package dr.evomodel.treedatalikelihood.preorder;

import dr.evolution.tree.Tree;
import dr.evomodel.continuous.MultivariateDiffusionModel;
import dr.evomodel.treedatalikelihood.continuous.ConjugateRootTraitPrior;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousRateTransformation;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.PartiallyMissingInformation;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/preorder/MultivariateConditionalOnTipsRealizedDelegate.class */
/**
 * Draws realized multivariate trait values for the root, internal nodes and tips.
 *
 * <p>Each draw combines a node partial read from {@code partialNodeBuffer} with a branch
 * (or, at the root, prior) distribution. A partial is a mean vector of length {@code dimTrait}
 * followed by a {@code dimTrait x dimTrait} precision block.
 *
 * <p>Tips whose precision diagonal is neither fully finite, fully non-finite, nor fully zero
 * are treated as partially missing. Only the coordinates reported by
 * {@link PartiallyMissingInformation#getMissingIndices} are resampled for them.
 */
public class MultivariateConditionalOnTipsRealizedDelegate extends ConditionalOnTipsRealizedDelegate {

    private final PartiallyMissingInformation missingInformation;

    public MultivariateConditionalOnTipsRealizedDelegate(String name, Tree tree, MultivariateDiffusionModel diffusionModel, ContinuousTraitPartialsProvider dataModel, ConjugateRootTraitPrior rootPrior, ContinuousRateTransformation rateTransformation, ContinuousDataLikelihoodDelegate likelihoodDelegate) {
        super(name, tree, diffusionModel, dataModel, rootPrior, rateTransformation, likelihoodDelegate);
        this.missingInformation = new PartiallyMissingInformation(tree, dataModel);
    }

    /**
     * Samples the root trait from the product of the root's node partial and the root prior.
     *
     * @param offsetSample  offset into {@code sample} where the draw is written
     * @param offsetPartial offset of the root's mean/precision in {@code partialNodeBuffer}
     *                      and {@code partialPriorBuffer}
     */
    @Override
    protected void simulateTraitForRoot(int offsetSample, int offsetPartial) {
        final int dim = this.dimTrait;

        // Precision block of the root's node partial.
        final DenseMatrix64F nodePrecision = MissingOps.wrap(this.partialNodeBuffer, offsetPartial + dim, dim, dim);

        // Prior precision block, pre-multiplied by the diffusion precision Pd.
        final DenseMatrix64F priorPrecision = new DenseMatrix64F(dim, dim);
        MissingOps.safeMult(this.Pd, MissingOps.wrap(this.partialPriorBuffer, offsetPartial + dim, dim, dim), priorPrecision);

        final DenseMatrix64F totalPrecision = new DenseMatrix64F(dim, dim);
        CommonOps.add(nodePrecision, priorPrecision, totalPrecision);

        final DenseMatrix64F totalVariance = new DenseMatrix64F(dim, dim);
        MissingOps.safeInvert2(totalPrecision, totalVariance, false);

        // Precision-weighted average of the node-partial mean and the prior mean.
        final double[] mean = new double[dim];
        MissingOps.safeWeightedAverage(
                new WrappedVector.Raw(this.partialNodeBuffer, offsetPartial, dim), nodePrecision,
                new WrappedVector.Raw(this.partialPriorBuffer, offsetPartial, dim), priorPrecision,
                new WrappedVector.Raw(mean, 0, dim), totalVariance, dim);

        MultivariateNormalDistribution.nextMultivariateNormalCholesky(
                new WrappedVector.Raw(mean, 0, dim),
                new WrappedMatrix.Raw(getCholeskyOfVariance(totalVariance, dim).getData(), 0, dim, dim),
                1.0d,
                new WrappedVector.Raw(this.sample, offsetSample, dim),
                this.tmpEpsilon);
    }

    /**
     * Dispatches a non-root draw to either the external (tip) or the internal sampling path.
     *
     * @param nodeIndex       node index, forwarded to the missing-index lookup on the tip path
     * @param traitIndex      trait index, forwarded to the missing-index lookup on the tip path
     * @param offsetSample    offset into {@code sample} where the draw is written
     * @param offsetParent    offset into {@code sample} of the already-drawn branch start value
     * @param offsetPartial   offset of this node's mean/precision in {@code partialNodeBuffer}
     * @param nodeType        {@code 1} selects the external (tip) path; any other value the internal path
     * @param branchPrecision scalar precision of the branch (scales {@code Pd} when there is no drift)
     */
    @Override
    protected void simulateTraitForNode(int nodeIndex, int traitIndex, int offsetSample, int offsetParent, int offsetPartial, int nodeType, double branchPrecision) {
        if (nodeType != 1) {
            simulateTraitForInternalNode(offsetSample, offsetParent, offsetPartial, branchPrecision);
            return;
        }
        simulateTraitForExternalNode(nodeIndex, traitIndex, offsetSample, offsetParent, offsetPartial, branchPrecision);
    }

