package dr.evolution.colouring;

import dr.evolution.alignment.Alignment;
import dr.evolution.coalescent.structure.MetaPopulation;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.evoxml.util.GraphMLUtils;
import dr.math.MathUtils;

/* loaded from: input_file:dr/evolution/colouring/BasicColourSampler.class */
/**
 * A {@link ColourSampler} that draws a complete colouring of a tree (a colour for every
 * internal node plus a colour-change history along every non-root branch) conditional on the
 * fixed colours of the tips.
 *
 * <p>Sampling proceeds in three passes:
 * <ol>
 *   <li>a post-order pruning pass computes, for each internal node and each colour, the
 *       probability of the tip colours below that node given that colour;</li>
 *   <li>a pre-order pass draws internal node colours from those partials, conditioned on the
 *       already-drawn parent colour (or the equilibrium distribution at the root);</li>
 *   <li>each branch's colour-change history is drawn by rejection sampling, conditional on the
 *       colours at both ends of the branch.</li>
 * </ol>
 *
 * <p>NOTE: the branch sampler switches colour with {@code 1 - colour} and the resulting
 * colouring is created for two colours, so sampling is only meaningful for two-colour models.
 */
public class BasicColourSampler implements ColourSampler {
    /** Maximum number of rejection-sampling attempts per branch before falling back. */
    static final int maxIterations = 1000;
    /** Number of distinct colours (states). */
    private final int colourCount;
    /** Current colour of every node, indexed by {@link NodeRef#getNumber()}. */
    private int[] nodeColours;
    /** Pruning partials per node (indexed by node number, then colour); filled for internal nodes by prune. */
    private double[][] nodePartials;
    /** Number of tips carrying each colour. */
    private final int[] leafColourCounts;

