package dr.evomodel.coalescent.basta;

import dr.evolution.alignment.PatternList;
import dr.evolution.coalescent.IntervalType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.TaxonList;
import dr.evolution.util.Units;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.evomodel.coalescent.AbstractCoalescentLikelihood;
import dr.evomodel.substmodel.GeneralSubstitutionModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.ComparableDouble;
import dr.util.HeapSort;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:dr/evomodel/coalescent/basta/OldStructuredCoalescentLikelihood.class */
public class OldStructuredCoalescentLikelihood extends AbstractCoalescentLikelihood implements Citable {
    // Compile-time switches; none of the four constants below is read anywhere in this class.
    private static final boolean DEBUG = false;
    private static final boolean MATRIX_DEBUG = false;
    private static final boolean UPDATE_DEBUG = false;
    private static final boolean ASSOC_MULTIPLICATION = true;
    // Publication handed back by getCitations().
    public static Citation CITATION = new Citation(new Author[]{new Author("Nicola", "De Maio"), new Author("Chieh-Hsi", "Wu"), new Author("Kathleen", "O'Reilly"), new Author("Daniel", "Wilson")}, "New routes to phylogeography: a Bayesian structured coalescent approximation", 2015, "PLOS Genetics", 11, "e1005421", "10.1371/journal.pgen.1005421");
    // The tree passed to the constructor, cast to TreeModel.
    private TreeModel treeModel;
    // Queried at the root node for the rate that scales interval lengths; never null after construction.
    private BranchRateModel branchRateModel;
    // One value per deme, read with getParameterValue(deme); used as a divisor in the likelihood terms.
    private Parameter popSizes;
    // Pattern 0 of this list is indexed by taxon index to obtain the deme index of each tip (see traverseTree).
    private PatternList patternList;
    // Per-deme sums of the start / end lineage probabilities over activeLineageList.
    private double[] startExpected;
    private double[] endExpected;
    // ProbDist of each processed node, indexed by NodeRef.getNumber().
    private ProbDist[] nodeProbDist;
    // Parallel lists filled by collectAllTimes: node height, child count and the node itself.
    private ArrayList<ComparableDouble> times;
    private ArrayList<Integer> children;
    private ArrayList<NodeRef> nodes;
    // Index permutation produced by HeapSort.sort(times, indices); the lists above are read through it.
    private int[] indices;
    // Scratch copies filled by updateTransitionProbabilities; at the end of that method the
    // primary fields above are re-pointed at these same objects (aliasing, not copying).
    private ArrayList<ComparableDouble> storedTimes;
    private ArrayList<Integer> storedChildren;
    private ArrayList<NodeRef> storedNodes;
    private int[] storedIndices;
    // Lineages used for the likelihood of the interval being processed.
    private ArrayList<ProbDist> activeLineageList;
    // Working list updated while events are processed; copied into activeLineageList after each event group.
    private ArrayList<ProbDist> tempLineageList;
    // Supplies the transition probability matrices; its data type's state count defines `demes`.
    private GeneralSubstitutionModel generalSubstitutionModel;
    // Number of demes (states).
    private int demes;
    // Stored from the constructor argument; not read anywhere else in this class.
    private int subIntervals;
    // 2 * taxonCount - 2; number of matrix rows allocated in migrationMatrices.
    private int maxCoalescentIntervals;
    // Row of migrationMatrices used by the next incrementActiveLineages call; reset to 0 after each traversal.
    private int currentCoalescentInterval;
    // One flattened demes x demes matrix per interval, read as matrix[from * demes + to].
    private double[][] migrationMatrices;
    // Value of currentCoalescentInterval at the end of the last traversal; bounds storeState / restoreState.
    private int finalCoalescentInterval;
    // Backup of migrationMatrices written by storeState and swapped back by restoreState.
    private double[][] storedMigrationMatrices;
    // When true, incrementActiveLineages reuses migrationMatrices instead of recomputing them.
    private boolean matricesKnown;
    // Set when the branch rate model or substitution model reports a change; cleared at the end of traverseTree.
    private boolean rateChanged;
    // Set when the tree model reports a change; cleared by traverseTree when it takes the incremental-update path.
    private boolean treeModelUpdateFired;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dr/evomodel/coalescent/basta/OldStructuredCoalescentLikelihood$ProbDist.class */
    /**
     * Probability distribution of a single lineage over the demes, held both at the
     * start and at the end of the interval the lineage currently spans.
     */
    public class ProbDist {
        private double[] startLineageProbs;
        private double[] endLineageProbs;
        private double intervalLength;
        private NodeRef node;
        private boolean incremented;
        private IntervalType intervalType;
        private NodeRef leftChild;
        private NodeRef rightChild;

