package dr.evomodel.continuous;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.distributions.WishartSufficientStatistics;
import dr.util.Citable;
import java.util.List;
import java.util.logging.Logger;

/* loaded from: input_file:dr/evomodel/continuous/NonPhylogeneticMultivariateTraitLikelihood.class */
/**
 * Multivariate trait likelihood that is evaluated directly from the tips instead of by
 * peeling the tree: {@link #peel()} returns {@code false} and
 * {@link #calculateLogLikelihood()} combines every tip that is not completely missing,
 * weighted by the reciprocal of {@link #getLengthToRoot(NodeRef)}, with the root prior.
 *
 * <p>When {@code exchangeableTips} is set, every tip uses the length-to-root of the first
 * external node whose height is exactly zero, so all tips receive the same weight.
 */
public class NonPhylogeneticMultivariateTraitLikelihood extends FullyConjugateMultivariateTraitLikelihood {
    /** If true, all tips share the length-to-root of the tip at index {@link #zeroHeightTip}. */
    private final boolean exchangeableTips;
    /** Index of the first external node with height exactly 0.0, or -1 if there is none. */
    private final int zeroHeightTip;
    // NOTE(review): neither flag is referenced in this file as seen here; presumably the
    // compiler folded away the branches that used them. Kept for fidelity.
    private static final boolean DEBUG_NO_TREE = false;
    private static final boolean NO_RESCALING = false;

    /**
     * Accumulated tip statistics returned by {@code computeInnerProductsForTips}.
     *
     * <p>NOTE(review): the decompiler reports that the access modifier of this class was
     * widened from {@code private}; it is left as found so existing references keep compiling.
     */
    public class SufficientStatistics {
        /** Sum of the weights of all tips with a positive weight. */
        double sumWeight;
        /** Product of the weights of all tips with a positive weight. */
        double productWeight;
        /** Sum of the values returned by {@code computeWeightedAverageAndSumOfSquares} over tips and data. */
        double innerProduct;
        /** Number of tips with a positive weight. */
        int nonMissingTips;

        SufficientStatistics(double d, double d2, double d3, int i) {
            this.sumWeight = d;
            this.productWeight = d2;
            this.innerProduct = d3;
            this.nonMissingTips = i;
        }
    }

    /**
     * Creates the likelihood and logs a description of it.
     *
     * <p>All arguments except {@code z6} are forwarded to the superclass constructor (note that
     * {@code dArr}, {@code list2}, {@code d} and {@code z5} are forwarded in a different order
     * than they are received); their meaning is defined there.
     *
     * @param mutableTreeModel tree whose external nodes are searched for a zero-height tip
     * @param z6               whether tips are exchangeable, i.e. share one length-to-root
     * @throws IllegalArgumentException if {@code z6} is true but no external node of the tree
     *                                  has height exactly zero
     */
    public NonPhylogeneticMultivariateTraitLikelihood(String str, MutableTreeModel mutableTreeModel, MultivariateDiffusionModel multivariateDiffusionModel, CompoundParameter compoundParameter, Parameter parameter, List<Integer> list, boolean z, boolean z2, boolean z3, BranchRateModel branchRateModel, Model model, boolean z4, double[] dArr, double d, List<RestrictedPartials> list2, boolean z5, boolean z6) {
        super(str, mutableTreeModel, multivariateDiffusionModel, compoundParameter, parameter, list, z, z2, z3, branchRateModel, null, null, null, model, z4, dArr, list2, d, z5);
        this.exchangeableTips = z6;
        this.zeroHeightTip = findZeroHeightTip(mutableTreeModel);
        // Fail fast with a clear message: without this check the -1 sentinel would be passed
        // to getExternalNode(...) by getLengthToRoot(...) as soon as it is first called
        // (which happens in printInformtion2() below).
        if (this.exchangeableTips && this.zeroHeightTip < 0) {
            throw new IllegalArgumentException(
                    "Exchangeable tips require at least one tip with node height exactly 0.0, but none was found");
        }
        printInformtion2();
    }

    /**
     * Returns the index of the first external node whose height is exactly {@code 0.0},
     * or {@code -1} when the tree has no such node.
     */
    private int findZeroHeightTip(Tree tree) {
        final int tipCount = tree.getExternalNodeCount();
        int tip = 0;
        while (tip < tipCount) {
            if (tree.getNodeHeight(tree.getExternalNode(tip)) == 0.0d) {
                return tip;
            }
            tip++;
        }
        return -1;
    }

    /**
     * Intentionally empty: this class logs its own description from
     * {@link #printInformtion2()}, which the constructor calls after the fields of this class
     * have been assigned.
     */
    @Override
    protected void printInformtion() {
    }

