package dr.evomodel.continuous;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import java.util.List;

/* loaded from: input_file:dr/evomodel/continuous/SemiConjugateMultivariateTraitLikelihood.class */
/**
 * Integrated multivariate trait likelihood whose root trait value has a multivariate normal prior.
 *
 * <p>The prior is taken from a {@link MultivariateNormalDistribution}. Its mean becomes
 * {@link #rootPriorMean} ({@code z}), and its scale matrix is used as the prior precision
 * {@link #rootPriorPrecision} ({@code B}). At the root, this prior is combined in closed form with
 * the tree contribution: a precision matrix {@code A} scaled by a root-precision factor {@code d},
 * and a vector {@code Ay}. See {@link #integrateLogLikelihoodAtRoot} and
 * {@link #computeMarginalRootMeanAndVariance}.
 */
public class SemiConjugateMultivariateTraitLikelihood extends IntegratedMultivariateTraitLikelihood {

    /** Mean {@code z} of the root prior, as returned by {@link MultivariateNormalDistribution#getMean()}. */
    protected double[] rootPriorMean;

    /**
     * Prior precision {@code B}: the scale matrix returned by
     * {@link MultivariateNormalDistribution#getScaleMatrix()}, added to the tree precision at the root.
     */
    protected double[][] rootPriorPrecision;

    /** Natural logarithm of the determinant of {@link #rootPriorPrecision}. */
    protected double logRootPriorPrecisionDeterminant;

    /**
     * Vector filled by {@code computeWeightedAverageAndSumOfSquares} from the root prior
     * (presumably {@code B z}). It is added to the tree-weighted vector at the root.
     */
    protected double[] Bz;

    /** Scalar returned by the same {@code computeWeightedAverageAndSumOfSquares} call (presumably {@code z' B z}). */
    private double zBz;

    /**
     * Creates the likelihood and installs the root prior.
     *
     * <p>All arguments except {@code multivariateNormalDistribution} are forwarded to the superclass
     * constructor. The superclass arguments not exposed here are passed as {@code null}.
     *
     * @param multivariateNormalDistribution prior on the root trait value; its mean and scale matrix
     *     become {@link #rootPriorMean} and {@link #rootPriorPrecision}
     * @throws IllegalArgumentException if the determinant of the prior's scale matrix cannot be
     *     computed ({@code Matrix.determinant()} raised {@link IllegalDimension})
     */
    public SemiConjugateMultivariateTraitLikelihood(String str, MutableTreeModel mutableTreeModel,
            MultivariateDiffusionModel multivariateDiffusionModel, CompoundParameter compoundParameter,
            List<Integer> list, boolean z, boolean z2, boolean z3, BranchRateModel branchRateModel,
            Model model, boolean z4, MultivariateNormalDistribution multivariateNormalDistribution,
            boolean z5, List<RestrictedPartials> list2) {
        super(str, mutableTreeModel, multivariateDiffusionModel, compoundParameter, null, list, z, z2, z3,
                branchRateModel, null, null, null, model, list2, z4, z5);
        setRootPrior(multivariateNormalDistribution);
    }

    /** Wishart sufficient statistics are never computed for this likelihood. */
    @Override
    public boolean getComputeWishartSufficientStatistics() {
        return false;
    }

    /**
     * Not supported for semi-conjugate trait likelihoods.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected double calculateAscertainmentCorrection(int i) {
        throw new UnsupportedOperationException(
                "Ascertainment correction not yet implemented for semi-conjugate trait likelihoods");
    }

    /**
     * Sums {@code getRescaledBranchLengthForPrecision} over the path from {@code nodeRef} up to,
     * but not including, the root.
     *
     * @param nodeRef node to start from
     * @return the accumulated rescaled length; {@code 0.0} when {@code nodeRef} is the root
     */
    public double getRescaledLengthToRoot(NodeRef nodeRef) {
        final NodeRef root = treeModel.getRoot();
        double length = 0.0;
        // Node references are compared by identity.
        for (NodeRef node = nodeRef; node != root; node = treeModel.getParent(node)) {
            length += getRescaledBranchLengthForPrecision(node);
        }
        return length;
    }

