package dr.evomodel.continuous;

import dr.app.tools.GetNSCountsFromTrees;
import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.continuous.MissingTraits;
import dr.inference.loggers.LogColumn;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.util.Author;
import dr.util.Citation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:dr/evomodel/continuous/IntegratedMultivariateTraitLikelihood.class */
public abstract class IntegratedMultivariateTraitLikelihood extends AbstractMultivariateTraitLikelihood {
    // Strategy object owning the mean caches; chosen in the constructor (CacheHelper, DriftCacheHelper or OUCacheHelper).
    protected final CacheHelper cacheHelper;
    // True once redrawAncestralStates() has filled drawnStates; reset by makeDirty() and calculateLogLikelihood().
    protected boolean areStatesRedrawn;
    // Partial means, dim consecutive entries per partials index (offset = dim * index).
    protected double[] meanCache;
    // Corrected means read by the remainder computations; only allocated by the drift and OU cache helpers.
    protected double[] correctedMeanCache;
    // Per-partials-index precision scalar passed up the branch to the parent.
    protected double[] upperPrecisionCache;
    // Per-partials-index precision scalar accumulated from the children (doPeel sets it to the children's sum).
    protected double[] lowerPrecisionCache;
    // Per-partials-index log-density remainders; summed by sumLogRemainders() into the total log-likelihood.
    private double[] logRemainderDensityCache;
    // Stored copies used by store/restore; the arrays are only allocated when branch caching is enabled.
    protected double[] storedMeanCache;
    private double[] storedUpperPrecisionCache;
    private double[] storedLowerPrecisionCache;
    private double[] storedLogRemainderDensityCache;
    // Ancestral states read by getTraitForNode(), dim entries per partials index; refreshed via redrawAncestralStates().
    protected double[] drawnStates;
    // Always true. Final and initialized here, so it must not also be assigned in a constructor.
    protected final boolean integrateRoot = true;
    // dim zeros, copied into a mean slot when the combined precision is zero.
    private double[] zeroDimVector;
    // Recreated on each calculateLogLikelihood() when getComputeWishartSufficientStatistics() is true.
    protected WishartSufficientStatistics wishartStatistics;
    // Scratch buffers of size dimTrait (Ay, tmpM) used when integrating the root; tmp2 holds the current root mean.
    protected double[] Ay;
    protected double[][] tmpM;
    protected double[] tmp2;
    // Missing-tip bookkeeping (a MissingTraits.CompletelyMissing instance created in the constructor).
    protected final MissingTraits missingTraits;
    // Restricted partials keyed by the BitSet of descendant tip numbers; populated outside this view (presumably addRestrictedPartials).
    protected Map<BitSet, RestrictedPartials> clampList;
    // Restricted partials attached to the internal node whose descendant tips match; rebuilt by setupClamps().
    protected Map<NodeRef, RestrictedPartials> nodeToClampMap;
    // Number of partials slots: tree nodes, then one per restricted partial, then one spare slot.
    private int partialsCount;
    // Scratch partials index used when combining a node with its restricted partials.
    private int spareIndex;
    // True when at least one node has restricted partials attached.
    protected boolean anyClamps;
    // log(sqrt(2 * pi)); 6.283185307179586 is 2 * pi.
    public static final double LOG_SQRT_2_PI = 0.5d * Math.log(6.283185307179586d);
    // Mutable debug switches that enable diagnostic output.
    protected static boolean DEBUG = false;
    protected static boolean DEBUG_PREORDER = false;
    protected static boolean DEBUG_PNAS = false;

    /**
     * Default mean-cache strategy (plain diffusion). Partial means for every partials index live in
     * the enclosing likelihood's {@code meanCache}, {@code dim} entries per index, and the
     * "corrected" means used for remainder densities are simply the plain means.
     *
     * <p>Subclasses override the hooks that differ for drift and Ornstein-Uhlenbeck models.
     * NOTE: the decompiler reports this class was originally package-private.
     */
    public class CacheHelper {
        /** Whether a stored copy of the mean cache was requested for store/restore. */
        protected boolean cacheBranches;

        /**
         * Allocates the enclosing likelihood's mean cache.
         *
         * @param i length of the mean cache
         * @param z if true, also allocate {@code storedMeanCache} for {@link #store()}/{@link #restore()}
         */
        public CacheHelper(int i, boolean z) {
            meanCache = new double[i];
            this.cacheBranches = z;
            if (z) {
                storedMeanCache = new double[i];
            }
        }

        /**
         * Returns the mean shift along the branch above {@code nodeRef}; always zero for plain diffusion.
         *
         * @return a new array of length {@code dimTrait * numData}
         */
        public double[] getShift(NodeRef nodeRef) {
            // Java zero-initializes new arrays, so no explicit fill is needed.
            return new double[dimTrait * numData];
        }

        /** Returns the live mean cache (not a copy). */
        public double[] getMeanCache() {
            return meanCache;
        }

        /** Plain diffusion applies no correction, so the corrected means are the plain means. */
        public double[] getCorrectedMeanCache() {
            return meanCache;
        }

        /**
         * Copies the current mean cache into the stored copy. The stored copy is (re)allocated when it
         * is absent (branch caching disabled at construction) or its length no longer matches.
         */
        public void store() {
            if (storedMeanCache == null || storedMeanCache.length != meanCache.length) {
                storedMeanCache = new double[meanCache.length];
            }
            System.arraycopy(meanCache, 0, storedMeanCache, 0, meanCache.length);
        }

        /** Swaps the current and stored mean caches (O(1); no copying). */
        public void restore() {
            final double[] previous = storedMeanCache;
            storedMeanCache = meanCache;
            meanCache = previous;
        }

        /** Multiplicative mean attenuation along a branch; 1 (none) for plain diffusion. */
        public double getOUFactor(NodeRef nodeRef) {
            return 1.0d;
        }

        /** Branch precision factor: the reciprocal of the rescaled branch length above {@code nodeRef}. */
        public double getUpperPrecFactor(NodeRef nodeRef) {
            return 1.0d / getRescaledBranchLengthForPrecision(nodeRef);
        }

        /**
         * Writes the combined mean for a parent slot at offset {@code i}. With zero total precision
         * {@code d} the slot is zero-filled; otherwise the children's means at {@code i2}/{@code i3}
         * are averaged with weights {@code d2}/{@code d3} by {@code missingTraits}.
         */
        public void computeMeanCaches(int i, int i2, int i3, double d, double d2, double d3, MissingTraits missingTraits, NodeRef nodeRef, NodeRef nodeRef2, NodeRef nodeRef3) {
            if (d == 0.0d) {
                System.arraycopy(zeroDimVector, 0, meanCache, i, dim);
            } else {
                missingTraits.computeWeightedAverage(meanCache, i2, d2, i3, d3, i, dim);
            }
        }

        /** Copies the first {@code i} entries of {@code dArr} into slot {@code i2} of the mean cache. */
        public void setTipMeans(double[] dArr, int i, int i2, NodeRef nodeRef) {
            copyTipMeans(dArr, i, i2);
        }

        /** Same as the four-argument overload, without a node reference. */
        public void setTipMeans(double[] dArr, int i, int i2) {
            copyTipMeans(dArr, i, i2);
        }

        // Shared body of both setTipMeans overloads: slot offset is length * tip.
        private void copyTipMeans(double[] values, int length, int tip) {
            System.arraycopy(values, 0, meanCache, length * tip, length);
        }

        /** Copies {@code i2} values from {@code dArr} into the mean cache starting at offset {@code i}. */
        public void copyToMeanCache(double[] dArr, int i, int i2) {
            System.arraycopy(dArr, 0, meanCache, i, i2);
        }

        /** Sets a single mean-cache entry. */
        public void setMeanCache(int i, double d) {
            meanCache[i] = d;
        }
    }

    /**
     * Cache helper for diffusion with per-branch drift. Branch shifts come from
     * {@code getShiftForBranchLength}, internal means are combined by
     * {@code computeCorrectedWeightedAverage}, and corrected means live in their own buffer.
     */
    class DriftCacheHelper extends CacheHelper {
        public DriftCacheHelper(int length, boolean keepStoredCopy) {
            super(length, keepStoredCopy);
            correctedMeanCache = new double[length];
        }

        /** Drift shift along the branch above {@code node}. */
        @Override
        public double[] getShift(NodeRef node) {
            return getShiftForBranchLength(node);
        }