    /**
     * Samples a tip, choosing the path from the diagonal of the tip's partial precision.
     *
     * <p>The paths are, in order:
     * <ul>
     *   <li>no finite diagonal: the partial mean is copied verbatim;</li>
     *   <li>all-zero diagonal: the draw comes from the branch distribution alone;</li>
     *   <li>all-finite diagonal: the draw uses the internal-node path;</li>
     *   <li>otherwise: only the missing coordinates are resampled.</li>
     * </ul>
     */
    private void simulateTraitForExternalNode(int nodeIndex, int traitIndex, int offsetSample, int offsetParent, int offsetPartial, double branchPrecision) {
        final DenseMatrix64F tipPrecision = MissingOps.wrap(this.partialNodeBuffer, offsetPartial + this.dimTrait, this.dimTrait, this.dimTrait);
        final int finiteDiagonals = MissingOps.countFiniteDiagonals(tipPrecision);

        if (finiteDiagonals == 0) {
            System.arraycopy(this.partialNodeBuffer, offsetPartial, this.sample, offsetSample, this.dimTrait);
            return;
        }

        if (MissingOps.countZeroDiagonals(tipPrecision) == this.dimTrait) {
            // The tip partial carries no precision, so sample around the branch mean.
            // The standard deviation scale is 1/sqrt(branchPrecision) applied to the shared Cholesky factor.
            MultivariateNormalDistribution.nextMultivariateNormalCholesky(
                    getMeanBranch(offsetParent),
                    new WrappedMatrix.ArrayOfArray(this.cholesky),
                    Math.sqrt(1.0d / branchPrecision),
                    new WrappedVector.Raw(this.sample, offsetSample, this.dimTrait),
                    this.tmpEpsilon);
            return;
        }

        if (finiteDiagonals == this.dimTrait) {
            simulateTraitForInternalNode(offsetSample, offsetParent, offsetPartial, branchPrecision);
            return;
        }

        // Mixed case: seed the output with the partial mean, then overwrite only the missing
        // coordinates with a draw conditional on the remaining ones.
        System.arraycopy(this.partialNodeBuffer, offsetPartial, this.sample, offsetSample, this.dimTrait);

        final PartiallyMissingInformation.HashedIntArray missing = this.missingInformation.getMissingIndices(nodeIndex, traitIndex);
        final int[] observedIdx = missing.getComplement();
        final int[] missingIdx = missing.getArray();
        final int nMissing = missingIdx.length;

        final ConditionalVarianceAndTransform2 transform = new ConditionalVarianceAndTransform2(getVarianceBranch(branchPrecision), missingIdx, observedIdx);

        final DenseMatrix64F tipPrecisionMissing = new DenseMatrix64F(nMissing, nMissing);
        MissingOps.gatherRowsAndColumns(tipPrecision, tipPrecisionMissing, missingIdx, missingIdx);

        // NOTE(review): the raw value at offsetParent is passed here, not getMeanBranch(offsetParent)
        // as on the other paths. Confirm this is intended when the branch has drift/actualization.
        final WrappedVector conditionalMean = transform.getConditionalMean(this.partialNodeBuffer, offsetPartial, this.sample, offsetParent);
        final DenseMatrix64F conditionalPrecision = transform.getConditionalPrecision();

        final DenseMatrix64F totalPrecision = new DenseMatrix64F(nMissing, nMissing);
        final DenseMatrix64F totalVariance = new DenseMatrix64F(nMissing, nMissing);
        CommonOps.add(tipPrecisionMissing, conditionalPrecision, totalPrecision);
        MissingOps.safeInvert2(totalPrecision, totalVariance, false);

        MultivariateNormalDistribution.nextMultivariateNormalCholesky(
                conditionalMean,
                new WrappedMatrix.ArrayOfArray(getCholeskyOfVariance(totalVariance.getData(), nMissing)),
                1.0d,
                new WrappedVector.Indexed(this.sample, offsetSample, missingIdx, nMissing),
                this.tmpEpsilon);
    }