        /**
         * Creates a distribution for a node without recorded children.
         *
         * @param i       number of entries in the start and end probability arrays
         * @param d       initial interval length
         * @param nodeRef node this distribution belongs to
         */
        public ProbDist(int i, double d, NodeRef nodeRef) {
            this(i, d, nodeRef, null, null);
        }

        /**
         * Creates a distribution for a node together with its two children.
         *
         * @param i        number of entries in the start and end probability arrays
         * @param d        initial interval length
         * @param nodeRef  node this distribution belongs to
         * @param nodeRef2 left child (may be {@code null})
         * @param nodeRef3 right child (may be {@code null})
         */
        public ProbDist(int i, double d, NodeRef nodeRef, NodeRef nodeRef2, NodeRef nodeRef3) {
            startLineageProbs = new double[i];
            endLineageProbs = new double[i];
            intervalLength = d;
            node = nodeRef;
            leftChild = nodeRef2;
            rightChild = nodeRef3;
            incremented = false;
        }

        /**
         * Sets the start probabilities of this lineage from the end probabilities of the two
         * given lineages: per deme, the product of both end probabilities divided by that
         * deme's population size, normalised to sum to one. Also resets the interval length to 0.
         *
         * @param probDist  first merging lineage
         * @param probDist2 second merging lineage
         * @return natural log of the normalising constant
         */
        public double computeCoalescedLineage(ProbDist probDist, ProbDist probDist2) {
            final int demeCount = OldStructuredCoalescentLikelihood.this.demes;
            final double[] unnormalised = new double[demeCount];
            double total = 0.0;
            for (int deme = 0; deme < demeCount; deme++) {
                double joint = probDist.getEndLineageProb(deme) * probDist2.getEndLineageProb(deme);
                unnormalised[deme] = joint / popSizes.getParameterValue(deme);
                total += unnormalised[deme];
            }
            for (int deme = 0; deme < demeCount; deme++) {
                setStartLineageProb(deme, unnormalised[deme] / total);
            }
            intervalLength = 0.0;
            return Math.log(total);
        }

        /**
         * Fills the end probabilities from the start probabilities.
         *
         * @param d    elapsed length; when exactly 0 the start probabilities are copied unchanged
         *             and {@code dArr} is not touched
         * @param dArr flattened transition matrix, read as {@code dArr[from * demes + to]}
         */
        public void computeEndLineageDensities(double d, double[] dArr) {
            final int demeCount = OldStructuredCoalescentLikelihood.this.demes;
            if (d != 0.0) {
                // end[to] = sum over source demes of start[from] * P(from -> to)
                for (int to = 0; to < demeCount; to++) {
                    double sum = 0.0;
                    for (int from = 0; from < demeCount; from++) {
                        sum += startLineageProbs[from] * dArr[from * demeCount + to];
                    }
                    setEndLineageProb(to, sum);
                }
            } else {
                for (int deme = 0; deme < demeCount; deme++) {
                    setEndLineageProb(deme, getStartLineageProb(deme));
                }
            }
        }

        /**
         * Extends this lineage by {@code d} and recomputes its end probabilities. On the first
         * call the accumulated interval length is used; on later calls the previous end
         * probabilities become the new start probabilities and only {@code d} is applied.
         *
         * @param d    length added to the interval
         * @param dArr flattened transition matrix for the added length
         */
        public void incrementIntervalLength(double d, double[] dArr) {
            intervalLength += d;
            if (!incremented) {
                computeEndLineageDensities(intervalLength, dArr);
            } else {
                System.arraycopy(endLineageProbs, 0, startLineageProbs, 0, OldStructuredCoalescentLikelihood.this.demes);
                computeEndLineageDensities(d, dArr);
            }
            incremented = true;
        }

