package dr.evomodel.speciation;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxa;
import dr.evolution.util.TaxonList;
import dr.inference.model.Statistic;
import dr.math.distributions.Distribution;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:dr/evomodel/speciation/CalibrationPoints.class */
/**
 * A set of calibrated clades, each paired with a density on the height of the clade's common
 * ancestor (or of that ancestor's parent), together with the machinery to evaluate the
 * calibration log-density of a tree and to correct it according to a {@link CorrectionType}.
 *
 * <p>Clades are stored sorted so that every clade appears before any clade that contains it;
 * the last clade is therefore always a maximal one.
 *
 * <p>NOTE(review): the built-in corrections are only available when the caller declares the
 * model to be Yule (see the constructor); the formulas presumably are marginal densities of
 * calibrated node heights under that model with birth rate {@code lam} - confirm against the
 * model documentation.
 */
public class CalibrationPoints {
    /** log(2), used when building and applying the log-binomial tables. */
    private static final double LOG_2 = Math.log(2.0d);

    private final CorrectionType correctionType;
    /** Per clade: tree taxon indices of the clade's members. */
    private final int[][] clades;
    /** Per clade: density applied to the calibrated node height. */
    private final Distribution[] densities;
    /** Per clade: true when the density applies to the parent of the clade's common ancestor. */
    private final boolean[] forParent;
    /** Per clade: indices of the clades whose smallest containing clade is this one. */
    private final int[][] taxaPartialOrder;
    /** Per clade: exponent used by the APPROXIMATED correction (asserted non-negative). */
    private final int[] freeHeights;
    /** True when the last (maximal) clade has fewer taxa than the tree has tips. */
    private final boolean rootCorrection;
    /** Optional user supplied correction; when non-null it replaces the built-in corrections. */
    private final Statistic calibrationLogPDF;
    /** lc2[n] = log(n * (n - 1) / 2); -infinity for n &lt; 2. */
    private double[] lc2;
    /** lNR[n] = lc2[2] + ... + lc2[n]; lNR[1] = 0, lNR[0] = -infinity. */
    private double[] lNR;
    /** lfactorials[n] = log(n!). */
    private double[] lfactorials;
    /** Only created for the general EXACT case (set in the constructor). */
    private CalibrationLineagesIterator linsIter;
    /** Heights used for the last general EXACT evaluation (cache key, together with lastLam). */
    double[] lastHeights;
    double lastLam = Double.NEGATIVE_INFINITY;
    double lastValue = Double.NEGATIVE_INFINITY;

    /** How the raw calibration density is corrected in {@link #getCorrection(Tree, double)}. */
    public enum CorrectionType {
        EXACT("exact"),
        APPROXIMATED("approximated"),
        PEXACT("pexact"),
        NONE("none");

        private final String name;

        CorrectionType(String str) {
            this.name = str;
        }

        @Override
        public String toString() {
            return this.name;
        }
    }