        /** The drift model keeps corrected means separately from the plain means. */
        @Override
        public double[] getCorrectedMeanCache() {
            return correctedMeanCache;
        }

        /** No attenuation of the mean under drift. */
        @Override
        public double getOUFactor(NodeRef node) {
            return 1.0d;
        }

        /** Branch precision: reciprocal of the rescaled branch length. */
        @Override
        public double getUpperPrecFactor(NodeRef node) {
            final double rescaledLength = getRescaledBranchLengthForPrecision(node);
            return 1.0d / rescaledLength;
        }

        /** Copies {@code length} tip values into slot {@code tip} of the mean cache. */
        @Override
        public void setTipMeans(double[] values, int length, int tip, NodeRef node) {
            final int offset = length * tip;
            System.arraycopy(values, 0, meanCache, offset, length);
        }

        /**
         * Writes the parent's combined mean at {@code targetOffset}; the slot is zero-filled when the
         * total precision is zero.
         */
        @Override
        public void computeMeanCaches(int targetOffset, int leftOffset, int rightOffset, double totalPrecision, double leftPrecision, double rightPrecision, MissingTraits missingTraits, NodeRef node, NodeRef leftChild, NodeRef rightChild) {
            if (totalPrecision != 0.0d) {
                computeCorrectedWeightedAverage(leftOffset, leftPrecision, leftChild, rightOffset, rightPrecision, rightChild, targetOffset, dim, node);
            } else {
                System.arraycopy(zeroDimVector, 0, meanCache, targetOffset, dim);
            }
        }
    }

    /** Kinds of integrated diffusion process. Declaration order (and hence ordinal) is significant. */
    public enum IntegratedDiffusionType {
        PLAIN, SCALED, DRIFT, OU
    }

    /**
     * Cache helper for the Ornstein-Uhlenbeck model. Means are attenuated along a branch by
     * exp(-scaled selection), and branch precision is 2*alpha / (1 - exp(-2 * scaled selection)),
     * where alpha is the selection strength for the branch.
     */
    class OUCacheHelper extends CacheHelper {
        public OUCacheHelper(int length, boolean keepStoredCopy) {
            super(length, keepStoredCopy);
            correctedMeanCache = new double[length];
        }

        /** The OU model keeps corrected means separately from the plain means. */
        @Override
        public double[] getCorrectedMeanCache() {
            return correctedMeanCache;
        }

        /** Mean attenuation along the branch above {@code node}. */
        @Override
        public double getOUFactor(NodeRef node) {
            final double scaledSelection = getTimeScaledSelection(node);
            return Math.exp(-scaledSelection);
        }

        /** OU branch precision factor for the branch above {@code node}. */
        @Override
        public double getUpperPrecFactor(NodeRef node) {
            final double alpha = strengthOfSelection.getBranchRate(treeModel, node);
            final double decay = Math.exp(-2.0d * getTimeScaledSelection(node));
            return (2.0d * alpha) / (1.0d - decay);
        }

        /** Copies {@code length} tip values into slot {@code tip} of the mean cache. */
        @Override
        public void setTipMeans(double[] values, int length, int tip, NodeRef node) {
            final int offset = length * tip;
            System.arraycopy(values, 0, meanCache, offset, length);
        }

        /**
         * Writes the parent's combined mean at {@code targetOffset} using the OU-corrected average;
         * the slot is zero-filled when the total precision is zero.
         */
        @Override
        public void computeMeanCaches(int targetOffset, int leftOffset, int rightOffset, double totalPrecision, double leftPrecision, double rightPrecision, MissingTraits missingTraits, NodeRef node, NodeRef leftChild, NodeRef rightChild) {
            if (totalPrecision != 0.0d) {
                computeCorrectedOUWeightedAverage(leftOffset, leftPrecision, leftChild, rightOffset, rightPrecision, rightChild, targetOffset, dim, node);
            } else {
                System.arraycopy(zeroDimVector, 0, meanCache, targetOffset, dim);
            }
        }
    }

    /**
     * Cache helper that writes tip means element by element through {@link #setMeanCache}.
     *
     * <p>NOTE(review): despite the name, setMeanCache stores values unchanged; no per-dimension
     * standardization is applied in this class.
     */
    class StandarizedCacheHelper extends CacheHelper {
        // Trait dimension per node; shadows the enclosing likelihood's dim inside this class.
        private final int dim;
        // Number of nodes the cache was sized for (not read within this class).
        private final int nodeCount;

        public StandarizedCacheHelper(int i, int i2, boolean z) {
            super(i * i2, z);
            this.dim = i;
            this.nodeCount = i2;
        }

        /** Writes the first {@code i} entries of {@code dArr} into tip slot {@code i2}. */
        @Override
        public void setTipMeans(double[] dArr, int i, int i2, NodeRef nodeRef) {
            writeTipMeans(dArr, i, i2);
        }

        /** Same as the four-argument overload, without a node reference. */
        @Override
        public void setTipMeans(double[] dArr, int i, int i2) {
            writeTipMeans(dArr, i, i2);
        }

        // Routes each tip value through setMeanCache so subclasses can intercept per-element writes.
        private void writeTipMeans(double[] values, int length, int tip) {
            final int base = length * tip;
            for (int k = 0; k < length; k++) {
                setMeanCache(base + k, values[k]);
            }
        }

        /** Stores {@code d} at index {@code i} of the mean cache. */
        @Override
        public void setMeanCache(int i, double d) {
            meanCache[i] = d;
        }
    }

    /**
     * Builds the integrated likelihood, sizes the per-partials caches and loads the tip data.
     *
     * <p>Partials indices 0..nodeCount-1 belong to tree nodes. Each restricted partial in
     * {@code list4} gets the next index, followed by one spare scratch index.
     *
     * <p>All arguments are forwarded to the superclass constructor (except {@code list4}). Uses visible here:
     *
     * @param z     if true, stored copies of the caches are allocated for store/restore
     * @param list  passed to {@code MissingTraits.CompletelyMissing} (presumably missing tip indices — confirm)
     * @param list2 when non-null, a {@code DriftCacheHelper} is used
     * @param list3 when non-null (and {@code list2} is null), an {@code OUCacheHelper} is used
     * @param list4 optional restricted (clamped) partials; may be null
     */
    public IntegratedMultivariateTraitLikelihood(String str, MutableTreeModel mutableTreeModel, MultivariateDiffusionModel multivariateDiffusionModel, CompoundParameter compoundParameter, Parameter parameter, List<Integer> list, boolean z, boolean z2, boolean z3, BranchRateModel branchRateModel, List<BranchRateModel> list2, List<BranchRateModel> list3, BranchRateModel branchRateModel2, Model model, List<RestrictedPartials> list4, boolean z4, boolean z5) {
        super(str, mutableTreeModel, multivariateDiffusionModel, compoundParameter, parameter, list, z, z2, z3, branchRateModel, list2, list3, branchRateModel2, model, z4, z5);
        this.areStatesRedrawn = false;
        // integrateRoot is final and initialized at its declaration; assigning it here would not compile.
        this.clampList = null;
        this.nodeToClampMap = null;
        this.anyClamps = false;
        this.partialsCount = mutableTreeModel.getNodeCount();
        if (list4 != null) {
            // Restricted partials occupy the slots after the tree nodes.
            for (RestrictedPartials restrictedPartials : list4) {
                restrictedPartials.setIndex(this.partialsCount);
                addRestrictedPartials(restrictedPartials);
                this.partialsCount++;
            }
            // One extra scratch slot used when combining a node with its restricted partials.
            this.spareIndex = this.partialsCount;
            this.partialsCount++;
            setupClamps();
        }
        if (list2 != null) {
            this.cacheHelper = new DriftCacheHelper(this.dim * this.partialsCount, z);
        } else if (list3 != null) {
            this.cacheHelper = new OUCacheHelper(this.dim * this.partialsCount, z);
        } else {
            this.cacheHelper = new CacheHelper(this.dim * this.partialsCount, z);
        }
        this.drawnStates = new double[this.dim * this.partialsCount];
        this.upperPrecisionCache = new double[this.partialsCount];
        this.lowerPrecisionCache = new double[this.partialsCount];
        this.logRemainderDensityCache = new double[this.partialsCount];
        if (z) {
            this.storedUpperPrecisionCache = new double[this.partialsCount];
            this.storedLowerPrecisionCache = new double[this.partialsCount];
            this.storedLogRemainderDensityCache = new double[this.partialsCount];
        }
        this.Ay = new double[this.dimTrait];
        this.tmpM = new double[this.dimTrait][this.dimTrait];
        this.tmp2 = new double[this.dimTrait];
        this.zeroDimVector = new double[this.dim];
        this.missingTraits = new MissingTraits.CompletelyMissing(mutableTreeModel, list, this.dim);
        setTipDataValuesForAllNodes();
    }