    /**
     * Integrates the root trait value against the root prior and returns the log-likelihood term
     * {@code 0.5 * (log|B| - log|A d + B| - zBz + (Ay + Bz)' (A d + B)^{-1} (Ay + Bz))}.
     *
     * <p>Side effects: {@code ay} is overwritten in place with {@code ay + Bz}. When
     * {@code dimTrait > 1}, {@code aPlusB} is overwritten with {@code A d + B}. In the scalar case
     * {@code aPlusB} is not touched.
     *
     * @param y not used by this implementation
     * @param ay tree-weighted root vector (presumably {@code A y}); incremented in place by {@link #Bz}
     * @param aPlusB scratch matrix that receives the combined precision (multivariate case only)
     * @param treePrecision tree precision matrix {@code A}
     * @param rootPrecision scalar {@code d} multiplying {@code treePrecision}
     * @return the log-likelihood contribution at the root
     * @throws IllegalStateException if the determinant of {@code A d + B} cannot be computed
     */
    @Override
    protected double integrateLogLikelihoodAtRoot(double[] y, double[] ay, double[][] aPlusB,
            double[][] treePrecision, double rootPrecision) {
        final double detAplusB;
        double square = 0.0;

        if (dimTrait > 1) {
            addRootPrior(ay, aPlusB, treePrecision, rootPrecision);
            final Matrix combined = new Matrix(aPlusB);
            try {
                detAplusB = combined.determinant();
            } catch (IllegalDimension e) {
                // Previously swallowed, which left the determinant at 0 and produced log(0).
                throw new IllegalStateException("Cannot compute determinant of root precision A + B", e);
            }
            final double[][] inverse = combined.inverse().toComponents();
            // Quadratic form (Ay + Bz)' (A d + B)^{-1} (Ay + Bz).
            for (int i = 0; i < dimTrait; i++) {
                for (int j = 0; j < dimTrait; j++) {
                    square += ay[i] * inverse[i][j] * ay[j];
                }
            }
        } else {
            // Scalar case: the combined precision is its own determinant and inverse.
            detAplusB = treePrecision[0][0] * rootPrecision + rootPriorPrecision[0][0];
            ay[0] += Bz[0];
            square = ay[0] * ay[0] / detAplusB;
        }

        final double logLikelihood =
                0.5 * (logRootPriorPrecisionDeterminant - Math.log(detAplusB) - zBz + square);
        if (DEBUG) {
            System.err.println("(Ay+Bz)(A+B)^{-1}(Ay+Bz) = " + square);
            System.err.println("density = " + logLikelihood);
            System.err.println("zBz = " + zBz);
        }
        return logLikelihood;
    }

    /**
     * Adds the root prior to the tree quantities for every {@code i, j < dimTrait}. It sets
     * {@code vector[i] += Bz[i]} and
     * {@code combinedPrecision[i][j] = treePrecision[i][j] * rootPrecision + rootPriorPrecision[i][j]}.
     */
    private void addRootPrior(double[] vector, double[][] combinedPrecision, double[][] treePrecision,
            double rootPrecision) {
        for (int i = 0; i < dimTrait; i++) {
            vector[i] += Bz[i];
            for (int j = 0; j < dimTrait; j++) {
                combinedPrecision[i][j] = treePrecision[i][j] * rootPrecision + rootPriorPrecision[i][j];
            }
        }
    }

    /** Recomputes {@link #Bz} and {@code zBz} from the current root prior mean and precision. */
    private void setRootPriorSumOfSquares() {
        Bz = new double[dimTrait];
        zBz = computeWeightedAverageAndSumOfSquares(rootPriorMean, Bz, rootPriorPrecision, dimTrait, 1.0);
    }

    /**
     * Installs the root prior: stores its mean and scale matrix, caches the log-determinant of the
     * precision, and refreshes {@link #Bz} / {@code zBz}.
     *
     * @throws IllegalArgumentException if the determinant of the prior's scale matrix cannot be computed
     */
    private void setRootPrior(MultivariateNormalDistribution multivariateNormalDistribution) {
        rootPriorMean = multivariateNormalDistribution.getMean();
        rootPriorPrecision = multivariateNormalDistribution.getScaleMatrix();
        try {
            logRootPriorPrecisionDeterminant = Math.log(new Matrix(rootPriorPrecision).determinant());
        } catch (IllegalDimension e) {
            // Previously swallowed, which silently left the log-determinant at 0.
            throw new IllegalArgumentException("Cannot compute determinant of root prior precision matrix", e);
        }
        setRootPriorSumOfSquares();
    }

    /**
     * Computes the marginal root mean and variance after combining the tree with the root prior.
     *
     * <p>This method first calls {@code computeWeightedAverageAndSumOfSquares(mean, Ay, treePrecision,
     * dimTrait, rootPrecision)}, which fills the inherited scratch vector {@code Ay}. It then adds
     * {@link #Bz} to {@code Ay} and writes {@code A d + B} into the inherited scratch matrix
     * {@code tmpM}.
     *
     * @param mean on entry, the vector passed to {@code computeWeightedAverageAndSumOfSquares}; on
     *     exit, overwritten with {@code (A d + B)^{-1} (Ay + Bz)}
     * @param treePrecision tree precision matrix {@code A}
     * @param unused not used by this implementation
     * @param rootPrecision scalar {@code d} multiplying {@code treePrecision}
     * @return {@code (A d + B)^{-1}}, the marginal root variance
     */
    @Override
    protected double[][] computeMarginalRootMeanAndVariance(double[] mean, double[][] treePrecision,
            double[][] unused, double rootPrecision) {
        computeWeightedAverageAndSumOfSquares(mean, this.Ay, treePrecision, dimTrait, rootPrecision);
        final double[][] combined = this.tmpM;
        addRootPrior(this.Ay, combined, treePrecision, rootPrecision);

        final double[][] variance = new Matrix(combined).inverse().toComponents();
        for (int i = 0; i < dimTrait; i++) {
            mean[i] = 0.0;
            for (int j = 0; j < dimTrait; j++) {
                mean[i] += variance[i][j] * this.Ay[j];
            }
        }
        return variance;
    }
}