    /**
     * Builds the calibration set.
     *
     * <p>The three lists are parallel (one entry per clade) and are <b>consumed</b>: entries are
     * removed from them while the clades are being sorted, so they are empty on normal return and
     * must be mutable.
     *
     * @param tree           tree used to resolve taxon indices and the number of tips
     * @param isYule         whether the enclosing model is Yule; required when no user PDF is given
     * @param dists          density for each clade
     * @param taxa           taxon set of each clade
     * @param forParentFlags per clade, whether the density applies to the parent of the clade's
     *                       common ancestor
     * @param userPDF        optional user supplied correction statistic, may be {@code null}
     * @param correctionType correction to apply when {@code userPDF} is {@code null}
     * @throws IllegalArgumentException if two clades partially overlap, if no maximal clade can be
     *                                  found (e.g. duplicate clades), if a taxon is missing from
     *                                  the tree, or if the requested configuration is not
     *                                  implemented
     */
    public CalibrationPoints(Tree tree, boolean isYule, List<Distribution> dists, List<Taxa> taxa,
                             List<Boolean> forParentFlags, Statistic userPDF, CorrectionType correctionType) {
        final int nClades = taxa.size();
        this.densities = new Distribution[dists.size()];
        this.clades = new int[nClades][];
        this.forParent = new boolean[nClades];

        // Any two clades must be either disjoint or nested.
        for (int i = 0; i < taxa.size(); i++) {
            final Taxa first = taxa.get(i);
            for (int j = i + 1; j < taxa.size(); j++) {
                final Taxa second = taxa.get(j);
                if (second.containsAny(first) && !second.containsAll(first) && !first.containsAll(second)) {
                    throw new IllegalArgumentException("Overlapping clades??");
                }
            }
        }

        // Fill the slots from the back with the first maximal clade still in the list, so that a
        // clade always ends up at a lower index than any clade containing it.
        final Taxa[] sortedTaxa = new Taxa[nClades];
        for (int slot = nClades - 1; slot >= 0; slot--) {
            int pick = 0;
            while (pick < taxa.size() && !isMaximal(taxa, pick)) {
                pick++;
            }
            if (pick == taxa.size()) {
                // Every remaining clade is contained in another one; previously this surfaced as an
                // IndexOutOfBoundsException from List.remove.
                throw new IllegalArgumentException(
                        "No maximal clade among the remaining calibration clades (duplicate clades?)");
            }
            this.densities[slot] = dists.remove(pick);
            this.forParent[slot] = forParentFlags.remove(pick).booleanValue();
            final Taxa clade = taxa.get(pick);
            final int taxonCount = clade.getTaxonCount();
            this.clades[slot] = new int[taxonCount];
            for (int k = 0; k < taxonCount; k++) {
                final int taxonIndex = tree.getTaxonIndex(clade.getTaxon(k));
                this.clades[slot][k] = taxonIndex;
                if (taxonIndex < 0) {
                    throw new IllegalArgumentException("Taxon not found in tree: " + clade.getTaxon(k));
                }
            }
            sortedTaxa[slot] = clade;
            taxa.remove(pick);
        }

        // Attach every clade to the first later clade that contains it.
        final List<List<Integer>> nested = new ArrayList<List<Integer>>(nClades);
        for (int i = 0; i < nClades; i++) {
            nested.add(new ArrayList<Integer>());
        }
        for (int i = 0; i < nClades; i++) {
            for (int j = i + 1; j < nClades; j++) {
                if (sortedTaxa[j].containsAll(sortedTaxa[i])) {
                    nested.get(j).add(i);
                    break;
                }
            }
        }
        this.taxaPartialOrder = new int[nClades][];
        for (int i = 0; i < nClades; i++) {
            final List<Integer> inside = nested.get(i);
            this.taxaPartialOrder[i] = new int[inside.size()];
            for (int k = 0; k < inside.size(); k++) {
                this.taxaPartialOrder[i][k] = inside.get(k).intValue();
            }
        }

        this.freeHeights = new int[this.clades.length];
        for (int i = 0; i < this.clades.length; i++) {
            int usedByNested = 0;
            for (int child : this.taxaPartialOrder[i]) {
                usedByNested += this.clades[child].length - (this.forParent[child] ? 0 : 1);
            }
            this.freeHeights[i] = (this.clades[i].length - (this.forParent[i] ? 1 : 2)) - usedByNested;
            assert this.freeHeights[i] >= 0;
        }

        // True for clades that are not attached below any other clade.
        final boolean[] notNested = new boolean[this.clades.length];
        Arrays.fill(notNested, true);
        for (int i = 0; i < this.clades.length; i++) {
            for (int child : this.taxaPartialOrder[i]) {
                notNested[child] = false;
            }
        }

        this.rootCorrection = this.clades[this.clades.length - 1].length < tree.getExternalNodeCount();
        this.calibrationLogPDF = userPDF;
        this.correctionType = correctionType;

        if (userPDF != null) {
            // A user supplied correction needs none of the tables below.
            return;
        }
        if (!isYule) {
            throw new IllegalArgumentException("Sorry, not implemented: conditional calibration prior for this non Yule models.");
        }
        if (correctionType != CorrectionType.EXACT) {
            if (correctionType == CorrectionType.PEXACT) {
                setUpTables(tree);
            }
            return;
        }
        if (this.densities.length == 1) {
            // Single clade: closed form, no tables required.
            return;
        }
        boolean anyForParent = false;
        for (boolean flag : this.forParent) {
            if (flag) {
                anyForParent = true;
            }
        }
        if (anyForParent) {
            throw new IllegalArgumentException("Sorry, not implemented: calibration on parent for more than one clade.");
        }
        if (this.densities.length == 2 && sortedTaxa[1].containsAll(sortedTaxa[0])) {
            // Two nested clades: closed form, no tables required.
            return;
        }
        setUpTables(tree);
        this.linsIter = new CalibrationLineagesIterator(this.clades, this.taxaPartialOrder, notNested, tree.getExternalNodeCount());
        this.lastHeights = new double[this.clades.length];
    }

