package dr.evomodel.continuous;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.CompoundSymmetricMatrix;
import dr.inference.model.Model;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.Vector;
import java.util.List;

/* loaded from: input_file:dr/evomodel/continuous/SampledMultivariateTraitLikelihood.class */
/**
 * Multivariate trait likelihood that reads a trait value for every node of the tree (tips and
 * internal nodes alike) and scores each parent-to-child branch with the diffusion model.
 * The log-likelihood is the sum of these per-branch log-densities. The internal-node values are
 * conditioned on directly (this class reports "Sampling internal trait values: true").
 */
public class SampledMultivariateTraitLikelihood extends AbstractMultivariateTraitLikelihood {

    /**
     * Creates the likelihood by forwarding to {@link AbstractMultivariateTraitLikelihood}.
     * The super-constructor slots this sampled variant does not use are passed as {@code null}.
     *
     * <p>NOTE(review): the boolean flags {@code z}..{@code z5} and the {@code list} argument are
     * forwarded positionally; their meaning is defined by the superclass constructor — confirm
     * there.
     *
     * @param str                        identifier forwarded to the superclass
     * @param mutableTreeModel           tree whose nodes carry the trait values
     * @param multivariateDiffusionModel diffusion model used to score each branch
     * @param compoundParameter          trait parameter forwarded to the superclass
     * @param list                       forwarded to the superclass (semantics defined there)
     * @param z                          forwarded to the superclass (semantics defined there)
     * @param z2                         forwarded to the superclass (semantics defined there)
     * @param z3                         forwarded to the superclass (semantics defined there)
     * @param branchRateModel            branch rate model forwarded to the superclass
     * @param model                      forwarded to the superclass (semantics defined there)
     * @param z4                         forwarded to the superclass (semantics defined there)
     * @param z5                         forwarded to the superclass (semantics defined there)
     */
    public SampledMultivariateTraitLikelihood(String str, MutableTreeModel mutableTreeModel, MultivariateDiffusionModel multivariateDiffusionModel, CompoundParameter compoundParameter, List<Integer> list, boolean z, boolean z2, boolean z3, BranchRateModel branchRateModel, Model model, boolean z4, boolean z5) {
        super(str, mutableTreeModel, multivariateDiffusionModel, compoundParameter, null, list, z, z2, z3, branchRateModel, null, null, null, model, z4, z5);
    }

    /** Returns the extra summary line this subclass contributes to the model description. */
    @Override
    protected String extraInfo() {
        return "\tSampling internal trait values: true\n";
    }

    /**
     * Computes the log-likelihood of all trait values on the tree by recursing from the root.
     * When {@code cacheBranches} is set, per-branch results are cached and reused. The largest
     * value seen so far is recorded in {@code maxLogLikelihood}.
     *
     * @return the summed per-branch log-likelihood
     */
    @Override
    public double calculateLogLikelihood() {
        final NodeRef root = this.treeModel.getRoot();
        final double logL = this.cacheBranches
                ? traitCachedLogLikelihood(null, root)
                : traitLogLikelihood(null, root);
        if (logL > this.maxLogLikelihood) {
            this.maxLogLikelihood = logL;
        }
        return logL;
    }

    /**
     * Not supported for sampled trait likelihoods.
     *
     * @throws RuntimeException always
     */
    @Override
    protected double calculateAscertainmentCorrection(int i) {
        throw new RuntimeException("Ascertainment correction not yet implemented for sampled trait likelihoods");
    }

    /**
     * Sums the branch log-likelihoods of the external (tip) branches only. A cached value is used
     * when branch caching is on and the tip's entry is valid. Otherwise the branch is re-scored
     * from the parent and tip traits. Recomputed values are not written back to the cache.
     *
     * @return the summed log-likelihood over all tip branches
     */
    @Override
    public final double getLogDataLikelihood() {
        double logL = 0.0d;
        final int tipCount = this.treeModel.getExternalNodeCount();
        for (int i = 0; i < tipCount; i++) {
            final NodeRef tip = this.treeModel.getExternalNode(i);
            final int index = tip.getNumber();
            if (this.cacheBranches && this.validLogLikelihoods[index]) {
                logL += this.cachedLogLikelihoods[index];
            } else {
                final NodeRef parent = this.treeModel.getParent(tip);
                final double[] tipTrait = this.treeModel.getMultivariateNodeTrait(tip, this.traitName);
                final double[] parentTrait = this.treeModel.getMultivariateNodeTrait(parent, this.traitName);
                final double time = getRescaledBranchLengthForPrecision(tip);
                logL += this.diffusionModel.getLogLikelihood(parentTrait, tipTrait, time);
            }
        }
        return logL;
    }

