package dr.evomodel.arg.coalescent;

import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Units;
import dr.evomodel.coalescent.DemographicModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Statistic;
import dr.inference.model.Variable;
import dr.math.Binomial;
import dr.util.ComparableDouble;
import dr.util.HeapSort;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;

/* loaded from: input_file:dr/evomodel/arg/coalescent/VeryOldCoalescentLikelihood.class */
public class VeryOldCoalescentLikelihood extends AbstractModelLikelihood implements Units {
    public static final String COALESCENT_LIKELIHOOD = "veryOldCoalescentLikelihood";
    public static final String ANALYTICAL = "analytical";
    public static final String MODEL = "model";
    public static final String POPULATION_TREE = "populationTree";

    /** Interval type: the lineage count drops at the end of the interval. */
    public static final int COALESCENT = 0;
    /** Interval type: the lineage count rises at the end of the interval (new samples). */
    public static final int NEW_SAMPLE = 1;
    /** Interval type: the lineage count is unchanged at the end of the interval. */
    public static final int NOTHING = 2;

    /** Node heights closer than this to the first height of a group are merged into one event time. */
    private static final double TIME_EPSILON = 1.0E-9;

    /** XML parser: reads a demographic model ({@code model}) and a tree ({@code populationTree}). */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = {
                new ElementRule(VeryOldCoalescentLikelihood.MODEL,
                        new XMLSyntaxRule[]{new ElementRule(DemographicModel.class)}),
                new ElementRule(VeryOldCoalescentLikelihood.POPULATION_TREE,
                        new XMLSyntaxRule[]{new ElementRule(TreeModel.class)})
        };

        @Override
        public String getParserName() {
            return VeryOldCoalescentLikelihood.COALESCENT_LIKELIHOOD;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) {
            final TreeModel populationTree = (TreeModel) xo
                    .getChild(VeryOldCoalescentLikelihood.POPULATION_TREE).getChild(TreeModel.class);
            final DemographicModel demographicModel = (DemographicModel) xo
                    .getChild(VeryOldCoalescentLikelihood.MODEL).getChild(DemographicModel.class);
            return new VeryOldCoalescentLikelihood(populationTree, demographicModel);
        }

        @Override
        public String getParserDescription() {
            return "This element represents the likelihood of the tree given the demographic function.";
        }