    /** Builds the description of this model and logs it at INFO level to the {@code dr.evomodel} logger. */
    protected void printInformtion2() {
        StringBuilder message = new StringBuilder("Creating non-phylogenetic multivariate diffusion model:\n");
        message.append("\tTrait: ").append(this.traitName).append("\n");
        message.append("\tDiffusion process: ").append(this.diffusionModel.getId()).append("\n");
        message.append("\tExchangeable tips: ");
        if (this.exchangeableTips) {
            message.append("yes");
            // With exchangeable tips every tip has the same weight, so tip 0 is representative.
            message.append(" initial inverse-weight = ").append(1.0d / getLengthToRoot(this.treeModel.getExternalNode(0)));
        } else {
            message.append("no");
        }
        message.append("\n");
        message.append(extraInfo());
        message.append("\tPlease cite:\n");
        message.append(Citable.Utils.getCitationString(this));
        message.append("\n\tDiffusion dimension   : ").append(this.dimTrait).append("\n");
        message.append("\tNumber of observations: ").append(this.numData).append("\n");
        Logger.getLogger("dr.evomodel").info(message.toString());
    }

    /**
     * Returns the sum, over all external nodes, of the root height minus the node height.
     */
    @Override
    protected double getTreeLength() {
        final double rootHeight = this.treeModel.getNodeHeight(this.treeModel.getRoot());
        double total = 0.0d;
        final int tipCount = this.treeModel.getExternalNodeCount();
        for (int tip = 0; tip < tipCount; tip++) {
            total += rootHeight - this.treeModel.getNodeHeight(this.treeModel.getExternalNode(tip));
        }
        return total;
    }

    /**
     * Returns the rescaled length-to-root used to weight {@code nodeRef}.
     *
     * <p>With exchangeable tips the argument is ignored and the value for the zero-height tip
     * is returned instead.
     */
    protected double getLengthToRoot(NodeRef nodeRef) {
        if (this.exchangeableTips) {
            return getRescaledLengthToRoot(this.treeModel.getExternalNode(this.zeroHeightTip));
        }
        return getRescaledLengthToRoot(nodeRef);
    }

    /**
     * Accumulates the tip statistics and writes the weighted tip mean into the root's slot of
     * {@code meanCache}.
     *
     * <p>Each tip that is not completely missing gets weight {@code 1 / getLengthToRoot(tip)}.
     * Its cached values are added, scaled by that weight, to the root slot, and the value
     * returned by {@code computeWeightedAverageAndSumOfSquares} is accumulated per datum.
     * Afterwards the root slot is divided by the total weight, which is also stored in
     * {@code lowerPrecisionCache} for the root.
     *
     * @param precision precision matrix handed through to {@code computeWeightedAverageAndSumOfSquares}
     * @param tipBuffer scratch array of at least {@code dimTrait} entries; overwritten
     * @return the summed weight, product of weights, accumulated inner product and number of
     *         tips with a positive weight
     */
    private SufficientStatistics computeInnerProductsForTips(double[][] precision, double[] tipBuffer) {
        final int rootNumber = this.treeModel.getRoot().getNumber();
        final int rootOffset = this.dim * rootNumber;
        // Clear the root slot before accumulating into it.
        for (int k = 0; k < this.dim; k++) {
            this.meanCache[rootOffset + k] = 0.0d;
        }

        double innerProduct = 0.0d;
        double productWeight = 1.0d;
        double sumWeight = 0.0d;
        int nonMissingTips = 0;

        final int tipCount = this.treeModel.getExternalNodeCount();
        for (int tip = 0; tip < tipCount; tip++) {
            NodeRef tipNode = this.treeModel.getExternalNode(tip);
            int tipNumber = tipNode.getNumber();
            double weight = 0.0d;
            if (!this.missingTraits.isCompletelyMissing(tipNumber)) {
                weight = 1.0d / getLengthToRoot(tipNode);
                int tipOffset = this.dim * tipNumber;
                int rootDatumOffset = this.dim * rootNumber;
                for (int datum = 0; datum < this.numData; datum++) {
                    for (int k = 0; k < this.dimTrait; k++) {
                        this.meanCache[rootDatumOffset + k] += weight * this.meanCache[tipOffset + k];
                        tipBuffer[k] = this.meanCache[tipOffset + k];
                    }
                    innerProduct += computeWeightedAverageAndSumOfSquares(tipBuffer, this.Ay, precision, this.dimTrait, weight);
                    tipOffset += this.dimTrait;
                    rootDatumOffset += this.dimTrait;
                }
                if (this.computeWishartStatistics) {
                    incrementOuterProducts(tipNumber, weight);
                }
            }
            // Only tips that ended up with a positive weight count towards the totals.
            if (weight > 0.0d) {
                sumWeight += weight;
                productWeight *= weight;
                nonMissingTips++;
            }
        }

        this.lowerPrecisionCache[rootNumber] = sumWeight;
        normalize(this.meanCache, rootOffset, this.dim, sumWeight);
        if (this.computeWishartStatistics) {
            // Subtract the root contribution; incrementOuterProducts adds one degree of
            // freedom, so it is taken back right after.
            incrementOuterProducts(rootNumber, -sumWeight);
            this.wishartStatistics.incrementDf(-1);
        }
        return new SufficientStatistics(sumWeight, productWeight, innerProduct, nonMissingTips);
    }