    /**
     * Recursively sums branch log-likelihoods below (and including the branch above) {@code node},
     * consulting and filling the per-node branch cache.
     *
     * @param parentTrait trait of {@code node}'s parent, or {@code null}. When {@code null} and the
     *                    branch must be recomputed, the parent trait is read from the tree.
     * @param node        node whose incoming branch and subtree are scored; the root has no
     *                    incoming branch and contributes only its subtree
     * @return the summed log-likelihood of the branch into {@code node} plus its subtree
     */
    private double traitCachedLogLikelihood(double[] parentTrait, NodeRef node) {
        double logL = 0.0d;
        // Only set when this node's branch is recomputed; otherwise children look up their
        // parent's trait themselves if they need it.
        double[] nodeTrait = null;
        final int index = node.getNumber();

        if (!this.treeModel.isRoot(node)) {
            if (this.validLogLikelihoods[index]) {
                logL = this.cachedLogLikelihoods[index];
            } else {
                nodeTrait = this.treeModel.getMultivariateNodeTrait(node, this.traitName);
                final double time = getRescaledBranchLengthForPrecision(node);
                if (parentTrait == null) {
                    parentTrait = this.treeModel.getMultivariateNodeTrait(this.treeModel.getParent(node), this.traitName);
                }
                logL = this.diffusionModel.getLogLikelihood(parentTrait, nodeTrait, time);
                this.cachedLogLikelihoods[index] = logL;
                this.validLogLikelihoods[index] = true;
            }
        }

        final int childCount = this.treeModel.getChildCount(node);
        for (int i = 0; i < childCount; i++) {
            logL += traitCachedLogLikelihood(nodeTrait, this.treeModel.getChild(node, i));
        }
        return logL;
    }

    /**
     * Recursively sums branch log-likelihoods below (and including the branch above) {@code node},
     * without using the cache.
     *
     * <p>If a single branch yields NaN, diagnostics are printed to stderr. If the running subtotal
     * at any node is NaN, the JVM is terminated with {@code System.exit(-1)}.
     * NOTE(review): exiting the JVM from library code is drastic; consider throwing instead.
     *
     * @param parentTrait trait of {@code node}'s parent, or {@code null} at the root (no incoming
     *                    branch is scored then)
     * @param node        node whose incoming branch and subtree are scored
     * @return the summed log-likelihood of the branch into {@code node} plus its subtree
     */
    private double traitLogLikelihood(double[] parentTrait, NodeRef node) {
        double logL = 0.0d;
        final double[] nodeTrait = this.treeModel.getMultivariateNodeTrait(node, this.traitName);

        if (parentTrait != null) {
            final double time = getRescaledBranchLengthForPrecision(node);
            logL = this.diffusionModel.getLogLikelihood(parentTrait, nodeTrait, time);
            if (Double.isNaN(logL)) {
                printUndefinedBranchDiagnostics(time, parentTrait, nodeTrait);
            }
        }

        final int childCount = this.treeModel.getChildCount(node);
        for (int i = 0; i < childCount; i++) {
            logL += traitLogLikelihood(nodeTrait, this.treeModel.getChild(node, i));
        }

        if (Double.isNaN(logL)) {
            System.err.println("logL = " + logL);
            System.exit(-1);
        }
        return logL;
    }

    /**
     * Prints the inputs of a branch whose diffusion log-likelihood evaluated to NaN, plus the
     * diffusion precision matrix when one is available.
     */
    private void printUndefinedBranchDiagnostics(double time, double[] parentTrait, double[] childTrait) {
        System.err.println("AbstractMultivariateTraitLikelihood: likelihood is undefined");
        System.err.println("time = " + time);
        System.err.println("parent trait value = " + new Vector(parentTrait));
        System.err.println("child trait value = " + new Vector(childTrait));
        if (this.diffusionModel.getPrecisionmatrix() != null) {
            System.err.println("precision matrix = " + new Matrix(this.diffusionModel.getPrecisionmatrix()));
        }
    }

    /**
     * Returns the trait vector stored on {@code nodeRef} under {@code str}.
     * NOTE(review): casts {@code tree} to {@link TreeModel}; any other {@link Tree} implementation
     * will fail with a ClassCastException.
     */
    @Override
    public double[] getTraitForNode(Tree tree, NodeRef nodeRef, String str) {
        return ((TreeModel) tree).getMultivariateNodeTrait(nodeRef, str);
    }
}
