package dr.evolution.colouring;

import dr.evolution.alignment.Alignment;
import dr.evolution.coalescent.structure.MetaPopulation;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.evoxml.util.GraphMLUtils;
import dr.math.MathUtils;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

/* loaded from: input_file:dr/evolution/colouring/StructuredColourSampler.class */
/**
 * Proposes colourings of a tree (a colour for every node plus colour-change events along the
 * branches) conditional on the observed leaf colours, and evaluates the log proposal density of a
 * given colouring under the same scheme.
 *
 * <p>Each call runs three passes over the tree:
 * <ol>
 *   <li>{@code prune} computes ordinary per-node partial likelihoods with
 *       {@code ColourChangeMatrix.forwardTimeEvolution}. From these, {@code calculateMeanColourCounts}
 *       derives the expected number of lineages of each colour in every inter-node interval.</li>
 *   <li>{@code pruneEM} recomputes the partials interval by interval, using the rates from
 *       {@code calculateMatrixElts}, and stores the partial at the bottom of every interval of every
 *       branch. When {@code useSecondColourIteration} is set, the expected colour counts are
 *       recomputed from these partials and the pass is repeated.</li>
 *   <li>{@code sampleEM} (or {@code calculateEMProposal}) walks down from the root, sampling (or
 *       scoring) the colour at each interval boundary and the colour changes inside each
 *       interval.</li>
 * </ol>
 *
 * <p>The interval machinery ({@code matrixEvolve}, {@code matrixPullBack},
 * {@code calculateMatrixElts}) hard-codes a 2x2 rate matrix. The sampling methods therefore only
 * support two colours, although the constructors accept more.
 *
 * <p>Instances keep mutable per-call scratch state without synchronisation, so one instance must
 * not be used from several threads at once.
 */
public class StructuredColourSampler implements ColourSampler {
    /** Maximum rejection-sampling attempts per interval in {@code sampleConditionalBranchColouringEM}. */
    static final int maxIterations = 1000;
    /** Offset below a node's height at which the demographic function is evaluated for node bias. */
    static final double tinyTime = 1.0E-6d;
    // Debug switches; none of them is referenced within this class.
    static final boolean debugMessages = false;
    static final boolean debugMeanColours = false;
    static final boolean debugNodePartials = false;
    static final boolean debugSampleLikelihoods = false;
    static final boolean debugRejectionSampler = false;
    static final boolean debugProposalProbabilityCalculator = false;
    /** Reset to zero by {@code sampleEM} at the root; not otherwise read within this class. */
    double _totalIntegratedRate;
    // Not referenced within this class.
    static final DecimalFormat df = new DecimalFormat("###.####");
    /** Weight of the total lineage count (versus the same-colour count) in the branch bias. */
    static final double propAffected = 0.0d;

    /** Partials whose largest entry is below this are scaled up to avoid underflow. */
    private static final double RESCALE_THRESHOLD = 1.0E-100d;
    /** Factor applied to partials that fall below {@link #RESCALE_THRESHOLD}. */
    private static final double RESCALE_FACTOR = 1.0E100d;
    private static final double LOG_RESCALE_FACTOR = Math.log(RESCALE_FACTOR);
    /** Below this eigenvalue gap the 2x2 transition matrix uses a non-spectral fallback form. */
    private static final double EIGEN_GAP_TOLERANCE = 1.0E-5d;

    private final boolean useNodeBias;
    private final boolean useBranchBias;
    private final boolean useSecondColourIteration;
    private final int colourCount;
    /** Current colour of every node, by node number; leaves start at their observed colour. */
    private final int[] nodeColours;
    /** Number of leaves observed in each colour. */
    private final int[] leafColourCounts;
    /** Expected number of lineages of each colour per interval: [interval][colour]. */
    private double[][] meanColourCounts;
    /** Colour sampled at the bottom of each interval of a node's parent branch: [node][offset]. */
    private int[][] nodeColoursEM;
    /** Per-node partial likelihoods from the latest prune/pruneEM pass: [node][colour]. */
    private double[][] nodePartials;
    /** Partials at the bottom of each interval of a node's parent branch: [node][offset][colour]. */
    private double[][][] nodePartialsEM;
    /** Sum of log scale factors removed during the latest pruning pass (zero or negative). */
    private double logNodePartialsRescaling;
    private double[] equilibriumColours;
    /** Index of the distinct height (interval boundary) at which each node sits. */
    private int[] node2Interval;
    /** Distinct node heights in ascending order; interval k spans [height k, height k+1]. */
    private double[] interval2Height;
    /** Per-interval value (length / metaPopulation.getIntegral(..., 0)) for colour 0. */
    private double[] avgN0;
    /** Per-interval value (length / metaPopulation.getIntegral(..., 1)) for colour 1. */
    private double[] avgN1;
    private int numIntervals;

