package dr.evomodel.tree;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodelxml.tree.UniformNodeHeightPriorParser;
import dr.evoxml.util.GraphMLUtils;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.math.LogTricks;
import dr.math.MathUtils;
import dr.math.Polynomial;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.logging.Logger;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

/**
 * Uniform-node-height prior on the internal node heights of a tree.
 *
 * <p>Depending on how it is constructed, {@link #calculateLogLikelihood()} evaluates one of:
 * <ul>
 *   <li><b>Nicholls mode</b>: a density on the root height only. It is {@code -infinity} outside
 *       {@code [0, 0.999 * maxRootHeight]}.</li>
 *   <li><b>Closed form</b> (all tips share one date, or {@code leadingTerm} was requested):
 *       {@code log(k!) - k * log(rootHeight)} with {@code k = internalNodeCount - 1}.</li>
 *   <li><b>Monte Carlo</b> ({@code useAnalytic == false}): minus the log of a sampled estimate of
 *       the feasible-height volume, using {@code mcSampleSize} samples.</li>
 *   <li><b>Marginal polynomial</b> ({@code useAnalytic && useMarginal}): minus the log of the two
 *       root-child polynomials evaluated at the root height.</li>
 *   <li><b>Recursive</b> (otherwise): minus the sum, over non-root internal nodes, of
 *       {@code log(parentHeight - maxTipHeightBelow)}.</li>
 * </ul>
 */
public class UniformNodeHeightPrior extends AbstractModelLikelihood {
    public static final int MAX_ANALYTIC_TIPS = 60;
    public static final int DEFAULT_MC_SAMPLE = 100000;

    /** Consecutive tip dates closer than this are merged when building reversedTipDateList. */
    private static final double tolerance = 1.0E-6d;

    /** Tip labels are rounded to multiples of {@code 1 / INV_PRECISION} by {@link #round(double)}. */
    private static final double INV_PRECISION = 10.0d;

    /** Exponent of the closed-form density: internal node count minus one. */
    private int k = 0;
    private double logFactorialK;
    private double maxRootHeight;
    private boolean isNicholls;
    private boolean useAnalytic;
    private boolean useMarginal;
    private boolean leadingTerm;
    /** True when the closed form {@code log(k!) - k log(h)} is used; set only by the main constructor. */
    private boolean useClosedForm;
    private int mcSampleSize;

    Set<Double> tipDates = new TreeSet<>();
    List<Double> reversedTipDateList = new ArrayList<>();
    Map<Double, Integer> intervals = new TreeMap<>();

    Tree tree;
    double logLikelihood;
    private double storedLogLikelihood;
    boolean likelihoodKnown;
    private boolean storedLikelihoodKnown;

    // Cached polynomials for the two root children (marginal mode only).
    private boolean treePolynomialKnown;
    private boolean storedTreePolynomialKnown;
    private Polynomial[] treePolynomials;
    private Polynomial[] storedTreePolynomials;

    /** Accumulator used by {@link #recursivelyComputeDensity}. */
    private double tmpLogLikelihood;
    private Polynomial.Type polynomialType;

    // Monte Carlo work arrays; allocated only when the Monte Carlo path is selected.
    private double[] logLikelihoods;
    private double[][] drawNodeHeights;
    private double[] minNodeHeights;

    /**
     * A polynomial tagged with a label and a tip flag.
     *
     * <p>For tips the label is the rounded tip height. {@link #multiply(TipLabeledPolynomial)}
     * keeps the larger of the two labels, so an internal node's label is the largest rounded tip
     * height beneath it. All {@link Polynomial} operations delegate to the wrapped polynomial.
     */
    public class TipLabeledPolynomial extends Polynomial.Abstract {
        private final double label;
        private Polynomial polynomial;
        private final boolean isTip;

        TipLabeledPolynomial(double[] coefficients, double label, Polynomial.Type type, boolean isTip) {
            switch (type) {
                case DOUBLE:
                    this.polynomial = new Polynomial.Double(coefficients);
                    break;
                case LOG_DOUBLE:
                    this.polynomial = new Polynomial.LogDouble(coefficients);
                    break;
                case BIG_DOUBLE:
                    this.polynomial = new Polynomial.BigDouble(coefficients);
                    break;
                default:
                    throw new RuntimeException("Unknown polynomial type");
            }
            this.label = label;
            this.isTip = isTip;
        }