    /**
     * Creates a sampler whose tip colours are read from a single-column alignment.
     *
     * @param alignment alignment holding one site; the state of each taxon at that site is its colour
     * @param tree the tree whose tips are coloured
     * @throws IllegalArgumentException if the alignment does not have exactly one site, if a tip
     *     taxon is absent from the alignment, or if a tip state is not a valid colour
     */
    public BasicColourSampler(Alignment alignment, Tree tree) {
        if (alignment.getSiteCount() != 1) {
            throw new IllegalArgumentException("Tip colour alignment must consist of a single column!");
        }
        this.nodeColours = new int[tree.getNodeCount()];
        this.colourCount = alignment.getDataType().getStateCount();
        this.leafColourCounts = new int[this.colourCount];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            int taxonIndex = alignment.getTaxonIndex(tree.getTaxonId(i));
            if (taxonIndex == -1) {
                throw new IllegalArgumentException(
                        "Taxon " + tree.getTaxonId(i) + " not found in tip colour alignment");
            }
            int state = alignment.getState(taxonIndex, 0);
            // Ambiguous/missing states would otherwise surface as an ArrayIndexOutOfBoundsException.
            if (state < 0 || state >= this.colourCount) {
                throw new IllegalArgumentException("Tip colour " + state + " of taxon "
                        + tree.getTaxonId(i) + " is outside of range of colours, [0, "
                        + (this.colourCount - 1) + "]");
            }
            this.nodeColours[tip.getNumber()] = state;
            this.leafColourCounts[state]++;
        }
        this.nodePartials = new double[tree.getNodeCount()][this.colourCount];
    }

    /**
     * Creates a sampler whose tip colours are given by taxon-list membership.
     *
     * <p>A tip that belongs to {@code taxonListArr[k]} gets colour {@code k + 1}. If it belongs to
     * several lists, the last matching list wins. A tip in none of the lists gets colour 0.
     *
     * @param taxonListArr taxon lists defining colours 1..length
     * @param tree the tree whose tips are coloured
     */
    public BasicColourSampler(TaxonList[] taxonListArr, Tree tree) {
        this.nodeColours = new int[tree.getNodeCount()];
        this.colourCount = taxonListArr.length + 1;
        this.leafColourCounts = new int[this.colourCount];
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef tip = tree.getExternalNode(i);
            int colour = 0;
            for (int k = 0; k < taxonListArr.length; k++) {
                if (taxonListArr[k].getTaxonIndex(tree.getTaxonId(i)) != -1) {
                    colour = k + 1;
                }
            }
            this.nodeColours[tip.getNumber()] = colour;
            this.leafColourCounts[colour]++;
        }
        this.nodePartials = new double[tree.getNodeCount()][this.colourCount];
    }

    /**
     * Returns the number of tips carrying each colour.
     *
     * @return the internal count array (not a copy), indexed by colour
     */
    @Override
    public int[] getLeafColourCounts() {
        return this.leafColourCounts;
    }

    /**
     * Samples a full colouring of {@code tree} conditional on the tip colours.
     *
     * @param tree the tree to colour
     * @param colourChangeMatrix rates and transition probabilities between colours
     * @param metaPopulation supplies population sizes at time 0; these are passed through the
     *     recursion but not currently used in any calculation
     * @return a new colouring whose log probability density is the log density of the sampled
     *     history minus the log probability of the tip colours
     */
    @Override
    public DefaultTreeColouring sampleTreeColouring(Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        // Two colours only: the branch sampler below assumes a binary colour space.
        DefaultTreeColouring colouring = new DefaultTreeColouring(2, tree);
        double[] populationSizes = metaPopulation.getPopulationSizes(0.0d);
        double[] rootPartials = prune(tree, tree.getRoot(), colourChangeMatrix, populationSizes);

        // Marginal probability of the tip colours: integrate the root partials over the equilibrium.
        double tipProbability = 0.0d;
        for (int c = 0; c < rootPartials.length; c++) {
            tipProbability += colourChangeMatrix.getEquilibrium(c) * rootPartials[c];
        }

        sampleInternalNodes(tree, tree.getRoot(), colourChangeMatrix);
        sampleBranchColourings(colouring, tree, tree.getRoot(), colourChangeMatrix);
        double logDensity = calculateLogProbabilityDensity(
                colouring, tree, tree.getRoot(), colourChangeMatrix, populationSizes);
        colouring.setLogProbabilityDensity(logDensity - Math.log(tipProbability));
        return colouring;
    }

    /** Returns the currently assigned colour of {@code nodeRef}. */
    private int getColour(NodeRef nodeRef) {
        return this.nodeColours[nodeRef.getNumber()];
    }

    /**
     * Assigns colour {@code i} to {@code nodeRef}.
     *
     * @throws IllegalArgumentException if {@code i} is not in {@code [0, colourCount)}
     */
    private void setColour(NodeRef nodeRef, int i) {
        if (i < 0 || i >= this.colourCount) {
            throw new IllegalArgumentException("colour value " + i + " is outside of range of colours, [0, "
                    + Integer.toString(this.colourCount - 1) + GraphMLUtils.END_ATTRIBUTE);
        }
        this.nodeColours[nodeRef.getNumber()] = i;
    }

    /**
     * Post-order pruning pass: returns, for each colour at {@code nodeRef}, the probability of the
     * tip colours below it, and caches the result in {@link #nodePartials} for internal nodes.
     * Assumes a bifurcating tree (only children 0 and 1 are visited).
     *
     * @param populationSizes currently unused; kept for signature compatibility
     */
    private double[] prune(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, double[] populationSizes) {
        double[] partials = new double[this.colourCount];
        if (tree.isExternal(nodeRef)) {
            // Tip colours are observed: all mass on the known colour.
            partials[getColour(nodeRef)] = 1.0d;
            return partials;
        }
        NodeRef leftChild = tree.getChild(nodeRef, 0);
        NodeRef rightChild = tree.getChild(nodeRef, 1);
        double[] leftPartials = prune(tree, leftChild, colourChangeMatrix, populationSizes);
        double[] rightPartials = prune(tree, rightChild, colourChangeMatrix, populationSizes);
        double height = tree.getNodeHeight(nodeRef);
        double leftLength = height - tree.getNodeHeight(leftChild);
        double rightLength = height - tree.getNodeHeight(rightChild);
        for (int parentColour = 0; parentColour < partials.length; parentColour++) {
            double leftSum = 0.0d;
            double rightSum = 0.0d;
            for (int childColour = 0; childColour < leftPartials.length; childColour++) {
                leftSum += leftPartials[childColour]
                        * colourChangeMatrix.forwardTimeEvolution(parentColour, childColour, leftLength);
                rightSum += rightPartials[childColour]
                        * colourChangeMatrix.forwardTimeEvolution(parentColour, childColour, rightLength);
            }
            partials[parentColour] = leftSum * rightSum;
        }
        this.nodePartials[nodeRef.getNumber()] = partials;
        return partials;
    }

    /**
     * Pre-order pass: draws a colour for {@code nodeRef} and then for every internal descendant.
     * The draw uses the cached partials times either the equilibrium (root) or the transition
     * probabilities from the parent's already-sampled colour.
     */
    private void sampleInternalNodes(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        double[] partials = this.nodePartials[nodeRef.getNumber()];
        double[] weights;
        if (tree.isRoot(nodeRef)) {
            // Copy before scaling in place below: getEquilibrium() may expose the matrix's own
            // array, and mutating it would corrupt later uses of the equilibrium distribution.
            weights = colourChangeMatrix.getEquilibrium().clone();
        } else {
            NodeRef parent = tree.getParent(nodeRef);
            int parentColour = getColour(parent);
            double branchLength = tree.getNodeHeight(parent) - tree.getNodeHeight(nodeRef);
            weights = new double[partials.length];
            for (int c = 0; c < partials.length; c++) {
                weights[c] = colourChangeMatrix.forwardTimeEvolution(parentColour, c, branchLength);
            }
        }
        for (int c = 0; c < partials.length; c++) {
            weights[c] *= partials[c];
        }
        setColour(nodeRef, MathUtils.randomChoicePDF(weights));
        for (int k = 0; k < tree.getChildCount(nodeRef); k++) {
            NodeRef child = tree.getChild(nodeRef, k);
            if (!tree.isExternal(child)) {
                sampleInternalNodes(tree, child, colourChangeMatrix);
            }
        }
    }

    /**
     * Draws a colour-change history for every branch below {@code nodeRef} (and for the branch
     * above it, unless it is the root), conditional on the node colours already assigned.
     */
    private void sampleBranchColourings(DefaultTreeColouring defaultTreeColouring, Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        if (!tree.isRoot(nodeRef)) {
            NodeRef parent = tree.getParent(nodeRef);
            defaultTreeColouring.setBranchColouring(nodeRef, sampleConditionalBranchColouring(
                    getColour(parent), tree.getNodeHeight(parent),
                    getColour(nodeRef), tree.getNodeHeight(nodeRef), colourChangeMatrix));
        }
        for (int k = 0; k < tree.getChildCount(nodeRef); k++) {
            sampleBranchColourings(defaultTreeColouring, tree, tree.getChild(nodeRef, k), colourChangeMatrix);
        }
    }

    /**
     * Samples a colour-change history along one branch by rejection. The history is simulated
     * from the parent down to the child and retried until it ends in {@code childColour}, for at
     * most {@link #maxIterations} attempts.
     *
     * <p>When the end colours differ, the first waiting time is drawn conditional on at least one
     * change occurring on the branch. If every attempt fails, a final change to the child colour
     * is forced just above the child and a diagnostic is printed.
     *
     * <p>Assumes two colours: a change always switches {@code c} to {@code 1 - c}. Also assumes
     * {@code getForwardRate(c, c)} is the negated total rate of leaving colour {@code c} — TODO
     * confirm against ColourChangeMatrix.
     */
    private DefaultBranchColouring sampleConditionalBranchColouring(int parentColour, double parentHeight, int childColour, double childHeight, ColourChangeMatrix colourChangeMatrix) {
        DefaultBranchColouring history = new DefaultBranchColouring(parentColour, childColour);
        int colour;
        double height;
        double waitingTime;
        int attemptsLeft = maxIterations;
        do {
            history.clear();
            colour = parentColour;
            height = parentHeight;
            do {
                double exitRate = -colourChangeMatrix.getForwardRate(colour, colour);
                double u;
                do {
                    u = MathUtils.nextDouble();
                } while (u == 0.0d);
                if (parentColour != childColour && history.getNumEvents() == 0) {
                    // Truncate the exponential so that the first change falls within the branch.
                    double noChangeProbability = Math.exp(-exitRate * (parentHeight - childHeight));
                    u = noChangeProbability + u * (1.0d - noChangeProbability);
                }
                waitingTime = -Math.log(u) / exitRate;
                height -= waitingTime;
                if (height > childHeight) {
                    colour = 1 - colour;
                    history.addEvent(colour, height);
                }
            } while (height > childHeight);
            attemptsLeft--;
        } while (colour != childColour && attemptsLeft > 0);

        if (colour != childColour) {
            // Fallback: force a change into the child colour 1% of the last interval above the child.
            history.addEvent(childColour, childHeight + (0.01d * ((height + waitingTime) - childHeight)));
            System.out.println("dr.evolution.colouring.BranchColourSampler: failed to generate sample after 1000 trials.");
            System.out.println(": parentColour=" + parentColour);
            System.out.println(": parentHeight=" + parentHeight);
            System.out.println(": childColour=" + childColour);
            System.out.println(": childHeight=" + childHeight);
            System.out.println(": migration rate 0->1 = " + colourChangeMatrix.getForwardRate(0, 1));
            System.out.println(": migration rate 1->0 = " + colourChangeMatrix.getForwardRate(1, 0));
        }
        return history;
    }

    /**
     * Returns the log density of the colouring in the subtree rooted at {@code nodeRef}.
     *
     * <p>The root contributes its equilibrium probability. Each branch contributes, for every
     * change event, {@code exp(rate(c,c) * duration) * rate(c, next)}, followed by a final
     * no-change term down to the child node.
     *
     * @param populationSizes currently unused; kept for signature compatibility
     */
    private double calculateLogProbabilityDensity(TreeColouring treeColouring, Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, double[] populationSizes) {
        double density;
        if (tree.isRoot(nodeRef)) {
            density = colourChangeMatrix.getEquilibrium(treeColouring.getNodeColour(nodeRef));
        } else {
            NodeRef parent = tree.getParent(nodeRef);
            BranchColouring history = treeColouring.getBranchColouring(nodeRef);
            int colour = treeColouring.getNodeColour(parent);
            double height = tree.getNodeHeight(parent);
            double branchDensity = 1.0d;
            for (int e = 1; e <= history.getNumEvents(); e++) {
                int nextColour = history.getForwardColourBelow(e);
                double eventHeight = history.getForwardTime(e);
                branchDensity = branchDensity
                        * Math.exp((-(height - eventHeight)) * (-colourChangeMatrix.getForwardRate(colour, colour)))
                        * colourChangeMatrix.getForwardRate(colour, nextColour);
                height = eventHeight;
                colour = nextColour;
            }
            // No further change between the last event (or the parent) and the child node.
            density = branchDensity
                    * Math.exp((-(height - tree.getNodeHeight(nodeRef))) * (-colourChangeMatrix.getForwardRate(colour, colour)));
        }
        double logDensity = Math.log(density);
        for (int k = 0; k < tree.getChildCount(nodeRef); k++) {
            logDensity += calculateLogProbabilityDensity(
                    treeColouring, tree, tree.getChild(nodeRef, k), colourChangeMatrix, populationSizes);
        }
        return logDensity;
    }

    /**
     * Returns, summed over every branch in the subtree rooted at {@code nodeRef}, the value
     * {@code log(t^n / n!)}, where {@code t} is the branch length and {@code n} is the number of
     * colour-change events on it. The root node itself contributes nothing.
     */
    public static final double calculateLogNormalization(TreeColouring treeColouring, Tree tree, NodeRef nodeRef) {
        double logNorm = 0.0d;
        if (!tree.isRoot(nodeRef)) {
            double term = 1.0d;
            double branchLength = tree.getNodeHeight(tree.getParent(nodeRef)) - tree.getNodeHeight(nodeRef);
            int eventCount = treeColouring.getBranchColouring(nodeRef).getNumEvents();
            for (int i = 1; i <= eventCount; i++) {
                term *= branchLength / i;
            }
            logNorm = Math.log(term);
        }
        for (int k = 0; k < tree.getChildCount(nodeRef); k++) {
            logNorm += calculateLogNormalization(treeColouring, tree, tree.getChild(nodeRef, k));
        }
        return logNorm;
    }

    /**
     * Not supported by this sampler.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public double getProposalProbability(TreeColouring treeColouring, Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        throw new IllegalArgumentException("Not implemented for BasicColourSampler; you can only use <ColouredOperator>s");
    }
}