    /** Copies every tip's trait values into the mean cache, then lets missingTraits handle missing tips. */
    private void setTipDataValuesForAllNodes() {
        final int tipCount = treeModel.getExternalNodeCount();
        for (int tip = 0; tip < tipCount; tip++) {
            final NodeRef tipNode = treeModel.getExternalNode(tip);
            setTipDataValuesForNode(tipNode);
        }
        missingTraits.handleMissingTips();
    }

    /**
     * Returns the precision scalar accumulated at the root.
     * getLogLikelihood() is called first for its side effect (presumably refreshing the caches).
     */
    public double getTotalTreePrecision() {
        getLogLikelihood();
        final int rootIndex = treeModel.getRoot().getNumber();
        return lowerPrecisionCache[rootIndex];
    }

    /**
     * Loads the trait values of tip {@code nodeRef} from the trait parameter into the mean cache.
     *
     * @throws RuntimeException if the tip's parameter has fewer than {@code dim} values
     */
    private void setTipDataValuesForNode(NodeRef nodeRef) {
        final int tipNumber = nodeRef.getNumber();
        final double[] values = traitParameter.getParameter(tipNumber).getParameterValues();
        if (values.length >= dim) {
            cacheHelper.setTipMeans(values, dim, tipNumber, nodeRef);
            return;
        }
        throw new RuntimeException("The trait parameter for the tip with index, " + tipNumber + ", is too short");
    }

    /** Returns a copy of the {@code dim} cached mean values stored for partials index {@code i}. */
    public double[] getTipDataValues(int i) {
        final double[] tipValues = new double[dim];
        final double[] source = cacheHelper.getMeanCache();
        System.arraycopy(source, i * dim, tipValues, 0, dim);
        return tipValues;
    }

    /** Overwrites the cached means of tip {@code i} with values from {@code dArr} and marks the likelihood dirty. */
    public void setTipDataValuesForNode(int i, double[] dArr) {
        final int traitDim = dim;
        cacheHelper.setTipMeans(dArr, traitDim, i);
        makeDirty();
    }

    /** Reports that internal node traits are not sampled (they are integrated out). Originally protected per the decompiler. */
    @Override
    public String extraInfo() {
        final String sampleInternal = "\tSample internal node traits: false\n";
        return sampleInternal;
    }

    /** Citation category for this model. */
    @Override
    public Citation.Category getCategory() {
        final Citation.Category category = Citation.Category.TRAIT_MODELS;
        return category;
    }

    /** Extends the superclass description with a note about the integrated internal traits. */
    @Override
    public String getDescription() {
        final String base = super.getDescription();
        return base + " (first citation) with efficiently integrated internal traits (second citation)";
    }

    /**
     * Returns the superclass citations followed by Pybus et al. (2012, PNAS), which introduced the
     * efficient integration of internal traits.
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood, dr.util.Citable
    public List<Citation> getCitations() {
        final List<Citation> citations = new ArrayList<>(super.getCitations());
        final Author[] authors = {
                new Author("OG", "Pybus"),
                new Author("MA", "Suchard"),
                new Author("P", "Lemey"),
                new Author("F", "Bernardin"),
                new Author("A", "Rambaut"),
                new Author("FW", "Crawford"),
                new Author("RR", "Gray"),
                new Author("N", "Arinaminpathy"),
                // NOTE(review): decompiler substituted a constant for a string literal; presumably
                // Stramer's initials ("SL") — confirm the value of GetNSCountsFromTrees.totalcS.
                new Author(GetNSCountsFromTrees.totalcS, "Stramer"),
                new Author("MP", "Busch"),
                new Author("E", "Delwart")
        };
        citations.add(new Citation(authors,
                "Unifying the spatial epidemiology and molecular evolution of emerging epidemics",
                2012, "Proceedings of the National Academy of Sciences", 109, 15066, 15071,
                Citation.Status.PUBLISHED));
        return citations;
    }

    /** The data likelihood is the full integrated log-likelihood. */
    @Override
    public double getLogDataLikelihood() {
        final double logLikelihood = getLogLikelihood();
        return logLikelihood;
    }

    /**
     * Rebuilds {@code nodeToClampMap} by matching each internal node's set of descendant tips against
     * the registered restricted partials, then updates {@code anyClamps}.
     *
     * <p>When no restricted partials have been registered ({@code clampList == null}, e.g. an empty
     * partials list passed to the constructor) the map is simply cleared. This avoids a
     * NullPointerException during the traversal.
     */
    private void setupClamps() {
        if (this.nodeToClampMap == null) {
            this.nodeToClampMap = new HashMap<>();
        }
        this.nodeToClampMap.clear();
        if (this.clampList != null) {
            recursiveSetupClamp(this.treeModel, this.treeModel.getRoot(), new BitSet());
        }
        this.anyClamps = !this.nodeToClampMap.isEmpty();
    }

    /**
     * Post-order walk that fills {@code bitSet} with the numbers of the tips below {@code nodeRef}.
     * It attaches any restricted partials whose tip set matches an internal node.
     *
     * @param tree    tree being traversed
     * @param nodeRef current node
     * @param bitSet  output: receives the tip numbers below {@code nodeRef}
     */
    private void recursiveSetupClamp(Tree tree, NodeRef nodeRef, BitSet bitSet) {
        if (tree.isExternal(nodeRef)) {
            bitSet.set(nodeRef.getNumber());
            return;
        }
        for (int i = 0; i < tree.getChildCount(nodeRef); i++) {
            // Each child needs its own set so its own lookup sees only its descendants.
            BitSet childTips = new BitSet();
            recursiveSetupClamp(tree, tree.getChild(nodeRef, i), childTips);
            bitSet.or(childTips);
        }
        // Single null-safe lookup: clampList stays null when no restricted partials were registered.
        RestrictedPartials restrictedPartials = this.clampList == null ? null : this.clampList.get(bitSet);
        if (restrictedPartials != null) {
            restrictedPartials.setNode(nodeRef);
            this.nodeToClampMap.put(nodeRef, restrictedPartials);
        }
    }

    /**
     * Whether calculateLogLikelihood() should build Wishart sufficient statistics.
     * The statistics are the precision-weighted outer products of child mean differences,
     * accumulated during the post-order traversal.
     */
    public abstract boolean getComputeWishartSufficientStatistics();