    /** Divides {@code length} consecutive entries of {@code values}, starting at {@code offset}, by {@code divisor}. */
    private void normalize(double[] values, int offset, int length, double divisor) {
        final int end = offset + length;
        for (int index = offset; index < end; index++) {
            values[index] = values[index] / divisor;
        }
    }

    /**
     * Adds {@code weight} times the outer product of each datum cached for node
     * {@code nodeNumber} to the Wishart scale matrix and increments its degrees of freedom by one.
     *
     * <p>The trait loops run over {@code dimTrait} and the scale matrix is indexed as
     * {@code row * dimTrait + col}. This matches the per-datum stride used below, the
     * {@code dimTrait}-sized statistics created in {@link #calculateLogLikelihood()} and the
     * indexing used there. Bounding the loops by {@code dim} instead would step outside the
     * current datum whenever {@code dim} exceeds {@code dimTrait}.
     */
    private void incrementOuterProducts(int nodeNumber, double weight) {
        final double[] scaleMatrix = this.wishartStatistics.getScaleMatrix();
        int offset = this.dim * nodeNumber;
        for (int datum = 0; datum < this.numData; datum++) {
            for (int row = 0; row < this.dimTrait; row++) {
                final double rowValue = this.meanCache[offset + row];
                for (int col = 0; col < this.dimTrait; col++) {
                    scaleMatrix[(row * this.dimTrait) + col] += rowValue * this.meanCache[offset + col] * weight;
                }
            }
            offset += this.dimTrait;
        }
        this.wishartStatistics.incrementDf(1);
    }

    /** No peeling is performed by this likelihood; always returns {@code false}. */
    @Override
    protected boolean peel() {
        return false;
    }

    /**
     * Computes the log-likelihood from the tip statistics and the root prior.
     *
     * <p>When {@code computeWishartStatistics} is set, fresh Wishart statistics are created
     * and filled as a side effect. {@code areStatesRedrawn} is reset to {@code false} before
     * returning.
     *
     * @return the log-likelihood value
     */
    @Override
    public double calculateLogLikelihood() {
        final double[][] precision = this.diffusionModel.getPrecisionmatrix();
        final double logDetPrecision = Math.log(this.diffusionModel.getDeterminantPrecisionMatrix());
        final double[] weightedSum = this.tmp2;
        if (this.computeWishartStatistics) {
            this.wishartStatistics = new WishartSufficientStatistics(this.dimTrait);
        }

        // tmp2 doubles as scratch space for the tip pass; it is overwritten again below.
        final SufficientStatistics tipStatistics = computeInnerProductsForTips(precision, this.tmp2);
        final double sumWeight = tipStatistics.sumWeight;
        final double productWeight = tipStatistics.productWeight;
        double innerProduct = tipStatistics.innerProduct;
        final int nonMissingTips = tipStatistics.nonMissingTips;

        final double totalWeight = sumWeight + this.rootPriorSampleSize;
        final double weightRatio = (productWeight * this.rootPriorSampleSize) / totalWeight;

        int rootOffset = this.dim * this.treeModel.getRoot().getNumber();
        for (int datum = 0; datum < this.numData; datum++) {
            for (int k = 0; k < this.dimTrait; k++) {
                weightedSum[k] = (sumWeight * this.meanCache[rootOffset + k]) + (this.rootPriorSampleSize * this.rootPriorMean[k]);
            }
            // Both calls share the Ay scratch array, so the prior term must be evaluated
            // first, exactly as written (Java evaluates operands left to right).
            innerProduct = (innerProduct
                    + computeWeightedAverageAndSumOfSquares(this.rootPriorMean, this.Ay, precision, this.dimTrait, this.rootPriorSampleSize))
                    - computeWeightedAverageAndSumOfSquares(weightedSum, this.Ay, precision, this.dimTrait, 1.0d / totalWeight);
            if (this.computeWishartStatistics) {
                final double[] scaleMatrix = this.wishartStatistics.getScaleMatrix();
                final double jointWeight = (sumWeight * this.rootPriorSampleSize) / totalWeight;
                for (int row = 0; row < this.dimTrait; row++) {
                    final double rowDiff = this.meanCache[rootOffset + row] - this.rootPriorMean[row];
                    for (int col = 0; col < this.dimTrait; col++) {
                        scaleMatrix[(row * this.dimTrait) + col] += rowDiff * jointWeight * (this.meanCache[rootOffset + col] - this.rootPriorMean[col]);
                    }
                }
                this.wishartStatistics.incrementDf(1);
            }
            rootOffset += this.dimTrait;
        }

        // Term order and grouping are kept as in the original to preserve floating-point results.
        final double logLikelihood = ((((((-LOG_SQRT_2_PI) * this.dimTrait) * nonMissingTips) * this.numData)
                + (((0.5d * logDetPrecision) * nonMissingTips) * this.numData))
                + (((0.5d * Math.log(weightRatio)) * this.dimTrait) * this.numData))
                - (0.5d * innerProduct);
        this.areStatesRedrawn = false;
        return logLikelihood;
    }
}