        /** @return the interval type assigned through {@link #setIntervalType} */
        public IntervalType getIntervalType() {
            return intervalType;
        }

        /** @param intervalType interval type to record for this lineage */
        public void setIntervalType(IntervalType intervalType) {
            this.intervalType = intervalType;
        }

        /** @return the left child given at construction, or {@code null} */
        public NodeRef getLeftChild() {
            return leftChild;
        }

        /** @return the right child given at construction, or {@code null} */
        public NodeRef getRightChild() {
            return rightChild;
        }

        /** @return start probability of deme {@code i} */
        public double getStartLineageProb(int i) {
            return startLineageProbs[i];
        }

        /** @return end probability of deme {@code i} */
        public double getEndLineageProb(int i) {
            return endLineageProbs[i];
        }

        /** Sets the start probability of deme {@code i} to {@code d}. */
        public void setStartLineageProb(int i, double d) {
            startLineageProbs[i] = d;
        }

        /** Sets the end probability of deme {@code i} to {@code d}. */
        private void setEndLineageProb(int i, double d) {
            endLineageProbs[i] = d;
        }

        /** Renders the node, interval length and both probability vectors on one line. */
        public String toString() {
            StringBuilder text = new StringBuilder();
            text.append("Node ").append(node).append(" ; length = ").append(intervalLength).append(" S(");
            for (double startProb : startLineageProbs) {
                text.append(startProb).append(" ");
            }
            text.append(") E(");
            for (double endProb : endLineageProbs) {
                text.append(endProb).append(" ");
            }
            return text.append(")").toString();
        }
    }

    public OldStructuredCoalescentLikelihood(Tree tree, BranchRateModel branchRateModel, Parameter parameter, PatternList patternList, GeneralSubstitutionModel generalSubstitutionModel, int i, TaxonList taxonList, List<TaxonList> list) throws TreeUtils.MissingTaxonException {
        super(StructuredCoalescentLikelihoodParser.STRUCTURED_COALESCENT, tree, taxonList, list);
        this.treeModel = (TreeModel) tree;
        this.patternList = patternList;
        this.popSizes = parameter;
        addVariable(this.popSizes);
        if (branchRateModel != null) {
            this.branchRateModel = branchRateModel;
        } else {
            this.branchRateModel = new DefaultBranchRateModel();
        }
        addModel(this.branchRateModel);
        this.generalSubstitutionModel = generalSubstitutionModel;
        addModel(this.generalSubstitutionModel);
        this.demes = generalSubstitutionModel.getDataType().getStateCount();
        this.startExpected = new double[this.demes];
        this.endExpected = new double[this.demes];
        this.subIntervals = i;
        this.activeLineageList = new ArrayList<>();
        this.tempLineageList = new ArrayList<>();
        this.nodeProbDist = new ProbDist[this.treeModel.getNodeCount()];
        this.maxCoalescentIntervals = (this.treeModel.getTaxonCount() * 2) - 2;
        this.currentCoalescentInterval = 0;
        this.migrationMatrices = new double[this.maxCoalescentIntervals][this.demes * this.demes];
        this.storedMigrationMatrices = new double[this.maxCoalescentIntervals][this.demes * this.demes];
        this.times = new ArrayList<>();
        this.children = new ArrayList<>();
        this.nodes = new ArrayList<>();
        this.storedTimes = new ArrayList<>();
        this.storedChildren = new ArrayList<>();
        this.storedNodes = new ArrayList<>();
        this.treeModelUpdateFired = false;
        this.rateChanged = false;
    }

    /**
     * Recomputes the log-likelihood by traversing the tree from its root and caches
     * the result in {@code logLikelihood}.
     *
     * @return the freshly computed log-likelihood
     */
    @Override // dr.evomodel.coalescent.AbstractCoalescentLikelihood
    public double calculateLogLikelihood() {
        final NodeRef root = treeModel.getRoot();
        logLikelihood = traverseTree(treeModel, root, patternList);
        return logLikelihood;
    }