    /**
     * Computes the log-likelihood of the tip traits with the internal-node traits integrated out.
     *
     * <p>Steps:
     * <ol>
     *   <li>Optionally re-map restricted partials to nodes.</li>
     *   <li>Run the post-order peeling pass. It fills the mean, precision and remainder caches and,
     *       if requested, the Wishart statistics.</li>
     *   <li>For each datum, evaluate the root term from the root mean and root precision.</li>
     *   <li>Add the per-node log remainders.</li>
     * </ol>
     *
     * @return the total log-likelihood
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood
    public double calculateLogLikelihood() {
        // Flag set elsewhere when restricted partials must be re-mapped onto nodes.
        if (this.updateRestrictedNodePartials) {
            if (this.clampList != null) {
                setupClamps();
            }
            this.updateRestrictedNodePartials = false;
        }
        double d = 0.0d;
        double[][] precisionmatrix = this.diffusionModel.getPrecisionmatrix();
        // Log-determinant of the diffusion precision matrix.
        double log = Math.log(this.diffusionModel.getDeterminantPrecisionMatrix());
        // Scratch buffer that receives each datum's root mean.
        double[] dArr = this.tmp2;
        boolean computeWishartSufficientStatistics = getComputeWishartSufficientStatistics();
        if (computeWishartSufficientStatistics) {
            this.wishartStatistics = new WishartSufficientStatistics(this.dimTrait);
        }
        postOrderTraverse(this.treeModel, this.treeModel.getRoot(), precisionmatrix, log, computeWishartSufficientStatistics);
        if (DEBUG) {
            System.err.println("mean: " + new Vector(this.cacheHelper.getMeanCache()));
            System.err.println("correctedMean: " + new Vector(this.cacheHelper.getCorrectedMeanCache()));
            System.err.println("upre: " + new Vector(this.upperPrecisionCache));
            System.err.println("lpre: " + new Vector(this.lowerPrecisionCache));
            System.err.println("cach: " + new Vector(this.logRemainderDensityCache));
        }
        int number = this.treeModel.getRoot().getNumber();
        // Precision scalar accumulated at the root.
        double d2 = this.lowerPrecisionCache[number];
        for (int i = 0; i < this.numData; i++) {
            // Root mean for datum i: dimTrait entries inside the root's dim-sized slot.
            System.arraycopy(this.cacheHelper.getMeanCache(), (number * this.dim) + (i * this.dimTrait), dArr, 0, this.dimTrait);
            if (DEBUG) {
                System.err.println("Datum #" + i);
                System.err.println("root mean: " + new Vector(dArr));
                System.err.println("root prec: " + d2);
                System.err.println("diffusion prec: " + new Matrix(precisionmatrix));
            }
            // Fills Ay = d2 * P * rootMean and returns rootMean' * Ay.
            double computeWeightedAverageAndSumOfSquares = computeWeightedAverageAndSumOfSquares(dArr, this.Ay, precisionmatrix, this.dimTrait, d2);
            // Log MVN density of the root mean under mean 0 and precision d2 * P (cross-checked in the DEBUG
            // block below). The term is dropped when the root carries no precision.
            double log2 = d2 != 0.0d ? 0.0d + ((-LOG_SQRT_2_PI) * this.dimTrait) + (0.5d * ((log + (this.dimTrait * Math.log(d2))) - computeWeightedAverageAndSumOfSquares)) : 0.0d;
            if (DEBUG) {
                double[][] dArr2 = new double[this.dimTrait][this.dimTrait];
                for (int i2 = 0; i2 < this.dimTrait; i2++) {
                    for (int i3 = 0; i3 < this.dimTrait; i3++) {
                        dArr2[i2][i3] = precisionmatrix[i2][i3] * d2;
                    }
                }
                System.err.println("Conditional root MVN precision = \n" + new Matrix(dArr2));
                System.err.println("Conditional root MVN density = " + MultivariateNormalDistribution.logPdf(dArr, new double[this.dimTrait], dArr2, Math.log(MultivariateNormalDistribution.calculatePrecisionMatrixDeterminate(dArr2)), 1.0d));
            }
            // Subclass-specific root contribution is added to the conditional density.
            double integrateLogLikelihoodAtRoot = log2 + integrateLogLikelihoodAtRoot(dArr, this.Ay, this.tmpM, precisionmatrix, d2);
            if (DEBUG) {
                System.err.println("yAy = " + computeWeightedAverageAndSumOfSquares);
                System.err.println("logLikelihood (before remainders) = " + integrateLogLikelihoodAtRoot + " (should match conditional root MVN density when root not integrated out)");
            }
            d += integrateLogLikelihoodAtRoot;
        }
        // Total = root terms over all data + per-node remainders from the post-order pass.
        double sumLogRemainders = d + sumLogRemainders();
        // NOTE(review): both DEBUG blocks print the same value (one to stdout, one to stderr).
        if (DEBUG) {
            System.out.println("logLikelihood is " + sumLogRemainders);
        }
        if (DEBUG) {
            System.err.println("logLikelihood (final) = " + sumLogRemainders);
        }
        if (DEBUG_PNAS) {
            checkLogLikelihood(sumLogRemainders, sumLogRemainders(), dArr, d2, precisionmatrix);
            for (int i4 = 0; i4 < this.logRemainderDensityCache.length; i4++) {
                if (this.logRemainderDensityCache[i4] < -1.0E10d) {
                    System.err.println(this.logRemainderDensityCache[i4] + " @ " + i4);
                }
            }
        }
        // Any previously drawn ancestral states are now stale.
        this.areStatesRedrawn = false;
        return sumLogRemainders;
    }

    /**
     * Debug hook called from calculateLogLikelihood() when DEBUG_PNAS is set; does nothing here.
     * Presumably overridden to cross-check the integrated log-likelihood — confirm.
     *
     * @param d     final log-likelihood (root terms plus remainders)
     * @param d2    sum of the per-node log remainders
     * @param dArr  root mean of the last datum processed
     * @param d3    precision scalar at the root
     * @param dArr2 diffusion precision matrix
     */
    protected void checkLogLikelihood(double d, double d2, double[] dArr, double d3, double[][] dArr2) {
    }

    /**
     * Copies tip trait updates into the mean cache, then delegates to the superclass.
     * The decompiler reports this method was originally protected.
     *
     * @param variable   the variable that changed
     * @param i          index of the changed element, or -1 when the whole parameter changed
     * @param changeType kind of change
     * @throws RuntimeException if {@code i} is neither -1 nor a valid tip-trait element index
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood, dr.inference.model.AbstractModel
    public void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        if (variable == this.traitParameter) {
            // Valid element indices are 0 .. total-1 (total itself is out of range); -1 means "all elements".
            final int total = this.dimTrait * this.numData * this.treeModel.getExternalNodeCount();
            if (i < -1 || i >= total) {
                throw new RuntimeException("Attempting to update an invalid index");
            }
            if (i != -1) {
                this.cacheHelper.setMeanCache(i, this.traitParameter.getValue(i).doubleValue());
            } else {
                for (int i2 = 0; i2 < this.traitParameter.getDimension(); i2++) {
                    this.cacheHelper.setMeanCache(i2, this.traitParameter.getValue(i2).doubleValue());
                }
            }
            this.likelihoodKnown = false;
        }
        super.handleVariableChangedEvent(variable, i, changeType);
    }

    /**
     * Computes {@code ay = scale * P * y} and returns the quadratic form {@code y' * ay}.
     * Only the leading {@code dim} entries of each array and the leading {@code dim x dim} block of
     * {@code precision} are used. The decompiler reports this was originally protected.
     *
     * @param y         input vector
     * @param ay        output vector; its first {@code dim} entries are overwritten
     * @param precision matrix P
     * @param dim       number of dimensions to use
     * @param scale     scalar multiplier applied to every product
     * @return sum over rows of {@code y[row] * ay[row]}
     */
    public static double computeWeightedAverageAndSumOfSquares(double[] y, double[] ay, double[][] precision, int dim, double scale) {
        double quadratic = 0.0d;
        int row = 0;
        while (row < dim) {
            ay[row] = 0.0d;
            final double[] precisionRow = precision[row];
            for (int col = 0; col < dim; col++) {
                ay[row] += precisionRow[col] * y[col] * scale;
            }
            quadratic += y[row] * ay[row];
            row++;
        }
        return quadratic;
    }

    /** Sums the per-node log remainders in index order. */
    private double sumLogRemainders() {
        double total = 0.0d;
        for (int k = 0; k < logRemainderDensityCache.length; k++) {
            total += logRemainderDensityCache[k];
        }
        return total;
    }

    /**
     * Returns the subclass-specific root contribution for one datum. calculateLogLikelihood() adds it
     * to the conditional root density.
     *
     * @param dArr  root mean for the current datum (length dimTrait)
     * @param dArr2 precomputed {@code d * P * rootMean} from computeWeightedAverageAndSumOfSquares
     * @param dArr3 dimTrait x dimTrait scratch matrix
     * @param dArr4 diffusion precision matrix P
     * @param d     precision scalar at the root
     * @return log-density term to add for this datum
     */
    protected abstract double integrateLogLikelihoodAtRoot(double[] dArr, double[] dArr2, double[][] dArr3, double[][] dArr4, double d);