    /**
     * Fills the log tables {@code lc2}, {@code lfactorials} and {@code lNR} for every
     * index from 0 up to and including the number of tips in {@code tree}.
     */
    private void setUpTables(Tree tree) {
        final int size = tree.getExternalNodeCount() + 1;
        final double[] logs = new double[size];
        this.lc2 = new double[size];
        this.lfactorials = new double[size];
        this.lNR = new double[size];

        logs[0] = Double.NEGATIVE_INFINITY;
        logs[1] = 0.0d;
        for (int n = 2; n < size; n++) {
            logs[n] = Math.log(n);
        }

        this.lc2[1] = Double.NEGATIVE_INFINITY;
        this.lc2[0] = Double.NEGATIVE_INFINITY;
        for (int n = 2; n < size; n++) {
            this.lc2[n] = (logs[n] + logs[n - 1]) - LOG_2;
        }

        this.lfactorials[0] = 0.0d;
        for (int n = 1; n < size; n++) {
            this.lfactorials[n] = this.lfactorials[n - 1] + logs[n];
        }

        this.lNR[0] = Double.NEGATIVE_INFINITY;
        this.lNR[1] = 0.0d;
        for (int n = 2; n < size; n++) {
            this.lNR[n] = this.lNR[n - 1] + this.lc2[n];
        }
    }

