package dr.evomodel.continuous;

import dr.evolution.tree.NodeRef;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.math.KroneckerOperation;
import dr.math.distributions.GaussianProcessRandomGenerator;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;

/* loaded from: input_file:dr/evomodel/continuous/GaussianProcessFromTree.class */
/**
 * A {@link GaussianProcessRandomGenerator} backed by a
 * {@link FullyConjugateMultivariateTraitLikelihood}.
 *
 * <p>Random draws are produced by walking the trait model's tree from the root to the
 * tips and sampling a multivariate normal value at every node, centred on the value
 * drawn for its parent. Only the values at external nodes are returned.
 */
public class GaussianProcessFromTree implements GaussianProcessRandomGenerator {
    private final FullyConjugateMultivariateTraitLikelihood traitModel;

    // Not referenced anywhere in this class; kept as-is.
    // NOTE(review): presumably once selected between the buffered and non-buffered
    // samplers below - confirm before removing.
    private static final boolean USE_BUFFER = true;

    /**
     * @param fullyConjugateMultivariateTraitLikelihood the trait likelihood that supplies the
     *        tree, the diffusion model and the root prior used by this generator
     */
    public GaussianProcessFromTree(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood) {
        this.traitModel = fullyConjugateMultivariateTraitLikelihood;
    }

    /** @return the wrapped trait likelihood */
    @Override // dr.math.distributions.GaussianProcessRandomGenerator
    public Likelihood getLikelihood() {
        return this.traitModel;
    }

    /** @return number of external nodes in the tree multiplied by the trait dimension */
    @Override // dr.math.distributions.GaussianProcessRandomGenerator
    public int getDimension() {
        return this.traitModel.getTreeModel().getExternalNodeCount() * this.traitModel.getDimTrait();
    }

    /**
     * Builds the precision matrix as the Kronecker product of the inverted tree variance
     * matrix and the diffusion model's precision matrix.
     *
     * @return the Kronecker product described above
     */
    @Override // dr.math.distributions.GaussianProcessRandomGenerator
    public double[][] getPrecisionMatrix() {
        // The flag passed to computeTreeVariance2 is false.
        // NOTE(review): presumably this excludes the root contribution - confirm.
        double[][] treeVariance = this.traitModel.computeTreeVariance2(false);
        double[][] traitPrecision = this.traitModel.getDiffusionModel().getPrecisionmatrix();
        double[][] treePrecision = new Matrix(treeVariance).inverse().toComponents();
        return KroneckerOperation.product(treePrecision, traitPrecision);
    }

    /**
     * Multiplies every entry of {@code dArr} by {@code d}, in place.
     * Not called by any other method in this class.
     */
    private static void scale(double[][] dArr, double d) {
        for (double[] row : dArr) {
            for (int col = 0; col < row.length; col++) {
                row[col] *= d;
            }
        }
    }

    /** @return the log-likelihood reported by the wrapped trait likelihood */
    public double getLogLikelihood() {
        return this.traitModel.getLogLikelihood();
    }

    /**
     * Draws one realisation of the trait values at the external nodes of the tree.
     *
     * @return a new array of length (external node count * trait dimension); the values for
     *         an external node start at index {@code node.getNumber() * traitDimension}
     * @throws IllegalStateException if the Cholesky factor of the inverted diffusion
     *         precision matrix cannot be computed because of an {@link IllegalDimension}
     */
    public double[] nextRandomFast() {
        final int dimTrait = this.traitModel.getDimTrait();
        final double[] tipValues = new double[this.traitModel.getTreeModel().getExternalNodeCount() * dimTrait];
        final NodeRef root = this.traitModel.getTreeModel().getRoot();
        final double[] priorMean = this.traitModel.getPriorMean();

        final double[][] cholesky;
        try {
            double[][] traitPrecision = this.traitModel.getDiffusionModel().getPrecisionmatrix();
            double[][] traitVariance = new SymmetricMatrix(traitPrecision).inverse().toComponents();
            cholesky = new CholeskyDecomposition(traitVariance).getL();
        } catch (IllegalDimension e) {
            // Continuing with a null factor would only fail later with an uninformative
            // NullPointerException, so surface the real cause immediately.
            throw new IllegalStateException(
                    "Unable to compute the Cholesky factor of the inverted diffusion precision matrix", e);
        }

        final int nodeCount = this.traitModel.getTreeModel().getNodeCount();
        // One slot of dimTrait values per node, plus one trailing slot that holds the prior
        // mean and acts as the "parent" value of the root.
        final double[] nodeValues = new double[(nodeCount + 1) * dimTrait];
        final double[] epsilon = new double[dimTrait];
        final int priorOffset = nodeCount * dimTrait;
        System.arraycopy(priorMean, 0, nodeValues, priorOffset, dimTrait);

        nextRandomFast2(nodeValues, priorOffset, root, tipValues, cholesky, epsilon);
        return tipValues;
    }