        TipLabeledPolynomial(Polynomial polynomial, double label, boolean isTip) {
            this.polynomial = polynomial;
            this.label = label;
            this.isTip = isTip;
        }

        @Override // dr.math.Polynomial
        public TipLabeledPolynomial copy() {
            return new TipLabeledPolynomial(polynomial.copy(), label, isTip);
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public Polynomial getPolynomial() {
            return polynomial;
        }

        /** Returns the product of the two polynomials, labeled with the larger of the two labels. */
        public TipLabeledPolynomial multiply(TipLabeledPolynomial other) {
            return new TipLabeledPolynomial(polynomial.multiply(other), Math.max(label, other.label), false);
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public int getDegree() {
            return polynomial.getDegree();
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public Polynomial multiply(Polynomial other) {
            return polynomial.multiply(other);
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public Polynomial integrate() {
            return polynomial.integrate();
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public void expand(double x) {
            polynomial.expand(x);
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public double evaluate(double x) {
            return polynomial.evaluate(x);
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public double logEvaluate(double x) {
            return polynomial.logEvaluate(x);
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public double logEvaluateHorner(double x) {
            return polynomial.logEvaluateHorner(x);
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public void setCoefficient(int n, double value) {
            polynomial.setCoefficient(n, value);
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public TipLabeledPolynomial integrateWithLowerBound(double bound) {
            return new TipLabeledPolynomial(polynomial.integrateWithLowerBound(bound), label, isTip);
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public double getCoefficient(int n) {
            return polynomial.getCoefficient(n);
        }

        @Override // dr.math.Polynomial.Abstract
        public String toString() {
            return polynomial.toString() + " {" + label + GraphMLUtils.END_SECTION;
        }

        @Override // dr.math.Polynomial.Abstract, dr.math.Polynomial
        public String getCoefficientString(int n) {
            return polynomial.getCoefficientString(n);
        }
    }

    public UniformNodeHeightPrior(Tree tree, boolean useAnalytic, boolean useMarginal, boolean leadingTerm) {
        this(UniformNodeHeightPriorParser.UNIFORM_NODE_HEIGHT_PRIOR, tree, useAnalytic, DEFAULT_MC_SAMPLE,
                useMarginal, leadingTerm);
    }

    public UniformNodeHeightPrior(Tree tree, boolean useAnalytic, int mcSampleSize) {
        this(UniformNodeHeightPriorParser.UNIFORM_NODE_HEIGHT_PRIOR, tree, useAnalytic, mcSampleSize, false, false);
    }

    /**
     * Main constructor: selects the evaluation mode from the tip dates and flags.
     *
     * @param name         likelihood name passed to the superclass
     * @param tree         tree whose node heights are scored; registered as a sub-model if it is a TreeModel
     * @param useAnalytic  false selects the Monte Carlo estimator (unless the closed form applies)
     * @param mcSampleSize number of Monte Carlo samples
     * @param useMarginal  with useAnalytic, selects the marginal polynomial evaluation
     * @param leadingTerm  forces the closed-form density even when tip dates differ
     */
    private UniformNodeHeightPrior(String name, Tree tree, boolean useAnalytic, int mcSampleSize,
                                   boolean useMarginal, boolean leadingTerm) {
        super(name);
        this.tree = tree;
        this.isNicholls = false;
        this.useAnalytic = useAnalytic;
        this.useMarginal = useMarginal;
        this.mcSampleSize = mcSampleSize;
        this.leadingTerm = leadingTerm;
        if (tree instanceof TreeModel) {
            addModel((TreeModel) tree);
        }
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            tipDates.add(tree.getNodeHeight(tree.getExternalNode(i)));
        }
        if (tipDates.size() == 1 || leadingTerm) {
            useClosedForm = true;
            k = tree.getInternalNodeCount() - 1;
            Logger.getLogger("dr.evomodel").info("Uniform Node Height Prior, Intervals = " + (k + 1));
            logFactorialK = logFactorial(k);
        } else {
            buildReversedTipDateList(tree);
            if (!useAnalytic) {
                logLikelihoods = new double[mcSampleSize];
                drawNodeHeights = new double[tree.getNodeCount()][mcSampleSize];
                minNodeHeights = new double[tree.getNodeCount()];
            }
        }
        // DOUBLE for fewer than 30 tips, LOG_DOUBLE otherwise.
        polynomialType = tree.getExternalNodeCount() < 30 ? Polynomial.Type.DOUBLE : Polynomial.Type.LOG_DOUBLE;
        Logger.getLogger("dr.evomodel").info("Using " + polynomialType + " polynomials!");
    }

    public UniformNodeHeightPrior(Tree tree, double maxRootHeight) {
        this(UniformNodeHeightPriorParser.UNIFORM_NODE_HEIGHT_PRIOR, tree, maxRootHeight);
    }

    /** Nicholls-mode constructor: only the root height, bounded by {@code maxRootHeight}, is scored. */
    private UniformNodeHeightPrior(String name, Tree tree, double maxRootHeight) {
        super(name);
        this.tree = tree;
        this.maxRootHeight = maxRootHeight;
        this.isNicholls = true;
        if (tree instanceof TreeModel) {
            addModel((TreeModel) tree);
        }
    }

    /** Bare constructor; leaves the tree unset. NOTE(review): presumably for subclasses — confirm. */
    UniformNodeHeightPrior(String name) {
        super(name);
    }

    /**
     * Fills reversedTipDateList with the distinct tip dates in descending order. When two
     * consecutive dates (starting with the root height as the first "previous" value) differ by
     * less than {@code tolerance}, the larger of the pair is removed.
     */
    private void buildReversedTipDateList(Tree tree) {
        reversedTipDateList.addAll(tipDates);
        Collections.reverse(reversedTipDateList);
        final List<Double> nearDuplicates = new ArrayList<>();
        double previous = tree.getNodeHeight(tree.getRoot());
        for (double date : reversedTipDateList) {
            if (previous - date < tolerance) {
                nearDuplicates.add(previous);
            }
            previous = date;
        }
        reversedTipDateList.removeAll(nearDuplicates);
    }

    @Override // dr.inference.model.AbstractModel
    protected final void handleModelChangedEvent(Model model, Object object, int index) {
        likelihoodKnown = false;
        treePolynomialKnown = false;
    }

    @Override // dr.inference.model.AbstractModel
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        // No variables are registered by this class's constructors.
    }

    @Override // dr.inference.model.AbstractModel
    protected final void storeState() {
        storedLikelihoodKnown = likelihoodKnown;
        storedLogLikelihood = logLikelihood;
        // constructRootPolynomials always returns a fresh array and this class never mutates it
        // in place, so keeping the reference is enough to snapshot the cache.
        storedTreePolynomialKnown = treePolynomialKnown;
        storedTreePolynomials = treePolynomials;
    }

    @Override // dr.inference.model.AbstractModel
    protected final void restoreState() {
        likelihoodKnown = storedLikelihoodKnown;
        logLikelihood = storedLogLikelihood;
        // Without this the cache could keep polynomials built for a rejected tree.
        treePolynomialKnown = storedTreePolynomialKnown;
        treePolynomials = storedTreePolynomials;
    }

    @Override // dr.inference.model.AbstractModel
    protected final void acceptState() {
    }

    @Override // dr.inference.model.Likelihood
    public final Model getModel() {
        return this;
    }

    /** Returns the cached log density, recomputing it first if it has been invalidated. */
    @Override // dr.inference.model.Likelihood
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = calculateLogLikelihood();
            likelihoodKnown = true;
        }
        return logLikelihood;
    }

    @Override // dr.inference.model.Likelihood
    public final void makeDirty() {
        likelihoodKnown = false;
        treePolynomialKnown = false;
    }

    /**
     * Computes the log prior density of the current tree using the mode chosen at construction.
     *
     * @return the log density; {@code -infinity} marks an infeasible tree
     */
    public double calculateLogLikelihood() {
        final double rootHeight = tree.getNodeHeight(tree.getRoot());
        if (isNicholls) {
            return calculateNichollsLogDensity(rootHeight);
        }
        final double logDensity;
        if (useClosedForm) {
            // With k <= 0 there are no non-root internal nodes to integrate over.
            logDensity = k > 0 ? logFactorialK - k * Math.log(rootHeight) : 0.0;
        } else if (!useAnalytic) {
            logDensity = calculateMonteCarloLogDensity(rootHeight);
        } else if (useMarginal) {
            logDensity = calculateMarginalLogDensity(rootHeight);
        } else {
            tmpLogLikelihood = 0.0;
            recursivelyComputeDensity(tree, tree.getRoot(), 0.0);
            logDensity = tmpLogLikelihood;
        }
        // -infinity is a legitimate "impossible state" result produced by the branches above.
        assert !Double.isNaN(logDensity) && logDensity != Double.POSITIVE_INFINITY
                : "UniformNodeHeightPrior: invalid log density " + logDensity;
        return logDensity;
    }

    /** Nicholls-mode density of the root height; {@code -infinity} outside {@code [0, 0.999 * max]}. */
    private double calculateNichollsLogDensity(double rootHeight) {
        if (rootHeight < 0.0 || rootHeight > 0.999 * maxRootHeight) {
            return Double.NEGATIVE_INFINITY;
        }
        final int tipCount = tree.getExternalNodeCount();
        // NOTE(review): (2 - tipCount) multiplies the root height itself rather than
        // log(rootHeight); confirm this matches the intended Nicholls prior.
        return rootHeight * (2 - tipCount) - Math.log(maxRootHeight - rootHeight);
    }

    /**
     * Monte Carlo mode: minus the log of the mean sample weight, where each sample's weight is the
     * product of the sampling-window lengths drawn top-down from the root.
     */
    private double calculateMonteCarloLogDensity(double rootHeight) {
        final NodeRef root = tree.getRoot();
        Arrays.fill(drawNodeHeights[root.getNumber()], rootHeight);
        recursivelyFindNodeMinHeights(tree, root);
        Arrays.fill(logLikelihoods, 0.0);
        recursivelyComputeMCIntegral(tree, root, root.getNumber());
        final double logWeightSum = LogTricks.logSum(logLikelihoods);
        if (logWeightSum == Double.NEGATIVE_INFINITY) {
            // Every sample was infeasible: report an impossible state, not +infinity.
            return Double.NEGATIVE_INFINITY;
        }
        return -logWeightSum + Math.log(mcSampleSize);
    }

    /**
     * Marginal mode: minus the log of both root-child polynomials at the root height. Falls back
     * to Horner evaluation on NaN, then to {@code -infinity}.
     */
    private double calculateMarginalLogDensity(double rootHeight) {
        if (!treePolynomialKnown) {
            treePolynomials = constructRootPolynomials(tree, polynomialType);
            treePolynomialKnown = true;
        }
        double logDensity = -treePolynomials[0].logEvaluate(rootHeight) - treePolynomials[1].logEvaluate(rootHeight);
        if (Double.isNaN(logDensity)) {
            logDensity = -treePolynomials[0].logEvaluateHorner(rootHeight)
                    - treePolynomials[1].logEvaluateHorner(rootHeight);
            if (Double.isNaN(logDensity)) {
                logDensity = Double.NEGATIVE_INFINITY;
            }
        }
        return logDensity;
    }

    /**
     * Subtracts {@code log(parentHeight - maxTipHeightBelow)} from tmpLogLikelihood for every
     * non-root internal node in the subtree. A non-positive window sets it to {@code -infinity}.
     *
     * @return the largest tip height in the subtree rooted at {@code node}
     */
    private double recursivelyComputeDensity(Tree tree, NodeRef node, double parentHeight) {
        if (tree.isExternal(node)) {
            return tree.getNodeHeight(node);
        }
        final double nodeHeight = tree.getNodeHeight(node);
        final double maxTipHeightBelow = Math.max(
                recursivelyComputeDensity(tree, tree.getChild(node, 0), nodeHeight),
                recursivelyComputeDensity(tree, tree.getChild(node, 1), nodeHeight));
        if (!tree.isRoot(node)) {
            final double window = parentHeight - maxTipHeightBelow;
            if (window <= 0.0) {
                tmpLogLikelihood = Double.NEGATIVE_INFINITY;
            } else {
                tmpLogLikelihood -= Math.log(window);
            }
        }
        return maxTipHeightBelow;
    }

    /** Records in minNodeHeights, for every node, the largest tip height in its subtree, and returns it. */
    private double recursivelyFindNodeMinHeights(Tree tree, NodeRef node) {
        final double minHeight;
        if (tree.isExternal(node)) {
            minHeight = tree.getNodeHeight(node);
        } else {
            minHeight = Math.max(
                    recursivelyFindNodeMinHeights(tree, tree.getChild(node, 0)),
                    recursivelyFindNodeMinHeights(tree, tree.getChild(node, 1)));
        }
        minNodeHeights[node.getNumber()] = minHeight;
        return minHeight;
    }

    /**
     * For each non-root internal node and each sample, draws a height uniformly in
     * {@code [minHeight, parentDraw)} and adds the log window length to that sample's weight.
     */
    private void recursivelyComputeMCIntegral(Tree tree, NodeRef node, int parentNumber) {
        if (tree.isExternal(node)) {
            return;
        }
        final int nodeNumber = node.getNumber();
        if (!tree.isRoot(node)) {
            final double[] parentDraws = drawNodeHeights[parentNumber];
            final double[] nodeDraws = drawNodeHeights[nodeNumber];
            final double minHeight = minNodeHeights[nodeNumber];
            // This node's draws are read only by internal children, so skip drawing when both are tips.
            final boolean bothChildrenExternal =
                    tree.isExternal(tree.getChild(node, 0)) && tree.isExternal(tree.getChild(node, 1));
            for (int i = 0; i < mcSampleSize; i++) {
                final double window = parentDraws[i] - minHeight;
                if (window <= 0.0) {
                    // Only this sample is infeasible; keep processing the remaining samples.
                    logLikelihoods[i] = Double.NEGATIVE_INFINITY;
                    continue;
                }
                if (!bothChildrenExternal) {
                    nodeDraws[i] = MathUtils.nextDouble() * window + minHeight;
                }
                logLikelihoods[i] += Math.log(window);
            }
        }
        recursivelyComputeMCIntegral(tree, tree.getChild(node, 0), nodeNumber);
        recursivelyComputeMCIntegral(tree, tree.getChild(node, 1), nodeNumber);
    }

    /** Rounds to the nearest multiple of {@code 1 / INV_PRECISION}. */
    private static double round(double value) {
        return Math.round(value * INV_PRECISION) / INV_PRECISION;
    }

    /** Builds a fresh array holding the polynomial of each of the root's two children. */
    private Polynomial[] constructRootPolynomials(Tree tree, Polynomial.Type type) {
        final NodeRef root = tree.getRoot();
        return new Polynomial[]{
                recursivelyComputePolynomial(tree, tree.getChild(root, 0), type).getPolynomial(),
                recursivelyComputePolynomial(tree, tree.getChild(root, 1), type).getPolynomial()
        };
    }

    /**
     * Tips yield the constant polynomial 1, labeled with their rounded height. Internal nodes
     * multiply their children's polynomials and, unless they are the root, integrate with the
     * product's label as the lower bound.
     */
    private TipLabeledPolynomial recursivelyComputePolynomial(Tree tree, NodeRef node, Polynomial.Type type) {
        if (tree.isExternal(node)) {
            return new TipLabeledPolynomial(new double[]{1.0d}, round(tree.getNodeHeight(node)), type, true);
        }
        TipLabeledPolynomial product = recursivelyComputePolynomial(tree, tree.getChild(node, 0), type)
                .multiply(recursivelyComputePolynomial(tree, tree.getChild(node, 1), type));
        if (!tree.isRoot(node)) {
            product = product.integrateWithLowerBound(product.label);
        }
        return product;
    }

    /** Returns {@code log(n!)}; 0 for {@code n <= 1}. */
    private double logFactorial(int n) {
        double sum = 0.0d;
        for (int i = n; i > 1; i--) {
            sum += Math.log(i);
        }
        return sum;
    }

    @Override // dr.inference.model.AbstractModel
    public Element createElement(Document document) {
        throw new UnsupportedOperationException("createElement not implemented");
    }
}