    /**
     * Creates a sampler whose leaf colours are read from a single-column alignment.
     *
     * @param alignment one column; the state of each taxon is its colour
     * @param tree the tree whose node count sizes the internal buffers
     * @param useNodeBias weight internal nodes by equilibrium / demographic value
     * @param useBranchBias add lineage-count-based penalties to the interval rates
     * @param useSecondColourIteration repeat the EM pruning pass once with refined colour counts
     * @throws IllegalArgumentException if the alignment does not have exactly one site
     */
    public StructuredColourSampler(Alignment alignment, Tree tree, boolean useNodeBias, boolean useBranchBias, boolean useSecondColourIteration) {
        if (alignment.getSiteCount() != 1) {
            throw new IllegalArgumentException("Tip colour alignment must consist of a single column!");
        }
        this.nodeColours = new int[tree.getNodeCount()];
        this.colourCount = alignment.getDataType().getStateCount();
        this.leafColourCounts = new int[this.colourCount];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef leaf = tree.getExternalNode(i);
            int colour = alignment.getState(alignment.getTaxonIndex(tree.getTaxonId(i)), 0);
            this.nodeColours[leaf.getNumber()] = colour;
            this.leafColourCounts[colour]++;
        }
        this.useNodeBias = useNodeBias;
        this.useBranchBias = useBranchBias;
        this.useSecondColourIteration = useSecondColourIteration;
        initialize(tree);
    }

    /**
     * Creates a sampler whose leaf colours come from taxon-list membership. A taxon in none of the
     * lists gets colour 0. A taxon in list {@code k} gets colour {@code k + 1}, and if it appears in
     * several lists the last one wins.
     *
     * @param taxonListArr one list per non-zero colour
     * @param tree the tree whose node count sizes the internal buffers
     * @param useNodeBias weight internal nodes by equilibrium / demographic value
     * @param useBranchBias add lineage-count-based penalties to the interval rates
     * @param useSecondColourIteration repeat the EM pruning pass once with refined colour counts
     */
    public StructuredColourSampler(TaxonList[] taxonListArr, Tree tree, boolean useNodeBias, boolean useBranchBias, boolean useSecondColourIteration) {
        this.nodeColours = new int[tree.getNodeCount()];
        this.colourCount = taxonListArr.length + 1;
        this.leafColourCounts = new int[this.colourCount];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef leaf = tree.getExternalNode(i);
            int colour = 0;
            for (int list = 0; list < taxonListArr.length; list++) {
                if (taxonListArr[list].getTaxonIndex(tree.getTaxonId(i)) != -1) {
                    colour = list + 1;
                }
            }
            this.nodeColours[leaf.getNumber()] = colour;
            this.leafColourCounts[colour]++;
        }
        this.useNodeBias = useNodeBias;
        this.useBranchBias = useBranchBias;
        this.useSecondColourIteration = useSecondColourIteration;
        initialize(tree);
    }

    /** Returns the number of leaves of each colour (the internal array, not a copy). */
    @Override // dr.evolution.colouring.ColourSampler
    public int[] getLeafColourCounts() {
        return this.leafColourCounts;
    }

    /** Allocates the per-node scratch buffers. */
    private void initialize(Tree tree) {
        int nodeCount = tree.getNodeCount();
        this.nodePartials = new double[nodeCount][this.colourCount];
        this.meanColourCounts = new double[nodeCount][this.colourCount];
        // Only the outer dimension is allocated here. The rows are created per branch by
        // pruneBranchEM (partials) and sampleEM (colours), because their length depends on how
        // many intervals the branch spans.
        this.nodeColoursEM = new int[nodeCount][];
        this.nodePartialsEM = new double[nodeCount][][];
        this.equilibriumColours = new double[this.colourCount];
    }

    /**
     * Groups nodes by distinct height and builds the interval tables {@code node2Interval},
     * {@code interval2Height}, {@code avgN0} and {@code avgN1}.
     */
    private void computeIntervals(Tree tree, MetaPopulation metaPopulation) {
        TreeMap<Double, List<NodeRef>> nodesByHeight = new TreeMap<Double, List<NodeRef>>();
        int nodeCount = tree.getNodeCount();
        for (int i = 0; i < nodeCount; i++) {
            NodeRef node = tree.getNode(i);
            Double height = Double.valueOf(tree.getNodeHeight(node));
            List<NodeRef> nodesAtHeight = nodesByHeight.get(height);
            if (nodesAtHeight == null) {
                nodesAtHeight = new ArrayList<NodeRef>(1);
                nodesByHeight.put(height, nodesAtHeight);
            }
            nodesAtHeight.add(node);
        }
        int boundaryCount = nodesByHeight.size();
        this.node2Interval = new int[nodeCount];
        this.interval2Height = new double[boundaryCount];
        this.avgN0 = new double[boundaryCount];
        this.avgN1 = new double[boundaryCount];
        int index = 0;
        // TreeMap iterates keys in ascending order, so heights are assigned bottom-up.
        for (Double height : nodesByHeight.keySet()) {
            double top = height.doubleValue();
            this.interval2Height[index] = top;
            for (NodeRef node : nodesByHeight.get(height)) {
                this.node2Interval[node.getNumber()] = index;
            }
            if (index > 0) {
                double bottom = this.interval2Height[index - 1];
                // NOTE(review): if getIntegral integrates 1/N(t), this is the harmonic-mean size
                // of each deme over the interval -- confirm against MetaPopulation.
                this.avgN0[index - 1] = (top - bottom) / metaPopulation.getIntegral(bottom, top, 0);
                this.avgN1[index - 1] = (top - bottom) / metaPopulation.getIntegral(bottom, top, 1);
            }
            index++;
        }
        this.numIntervals = index;
    }

    /**
     * Samples a colouring of {@code tree}. The returned colouring carries its log proposal density:
     * the sampled path density minus the log marginal likelihood of the leaf colours under the
     * proposal.
     */
    @Override // dr.evolution.colouring.ColourSampler
    public DefaultTreeColouring sampleTreeColouring(Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        double scaledRootLikelihood = computeScaledRootLikelihood(tree, colourChangeMatrix, metaPopulation);
        DefaultTreeColouring colouring = new DefaultTreeColouring(2, tree);
        double logPathDensity = sampleEM(tree, tree.getRoot(), colourChangeMatrix, metaPopulation, colouring);
        colouring.setLogProbabilityDensity(logPathDensity - (Math.log(scaledRootLikelihood) + this.logNodePartialsRescaling));
        return colouring;
    }

    /**
     * Returns the log density with which {@link #sampleTreeColouring} would propose
     * {@code treeColouring}.
     */
    @Override // dr.evolution.colouring.ColourSampler
    public double getProposalProbability(TreeColouring treeColouring, Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        double scaledRootLikelihood = computeScaledRootLikelihood(tree, colourChangeMatrix, metaPopulation);
        return (calculateEMProposal(tree, tree.getRoot(), colourChangeMatrix, metaPopulation, treeColouring) - Math.log(scaledRootLikelihood)) - this.logNodePartialsRescaling;
    }

    /**
     * Runs all pruning passes shared by sampling and scoring, and returns the equilibrium-weighted
     * root partial sum. The partials may have been rescaled: the true value is the result times
     * {@code exp(logNodePartialsRescaling)}.
     */
    private double computeScaledRootLikelihood(Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        populateEquilibriumColourArray(colourChangeMatrix);
        computeIntervals(tree, metaPopulation);
        this.logNodePartialsRescaling = 0.0d;
        prune(tree, tree.getRoot(), colourChangeMatrix);
        calculateMeanColourCounts(tree, colourChangeMatrix);
        this.logNodePartialsRescaling = 0.0d;
        double[] rootPartials = pruneEM(tree, tree.getRoot(), colourChangeMatrix, metaPopulation);
        if (this.useSecondColourIteration) {
            calculateMeanColourCountsEM(tree, tree.getRoot(), colourChangeMatrix);
            this.logNodePartialsRescaling = 0.0d;
            rootPartials = pruneEM(tree, tree.getRoot(), colourChangeMatrix, metaPopulation);
        }
        double likelihood = 0.0d;
        for (int colour = 0; colour < this.colourCount; colour++) {
            likelihood += this.equilibriumColours[colour] * rootPartials[colour];
        }
        return likelihood;
    }

    private int getColour(NodeRef nodeRef) {
        return this.nodeColours[nodeRef.getNumber()];
    }

    /** Sets a node's current colour, rejecting values outside {@code [0, colourCount)}. */
    private void setColour(NodeRef nodeRef, int colour) {
        if (colour < 0 || colour >= this.colourCount) {
            throw new IllegalArgumentException("colour value " + colour + " is outside of range of colours, [0, " + Integer.toString(this.colourCount - 1) + GraphMLUtils.END_ATTRIBUTE);
        }
        this.nodeColours[nodeRef.getNumber()] = colour;
    }

    /** Copies the colour equilibrium frequencies from {@code colourChangeMatrix}. */
    void populateEquilibriumColourArray(ColourChangeMatrix colourChangeMatrix) {
        for (int colour = 0; colour < this.colourCount; colour++) {
            this.equilibriumColours[colour] = colourChangeMatrix.getEquilibrium(colour);
        }
    }

    /**
     * Returns the normalised colour probabilities at a node (partials times equilibrium).
     * {@code colourChangeMatrix} is unused and kept for compatibility.
     */
    double[] getMeanColours(int nodeNumber, ColourChangeMatrix colourChangeMatrix) {
        return weightByEquilibriumAndNormalise(this.nodePartials[nodeNumber]);
    }

    /**
     * Returns the normalised colour probabilities at the bottom of interval {@code offset} of the
     * node's parent branch. {@code colourChangeMatrix} is unused and kept for compatibility.
     */
    double[] getMeanColoursEM(int nodeNumber, int offset, ColourChangeMatrix colourChangeMatrix) {
        return weightByEquilibriumAndNormalise(this.nodePartialsEM[nodeNumber][offset]);
    }

    /** Multiplies partials by the equilibrium frequencies and normalises the result to sum to 1. */
    private double[] weightByEquilibriumAndNormalise(double[] partials) {
        double[] probabilities = new double[this.colourCount];
        double total = 0.0d;
        for (int colour = 0; colour < this.colourCount; colour++) {
            probabilities[colour] = partials[colour] * this.equilibriumColours[colour];
            total += probabilities[colour];
        }
        for (int colour = 0; colour < this.colourCount; colour++) {
            probabilities[colour] /= total;
        }
        return probabilities;
    }

    /**
     * Adds each branch's average colour probability (mean of both end nodes) at the child's
     * interval and subtracts it at the parent's interval. A later cumulative sum then yields
     * per-interval lineage counts.
     */
    void fillMeanColourCounts(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        if (!tree.isRoot(nodeRef)) {
            int parentNumber = tree.getParent(nodeRef).getNumber();
            int nodeNumber = nodeRef.getNumber();
            double[] parentProbabilities = getMeanColours(parentNumber, colourChangeMatrix);
            double[] nodeProbabilities = getMeanColours(nodeNumber, colourChangeMatrix);
            double[] branchStartRow = this.meanColourCounts[this.node2Interval[nodeNumber]];
            double[] branchEndRow = this.meanColourCounts[this.node2Interval[parentNumber]];
            for (int colour = 0; colour < this.colourCount; colour++) {
                double branchAverage = (parentProbabilities[colour] + nodeProbabilities[colour]) / 2.0d;
                branchStartRow[colour] += branchAverage;
                branchEndRow[colour] -= branchAverage;
            }
        }
        if (tree.isExternal(nodeRef)) {
            return;
        }
        fillMeanColourCounts(tree, tree.getChild(nodeRef, 0), colourChangeMatrix);
        fillMeanColourCounts(tree, tree.getChild(nodeRef, 1), colourChangeMatrix);
    }

    /**
     * Recomputes {@code meanColourCounts} directly per interval from the EM partials, using the
     * average colour probability at both ends of every interval of every branch. Call it on the
     * root; the counts are cleared there first.
     */
    void calculateMeanColourCountsEM(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        if (tree.isRoot(nodeRef)) {
            clearMeanColourCounts();
        } else {
            addBranchMeanColourCountsEM(tree, nodeRef, colourChangeMatrix);
        }
        if (!tree.isExternal(nodeRef)) {
            calculateMeanColourCountsEM(tree, tree.getChild(nodeRef, 0), colourChangeMatrix);
            calculateMeanColourCountsEM(tree, tree.getChild(nodeRef, 1), colourChangeMatrix);
        }
    }

    /** Adds the contribution of the branch above {@code nodeRef} to every interval it spans. */
    private void addBranchMeanColourCountsEM(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        NodeRef parent = tree.getParent(nodeRef);
        int parentNumber = parent.getNumber();
        int nodeNumber = nodeRef.getNumber();
        int firstInterval = this.node2Interval[nodeNumber];
        int intervalsOnBranch = this.node2Interval[parentNumber] - firstInterval;
        if (intervalsOnBranch <= 0) {
            // A zero-length branch spans no interval and pruneBranchEM stored no rows for it,
            // so there is nothing to add (and row 0 would not exist).
            return;
        }
        double[] bottom = getMeanColoursEM(nodeNumber, 0, colourChangeMatrix);
        for (int offset = 0; offset < intervalsOnBranch; offset++) {
            double[] top;
            if (offset + 1 < intervalsOnBranch) {
                top = getMeanColoursEM(nodeNumber, offset + 1, colourChangeMatrix);
            } else if (!tree.isRoot(parent)) {
                // The parent's own partial is row 0 of its EM partials. If the parent's branch has
                // zero length that row does not exist; nodePartials holds the same values.
                top = this.nodePartialsEM[parentNumber].length > 0
                        ? getMeanColoursEM(parentNumber, 0, colourChangeMatrix)
                        : getMeanColours(parentNumber, colourChangeMatrix);
            } else {
                // The root has no EM partials; reuse the value at the bottom of the interval.
                top = bottom;
            }
            double[] row = this.meanColourCounts[firstInterval + offset];
            for (int colour = 0; colour < this.colourCount; colour++) {
                row[colour] += (top[colour] + bottom[colour]) / 2.0d;
            }
            bottom = top;
        }
    }

    /**
     * Fills {@code meanColourCounts} with the expected number of lineages of each colour per
     * interval: per-branch increments from {@code fillMeanColourCounts}, then a running sum.
     */
    void calculateMeanColourCounts(Tree tree, ColourChangeMatrix colourChangeMatrix) {
        clearMeanColourCounts();
        fillMeanColourCounts(tree, tree.getRoot(), colourChangeMatrix);
        for (int colour = 0; colour < this.colourCount; colour++) {
            double runningTotal = 0.0d;
            for (double[] row : this.meanColourCounts) {
                runningTotal += row[colour];
                row[colour] = runningTotal;
            }
        }
    }

    /** Sets every entry of {@code meanColourCounts} to zero. */
    private void clearMeanColourCounts() {
        for (double[] row : this.meanColourCounts) {
            for (int colour = 0; colour < this.colourCount; colour++) {
                row[colour] = 0.0d;
            }
        }
    }

    /**
     * If {@code largest} (the largest entry of {@code partials}) is below the underflow threshold,
     * scales the partials up and records the log factor in {@code logNodePartialsRescaling}.
     */
    private void rescaleIfUnderflowing(double[] partials, double largest) {
        if (largest < RESCALE_THRESHOLD) {
            for (int colour = 0; colour < partials.length; colour++) {
                partials[colour] *= RESCALE_FACTOR;
            }
            this.logNodePartialsRescaling -= LOG_RESCALE_FACTOR;
        }
    }

    /**
     * Felsenstein pruning with {@code forwardTimeEvolution} transition probabilities. Stores and
     * returns the (possibly rescaled) partials of {@code nodeRef}.
     */
    private double[] prune(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        double[] partials = new double[this.colourCount];
        if (tree.isExternal(nodeRef)) {
            partials[getColour(nodeRef)] = 1.0d;
        } else {
            NodeRef left = tree.getChild(nodeRef, 0);
            NodeRef right = tree.getChild(nodeRef, 1);
            double[] leftPartials = prune(tree, left, colourChangeMatrix);
            double[] rightPartials = prune(tree, right, colourChangeMatrix);
            double height = tree.getNodeHeight(nodeRef);
            double leftLength = height - tree.getNodeHeight(left);
            double rightLength = height - tree.getNodeHeight(right);
            double largest = 0.0d;
            for (int parentColour = 0; parentColour < this.colourCount; parentColour++) {
                double leftSum = 0.0d;
                double rightSum = 0.0d;
                for (int childColour = 0; childColour < this.colourCount; childColour++) {
                    leftSum += colourChangeMatrix.forwardTimeEvolution(parentColour, childColour, leftLength) * leftPartials[childColour];
                    rightSum += colourChangeMatrix.forwardTimeEvolution(parentColour, childColour, rightLength) * rightPartials[childColour];
                }
                partials[parentColour] = leftSum * rightSum;
                if (partials[parentColour] > largest) {
                    largest = partials[parentColour];
                }
            }
            rescaleIfUnderflowing(partials, largest);
        }
        this.nodePartials[nodeRef.getNumber()] = partials;
        return partials;
    }

    /**
     * Computes {@code exp(Q)} for {@code Q = [[-r0, r1], [r2, -r3]]}, given {@code rates = {r0, r1,
     * r2, r3}}. The result is returned row-major as {P00, P01, P10, P11}. When the eigenvalue gap is
     * below {@link #EIGEN_GAP_TOLERANCE} it falls back to
     * {exp(-r0), r1 exp(-r0), r2 exp(-r3), exp(-r3)}.
     */
    private static double[] transitionMatrix(double[] rates) {
        double r0 = rates[0];
        double r1 = rates[1];
        double r2 = rates[2];
        double r3 = rates[3];
        double gap = Math.sqrt(((r0 - r3) * (r0 - r3)) + (4.0d * r1 * r2));
        if (gap < EIGEN_GAP_TOLERANCE) {
            double exp0 = Math.exp(-r0);
            double exp3 = Math.exp(-r3);
            return new double[]{exp0, r1 * exp0, r2 * exp3, exp3};
        }
        double expFast = Math.exp((-((r0 + r3) + gap)) / 2.0d);
        double expSlow = Math.exp((-((r0 + r3) - gap)) / 2.0d);
        return new double[]{
            ((((r3 - r0) + gap) * expSlow) - (((r3 - r0) - gap) * expFast)) / (2.0d * gap),
            (r1 * (expSlow - expFast)) / gap,
            (r2 * (expSlow - expFast)) / gap,
            ((((r0 - r3) + gap) * expSlow) - (((r0 - r3) - gap) * expFast)) / (2.0d * gap)
        };
    }

    /**
     * Returns row {@code fromColour} of the transition matrix for {@code rates}. Row 0 is returned
     * for 0; row 1 for any other value.
     */
    static double[] matrixEvolve(double[] rates, int fromColour) {
        double[] p = transitionMatrix(rates);
        return fromColour == 0 ? new double[]{p[0], p[1]} : new double[]{p[2], p[3]};
    }

    /** Replaces {@code vector} in place with {@code P * vector}, P being the transition matrix for {@code rates}. */
    static void matrixPullBack(double[] rates, double[] vector) {
        double[] p = transitionMatrix(rates);
        double v0 = vector[0];
        double v1 = vector[1];
        vector[0] = (v0 * p[0]) + (v1 * p[1]);
        vector[1] = (v0 * p[2]) + (v1 * p[3]);
    }

    /**
     * Builds the rate array {@code {q01 + bias0, q01, q10, q10 + bias1}} for {@code interval}. Here
     * {@code qij = getForwardRate(i, j) * length}.
     *
     * <p>The bias terms are zero unless {@code useBranchBias}. Otherwise bias for colour c is
     * {@code max(0, (propAffected*(n-1) + (1-propAffected)*(n_c-1)) / (2*avgN_c)) * length}, where
     * n and n_c are the expected total and colour-c lineage counts. Their common part is then
     * subtracted so at most one bias is non-zero. (Presumably this approximates the coalescent rate
     * of the lineage -- confirm.)
     */
    double[] calculateMatrixElts(int interval, NodeRef nodeRef, Tree tree, double length, double avgN0, double avgN1, ColourChangeMatrix colourChangeMatrix) {
        double bias0 = 0.0d;
        double bias1 = 0.0d;
        if (this.useBranchBias) {
            double lineages0 = this.meanColourCounts[interval][0];
            double lineages1 = this.meanColourCounts[interval][1];
            double lineages = lineages0 + lineages1;
            bias0 = (((propAffected * (lineages - 1.0d)) + ((1.0d - propAffected) * (lineages0 - 1.0d))) / (2.0d * avgN0)) * length;
            if (bias0 < 0.0d) {
                bias0 = 0.0d;
            }
            bias1 = (((propAffected * (lineages - 1.0d)) + ((1.0d - propAffected) * (lineages1 - 1.0d))) / (2.0d * avgN1)) * length;
            if (bias1 < 0.0d) {
                bias1 = 0.0d;
            }
        }
        double common = Math.min(bias0, bias1);
        bias0 -= common;
        bias1 -= common;
        double rate01 = colourChangeMatrix.getForwardRate(0, 1) * length;
        double rate10 = colourChangeMatrix.getForwardRate(1, 0) * length;
        return new double[]{rate01 + bias0, rate01, rate10, rate10 + bias1};
    }

    /**
     * Pulls {@code childPartials} up the branch from {@code child} to {@code parent}, one interval at
     * a time. The partial at the bottom of every interval is stored in {@code nodePartialsEM[child]}.
     *
     * @return the partial at the top of the branch (at the parent's height)
     */
    double[] pruneBranchEM(ColourChangeMatrix colourChangeMatrix, double[] childPartials, NodeRef parent, NodeRef child, Tree tree, MetaPopulation metaPopulation) {
        int parentInterval = this.node2Interval[parent.getNumber()];
        int childInterval = this.node2Interval[child.getNumber()];
        double[][] bottomPartials = new double[parentInterval - childInterval][2];
        double[] partials = childPartials.clone();
        for (int interval = childInterval; interval < parentInterval; interval++) {
            double[] row = bottomPartials[interval - childInterval];
            row[0] = partials[0];
            row[1] = partials[1];
            double length = this.interval2Height[interval + 1] - this.interval2Height[interval];
            matrixPullBack(calculateMatrixElts(interval, child, tree, length, this.avgN0[interval], this.avgN1[interval], colourChangeMatrix), partials);
        }
        this.nodePartialsEM[child.getNumber()] = bottomPartials;
        return partials;
    }

    /** Node-bias weight {@code equilibrium(colour) / demographic(height - tinyTime, colour)}. */
    private double nodeBias(Tree tree, NodeRef nodeRef, int colour, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        return colourChangeMatrix.getEquilibrium(colour) / metaPopulation.getDemographic(tree.getNodeHeight(nodeRef) - tinyTime, colour);
    }

    /**
     * Interval-wise pruning with the (optionally branch-biased) rates. Stores and returns the
     * (possibly rescaled) partials of {@code nodeRef}.
     */
    private double[] pruneEM(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        double[] partials = new double[this.colourCount];
        if (tree.isExternal(nodeRef)) {
            partials[getColour(nodeRef)] = 1.0d;
        } else {
            NodeRef left = tree.getChild(nodeRef, 0);
            NodeRef right = tree.getChild(nodeRef, 1);
            double[] leftPartials = pruneEM(tree, left, colourChangeMatrix, metaPopulation);
            double[] rightPartials = pruneEM(tree, right, colourChangeMatrix, metaPopulation);
            double[] leftAtNode = pruneBranchEM(colourChangeMatrix, leftPartials, nodeRef, left, tree, metaPopulation);
            double[] rightAtNode = pruneBranchEM(colourChangeMatrix, rightPartials, nodeRef, right, tree, metaPopulation);
            double largest = 0.0d;
            for (int colour = 0; colour < this.colourCount; colour++) {
                partials[colour] = leftAtNode[colour] * rightAtNode[colour];
                if (this.useNodeBias) {
                    partials[colour] *= nodeBias(tree, nodeRef, colour, colourChangeMatrix, metaPopulation);
                }
                if (partials[colour] > largest) {
                    largest = partials[colour];
                }
            }
            rescaleIfUnderflowing(partials, largest);
        }
        this.nodePartials[nodeRef.getNumber()] = partials;
        return partials;
    }

    /**
     * Samples colours top-down from {@code nodeRef}, recording branch colourings in
     * {@code defaultTreeColouring} and updating {@code nodeColours}.
     *
     * @return the log density of the sampled path in this subtree
     */
    private double sampleEM(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation, DefaultTreeColouring defaultTreeColouring) {
        int colour;
        double logDensity;
        if (tree.isRoot(nodeRef)) {
            this._totalIntegratedRate = 0.0d;
            double[] equilibrium = colourChangeMatrix.getEquilibrium();
            double[] rootPartials = this.nodePartials[nodeRef.getNumber()];
            double[] weights = new double[this.colourCount];
            for (int c = 0; c < equilibrium.length; c++) {
                weights[c] = equilibrium[c] * rootPartials[c];
            }
            colour = MathUtils.randomChoicePDF(weights);
            logDensity = Math.log(equilibrium[colour]);
        } else {
            logDensity = 0.0d;
            int nodeNumber = nodeRef.getNumber();
            double[][] bottomPartials = this.nodePartialsEM[nodeNumber];
            int firstInterval = this.node2Interval[nodeNumber];
            this.nodeColoursEM[nodeNumber] = new int[bottomPartials.length];
            colour = getColour(tree.getParent(nodeRef));
            DefaultBranchColouring branchColouring = new DefaultBranchColouring(colour, colour);
            double[] weights = new double[this.colourCount];
            // Walk the branch top-down: sample the colour at each interval bottom given the colour
            // at its top, then sample the changes inside the interval.
            for (int offset = bottomPartials.length - 1; offset >= 0; offset--) {
                int interval = offset + firstInterval;
                double bottomHeight = this.interval2Height[interval];
                double length = this.interval2Height[interval + 1] - bottomHeight;
                double[] rates = calculateMatrixElts(interval, nodeRef, tree, length, this.avgN0[interval], this.avgN1[interval], colourChangeMatrix);
                double[] transition = matrixEvolve(rates, colour);
                for (int c = 0; c < this.colourCount; c++) {
                    weights[c] = transition[c] * bottomPartials[offset][c];
                }
                int bottomColour = MathUtils.randomChoicePDF(weights);
                this.nodeColoursEM[nodeNumber][offset] = bottomColour;
                logDensity += sampleConditionalBranchColouringEM(nodeRef, colour, bottomColour, length, bottomHeight, rates, branchColouring);
                colour = bottomColour;
            }
            defaultTreeColouring.setBranchColouring(nodeRef, branchColouring);
        }
        setColour(nodeRef, colour);
        if (!tree.isExternal(nodeRef) && this.useNodeBias) {
            logDensity += Math.log(nodeBias(tree, nodeRef, colour, colourChangeMatrix, metaPopulation));
        }
        for (int child = 0; child < tree.getChildCount(nodeRef); child++) {
            logDensity += sampleEM(tree, tree.getChild(nodeRef, child), colourChangeMatrix, metaPopulation, defaultTreeColouring);
        }
        return logDensity;
    }

    /**
     * Samples colour-change events within one interval, forward in time from {@code topColour} at
     * the top to {@code bottomColour} at the bottom, by rejection. Up to {@link #maxIterations}
     * attempts are made. If all are rejected while ending in the wrong colour, a forced change to
     * {@code bottomColour} is added near the bottom. The events are appended to
     * {@code branchColouring}.
     *
     * @return the log density accumulated by the last attempt
     */
    private double sampleConditionalBranchColouringEM(NodeRef nodeRef, int topColour, int bottomColour, double length, double bottomHeight, double[] rates, DefaultBranchColouring branchColouring) {
        DefaultBranchColouring history = new DefaultBranchColouring(topColour, bottomColour);
        int colour;
        double remaining;
        double logDensity;
        double wait;
        boolean rejected;
        int attempts = 0;
        do {
            history.clear();
            colour = topColour;
            remaining = length;
            logDensity = 0.0d;
            rejected = false;
            boolean firstEvent = true;
            do {
                double exitRate;
                double changeRate;
                if (colour == 0) {
                    exitRate = rates[0] / length;
                    changeRate = rates[1] / length;
                } else {
                    exitRate = rates[3] / length;
                    changeRate = rates[2] / length;
                }
                double u;
                do {
                    u = MathUtils.nextDouble();
                } while (u == 0.0d);
                if (firstEvent && topColour != bottomColour) {
                    // Condition the first waiting time to fall inside the interval.
                    double noEventProbability = Math.exp(-exitRate * length);
                    u = noEventProbability + (u * (1.0d - noEventProbability));
                }
                wait = -Math.log(u) / exitRate;
                remaining -= wait;
                if (remaining <= 0.0d) {
                    logDensity += -exitRate * (remaining + wait);
                } else if (firstEvent || changeRate == exitRate || MathUtils.nextDouble() < changeRate / exitRate) {
                    colour = 1 - colour;
                    history.addEvent(colour, remaining + bottomHeight);
                    logDensity += (-exitRate * wait) + Math.log(changeRate);
                } else {
                    rejected = true;
                }
                firstEvent = false;
            } while (!rejected && remaining > 0.0d);
            attempts++;
            if (colour != bottomColour) {
                rejected = true;
            }
        } while (rejected && attempts < maxIterations);
        if (rejected && colour != bottomColour) {
            history.addEvent(bottomColour, (0.01d * (remaining + wait)) + bottomHeight);
        }
        branchColouring.addHistory(history);
        return logDensity;
    }

    /**
     * Scores an existing colouring of the subtree at {@code nodeRef} with the same densities that
     * {@code sampleEM} accumulates.
     */
    private double calculateEMProposal(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation, TreeColouring treeColouring) {
        int nodeColour;
        double logDensity;
        if (tree.isRoot(nodeRef)) {
            double[] equilibrium = colourChangeMatrix.getEquilibrium();
            nodeColour = treeColouring.getNodeColour(nodeRef);
            logDensity = Math.log(equilibrium[nodeColour]);
        } else {
            logDensity = 0.0d;
            int nodeNumber = nodeRef.getNumber();
            int intervalsOnBranch = this.nodePartialsEM[nodeNumber].length;
            int firstInterval = this.node2Interval[nodeNumber];
            BranchColouring branchColouring = treeColouring.getBranchColouring(nodeRef);
            for (int offset = intervalsOnBranch - 1; offset >= 0; offset--) {
                int interval = offset + firstInterval;
                double bottomHeight = this.interval2Height[interval];
                double length = this.interval2Height[interval + 1] - bottomHeight;
                double[] rates = calculateMatrixElts(interval, nodeRef, tree, length, this.avgN0[interval], this.avgN1[interval], colourChangeMatrix);
                logDensity += calculateConditionalBranchColouringEM(nodeRef, length, bottomHeight, rates, branchColouring);
            }
            nodeColour = treeColouring.getNodeColour(nodeRef);
        }
        if (!tree.isExternal(nodeRef) && this.useNodeBias) {
            logDensity += Math.log(nodeBias(tree, nodeRef, nodeColour, colourChangeMatrix, metaPopulation));
        }
        for (int child = 0; child < tree.getChildCount(nodeRef); child++) {
            logDensity += calculateEMProposal(tree, tree.getChild(nodeRef, child), colourChangeMatrix, metaPopulation, treeColouring);
        }
        return logDensity;
    }

    /**
     * Returns the log density of the colour-change events of {@code branchColouring} that fall in
     * the interval {@code [bottomHeight, bottomHeight + length]}, walking forward in time.
     */
    private double calculateConditionalBranchColouringEM(NodeRef nodeRef, double length, double bottomHeight, double[] rates, BranchColouring branchColouring) {
        double time = length + bottomHeight;
        int event = branchColouring.getNextForwardEvent(time);
        int colour = branchColouring.getForwardColourBelow(event - 1);
        double logDensity = 0.0d;
        while (time > bottomHeight) {
            // Index getNumEvents() + 1 means "no further event". A time below the interval makes
            // the loop charge the no-change density down to the bottom and stop.
            double eventTime = event == branchColouring.getNumEvents() + 1 ? bottomHeight - 1.0d : branchColouring.getForwardTime(event);
            double exitRate;
            double changeRate;
            if (colour == 0) {
                exitRate = rates[0] / length;
                changeRate = rates[1] / length;
            } else {
                exitRate = rates[3] / length;
                changeRate = rates[2] / length;
            }
            if (eventTime < bottomHeight) {
                logDensity += -exitRate * (time - bottomHeight);
            } else {
                logDensity += (-exitRate * (time - eventTime)) + Math.log(changeRate);
                colour = branchColouring.getForwardColourBelow(event);
            }
            time = eventTime;
            event++;
        }
        return logDensity;
    }

    /** Debug helper: prints {@code label = (v0, v1, ..., )} to standard output. */
    private void prettyPrint(String label, double[] values) {
        System.out.print(label + "= (");
        for (double value : values) {
            System.out.print(value + ", ");
        }
        System.out.println(")");
    }

    /**
     * Checks {@code matrixEvolve} and {@code matrixPullBack} against an expected transition matrix
     * {@code {P00, P01, P10, P11}}.
     *
     * @throws AssertionError labelled "1" to "8" identifying the first mismatching entry
     */
    static void testMatrix(double[] rates, double[] expected) {
        double[] row0 = matrixEvolve(rates, 0);
        double[] row1 = matrixEvolve(rates, 1);
        checkClose(row0[0], expected[0], "1");
        checkClose(row0[1], expected[1], "2");
        checkClose(row1[0], expected[2], "3");
        checkClose(row1[1], expected[3], "4");
        // Pulling back the unit vectors yields the matrix columns.
        double[] column0 = {1.0d, 0.0d};
        double[] column1 = {0.0d, 1.0d};
        matrixPullBack(rates, column0);
        matrixPullBack(rates, column1);
        checkClose(column0[0], expected[0], "5");
        checkClose(column0[1], expected[2], "7");
        checkClose(column1[0], expected[1], "6");
        checkClose(column1[1], expected[3], "8");
    }

    /** Throws {@link AssertionError} with {@code label} when values differ by more than 1e-6. */
    private static void checkClose(double actual, double expected, String label) {
        if (Math.abs(actual - expected) > 1.0E-6d) {
            throw new AssertionError(label);
        }
    }

    public static void main(String[] strArr) {
        testMatrix(new double[]{5.0d, 3.0d, 2.0d, 3.0d}, new double[]{0.0811818d, 0.145616d, 0.097077d, 0.178259d});
        System.out.println("First matrix OK");
        testMatrix(new double[]{1.0d, 1.0d, 0.0d, 1.0d}, new double[]{0.367879d, 0.367879d, 0.0d, 0.367879d});
        System.out.println("Second matrix OK");
        testMatrix(new double[]{1.0d, 0.0d, 1.0d, 1.0d}, new double[]{0.367879d, 0.0d, 0.367879d, 0.367879d});
        System.out.println("Third matrix OK");
    }
}