    /** Returns true when no other entry of {@code list} contains all taxa of entry {@code i}. */
    private boolean isMaximal(List<Taxa> list, int i) {
        final Taxa candidate = list.get(i);
        for (int other = 0; other < list.size(); other++) {
            if (other != i && list.get(other).containsAll(candidate)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Evaluates the calibration log-density of {@code tree}: the sum over clades of the clade's
     * density at its calibrated node height, minus the configured correction.
     *
     * @param tree tree to evaluate
     * @param lam  rate parameter used by the built-in corrections
     * @return the corrected log-density; {@link Double#NEGATIVE_INFINITY} when the tips below the
     *         common ancestor of some multi-taxon clade are not exactly that clade, or when the
     *         user statistic is NaN or infinite. When the summed density is infinite, or the
     *         correction type is {@code NONE}, the uncorrected sum is returned.
     */
    public double getCorrection(Tree tree, double lam) {
        final int nDists = this.densities.length;
        final double[] heights = new double[nDists];
        double logL = 0.0d;
        for (int k = 0; k < nDists; k++) {
            final int[] clade = this.clades[k];
            NodeRef node;
            if (clade.length > 1) {
                node = TreeUtils.getCommonAncestor(tree, clade);
                if (TreeUtils.getLeafCount(tree, node) != clade.length) {
                    return Double.NEGATIVE_INFINITY;
                }
            } else {
                node = tree.getNode(clade[0]);
                // A single-taxon clade can only be calibrated through its parent.
                assert this.forParent[k];
            }
            if (this.forParent[k]) {
                node = tree.getParent(node);
            }
            final double height = tree.getNodeHeight(node);
            logL += this.densities[k].logPdf(height);
            heights[k] = height;
        }

        if (Double.isInfinite(logL) || this.correctionType == CorrectionType.NONE) {
            return logL;
        }
        if (this.calibrationLogPDF != null) {
            final double userValue = this.calibrationLogPDF.getStatisticValue(0);
            return (Double.isNaN(userValue) || Double.isInfinite(userValue))
                    ? Double.NEGATIVE_INFINITY
                    : logL - userValue;
        }
        switch (this.correctionType) {
            case EXACT:
                return exactCorrection(tree, lam, heights, logL);
            case APPROXIMATED:
                return approximatedCorrection(lam, heights, logL);
            case PEXACT:
                return pexactCorrection(tree, lam, heights, logL);
            default:
                return logL;
        }
    }

    /**
     * EXACT correction: closed forms for one clade and for two nested clades, otherwise the
     * general lineage-iterator computation, cached on (lam, heights).
     */
    private double exactCorrection(Tree tree, double lam, double[] heights, double logL) {
        final int nDists = heights.length;
        if (nDists == 1) {
            return logL - logMarginalDensity(lam, tree.getExternalNodeCount(), heights[0],
                    this.clades[0].length, this.forParent[0]);
        }
        if (nDists == 2 && this.taxaPartialOrder[1].length == 1) {
            assert !(this.forParent[0] || this.forParent[1]);
            return logL - logMarginalDensity(lam, tree.getExternalNodeCount(), heights[0],
                    this.clades[0].length, heights[1], this.clades[1].length);
        }

        // Reuse the previous result when neither the rate nor any height changed.
        if (this.lastLam == lam) {
            int same = 0;
            while (same < heights.length && heights[same] == this.lastHeights[same]) {
                same++;
            }
            if (same == heights.length) {
                return this.lastValue;
            }
        }

        // ranks[k] is the 1-based rank of heights[k]; sorted[] holds the heights in rank order.
        final double[] sorted = new double[heights.length];
        final int[] ranks = new int[heights.length];
        for (int k = 0; k < heights.length; k++) {
            int below = 0;
            for (double other : heights) {
                below += other < heights[k] ? 1 : 0;
            }
            ranks[k] = below + 1;
            sorted[below] = heights[k];
        }

        final double corrected = logL - logMarginalDensity(lam, sorted, ranks, this.linsIter);
        this.lastLam = lam;
        System.arraycopy(heights, 0, this.lastHeights, 0, this.lastHeights.length);
        this.lastValue = corrected;
        return corrected;
    }

    /** APPROXIMATED correction; a NaN result is mapped to negative infinity. */
    private double approximatedCorrection(double lam, double[] heights, double logL) {
        final double logLam = Math.log(lam);
        int highest = 0;
        for (int k = 0; k < heights.length; k++) {
            final double minusLamH = (-lam) * heights[k];
            if (this.freeHeights[k] > 0) {
                logL -= Math.log1p(-Math.exp(minusLamH)) * this.freeHeights[k];
            }
            logL -= minusLamH + logLam;
            if (heights[k] > heights[highest]) {
                highest = k;
            }
        }
        // NOTE(review): the decompiled code had an empty `if (!rootCorrection)` branch here, so
        // the term below is applied unconditionally and rootCorrection is never consulted.
        logL -= ((-(this.forParent[highest] ? 0 : 1)) * lam) * heights[highest];
        if (Double.isNaN(logL)) {
            logL = Double.NEGATIVE_INFINITY;
        }
        return logL;
    }

    /**
     * PEXACT correction: counts the internal nodes falling strictly between consecutive sorted
     * calibrated heights and combines the per-interval terms. Sorts {@code heights} in place.
     */
    private double pexactCorrection(Tree tree, double lam, double[] heights, double logL) {
        final int nDists = heights.length;
        Arrays.sort(heights);

        // counts[i]: internal nodes strictly below heights[i] and above heights[i - 1];
        // counts[nDists]: internal nodes above every calibrated height. Nodes whose height
        // equals a calibrated height are not counted.
        final int[] counts = new int[nDists + 1];
        final int internalNodeCount = tree.getInternalNodeCount();
        for (int n = 0; n < internalNodeCount; n++) {
            final double nodeHeight = tree.getNodeHeight(tree.getInternalNode(n));
            int slot = 0;
            while (slot < heights.length && heights[slot] < nodeHeight) {
                slot++;
            }
            if (slot == heights.length) {
                counts[slot] = counts[slot] + 1;
            } else if (nodeHeight < heights[slot]) {
                counts[slot] = counts[slot] + 1;
            }
        }

        double sum = 0.0d + (((counts[0] * Math.log1p(-Math.exp((-lam) * heights[0]))) - (lam * heights[0])) - this.lfactorials[counts[0]]);
        for (int i = 1; i < counts.length - 1; i++) {
            final int c = counts[i];
            sum = sum + (c * (Math.log1p(-Math.exp((-lam) * (heights[i] - heights[i - 1]))) - (lam * heights[i - 1]))) + (((-lam) * heights[i]) - this.lfactorials[c]);
        }
        return logL - ((sum + ((((-lam) * (counts[nDists] + 1)) * heights[nDists - 1]) - this.lfactorials[counts[nDists] + 1])) + (Math.log(lam) * nDists));
    }

    /**
     * Closed-form log marginal for a single calibrated clade.
     *
     * @param lam       rate parameter
     * @param nTaxa     number of tips in the tree
     * @param h         calibrated height
     * @param nClade    number of taxa in the clade
     * @param forParent whether the calibration is on the parent of the clade's common ancestor
     */
    private double logMarginalDensity(double lam, int nTaxa, double h, int nClade, boolean forParent) {
        final double lamH = lam * h;
        double result;
        if (forParent) {
            result = ((-2.0d) * lamH) + Math.log(lam);
            if (nClade > 1) {
                result += (nClade - 1) * Math.log(1.0d - Math.exp(-lamH));
            }
        } else {
            assert nClade > 1;
            result = ((-3.0d) * lamH) + ((nClade - 2) * Math.log(1.0d - Math.exp(-lamH))) + Math.log(lam);
            if (nTaxa == nClade) {
                // The clade spans the whole tree.
                result += lamH;
            }
        }
        return result;
    }

    /**
     * Closed-form log marginal for two nested calibrated clades.
     *
     * @param lam   rate parameter
     * @param nTaxa number of tips in the tree
     * @param h2    height of the inner clade
     * @param n2    size of the inner clade
     * @param hc    height of the outer clade
     * @param nc    size of the outer clade
     */
    private double logMarginalDensity(double lam, int nTaxa, double h2, int n2, double hc, int nc) {
        assert !(h2 > hc || n2 >= nc);
        final int m = nc - n2;
        final double eInner = Math.exp((-lam) * h2);
        final double eOuter = Math.exp((-lam) * hc);
        final double poly = (((1.0d - ((2 * m) * eOuter)) + ((2 * (m - 1)) * eInner)) - (((m * (m - 1)) * eOuter) * eInner))
                + (((m * (m + 1)) / 2.0d) * eOuter * eOuter)
                + ((((m - 1) * (m - 2)) / 2.0d) * eInner * eInner);
        final double result = (2.0d * Math.log(lam))
                + ((n2 - 2) * Math.log(1.0d - eInner))
                + ((m - 3) * Math.log(1.0d - eOuter))
                + Math.log(poly);
        return nc < nTaxa
                ? result - (lam * (h2 + (3.0d * hc)))
                : result - (lam * (h2 + (2.0d * hc)));
    }

    /**
     * General log marginal: log-sum-exp over every configuration produced by {@code iter},
     * plus terms that depend only on the iterator's start counts and on the heights.
     *
     * @param lam   rate parameter
     * @param hs    calibrated heights in increasing rank order
     * @param ranks 1-based rank of each clade's height, handed to {@code iter.setup}
     * @param iter  lineage iterator created in the constructor
     */
    private double logMarginalDensity(double lam, double[] hs, int[] ranks, CalibrationLineagesIterator iter) {
        final int nLevels = iter.setup(ranks);
        final int nHeights = hs.length;

        // lamH[0] = 0, lamH[i] = -lam * hs[i - 1].
        final double[] lamH = new double[nHeights + 1];
        lamH[0] = 0.0d;
        for (int i = 1; i < lamH.length; i++) {
            lamH[i] = (-lam) * hs[i - 1];
        }

        // NOTE(review): presumably the case where the tree root lies above every calibrated
        // clade, adding one interval above the highest calibration - confirm against
        // CalibrationLineagesIterator.setup.
        final boolean extraLevel = nLevels == lamH.length;
        final int nIntervals = nHeights + (extraLevel ? 1 : 0);

        final double[] logInterval = new double[nIntervals];
        for (int i = 0; i < nHeights; i++) {
            logInterval[i] = lamH[i] + Math.log1p(-Math.exp(lamH[i + 1] - lamH[i]));
        }
        if (extraLevel) {
            logInterval[nHeights] = lamH[nHeights];
        }

        final int[] perInterval = new int[nIntervals];
        final int[][] joiners = iter.allJoiners();
        double logSum = 0.0d;
        boolean first = true;
        int[][] lineages;
        while ((lineages = iter.next()) != null) {
            double term = countRankedTrees(nIntervals, lineages, joiners, perInterval);
            if (extraLevel) {
                final int top = perInterval[nIntervals - 1] + 2;
                perInterval[nIntervals - 1] = top;
                term -= this.lc2[top] + LOG_2;
            }
            for (int i = 0; i < nIntervals; i++) {
                term += perInterval[i] * logInterval[i];
            }
            if (first) {
                logSum = term;
                first = false;
            } else {
                // Numerically stable log(exp(logSum) + exp(term)).
                logSum = logSum > term
                        ? logSum + Math.log1p(Math.exp(term - logSum))
                        : term + Math.log1p(Math.exp(logSum - term));
            }
        }

        double startTerm = 0.0d;
        int totalStart = 0;
        for (int level = 0; level < nLevels; level++) {
            final int nStart = iter.nStart(level);
            if (nStart > 0) {
                startTerm += this.lNR[nStart];
                totalStart += nStart;
            }
        }
        final double factorialTerm = this.lfactorials[totalStart];

        double heightTerm = nHeights * Math.log(lam);
        for (int i = 1; i < nHeights + 1; i++) {
            heightTerm += lamH[i];
        }
        if (!extraLevel) {
            heightTerm += lamH[nHeights];
        }
        return logSum + startTerm + factorialTerm + heightTerm;
    }

    /**
     * Accumulates the log count for one lineage configuration and writes, for every interval,
     * the total number of events into {@code eventsOut}.
     *
     * @param nIntervals number of intervals to process
     * @param lineages   configuration from the iterator, indexed [clade][interval]
     * @param joiners    indexed like {@code lineages}; a positive entry adds one lineage
     * @param eventsOut  receives the per-interval totals (first {@code nIntervals} entries)
     * @return the accumulated log count
     */
    private double countRankedTrees(int nIntervals, int[][] lineages, int[][] joiners, int[] eventsOut) {
        double logCount = 0.0d;
        for (int interval = 0; interval < nIntervals; interval++) {
            int events = 0;
            for (int row = interval; row < nIntervals; row++) {
                final int[] rowLineages = lineages[row];
                int atStart = rowLineages[interval];
                if (joiners[row][interval] > 0) {
                    atStart++;
                    if (atStart > 1) {
                        logCount += this.lc2[atStart];
                    }
                }
                final int merged = atStart - rowLineages[interval + 1];
                logCount -= this.lfactorials[merged];
                events += merged;
            }
            eventsOut[interval] = events;
        }
        return logCount;
    }
}