    /**
     * Computes the branch expectation from the value stored in {@code sample} at {@code offsetStart}.
     *
     * <p>The expectation is obtained from {@code cdi.getBranchExpectation} using the current
     * {@code actualizationBuffer} and {@code displacementBuffer}.
     *
     * @param offsetStart offset into {@code sample} of the branch start value
     * @return a freshly allocated vector of length {@code dimTrait}
     */
    ReadableVector getMeanBranch(int offsetStart) {
        final double[] start = new double[this.dimTrait];
        System.arraycopy(this.sample, offsetStart, start, 0, this.dimTrait);
        final double[] expectation = new double[this.dimTrait];
        this.cdi.getBranchExpectation(this.actualizationBuffer, start, this.displacementBuffer, expectation);
        return new WrappedVector.Raw(expectation, 0, this.dimTrait);
    }

    /**
     * Samples a node from the product of its node partial and the branch distribution.
     *
     * <p>With infinite branch precision the start value at {@code offsetParent} is copied unchanged.
     */
    private void simulateTraitForInternalNode(int offsetSample, int offsetParent, int offsetPartial, double branchPrecision) {
        if (Double.isInfinite(branchPrecision)) {
            System.arraycopy(this.sample, offsetParent, this.sample, offsetSample, this.dimTrait);
            return;
        }

        final WrappedVector.Raw nodeMean = new WrappedVector.Raw(this.partialNodeBuffer, offsetPartial, this.dimTrait);
        final DenseMatrix64F nodePrecision = MissingOps.wrap(this.partialNodeBuffer, offsetPartial + this.dimTrait, this.dimTrait, this.dimTrait);
        final ReadableVector branchMean = getMeanBranch(offsetParent);
        final DenseMatrix64F branchPrecisionMatrix = getPrecisionBranch(branchPrecision);
        final WrappedVector.Raw combinedMean = new WrappedVector.Raw(this.tmpMean, 0, this.dimTrait);

        final DenseMatrix64F totalPrecision = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        final DenseMatrix64F totalVariance = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        CommonOps.add(nodePrecision, branchPrecisionMatrix, totalPrecision);
        MissingOps.safeInvert2(totalPrecision, totalVariance, false);

        // Precision-weighted average of the node-partial mean and the branch mean.
        MissingOps.weightedAverage(nodeMean, nodePrecision, branchMean, branchPrecisionMatrix, combinedMean, totalVariance, this.dimTrait);

        MultivariateNormalDistribution.nextMultivariateNormalCholesky(
                combinedMean,
                new WrappedMatrix.ArrayOfArray(getCholeskyOfVariance(totalVariance.getData(), this.dimTrait)),
                1.0d,
                new WrappedVector.Raw(this.sample, offsetSample, this.dimTrait),
                this.tmpEpsilon);
    }

    /**
     * Returns the branch precision matrix.
     *
     * <p>With drift, this is a view wrapping {@code precisionBuffer}. {@code DenseMatrix64F.wrap}
     * does not copy, so callers must not modify it. Without drift, it is a new matrix equal to
     * {@code branchPrecision * Pd}.
     */
    DenseMatrix64F getPrecisionBranch(double branchPrecision) {
        if (this.hasDrift) {
            return DenseMatrix64F.wrap(this.dimTrait, this.dimTrait, this.precisionBuffer);
        }
        final DenseMatrix64F scaled = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        CommonOps.scale(branchPrecision, this.Pd, scaled);
        return scaled;
    }

    /**
     * Returns a new branch variance matrix.
     *
     * <p>Without drift, this is {@code Vd / branchPrecision}. With drift, it is the inverse of
     * {@link #getPrecisionBranch(double)}.
     */
    DenseMatrix64F getVarianceBranch(double branchPrecision) {
        if (!this.hasDrift) {
            final DenseMatrix64F scaled = new DenseMatrix64F(this.dimTrait, this.dimTrait);
            CommonOps.scale(1.0d / branchPrecision, this.Vd, scaled);
            return scaled;
        }
        final DenseMatrix64F precision = getPrecisionBranch(branchPrecision);
        final DenseMatrix64F variance = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        // NOTE(review): CommonOps.invert reports failure via its boolean return, which is ignored here.
        CommonOps.invert(precision, variance);
        return variance;
    }
}