    /**
     * Returns the cached log-likelihood, recomputing it first when it is flagged as unknown.
     *
     * @return the current log-likelihood
     */
    @Override // dr.evomodel.coalescent.AbstractCoalescentLikelihood, dr.inference.model.Likelihood
    public double getLogLikelihood() {
        if (likelihoodKnown) {
            return logLikelihood;
        }
        logLikelihood = calculateLogLikelihood();
        likelihoodKnown = true;
        return logLikelihood;
    }

    /**
     * Processes every node of the tree as a timed event, in the order given by the sorted
     * index permutation, and accumulates the log-likelihood: one term per coalescent event
     * (from {@code ProbDist.computeCoalescedLineage}) plus one term per event group
     * (from {@code computeLogLikelihood}).
     *
     * <p>Side effects: rebuilds or updates {@code times}/{@code children}/{@code nodes}/{@code indices},
     * fills {@code nodeProbDist}, mutates both lineage lists, records {@code finalCoalescentInterval},
     * resets {@code currentCoalescentInterval} to 0, sets {@code matricesKnown} and clears {@code rateChanged}.
     *
     * @param tree        tree whose nodes are collected when the event lists are rebuilt
     * @param nodeRef     node the collection starts from (the caller passes the root)
     * @param patternList pattern 0 is indexed by taxon index to get the deme index of each tip
     * @return the accumulated log-likelihood
     * @throws RuntimeException if an internal node has more than two children, or if an internal
     *                          node is met where only tips are accepted (see the second loop below)
     */
    private double traverseTree(Tree tree, NodeRef nodeRef, PatternList patternList) {
        // Full rebuild unless a tree update is pending; a pending rate change always forces the rebuild.
        if (!this.treeModelUpdateFired || this.rateChanged) {
            this.times.clear();
            this.children.clear();
            this.nodes.clear();
            collectAllTimes(tree, nodeRef, this.nodes, this.times, this.children);
            this.indices = new int[this.times.size()];
            // Fills `indices` with the sort order of `times`; the lists themselves are always read through it.
            // NOTE(review): presumably ascending (d3 below is "this time minus previous time") - confirm in HeapSort.
            HeapSort.sort(this.times, this.indices);
        } else {
            // Tree-only change: refresh the event lists and recompute only the matrices whose interval length changed.
            updateTransitionProbabilities();
            this.treeModelUpdateFired = false;
        }
        // d: running log-likelihood. doubleValue: time of the previously processed event group.
        double d = 0.0d;
        double doubleValue = this.times.get(this.indices[0]).doubleValue();
        // i: cursor used to count the events of a group; i2: cursor used to actually process the nodes.
        int i = 0;
        int i2 = 0;
        while (i < this.times.size()) {
            // i3: sum of (childCount - 1) over internal nodes in this group; i4: number of childless nodes (tips).
            int i3 = 0;
            int i4 = 0;
            double doubleValue2 = this.times.get(this.indices[i]).doubleValue();
            double d2 = doubleValue2;
            // Length of the interval between the previous event group and this one.
            double d3 = doubleValue2 - doubleValue;
            // Group all events whose time is within 1e-9 of this group's time.
            while (Math.abs(d2 - doubleValue2) < 1.0E-9d) {
                int intValue = this.children.get(this.indices[i]).intValue();
                if (intValue == 0) {
                    i4++;
                } else {
                    i3 += intValue - 1;
                }
                i++;
                if (i == this.times.size()) {
                    break;
                }
                d2 = this.times.get(this.indices[i]).doubleValue();
            }
            // The group contains at least one tip.
            if (i4 > 0) {
                if (d3 > 1.0E-9d) {
                    // Advance the active lineages across the interval before adding the new nodes.
                    incrementActiveLineages(doubleValue2 - doubleValue);
                    // Consume the nodes whose height matches this group's time.
                    while (true) {
                        if (Math.abs(this.treeModel.getNodeHeight(this.nodes.get(this.indices[i2])) - doubleValue2) >= 1.0E-9d) {
                            break;
                        }
                        NodeRef nodeRef2 = this.nodes.get(this.indices[i2]);
                        if (this.treeModel.isExternal(nodeRef2)) {
                            // Tip: probability 1 in the deme read from pattern 0; end probabilities copy the start ones.
                            ProbDist probDist = new ProbDist(this.demes, 0.0d, nodeRef2);
                            probDist.setIntervalType(IntervalType.SAMPLE);
                            probDist.setStartLineageProb(patternList.getPattern(0)[patternList.getTaxonIndex(this.treeModel.getNodeTaxon(probDist.node).getId())], 1.0d);
                            probDist.computeEndLineageDensities(0.0d, null);
                            this.tempLineageList.add(probDist);
                            this.nodeProbDist[nodeRef2.getNumber()] = probDist;
                        } else {
                            if (this.treeModel.getChildCount(nodeRef2) > 2) {
                                throw new RuntimeException("Structured coalescent currently only allows strictly bifurcating trees.");
                            }
                            // NOTE(review): unlike the i3 branch below, this path neither merges the children's
                            // distributions nor records the new ProbDist in nodeProbDist - confirm this is intended.
                            ProbDist probDist2 = new ProbDist(this.demes, d3, nodeRef2, this.treeModel.getChild(nodeRef2, 0), this.treeModel.getChild(nodeRef2, 1));
                            probDist2.setIntervalType(IntervalType.COALESCENT);
                            this.tempLineageList.add(probDist2);
                        }
                        i2++;
                        // Wrap the processing cursor once every node has been consumed.
                        if (i2 >= this.indices.length) {
                            i2 = 0;
                            break;
                        }
                    }
                    doubleValue = doubleValue2;
                }
                // Consume nodes at height `doubleValue` (the previous group's time unless it was just
                // updated above); only tips are accepted here.
                while (true) {
                    if (Math.abs(this.treeModel.getNodeHeight(this.nodes.get(this.indices[i2])) - doubleValue) >= 1.0E-9d) {
                        break;
                    }
                    NodeRef nodeRef3 = this.nodes.get(this.indices[i2]);
                    if (!this.treeModel.isExternal(nodeRef3)) {
                        throw new RuntimeException("First interval cannot be a coalescent event.");
                    }
                    ProbDist probDist3 = new ProbDist(this.demes, 0.0d, nodeRef3);
                    probDist3.setIntervalType(IntervalType.SAMPLE);
                    probDist3.setStartLineageProb(patternList.getPattern(0)[patternList.getTaxonIndex(this.treeModel.getNodeTaxon(probDist3.node).getId())], 1.0d);
                    probDist3.computeEndLineageDensities(0.0d, null);
                    this.tempLineageList.add(probDist3);
                    this.nodeProbDist[nodeRef3.getNumber()] = probDist3;
                    i2++;
                    if (i2 >= this.indices.length) {
                        i2 = 0;
                        break;
                    }
                }
                doubleValue = doubleValue2;
            }
            // The group contains at least one internal (coalescent) node.
            if (i3 > 0) {
                // If the tip branch above already ran, doubleValue == doubleValue2 and the increment is 0.
                incrementActiveLineages(doubleValue2 - doubleValue);
                while (Math.abs(this.treeModel.getNodeHeight(this.nodes.get(this.indices[i2])) - doubleValue2) < 1.0E-9d) {
                    NodeRef nodeRef4 = this.nodes.get(this.indices[i2]);
                    if (this.treeModel.isExternal(nodeRef4)) {
                        ProbDist probDist4 = new ProbDist(this.demes, d3, nodeRef4);
                        probDist4.setIntervalType(IntervalType.SAMPLE);
                        probDist4.setStartLineageProb(patternList.getPattern(0)[patternList.getTaxonIndex(this.treeModel.getNodeTaxon(probDist4.node).getId())], 1.0d);
                        probDist4.computeEndLineageDensities(0.0d, null);
                        this.tempLineageList.add(probDist4);
                        this.nodeProbDist[nodeRef4.getNumber()] = probDist4;
                    } else {
                        if (this.treeModel.getChildCount(nodeRef4) > 2) {
                            throw new RuntimeException("Structured coalescent currently only allows strictly bifurcating trees.");
                        }
                        // Replace the two child lineages by their parent and add the coalescence log term.
                        NodeRef child = this.treeModel.getChild(nodeRef4, 0);
                        NodeRef child2 = this.treeModel.getChild(nodeRef4, 1);
                        ProbDist probDist5 = this.nodeProbDist[child.getNumber()];
                        ProbDist probDist6 = this.nodeProbDist[child2.getNumber()];
                        this.tempLineageList.remove(probDist5);
                        this.tempLineageList.remove(probDist6);
                        ProbDist probDist7 = new ProbDist(this.demes, d3, nodeRef4, child, child2);
                        probDist7.setIntervalType(IntervalType.COALESCENT);
                        d += probDist7.computeCoalescedLineage(probDist5, probDist6);
                        // The root is not carried forward as an active lineage.
                        if (!this.treeModel.isRoot(nodeRef4)) {
                            this.tempLineageList.add(probDist7);
                            this.nodeProbDist[nodeRef4.getNumber()] = probDist7;
                        }
                    }
                    i2++;
                    // Unlike the loops above, the cursor is not wrapped back to 0 here.
                    if (i2 >= this.indices.length) {
                        break;
                    }
                }
                doubleValue = doubleValue2;
            }
            // Interval contribution; skipped for an event group at time exactly 0.
            if (doubleValue2 != 0.0d) {
                computeExpectedLineageCounts();
            }
            if (doubleValue2 != 0.0d) {
                d += computeLogLikelihood(d3);
            }
            // The lineages alive after this group become the active set for the next interval.
            this.activeLineageList.clear();
            Iterator<ProbDist> it = this.tempLineageList.iterator();
            while (it.hasNext()) {
                this.activeLineageList.add(it.next());
            }
        }
        // Remember how many matrices were used (bounds storeState/restoreState) and reset per-traversal state.
        this.finalCoalescentInterval = this.currentCoalescentInterval;
        this.currentCoalescentInterval = 0;
        this.matricesKnown = true;
        this.rateChanged = false;
        return d;
    }