    /** Marks the likelihood dirty and invalidates any previously drawn ancestral states. */
    @Override
    public void makeDirty() {
        super.makeDirty();
        areStatesRedrawn = false;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Post-order peeling pass that fills the per-node caches read by calculateLogLikelihood().
     *
     * <p>Tips: a completely missing tip gets zero upper and lower precision. Otherwise its upper
     * precision is the branch factor times the squared OU factor, and its lower precision is infinite.
     *
     * <p>Internal nodes: both children are visited first, then doPeel() combines them into this node.
     * If restricted partials are attached to the node (nodeToClampMap), the result is further combined
     * with them in the spare slot and copied back into the node's slot.
     *
     * <p>NOTE(review): only children 0 and 1 are visited, so a bifurcating tree is assumed.
     *
     * @param mutableTreeModel tree being traversed
     * @param nodeRef          current node
     * @param dArr             diffusion precision matrix
     * @param d                log-determinant of the diffusion precision matrix
     * @param z                whether to accumulate Wishart sufficient statistics
     */
    public void postOrderTraverse(MutableTreeModel mutableTreeModel, NodeRef nodeRef, double[][] dArr, double d, boolean z) {
        int number = nodeRef.getNumber();
        if (mutableTreeModel.isExternal(nodeRef)) {
            if (this.missingTraits.isCompletelyMissing(number)) {
                // Completely missing tip: contributes no information.
                this.upperPrecisionCache[number] = 0.0d;
                this.lowerPrecisionCache[number] = 0.0d;
                return;
            } else {
                // Observed tip: value known exactly, so the lower precision is infinite.
                this.upperPrecisionCache[number] = this.cacheHelper.getUpperPrecFactor(nodeRef) * Math.pow(this.cacheHelper.getOUFactor(nodeRef), 2.0d);
                this.lowerPrecisionCache[number] = Double.POSITIVE_INFINITY;
                return;
            }
        }
        NodeRef child = mutableTreeModel.getChild(nodeRef, 0);
        NodeRef child2 = mutableTreeModel.getChild(nodeRef, 1);
        postOrderTraverse(mutableTreeModel, child, dArr, d, z);
        postOrderTraverse(mutableTreeModel, child2, dArr, d, z);
        int number2 = child.getNumber();
        int number3 = child2.getNumber();
        // Offsets of the two children and this node into meanCache (dim entries per partials index).
        int i = this.dim * number2;
        int i2 = this.dim * number3;
        int i3 = this.dim * number;
        double d2 = this.upperPrecisionCache[number2];
        double d3 = this.upperPrecisionCache[number3];
        // Combine both children into this node; the remainder is stored at this node's index.
        doPeel(number, i3, i, i2, d2 + d3, d2, d3, this.missingTraits, number, dArr, d, this.cacheHelper.getOUFactor(child), this.cacheHelper.getOUFactor(child2), z, nodeRef, child, child2, true, false);
        if (this.nodeToClampMap == null || !this.nodeToClampMap.containsKey(nodeRef)) {
            return;
        }
        // Restricted partials are attached here: treat them as a second "child" of this node.
        RestrictedPartials restrictedPartials = this.nodeToClampMap.get(nodeRef);
        int index = restrictedPartials.getIndex();
        int i4 = this.dim * index;
        int i5 = this.dim * this.spareIndex;
        for (int i6 = 0; i6 < this.dim; i6++) {
            this.meanCache[i4 + i6] = restrictedPartials.getPartial(i6);
        }
        double d4 = this.lowerPrecisionCache[number];
        // Clamp precision: prior sample size divided by rescaleLength(1.0) (presumably the branch-length scaling — confirm).
        double priorSampleSize = restrictedPartials.getPriorSampleSize() / rescaleLength(1.0d);
        // Peel node + clamp into the spare slot; the remainder is recorded at the clamp's own index.
        doPeel(this.spareIndex, i5, i3, i4, d4 + priorSampleSize, d4, priorSampleSize, this.missingTraits, index, dArr, d, 1.0d, 1.0d, z, nodeRef, null, null, true, false);
        // Copy the combined precisions and mean back into this node's slot.
        this.lowerPrecisionCache[number] = this.lowerPrecisionCache[this.spareIndex];
        this.upperPrecisionCache[number] = this.upperPrecisionCache[this.spareIndex];
        for (int i7 = 0; i7 < this.dim; i7++) {
            this.meanCache[i3 + i7] = this.meanCache[i5 + i7];
        }
    }

    /**
     * Combines two children into the target partials slot. It sets the target's lower precision,
     * computes its mean via the cache helper and, for non-root nodes, computes its upper (branch)
     * precision. It resets the remainder at {@code remainderIndex} and, when both children carry
     * precision and {@code computeRemainders} is set, accumulates the remainder density.
     *
     * @param unusedFlag not read by this method
     */
    private void doPeel(int targetIndex, int targetOffset, int leftOffset, int rightOffset, double totalPrecision, double leftPrecision, double rightPrecision, MissingTraits missing, int remainderIndex, double[][] precisionMatrix, double logDetPrecision, double leftOUFactor, double rightOUFactor, boolean accumulateWishart, NodeRef node, NodeRef leftChild, NodeRef rightChild, boolean computeRemainders, boolean unusedFlag) {
        lowerPrecisionCache[targetIndex] = totalPrecision;
        cacheHelper.computeMeanCaches(targetOffset, leftOffset, rightOffset, totalPrecision, leftPrecision, rightPrecision, missing, node, leftChild, rightChild);

        if (!treeModel.isRoot(node)) {
            final double branchPrecision = cacheHelper.getUpperPrecFactor(node);
            // An infinite branch precision (zero-length branch) passes the total precision through unchanged.
            upperPrecisionCache[targetIndex] = Double.isInfinite(branchPrecision)
                    ? totalPrecision
                    : ((totalPrecision * branchPrecision) / (totalPrecision + branchPrecision)) * Math.pow(cacheHelper.getOUFactor(node), 2.0d);
        }

        logRemainderDensityCache[remainderIndex] = 0.0d;
        final boolean bothChildrenInformative = leftPrecision != 0.0d && rightPrecision != 0.0d;
        if (computeRemainders && bothChildrenInformative) {
            incrementRemainderDensities(precisionMatrix, logDetPrecision, remainderIndex, targetOffset, leftOffset, rightOffset, leftPrecision, rightPrecision, leftOUFactor, rightOUFactor, accumulateWishart);
        }
    }

    /**
     * Adds, for each datum, the log remainder produced when two children are combined, into
     * {@code logRemainderDensityCache[i]}. Optionally accumulates Wishart outer products first.
     *
     * @param dArr precision matrix P (dimTrait x dimTrait)
     * @param d    log-determinant of P
     * @param i    remainder slot to increment
     * @param i2   offset of the combined (parent) mean in the mean cache
     * @param i3   offset of the first child's corrected mean
     * @param i4   offset of the second child's corrected mean
     * @param d2   first child's precision
     * @param d3   second child's precision
     * @param d4   first child's OU factor
     * @param d5   second child's OU factor
     * @param z    whether to accumulate Wishart sufficient statistics
     */
    private void incrementRemainderDensities(double[][] dArr, double d, int i, int i2, int i3, int i4, double d2, double d3, double d4, double d5, boolean z) {
        // Reduced precision of the two children: p0 * p1 / (p0 + p1).
        final double d6 = (d2 * d3) / (d2 + d3);
        if (z) {
            incrementOuterProducts(i2, i3, i4, d2, d3);
        }
        // Loop-invariant: fetch the cache arrays once instead of on every inner-loop access.
        final double[] corrected = this.cacheHelper.getCorrectedMeanCache();
        final double[] means = this.cacheHelper.getMeanCache();
        for (int i5 = 0; i5 < this.numData; i5++) {
            final int datumOffset = i5 * this.dimTrait;
            final int left = i3 + datumOffset;
            final int right = i4 + datumOffset;
            final int parent = i2 + datumOffset;
            double d7 = 0.0d; // (p0 * m0)' P m0
            double d8 = 0.0d; // (p1 * m1)' P m1
            double d9 = 0.0d; // (p0 * m0 + p1 * m1)' P m_parent
            for (int i6 = 0; i6 < this.dimTrait; i6++) {
                final double d10 = corrected[left + i6] * d2;
                final double d11 = corrected[right + i6] * d3;
                final double[] precisionRow = dArr[i6];
                for (int i7 = 0; i7 < this.dimTrait; i7++) {
                    d7 += d10 * precisionRow[i7] * corrected[left + i7];
                    d8 += d11 * precisionRow[i7] * corrected[right + i7];
                    d9 += (d10 + d11) * precisionRow[i7] * means[parent + i7];
                }
            }
            final double logRemainder = (((-this.dimTrait) * LOG_SQRT_2_PI) + (0.5d * ((this.dimTrait * Math.log(d6)) + d))) - (0.5d * ((d7 + d8) - d9)) - (this.dimTrait * (Math.log(d4) + Math.log(d5)));
            this.logRemainderDensityCache[i] += logRemainder;
            if (DEBUG && this.logRemainderDensityCache[i] > 100.0d) {
                System.err.println(i);
                System.err.println(this.logRemainderDensityCache[i]);
                System.err.println("rP = " + d6);
                System.err.println("p0 = " + d2);
                System.err.println("p1 = " + d3 + "\n");
                System.err.println(new Matrix(dArr));
                System.err.println(d7);
                System.err.println(d8);
                System.err.println(d9);
                // Prints the first datum's child means only.
                for (int i8 = 0; i8 < this.dimTrait; i8++) {
                    System.err.println("\t" + corrected[i3 + i8] + " " + corrected[i4 + i8]);
                }
                System.exit(-1);
            }
        }
    }

    /**
     * Adds {@code d3 * (m0 - m1)(m0 - m1)'} for every datum into the Wishart scale matrix, where
     * m0 and m1 are the two children's corrected means and {@code d3 = d * d2 / (d + d2)}. It then
     * increments the degrees of freedom by {@code numData}. Warnings are printed to stderr for
     * zero or very small (< 1e-16) precisions.
     *
     * @param i  parent offset (not read here)
     * @param i2 offset of the first child's corrected mean
     * @param i3 offset of the second child's corrected mean
     * @param d  first child's precision
     * @param d2 second child's precision
     */
    private void incrementOuterProducts(int i, int i2, int i3, double d, double d2) {
        final double[] scaleMatrix = this.wishartStatistics.getScaleMatrix();
        if (d == 0.0d || d2 == 0.0d) {
            System.err.println("ZERO PRECISION");
        }
        if (d < 1.0E-16d || d2 < 1.0E-16d) {
            System.err.println("LOW PRECISION");
        }
        final double d3 = (d * d2) / (d + d2);
        // Loop-invariant: fetch the corrected-mean array once.
        final double[] corrected = this.cacheHelper.getCorrectedMeanCache();
        for (int i4 = 0; i4 < this.numData; i4++) {
            final int left = i2 + (i4 * this.dimTrait);
            final int right = i3 + (i4 * this.dimTrait);
            for (int i5 = 0; i5 < this.dimTrait; i5++) {
                final double rowDiff = corrected[left + i5] - corrected[right + i5];
                for (int i6 = 0; i6 < this.dimTrait; i6++) {
                    final double term = rowDiff * (corrected[left + i6] - corrected[right + i6]) * d3;
                    scaleMatrix[(i5 * this.dimTrait) + i6] += term;
                }
            }
        }
        this.wishartStatistics.incrementDf(this.numData);
    }

    /**
     * Returns the sampled trait vector at the root of {@code treeModel}, delegating to
     * {@link #getTraitForNode(Tree, NodeRef, String)} with this likelihood's trait name.
     *
     * @return a fresh copy of the root's {@code dim}-length entry in the drawn states
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood
    public double[] getRootNodeTrait() {
        final NodeRef root = this.treeModel.getRoot();
        return getTraitForNode(this.treeModel, root, this.traitName);
    }

    /**
     * Returns a copy of the sampled ancestral state for one node.
     * <p>
     * The log likelihood is evaluated first; its return value is discarded, so the call is made
     * for its side effects on the caches. If {@code areStatesRedrawn} is false, all ancestral
     * states are resampled via {@link #redrawAncestralStates()}. The {@code tree} and
     * {@code str} arguments are not read.
     *
     * @param tree    not read
     * @param nodeRef node whose state is requested; its number selects a {@code dim}-sized slice
     * @param str     not read
     * @return a new array of length {@code dim} copied from {@code drawnStates}
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood
    public double[] getTraitForNode(Tree tree, NodeRef nodeRef, String str) {
        getLogLikelihood();
        if (!this.areStatesRedrawn) {
            redrawAncestralStates();
        }
        final int width = this.dim;
        final double[] trait = new double[width];
        System.arraycopy(this.drawnStates, nodeRef.getNumber() * width, trait, 0, width);
        return trait;
    }

    /**
     * Resamples the ancestral states of every node by a pre-order traversal from the root.
     * <p>
     * The diffusion precision matrix and its inverse (computed via {@link SymmetricMatrix}) are
     * passed to the traversal. Afterwards {@code areStatesRedrawn} is set to true.
     */
    public void redrawAncestralStates() {
        final double[][] precision = this.diffusionModel.getPrecisionmatrix();
        final double[][] variance = new SymmetricMatrix(precision).inverse().toComponents();
        final NodeRef root = this.treeModel.getRoot();
        preOrderTraverseSample(this.treeModel, root, 0, precision, variance);
        if (DEBUG) {
            System.err.println("all draws = " + new Vector(this.drawnStates));
        }
        this.areStatesRedrawn = true;
    }

    /**
     * Saves model state: the superclass state always; with branch caching enabled, also the
     * cache helper's state and copies of the upper-precision, lower-precision and
     * log-remainder-density caches into their "stored" counterparts.
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood, dr.inference.model.AbstractModel
    public void storeState() {
        super.storeState();
        if (!this.cacheBranches) {
            return;
        }
        this.cacheHelper.store();

        final double[] upper = this.upperPrecisionCache;
        System.arraycopy(upper, 0, this.storedUpperPrecisionCache, 0, upper.length);

        final double[] lower = this.lowerPrecisionCache;
        System.arraycopy(lower, 0, this.storedLowerPrecisionCache, 0, lower.length);

        final double[] remainder = this.logRemainderDensityCache;
        System.arraycopy(remainder, 0, this.storedLogRemainderDensityCache, 0, remainder.length);
    }

    /**
     * Restores the state saved by {@code storeState}.
     * <p>
     * With branch caching enabled, the cache helper is restored. The current and stored
     * precision and log-remainder caches are then swapped by reference, not copied, so the
     * previously current arrays become the stored buffers.
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood, dr.inference.model.AbstractModel
    public void restoreState() {
        super.restoreState();
        if (!this.cacheBranches) {
            return;
        }
        this.cacheHelper.restore();

        final double[] currentUpper = this.upperPrecisionCache;
        this.upperPrecisionCache = this.storedUpperPrecisionCache;
        this.storedUpperPrecisionCache = currentUpper;

        final double[] currentLower = this.lowerPrecisionCache;
        this.lowerPrecisionCache = this.storedLowerPrecisionCache;
        this.storedLowerPrecisionCache = currentLower;

        final double[] currentRemainder = this.logRemainderDensityCache;
        this.logRemainderDensityCache = this.storedLogRemainderDensityCache;
        this.storedLogRemainderDensityCache = currentRemainder;
    }

    /**
     * Computes the bilinear form {@code x^T M y} restricted to the leading {@code i} entries.
     *
     * @param dArr  left vector {@code x}
     * @param dArr2 matrix {@code M} (at least {@code i x i})
     * @param dArr3 right vector {@code y}
     * @param i     number of leading rows/columns to use
     * @return {@code sum_{r,c < i} x[r] * M[r][c] * y[c]}; 0 when {@code i <= 0}
     */
    public static double computeQuadraticProduct(double[] dArr, double[][] dArr2, double[] dArr3, int i) {
        double total = 0.0d;
        for (int row = 0; row < i; row++) {
            final double left = dArr[row];
            final double[] matrixRow = dArr2[row];
            for (int col = 0; col < i; col++) {
                total += left * matrixRow[col] * dArr3[col];
            }
        }
        return total;
    }

    /**
     * Writes the weighted average of two vector slices into a destination slice:
     * {@code dArr3[i3 + k] = (dArr[i + k] * d + dArr2[i2 + k] * d2) / (d + d2)} for
     * {@code 0 <= k < i4}.
     *
     * @param dArr  first source array
     * @param i     start offset in {@code dArr}
     * @param d     weight of the first source
     * @param dArr2 second source array
     * @param i2    start offset in {@code dArr2}
     * @param d2    weight of the second source
     * @param dArr3 destination array
     * @param i3    start offset in {@code dArr3}
     * @param i4    number of elements to average
     */
    public static void computeWeightedAverage(double[] dArr, int i, double d, double[] dArr2, int i2, double d2, double[] dArr3, int i3, int i4) {
        final double weightSum = d + d2;
        int k = 0;
        while (k < i4) {
            final double weighted = (dArr[i + k] * d) + (dArr2[i2 + k] * d2);
            dArr3[i3 + k] = weighted / weightSum;
            k++;
        }
    }

    /**
     * Combines two child mean blocks into the parent block, applying branch shifts.
     * <p>
     * For an external child, its corrected mean is first set to its raw mean minus the shift
     * returned by {@code getShiftForBranchLength} for that child. The parent's mean at
     * {@code i3} is then the precision-weighted average of the two children's corrected means.
     * The parent's corrected mean equals that average at the root; otherwise it is the average
     * minus the parent's own shift.
     *
     * @param i        offset of the first child's block
     * @param d        weight (precision) of the first child
     * @param nodeRef  first child
     * @param i2       offset of the second child's block
     * @param d2       weight (precision) of the second child
     * @param nodeRef2 second child
     * @param i3       offset of the parent's block
     * @param i4       block length
     * @param nodeRef3 parent node
     */
    protected void computeCorrectedWeightedAverage(int i, double d, NodeRef nodeRef, int i2, double d2, NodeRef nodeRef2, int i3, int i4, NodeRef nodeRef3) {
        final double inverseTotal = 1.0d / (d + d2);
        final boolean parentIsRoot = this.treeModel.isRoot(nodeRef3);
        // The root has no parent branch, hence no shift of its own.
        final double[] parentShift = parentIsRoot ? null : getShiftForBranchLength(nodeRef3);
        final double[] firstShift = getShiftForBranchLength(nodeRef);
        final double[] secondShift = getShiftForBranchLength(nodeRef2);

        if (this.treeModel.isExternal(nodeRef)) {
            for (int k = 0; k < i4; k++) {
                this.correctedMeanCache[i + k] = this.meanCache[i + k] - firstShift[k];
            }
        }
        if (this.treeModel.isExternal(nodeRef2)) {
            for (int k = 0; k < i4; k++) {
                this.correctedMeanCache[i2 + k] = this.meanCache[i2 + k] - secondShift[k];
            }
        }

        for (int k = 0; k < i4; k++) {
            final double average = ((this.correctedMeanCache[i + k] * d) + (this.correctedMeanCache[i2 + k] * d2)) * inverseTotal;
            this.meanCache[i3 + k] = average;
            this.correctedMeanCache[i3 + k] = parentIsRoot ? average : average - parentShift[k];
        }
    }

    /**
     * Ornstein-Uhlenbeck variant of the corrected weighted average.
     * <p>
     * For an external child with time-scaled selection {@code s} and optimum {@code theta}, the
     * corrected mean is set to {@code exp(s) * mean - (exp(s) - 1) * theta}. The parent's mean at
     * {@code i3} is the precision-weighted average of the two children's corrected means. The
     * parent's corrected mean equals that average at the root. Otherwise it applies the same
     * transform with the parent's own selection and optimum.
     *
     * @param i        offset of the first child's block
     * @param d        weight (precision) of the first child
     * @param nodeRef  first child
     * @param i2       offset of the second child's block
     * @param d2       weight (precision) of the second child
     * @param nodeRef2 second child
     * @param i3       offset of the parent's block
     * @param i4       block length
     * @param nodeRef3 parent node
     */
    protected void computeCorrectedOUWeightedAverage(int i, double d, NodeRef nodeRef, int i2, double d2, NodeRef nodeRef2, int i3, int i4, NodeRef nodeRef3) {
        final double inverseTotal = 1.0d / (d + d2);
        final boolean parentIsRoot = this.treeModel.isRoot(nodeRef3);
        double[] parentOptimum = null;
        double parentSelection = 1.0d;
        if (!parentIsRoot) {
            parentOptimum = getOptimalValue(nodeRef3);
            parentSelection = getTimeScaledSelection(nodeRef3);
        }
        final double[] firstOptimum = getOptimalValue(nodeRef);
        final double[] secondOptimum = getOptimalValue(nodeRef2);
        final double firstFactor = Math.exp(getTimeScaledSelection(nodeRef));
        final double secondFactor = Math.exp(getTimeScaledSelection(nodeRef2));

        if (this.treeModel.isExternal(nodeRef)) {
            for (int k = 0; k < i4; k++) {
                this.correctedMeanCache[i + k] = (firstFactor * this.meanCache[i + k]) - ((firstFactor - 1.0d) * firstOptimum[k]);
            }
        }
        if (this.treeModel.isExternal(nodeRef2)) {
            for (int k = 0; k < i4; k++) {
                this.correctedMeanCache[i2 + k] = (secondFactor * this.meanCache[i2 + k]) - ((secondFactor - 1.0d) * secondOptimum[k]);
            }
        }

        final double parentFactor = Math.exp(parentSelection);
        for (int k = 0; k < i4; k++) {
            final double average = ((this.correctedMeanCache[i + k] * d) + (this.correctedMeanCache[i2 + k] * d2)) * inverseTotal;
            this.meanCache[i3 + k] = average;
            if (parentIsRoot) {
                this.correctedMeanCache[i3 + k] = average;
            } else {
                this.correctedMeanCache[i3 + k] = (parentFactor * average) - ((parentFactor - 1.0d) * parentOptimum[k]);
            }
        }
    }

    /**
     * Supplies the variance matrix used to sample the root state.
     * <p>
     * In this class it is called from {@code preOrderTraverseSample} once per data column with:
     * the root's block of the cache helper's mean cache ({@code dArr}); the two matrices handed
     * to the traversal, which {@code redrawAncestralStates} supplies as the diffusion precision
     * matrix ({@code dArr2}) and its inverse ({@code dArr3}); and the root's entry in the
     * lower-precision cache ({@code d}). The returned matrix is used as the variance of a
     * multivariate-normal draw whose mean is {@code dArr} as it stands after this call.
     * NOTE(review): implementations presumably may overwrite {@code dArr} in place with the
     * marginal root mean (the name suggests it), so confirm against subclasses.
     *
     * @param dArr  root mean for one data column (possibly updated in place; see note)
     * @param dArr2 precision matrix as passed by the caller
     * @param dArr3 variance matrix as passed by the caller
     * @param d     root's lower precision
     * @return variance matrix for the root draw
     */
    protected abstract double[][] computeMarginalRootMeanAndVariance(double[] dArr, double[][] dArr2, double[][] dArr3, double d);

    /**
     * Recursively samples ancestral trait values in pre-order, writing each node's values into
     * {@code drawnStates} at offset {@code nodeNumber * dim} (one {@code dimTrait} slice per
     * data column).
     * <p>
     * Handles three cases for {@code nodeRef}:
     * <ul>
     *   <li>Root: for each data column, the root's slice of the cache helper's mean cache is
     *       passed to {@code computeMarginalRootMeanAndVariance} together with the root's lower
     *       precision. A multivariate-normal sample is then drawn with the returned variance.</li>
     *   <li>Completely missing (non-root): the conditional mean combines two terms. The first is
     *       the parent's draw plus the cache helper's shift for this node, weighted by
     *       {@code 1 / rescaledBranchLength}. The second is this node's cached mean, weighted
     *       by its lower precision. The variance is {@code dArr2} divided by the summed
     *       precision.</li>
     *   <li>Neither completely nor partially missing: the node's mean-cache block is copied
     *       into {@code drawnStates} as-is.</li>
     * </ul>
     * Partially missing nodes raise a {@link RuntimeException}. Recursion continues into
     * children 0 and 1 only (a binary tree is assumed) when {@link #peel()} returns true and the
     * node is not external.
     *
     * @param mutableTreeModel tree being traversed
     * @param nodeRef          node whose state is sampled in this call
     * @param i                node number of the parent (not read for the root)
     * @param dArr             matrix forwarded to computeMarginalRootMeanAndVariance
     *                         ({@code redrawAncestralStates} passes the precision matrix)
     * @param dArr2            variance-like matrix ({@code redrawAncestralStates} passes the
     *                         inverse precision); scaled per branch for missing nodes
     */
    private void preOrderTraverseSample(MutableTreeModel mutableTreeModel, NodeRef nodeRef, int i, double[][] dArr, double[][] dArr2) {
        int number = nodeRef.getNumber();
        if (mutableTreeModel.isRoot(nodeRef)) {
            // Root: sample from the marginal root distribution, one data column at a time.
            double[] dArr3 = new double[this.dimTrait];
            int number2 = mutableTreeModel.getRoot().getNumber();
            double d = this.lowerPrecisionCache[number2];
            for (int i2 = 0; i2 < this.numData; i2++) {
                System.arraycopy(this.cacheHelper.getMeanCache(), (number * this.dim) + (i2 * this.dimTrait), dArr3, 0, this.dimTrait);
                double[][] computeMarginalRootMeanAndVariance = computeMarginalRootMeanAndVariance(dArr3, dArr, dArr2, d);
                double[] nextMultivariateNormalVariance = MultivariateNormalDistribution.nextMultivariateNormalVariance(dArr3, computeMarginalRootMeanAndVariance);
                if (DEBUG_PREORDER) {
                    // Debug mode replaces the draw with a constant vector of ones.
                    Arrays.fill(nextMultivariateNormalVariance, 1.0d);
                }
                System.arraycopy(nextMultivariateNormalVariance, 0, this.drawnStates, (number2 * this.dim) + (i2 * this.dimTrait), this.dimTrait);
                if (DEBUG) {
                    System.err.println("Root mean: " + new Vector(dArr3));
                    System.err.println("Root var : " + new Matrix(computeMarginalRootMeanAndVariance));
                    System.err.println("Root draw: " + new Vector(nextMultivariateNormalVariance));
                }
            }
        } else if (this.missingTraits.isCompletelyMissing(number) || this.missingTraits.isPartiallyMissing(number)) {
            if (this.missingTraits.isPartiallyMissing(number)) {
                throw new RuntimeException("Partially missing values are not yet implemented");
            }
            // Despite its name, this is the branch precision: 1 / rescaled branch length.
            double rescaledBranchLengthForPrecision = 1.0d / getRescaledBranchLengthForPrecision(nodeRef);
            double d2 = this.lowerPrecisionCache[number];
            // Total precision of the conditional distribution (branch + subtree).
            double d3 = d2 + rescaledBranchLengthForPrecision;
            // Reused scratch buffers for the conditional mean and variance.
            double[] dArr4 = this.Ay;
            double[][] dArr5 = this.tmpM;
            for (int i3 = 0; i3 < this.numData; i3++) {
                // i4: parent's slice in drawnStates; i5: this node's slice.
                int i4 = (i * this.dim) + (i3 * this.dimTrait);
                int i5 = (number * this.dim) + (i3 * this.dimTrait);
                if (DEBUG) {
                    double[] dArr6 = new double[this.dimTrait];
                    System.arraycopy(this.drawnStates, i4, dArr6, 0, this.dimTrait);
                    System.err.println("Parent draw: " + new Vector(dArr6));
                    if (dArr6[0] != this.drawnStates[i4]) {
                        throw new RuntimeException("Error in setting indices");
                    }
                }
                for (int i6 = 0; i6 < this.dimTrait; i6++) {
                    // Precision-weighted mean of (parent draw + shift) and this node's cached mean.
                    dArr4[i6] = (((this.drawnStates[i4 + i6] + this.cacheHelper.getShift(nodeRef)[i6]) * rescaledBranchLengthForPrecision) + (this.cacheHelper.getMeanCache()[i5 + i6] * d2)) / d3;
                    for (int i7 = 0; i7 < this.dimTrait; i7++) {
                        dArr5[i6][i7] = dArr2[i6][i7] / d3;
                    }
                }
                double[] nextMultivariateNormalVariance2 = MultivariateNormalDistribution.nextMultivariateNormalVariance(dArr4, dArr5);
                System.arraycopy(nextMultivariateNormalVariance2, 0, this.drawnStates, i5, this.dimTrait);
                if (DEBUG) {
                    System.err.println("Int prec: " + d3);
                    System.err.println("Int mean: " + new Vector(dArr4));
                    System.err.println("Int var : " + new Matrix(dArr5));
                    System.err.println("Int draw: " + new Vector(nextMultivariateNormalVariance2));
                    System.err.println("");
                }
            }
        } else {
            // Not missing: the node's cached mean block is used directly as its state.
            System.arraycopy(this.cacheHelper.getMeanCache(), number * this.dim, this.drawnStates, number * this.dim, this.dim);
        }
        if (!peel() || mutableTreeModel.isExternal(nodeRef)) {
            return;
        }
        // Only children 0 and 1 are visited, so a binary tree is assumed.
        preOrderTraverseSample(mutableTreeModel, mutableTreeModel.getChild(nodeRef, 0), number, dArr, dArr2);
        preOrderTraverseSample(mutableTreeModel, mutableTreeModel.getChild(nodeRef, 1), number, dArr, dArr2);
    }

    /**
     * Responds to a change in a dependent model.
     * <p>
     * Some changes invalidate every node: a change in one of the drift models, a change in one
     * of the optimal-value models, or any change at all while {@code strengthOfSelection} is
     * set. Then {@code updateAllNodes()} is called with branch caching, or the whole likelihood
     * is marked unknown without it. All other changes are delegated to the superclass.
     *
     * @param model the model that changed
     * @param obj   change payload, forwarded to the superclass
     * @param i     change index, forwarded to the superclass
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood, dr.inference.model.AbstractModel
    public void handleModelChangedEvent(Model model, Object obj, int i) {
        final boolean invalidatesEverything =
                (this.driftModels != null && this.driftModels.contains(model))
                || (this.optimalValues != null && this.optimalValues.contains(model))
                || this.strengthOfSelection != null;

        if (!invalidatesEverything) {
            super.handleModelChangedEvent(model, obj, i);
        } else if (this.cacheBranches) {
            updateAllNodes();
        } else {
            this.likelihoodKnown = false;
        }
    }

    /**
     * Controls whether {@code preOrderTraverseSample} recurses from an internal node into its
     * children. Always true in this class; this protected, non-final hook can be overridden by
     * subclasses.
     *
     * @return {@code true}
     */
    protected boolean peel() {
        return true;
    }

    /**
     * Returns the logger columns for this likelihood: a single likelihood column labelled with
     * this model's id.
     *
     * @return a one-element array holding an {@code AbstractModelLikelihood.LikelihoodColumn}
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood, dr.inference.model.AbstractModelLikelihood, dr.inference.loggers.Loggable
    public LogColumn[] getColumns() {
        final LogColumn likelihoodColumn = new AbstractModelLikelihood.LikelihoodColumn(getId());
        return new LogColumn[]{likelihoodColumn};
    }

    /**
     * Creates the cache helper matching the diffusion type.
     * <p>
     * PLAIN and SCALED share the basic {@code CacheHelper}; DRIFT uses {@code DriftCacheHelper};
     * OU uses {@code OUCacheHelper}. Any other value yields {@code null}.
     *
     * @param integratedDiffusionType diffusion type selecting the helper class
     * @param i                       first constructor argument forwarded to the helper
     * @param z                       second constructor argument forwarded to the helper
     * @return the new helper, or {@code null} for an unhandled type
     */
    private CacheHelper createCacheHelper(IntegratedDiffusionType integratedDiffusionType, int i, boolean z) {
        switch (integratedDiffusionType) {
            case PLAIN:
            case SCALED:
                return new CacheHelper(i, z);
            case DRIFT:
                return new DriftCacheHelper(i, z);
            case OU:
                return new OUCacheHelper(i, z);
            default:
                return null;
        }
    }

    /**
     * Registers a clamp (restricted partials) keyed by its tip bit set and adds it as a
     * dependent model.
     * <p>
     * The clamp map is created on first use. A later clamp with an equal tip bit set replaces
     * the earlier entry, following {@link Map#put} semantics. A notice is printed to stderr.
     *
     * @param restrictedPartials the clamp to register
     */
    @Override // dr.evomodel.continuous.AbstractMultivariateTraitLikelihood
    protected void addRestrictedPartials(RestrictedPartials restrictedPartials) {
        if (this.clampList == null) {
            // Typed instantiation instead of the raw HashMap (avoids an unchecked assignment).
            this.clampList = new HashMap<>();
        }
        this.clampList.put(restrictedPartials.getTipBitSet(), restrictedPartials);
        addModel(restrictedPartials);
        System.err.println("Added a CLAMP!");
    }
}