        @Override
        public Class getReturnType() {
            return Likelihood.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    // Field initializers run right after super(), exactly where the previous
    // explicit constructor assignments ran, so initialization order is unchanged.
    DemographicModel demoModel = null;
    Tree tree = null;
    /** Interval widths; allocated lazily by setupIntervals() with one slot per tree node. */
    double[] intervals;
    private double[] storedIntervals;
    /** Number of lineages present during each interval. */
    int[] lineageCounts;
    private int[] storedLineageCounts;
    boolean intervalsKnown = false;
    protected boolean storedIntervalsKnown = false;
    double logLikelihood;
    protected double storedLogLikelihood;
    boolean likelihoodKnown = false;
    protected boolean storedLikelihoodKnown = false;
    int intervalCount = 0;
    private int storedIntervalCount = 0;

    /** Placeholder statistic named "delta"; reading its value always throws. */
    public class DeltaStatistic extends Statistic.Abstract {
        public DeltaStatistic() {
            super("delta");
        }

        @Override
        public int getDimension() {
            return 1;
        }

        @Override
        public double getStatisticValue(int i) {
            throw new RuntimeException("Not implemented");
        }
    }

    /**
     * Creates a likelihood named {@link #COALESCENT_LIKELIHOOD} and sets up intervals immediately.
     */
    public VeryOldCoalescentLikelihood(Tree tree, DemographicModel demographicModel) {
        this(COALESCENT_LIKELIHOOD, tree, demographicModel, true);
    }

    /**
     * @param name             model/likelihood name
     * @param tree             tree whose node heights define the intervals; registered as a
     *                         sub-model when it is a {@link TreeModel}
     * @param demographicModel demographic model, or {@code null} to use the analytical likelihood
     * @param setupNow         if true, compute intervals now (this calls the overridable
     *                         {@link #getMRCAOfCoalescent} / {@link #getExcludedMRCAs} hooks)
     */
    public VeryOldCoalescentLikelihood(String name, Tree tree, DemographicModel demographicModel, boolean setupNow) {
        super(name);
        this.tree = tree;
        this.demoModel = demographicModel;
        if (tree instanceof TreeModel) {
            addModel((TreeModel) tree);
        }
        if (demographicModel != null) {
            addModel(demographicModel);
        }
        if (setupNow) {
            setupIntervals();
        }
        addStatistic(new DeltaStatistic());
    }

    /**
     * Bare constructor for subclasses; leaves tree, model and interval arrays unset.
     * NOTE(review): the decompiler widened this from package-private; kept public for compatibility.
     */
    public VeryOldCoalescentLikelihood(String name) {
        super(name);
    }

    /** Hook: the node whose subtree defines the coalescent. Defaults to the root. */
    public NodeRef getMRCAOfCoalescent(Tree tree) {
        return tree.getRoot();
    }

    /** Hook: nodes whose subtrees are skipped when collecting event times; {@code null} means none. */
    public NodeRef[] getExcludedMRCAs(Tree tree) {
        return null;
    }

    /** A change to the tree invalidates the intervals; any change invalidates the likelihood. */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model == this.tree) {
            this.intervalsKnown = false;
        }
        this.likelihoodKnown = false;
    }

    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType changeType) {
        // No variables are registered directly on this model.
    }

    /** Saves intervals, counts, and cached likelihood so restoreState() can roll them back. */
    @Override
    protected void storeState() {
        // The arrays only exist once setupIntervals() has run; before that there is nothing to copy.
        if (this.intervals != null) {
            System.arraycopy(this.intervals, 0, this.storedIntervals, 0, this.intervals.length);
            System.arraycopy(this.lineageCounts, 0, this.storedLineageCounts, 0, this.lineageCounts.length);
        }
        this.storedIntervalsKnown = this.intervalsKnown;
        this.storedIntervalCount = this.intervalCount;
        this.storedLikelihoodKnown = this.likelihoodKnown;
        this.storedLogLikelihood = this.logLikelihood;
    }

    /** Restores the state saved by storeState(); the likelihood is invalidated if the intervals are. */
    @Override
    protected void restoreState() {
        // Stored arrays are allocated together with the live ones in setupIntervals().
        if (this.storedIntervals != null) {
            System.arraycopy(this.storedIntervals, 0, this.intervals, 0, this.storedIntervals.length);
            System.arraycopy(this.storedLineageCounts, 0, this.lineageCounts, 0, this.storedLineageCounts.length);
        }
        this.intervalsKnown = this.storedIntervalsKnown;
        this.intervalCount = this.storedIntervalCount;
        this.likelihoodKnown = this.storedLikelihoodKnown;
        this.logLikelihood = this.storedLogLikelihood;
        if (!this.intervalsKnown) {
            this.likelihoodKnown = false;
        }
    }

    @Override
    protected final void acceptState() {
        // Nothing to do: the current state simply becomes the accepted state.
    }

    protected final void adoptState(Model model) {
        makeDirty();
    }

    @Override
    public final Model getModel() {
        return this;
    }

    /** Returns the cached log-likelihood, recomputing it first if it is stale. */
    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            this.logLikelihood = calculateLogLikelihood();
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    /** Forces both the intervals and the likelihood to be recomputed. */
    @Override
    public final void makeDirty() {
        this.likelihoodKnown = false;
        this.intervalsKnown = false;
    }

    /**
     * Computes the log-likelihood of the tree's intervals under the demographic function,
     * or the analytical form when no demographic model was supplied.
     */
    public double calculateLogLikelihood() {
        if (!this.intervalsKnown) {
            setupIntervals();
        }
        if (this.demoModel == null) {
            return calculateAnalyticalLogLikelihood();
        }

        final DemographicFunction demographicFunction = this.demoModel.getDemographicFunction();
        double logL = 0.0;
        double startTime = 0.0;
        for (int i = 0; i < this.intervalCount; i++) {
            final double endTime = startTime + this.intervals[i];
            logL += calculateIntervalLikelihood(demographicFunction, this.intervals[i], startTime,
                    this.lineageCounts[i], getIntervalType(i));

            // A multifurcation with c coalescent events is treated as c binary coalescences.
            // The extra c-1 happen at zero width at the END of this interval (endTime), with
            // the lineage count dropping by one each time.
            final int extraCoalescences = getCoalescentEvents(i) - 1;
            for (int j = 0; j < extraCoalescences; j++) {
                logL += calculateIntervalLikelihood(demographicFunction, 0.0, endTime,
                        this.lineageCounts[i] - j - 1, COALESCENT);
            }
            startTime = endTime;
        }
        return logL;
    }

    /**
     * Analytical log-likelihood {@code -(n - 1) * log(lambda)}, where n is the tree's external
     * node count and lambda comes from {@link #getLambda()}. It is computed in log space so that
     * large n or lambda do not overflow or underflow the way {@code log(1 / pow(lambda, n - 1))} does.
     */
    private double calculateAnalyticalLogLikelihood() {
        final int exponent = this.tree.getExternalNodeCount() - 1;
        if (exponent == 0) {
            // pow(lambda, 0) == 1, so the likelihood is exactly log(1) = 0 (avoids 0 * -Inf = NaN).
            return 0.0;
        }
        return -exponent * Math.log(getLambda());
    }

    /** Same as the 5-argument overload with type {@link #COALESCENT}. */
    public final double calculateIntervalLikelihood(DemographicFunction demographicFunction, double width,
                                                    double timeOfPrevCoal, int lineageCount) {
        return calculateIntervalLikelihood(demographicFunction, width, timeOfPrevCoal, lineageCount, COALESCENT);
    }

    /**
     * Log-likelihood contribution of one interval [timeOfPrevCoal, timeOfPrevCoal + width] with
     * {@code lineageCount} lineages. Let {@code area = demographicFunction.getIntegral(start, end)}.
     * <ul>
     *   <li>{@link #COALESCENT}: {@code -log(getDemographic(end)) - choose2(k) * area}</li>
     *   <li>{@link #NEW_SAMPLE}: {@code -choose2(k) * area}</li>
     *   <li>any other type (e.g. {@link #NOTHING}): 0</li>
     * </ul>
     */
    public final double calculateIntervalLikelihood(DemographicFunction demographicFunction, double width,
                                                    double timeOfPrevCoal, int lineageCount, int type) {
        final double timeOfThisCoal = width + timeOfPrevCoal;
        final double intervalArea = demographicFunction.getIntegral(timeOfPrevCoal, timeOfThisCoal);
        switch (type) {
            case COALESCENT:
                return (-Math.log(demographicFunction.getDemographic(timeOfThisCoal)))
                        - (Binomial.choose2(lineageCount) * intervalArea);
            case NEW_SAMPLE:
                return -(Binomial.choose2(lineageCount) * intervalArea);
            default:
                return 0.0;
        }
    }

    /** Returns (sum over intervals of width * lineage count) / 2. */
    private double getLambda() {
        double sum = 0.0;
        for (int i = 0; i < getIntervalCount(); i++) {
            sum += this.intervals[i] * this.lineageCounts[i];
        }
        return sum / 2.0;
    }

    /**
     * Rebuilds {@link #intervals} and {@link #lineageCounts} from the node heights under
     * {@link #getMRCAOfCoalescent}, skipping subtrees rooted at {@link #getExcludedMRCAs}.
     * Heights within {@link #TIME_EPSILON} of a group's first height are treated as one event time.
     * At each event time, sampling (tips) is recorded before coalescence, so a sample and a
     * coalescence at the same height produce a zero-width coalescent interval.
     */
    protected final void setupIntervals() {
        final ArrayList<ComparableDouble> times = new ArrayList<ComparableDouble>();
        final ArrayList<Integer> childCounts = new ArrayList<Integer>();
        collectAllTimes(this.tree, getMRCAOfCoalescent(this.tree), getExcludedMRCAs(this.tree), times, childCounts);

        // order[k] is the index in `times` of the k-th smallest height.
        final int[] order = new int[times.size()];
        HeapSort.sort(times, order);

        final int nodeCount = this.tree.getNodeCount();
        if (this.intervals == null) {
            // NOTE(review): sized once; assumes the tree's node count never grows afterwards.
            this.intervals = new double[nodeCount];
            this.lineageCounts = new int[nodeCount];
            this.storedIntervals = new double[nodeCount];
            this.storedLineageCounts = new int[nodeCount];
        }

        double lastEventTime = times.get(order[0]).doubleValue();
        int lineages = 0;
        int k = 0;
        this.intervalCount = 0;
        while (k < times.size()) {
            int coalescences = 0;
            int samples = 0;
            final double groupTime = times.get(order[k]).doubleValue();

            // Consume every event whose height is within TIME_EPSILON of groupTime.
            double time = groupTime;
            while (Math.abs(time - groupTime) < TIME_EPSILON) {
                final int children = childCounts.get(order[k]);
                if (children == 0) {
                    samples++;
                } else {
                    // A node with c children removes c - 1 lineages.
                    coalescences += children - 1;
                }
                k++;
                if (k >= times.size()) {
                    break;
                }
                time = times.get(order[k]).doubleValue();
            }

            if (samples > 0) {
                // Skip a leading zero-width interval when the first event is a sampling event.
                if (this.intervalCount > 0 || groupTime - lastEventTime > TIME_EPSILON) {
                    this.intervals[this.intervalCount] = groupTime - lastEventTime;
                    this.lineageCounts[this.intervalCount] = lineages;
                    this.intervalCount++;
                }
                lastEventTime = groupTime;
            }
            lineages += samples;
            if (coalescences > 0) {
                this.intervals[this.intervalCount] = groupTime - lastEventTime;
                this.lineageCounts[this.intervalCount] = lineages;
                this.intervalCount++;
                lastEventTime = groupTime;
            }
            lineages -= coalescences;
        }
        this.intervalsKnown = true;
    }

    /**
     * Depth-first traversal from {@code node}. For each visited node it appends the height to
     * {@code times} and the child count to {@code childCounts}; children listed in
     * {@code excludedNodes} (matched by node number) are not descended into.
     */
    private static void collectAllTimes(Tree tree, NodeRef node, NodeRef[] excludedNodes,
                                        ArrayList<ComparableDouble> times, ArrayList<Integer> childCounts) {
        times.add(new ComparableDouble(tree.getNodeHeight(node)));
        childCounts.add(Integer.valueOf(tree.getChildCount(node)));
        for (int i = 0; i < tree.getChildCount(node); i++) {
            final NodeRef child = tree.getChild(node, i);
            if (!isExcluded(child, excludedNodes)) {
                collectAllTimes(tree, child, excludedNodes, times, childCounts);
            }
        }
    }

    /** True if {@code node}'s number matches any entry of {@code excludedNodes} (null = no exclusions). */
    private static boolean isExcluded(NodeRef node, NodeRef[] excludedNodes) {
        if (excludedNodes == null) {
            return false;
        }
        for (NodeRef excluded : excludedNodes) {
            if (excluded.getNumber() == node.getNumber()) {
                return true;
            }
        }
        return false;
    }

    /** Throws if {@code i} is not a valid interval index. */
    private void checkIntervalIndex(int i) {
        if (i < 0 || i >= this.intervalCount) {
            throw new IllegalArgumentException(
                    "interval index " + i + " out of range [0, " + this.intervalCount + ")");
        }
    }

    public final int getIntervalCount() {
        return this.intervalCount;
    }

    /** Width of interval {@code i}. */
    public final double getInterval(int i) {
        checkIntervalIndex(i);
        return this.intervals[i];
    }

    /** Number of lineages present during interval {@code i}. */
    public final int getLineageCount(int i) {
        checkIntervalIndex(i);
        return this.lineageCounts[i];
    }

    /**
     * Drop in lineage count at the end of interval {@code i}. The value is negative when samples
     * are added. The last interval is taken to end with all remaining lineages merging into one.
     */
    public final int getCoalescentEvents(int i) {
        checkIntervalIndex(i);
        if (i < this.intervalCount - 1) {
            return this.lineageCounts[i] - this.lineageCounts[i + 1];
        }
        return this.lineageCounts[i] - 1;
    }

    /** Classifies interval {@code i} as {@link #COALESCENT}, {@link #NEW_SAMPLE} or {@link #NOTHING}. */
    public final int getIntervalType(int i) {
        final int events = getCoalescentEvents(i); // validates i
        if (events > 0) {
            return COALESCENT;
        }
        return events < 0 ? NEW_SAMPLE : NOTHING;
    }

    /** Sum of all interval widths. */
    public final double getTotalHeight() {
        double total = 0.0;
        for (int i = 0; i < this.intervalCount; i++) {
            total += this.intervals[i];
        }
        return total;
    }

    /** True if every interval ends in exactly one coalescence. */
    public final boolean isBinaryCoalescent() {
        for (int i = 0; i < this.intervalCount; i++) {
            if (getCoalescentEvents(i) != 1) {
                return false;
            }
        }
        return true;
    }

    /** True if every interval ends in at least one coalescence (no sampling-only intervals). */
    public final boolean isCoalescentOnly() {
        for (int i = 0; i < this.intervalCount; i++) {
            if (getCoalescentEvents(i) < 1) {
                return false;
            }
        }
        return true;
    }

    @Override
    public String toString() {
        return Double.toString(getLogLikelihood());
    }

    /** Delegates to the demographic model; NOTE(review): throws NPE if none was supplied. */
    @Override
    public final Units.Type getUnits() {
        return this.demoModel.getUnits();
    }

    /** Delegates to the demographic model; NOTE(review): throws NPE if none was supplied. */
    @Override
    public void setUnits(Units.Type type) {
        this.demoModel.setUnits(type);
    }
}