    /**
     * Computes the interval term of the log-likelihood from the active lineages. For every
     * deme it takes {@code (expected^2 - sum of squared lineage probabilities) / popSize},
     * once with start and once with end probabilities; each total is weighted by
     * {@code -intervalLength / 4} and the two are added.
     *
     * <p>Relies on {@code startExpected}/{@code endExpected} having been filled by
     * {@code computeExpectedLineageCounts()} beforehand.
     *
     * @param intervalLength length of the interval
     * @return the interval's log-likelihood contribution
     */
    private double computeLogLikelihood(double intervalLength) {
        double startTerm = 0.0;
        double endTerm = 0.0;
        for (int deme = 0; deme < demes; deme++) {
            double startSquares = 0.0;
            double endSquares = 0.0;
            for (ProbDist lineage : activeLineageList) {
                startSquares += lineage.getStartLineageProb(deme) * lineage.getStartLineageProb(deme);
                endSquares += lineage.getEndLineageProb(deme) * lineage.getEndLineageProb(deme);
            }
            final double inversePopSize = 1.0 / popSizes.getParameterValue(deme);
            startTerm += inversePopSize * ((startExpected[deme] * startExpected[deme]) - startSquares);
            endTerm += inversePopSize * ((endExpected[deme] * endExpected[deme]) - endSquares);
        }
        final double weight = (-intervalLength) / 4.0;
        return (startTerm * weight) + (endTerm * weight);
    }