    /**
     * Square root of the scale applied to the Cholesky factor when sampling at a node:
     * {@code 1 / priorSampleSize} for the root, otherwise the node's rescaled branch length
     * for precision.
     */
    private double sqrtVarianceScale(NodeRef nodeRef) {
        final double scale = this.traitModel.getTreeModel().isRoot(nodeRef)
                ? 1.0d / this.traitModel.getPriorSampleSize()
                : this.traitModel.getRescaledBranchLengthForPrecision(nodeRef);
        return Math.sqrt(scale);
    }

    /**
     * Non-buffered recursive sampler: draws a value for {@code nodeRef} centred on
     * {@code dArr}, copies it into {@code dArr2} when the node is external, and otherwise
     * recurses into each child using the fresh draw as the new centre.
     * Not called by any other method in this class.
     */
    private void nextRandomFast(double[] dArr, NodeRef nodeRef, double[] dArr2, double[][] dArr3) {
        final double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(
                dArr, dArr3, sqrtVarianceScale(nodeRef));
        if (this.traitModel.getTreeModel().isExternal(nodeRef)) {
            System.arraycopy(draw, 0, dArr2, nodeRef.getNumber() * draw.length, draw.length);
            return;
        }
        final int childCount = this.traitModel.getTreeModel().getChildCount(nodeRef);
        for (int child = 0; child < childCount; child++) {
            nextRandomFast(draw, this.traitModel.getTreeModel().getChild(nodeRef, child), dArr2, dArr3);
        }
    }

    /**
     * Buffered recursive sampler. All node values live in the shared buffer {@code dArr};
     * the value for a node is written at offset {@code node.getNumber() * dArr3.length}.
     *
     * @param dArr    shared buffer holding the value of every node (and the prior mean)
     * @param i       offset in {@code dArr} of the parent's value, used as the centre of the draw
     * @param nodeRef node to sample at
     * @param dArr2   output array receiving the values of external nodes
     * @param dArr3   Cholesky factor passed through to the sampler; its length is used as the
     *                trait dimension
     * @param dArr4   scratch array handed to the sampler
     *                (NOTE(review): presumably holds the standard normal draws - confirm)
     */
    private void nextRandomFast2(double[] dArr, int i, NodeRef nodeRef, double[] dArr2, double[][] dArr3, double[] dArr4) {
        final int dim = dArr3.length;
        final double sqrtScale = sqrtVarianceScale(nodeRef);
        final int nodeOffset = nodeRef.getNumber() * dim;
        MultivariateNormalDistribution.nextMultivariateNormalCholesky(dArr, i, dArr3, sqrtScale, dArr, nodeOffset, dArr4);
        if (this.traitModel.getTreeModel().isExternal(nodeRef)) {
            // External nodes share the same offset in the node buffer and in the output.
            System.arraycopy(dArr, nodeOffset, dArr2, nodeOffset, dim);
            return;
        }
        final int childCount = this.traitModel.getTreeModel().getChildCount(nodeRef);
        for (int child = 0; child < childCount; child++) {
            nextRandomFast2(dArr, nodeOffset, this.traitModel.getTreeModel().getChild(nodeRef, child), dArr2, dArr3, dArr4);
        }
    }

    /** @return a fresh draw from {@link #nextRandomFast()}, as a {@code double[]} */
    @Override // dr.math.distributions.RandomGenerator
    public Object nextRandom() {
        return nextRandomFast();
    }

    /**
     * Evaluates the trait model's log-likelihood at the supplied values.
     *
     * <p>Side effect: the values are written into the trait model's trait parameter (quietly,
     * followed by a single change event), so the model's state is modified by this call.
     *
     * @param obj a {@code double[]} of trait values, written to parameter indices
     *            {@code 0..length-1}
     * @return the trait model's log-likelihood after the values have been set
     */
    @Override // dr.math.distributions.RandomGenerator
    public double logPdf(Object obj) {
        final double[] values = (double[]) obj;
        final CompoundParameter traitParameter = this.traitModel.getTraitParameter();
        for (int index = 0; index < values.length; index++) {
            traitParameter.setParameterValueQuietly(index, values[index]);
        }
        traitParameter.fireParameterChangedEvent();
        return this.traitModel.getLogLikelihood();
    }
}