    /**
     * Refreshes {@code startExpected} and {@code endExpected}: for each deme, the sum of the
     * start (respectively end) probabilities of all lineages in {@code activeLineageList}.
     */
    private void computeExpectedLineageCounts() {
        for (int deme = 0; deme < demes; deme++) {
            double startSum = 0.0;
            double endSum = 0.0;
            for (ProbDist lineage : activeLineageList) {
                startSum += lineage.getStartLineageProb(deme);
                endSum += lineage.getEndLineageProb(deme);
            }
            startExpected[deme] = startSum;
            endExpected[deme] = endSum;
        }
    }

    /**
     * Advances every active lineage by {@code elapsed} using the matrix of the current
     * interval, then moves on to the next interval slot. The matrix is recomputed from the
     * substitution model (for {@code rate * elapsed}) only when the matrices are flagged unknown.
     *
     * @param elapsed length by which the active lineages are extended
     */
    private void incrementActiveLineages(double elapsed) {
        final double rate;
        synchronized (branchRateModel) {
            rate = branchRateModel.getBranchRate(treeModel, treeModel.getRoot());
        }

        final int interval = currentCoalescentInterval;
        if (!matricesKnown) {
            generalSubstitutionModel.getTransitionProbabilities(rate * elapsed, migrationMatrices[interval]);
        }
        for (ProbDist lineage : activeLineageList) {
            lineage.incrementIntervalLength(elapsed, migrationMatrices[interval]);
        }
        currentCoalescentInterval = interval + 1;
    }

    /**
     * Appends the given node and then, recursively, each of its children (parent first) to
     * three parallel lists: the node, its height, and its child count.
     *
     * @param tree         tree being traversed
     * @param nodeRef      node to record before descending into its children
     * @param visitedNodes receives the nodes
     * @param heights      receives the node heights
     * @param childCounts  receives the child counts
     */
    private void collectAllTimes(Tree tree, NodeRef nodeRef, ArrayList<NodeRef> visitedNodes, ArrayList<ComparableDouble> heights, ArrayList<Integer> childCounts) {
        final int childCount = tree.getChildCount(nodeRef);
        heights.add(new ComparableDouble(tree.getNodeHeight(nodeRef)));
        visitedNodes.add(nodeRef);
        childCounts.add(childCount);
        for (int child = 0; child < childCount; child++) {
            collectAllTimes(tree, tree.getChild(nodeRef, child), visitedNodes, heights, childCounts);
        }
    }

    /**
     * Incremental matrix update after a tree change. Compares the interval lengths derived
     * from the previously collected event times with those of the current tree and recomputes
     * a matrix only where a non-zero interval length differs. Zero-length intervals do not
     * occupy a matrix slot. Afterwards the event lists point at the freshly collected data
     * and the matrices are flagged as known.
     *
     * @throws RuntimeException if the number of intervals differs between old and new event lists
     */
    private void updateTransitionProbabilities() {
        // Interval lengths of the event list collected during the previous traversal.
        final double[] previousLengths = new double[times.size() - 1];
        for (int k = 0; k < previousLengths.length; k++) {
            previousLengths[k] = times.get(indices[k + 1]).doubleValue() - times.get(indices[k]).doubleValue();
        }

        // Re-collect the events from the current tree into the scratch lists.
        storedTimes.clear();
        storedNodes.clear();
        storedChildren.clear();
        collectAllTimes(treeModel, treeModel.getRoot(), storedNodes, storedTimes, storedChildren);
        storedIndices = new int[storedTimes.size()];
        HeapSort.sort(storedTimes, storedIndices);

        final double[] currentLengths = new double[storedTimes.size() - 1];
        for (int k = 0; k < currentLengths.length; k++) {
            currentLengths[k] = storedTimes.get(storedIndices[k + 1]).doubleValue() - storedTimes.get(storedIndices[k]).doubleValue();
        }

        final double rate;
        synchronized (branchRateModel) {
            rate = branchRateModel.getBranchRate(treeModel, treeModel.getRoot());
        }

        if (previousLengths.length != currentLengths.length) {
            throw new RuntimeException("Number of coalescent intervals has increased?");
        }

        int matrixSlot = 0;
        for (int k = 0; k < previousLengths.length; k++) {
            final double length = currentLengths[k];
            if (length == 0.0) {
                // Zero-length intervals are skipped without consuming a matrix slot.
                continue;
            }
            if (previousLengths[k] != length) {
                generalSubstitutionModel.getTransitionProbabilities(rate * length, migrationMatrices[matrixSlot]);
            }
            matrixSlot++;
        }

        // The primary lists now alias the scratch lists (same objects, no copy).
        times = storedTimes;
        nodes = storedNodes;
        children = storedChildren;
        indices = storedIndices;
        matricesKnown = true;
    }

    /**
     * Invalidates the cached likelihood and matrices when one of the registered models changes.
     * A tree change raises {@code treeModelUpdateFired}; a change of the branch rate model or
     * the substitution model raises {@code rateChanged}.
     *
     * @param model the model that changed
     * @param obj   unused
     * @param i     unused
     * @throws RuntimeException if {@code model} is none of the three registered models
     */
    @Override // dr.evomodel.coalescent.AbstractCoalescentLikelihood, dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
        final boolean fromTree = (model == treeModel);
        if (!fromTree && model != branchRateModel && model != generalSubstitutionModel) {
            throw new RuntimeException("Unknown handleModelChangedEvent source, exiting.");
        }

        likelihoodKnown = false;
        matricesKnown = false;
        if (fromTree) {
            treeModelUpdateFired = true;
        } else {
            rateChanged = true;
        }
    }

    /**
     * Invalidates the cached likelihood when a registered variable changes. The only variable
     * this class registers is the population-size parameter, which is not an input of the
     * transition matrices, so the matrix cache flag is left exactly as it was.
     *
     * <p>Earlier this method forced {@code matricesKnown = true}. That re-validated matrices
     * that an earlier rate or substitution-model change (or {@code makeDirty()}) had just
     * invalidated, so the next evaluation could reuse stale or never-computed matrices.
     *
     * @param variable   the variable that changed
     * @param i          index of the changed value (unused)
     * @param changeType kind of change (unused)
     */
    @Override // dr.evomodel.coalescent.AbstractCoalescentLikelihood, dr.inference.model.AbstractModel
    protected void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        // Do not touch matricesKnown: a population-size change neither invalidates valid
        // matrices nor makes invalid ones valid.
        this.likelihoodKnown = false;
    }

    /**
     * Backs up the matrices used by the last traversal (the first {@code finalCoalescentInterval}
     * rows) together with the cached likelihood and its validity flag.
     */
    @Override // dr.evomodel.coalescent.AbstractCoalescentLikelihood, dr.inference.model.AbstractModel
    protected void storeState() {
        final int matrixLength = demes * demes;
        for (int interval = 0; interval < finalCoalescentInterval; interval++) {
            System.arraycopy(migrationMatrices[interval], 0, storedMigrationMatrices[interval], 0, matrixLength);
        }
        storedLikelihoodKnown = likelihoodKnown;
        storedLogLikelihood = logLikelihood;
    }

    /**
     * Brings back the state saved by {@code storeState()}: swaps the first
     * {@code finalCoalescentInterval} matrix rows with their backups (by reference, no copy)
     * and restores the cached likelihood and its validity flag.
     */
    @Override // dr.evomodel.coalescent.AbstractCoalescentLikelihood, dr.inference.model.AbstractModel
    protected void restoreState() {
        for (int interval = 0; interval < finalCoalescentInterval; interval++) {
            final double[] current = migrationMatrices[interval];
            migrationMatrices[interval] = storedMigrationMatrices[interval];
            storedMigrationMatrices[interval] = current;
        }
        likelihoodKnown = storedLikelihoodKnown;
        logLikelihood = storedLogLikelihood;
    }

    /**
     * Flags both the cached likelihood and the cached matrices as unknown so that the next
     * evaluation recomputes them.
     */
    @Override // dr.evomodel.coalescent.AbstractCoalescentLikelihood, dr.inference.model.Likelihood
    public void makeDirty() {
        matricesKnown = false;
        likelihoodKnown = false;
    }

    /**
     * Forwards the units to the underlying tree model.
     *
     * @param type units to set on the tree model
     */
    @Override // dr.evolution.util.Units
    public final void setUnits(Units.Type type) {
        treeModel.setUnits(type);
    }

    /**
     * @return the units reported by the underlying tree model
     */
    @Override // dr.evolution.util.Units
    public final Units.Type getUnits() {
        final Units.Type units = treeModel.getUnits();
        return units;
    }

    /**
     * @return the citation category of this model, {@code TREE_PRIORS}
     */
    @Override // dr.util.Citable
    public Citation.Category getCategory() {
        final Citation.Category category = Citation.Category.TREE_PRIORS;
        return category;
    }

    /**
     * @return a short human-readable description used alongside the citation
     */
    @Override // dr.util.Citable
    public String getDescription() {
        final String description = "Bayesian structured coalescent approximation";
        return description;
    }

    /**
     * @return an immutable single-element list holding {@link #CITATION}
     */
    @Override // dr.util.Citable
    public List<Citation> getCitations() {
        final List<Citation> citations = Collections.singletonList(CITATION);
        return citations;
    }
}
