package dr.evomodel.continuous;

import dr.evolution.tree.Tree;
import dr.inference.distribution.ParametricMultivariateDistributionModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inferencexml.distribution.MultivariateNormalDistributionModelParser;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.RandomGenerator;

/* loaded from: input_file:dr/evomodel/continuous/TreeTraitNormalDistributionModel.class */
/**
 * Multivariate normal distribution over tree-trait values whose mean vector and precision
 * matrix are derived from a {@link FullyConjugateMultivariateTraitLikelihood}.
 *
 * <p>The underlying {@link MultivariateNormalDistribution} is built lazily on first use and
 * cached. Any model-changed or variable-changed event received by this model invalidates the
 * cache. {@link #storeState()} / {@link #restoreState()} save and restore the cached
 * distribution reference together with its validity flag.
 */
public class TreeTraitNormalDistributionModel extends AbstractModel implements ParametricMultivariateDistributionModel, RandomGenerator {

    /** Trait likelihood supplying the tree and the process used to compute mean and precision. */
    private final FullyConjugateMultivariateTraitLikelihood traitModel;

    /** Passed through to MultivariateTraitUtils when computing the mean and precision. */
    private final boolean conditionOnRoot;

    /**
     * Root trait values copied once at construction, or {@code null} when no root parameter
     * was supplied.
     * NOTE(review): this is a snapshot; later changes to the source Parameter are not observed,
     * because the parameter is never registered via addVariable. Confirm this is intended.
     */
    private final double[] rootValue;

    /** External-node count times trait dimension, computed once at construction. */
    private final int dim;

    /** Cached distribution; only valid while {@link #distributionKnown} is true. */
    private MultivariateNormalDistribution distribution;
    private MultivariateNormalDistribution storedDistribution;
    private boolean distributionKnown;
    private boolean storedDistributionKnown;

    /**
     * Creates the model and registers {@code traitModel} so its change events invalidate the
     * cached distribution.
     *
     * @param traitModel      trait likelihood providing the tree and process parameters
     * @param rootValue       optional root trait values; may be {@code null}
     * @param conditionOnRoot whether to compute mean and precision conditional on the root
     */
    public TreeTraitNormalDistributionModel(FullyConjugateMultivariateTraitLikelihood traitModel, Parameter rootValue, boolean conditionOnRoot) {
        super(MultivariateNormalDistributionModelParser.NORMAL_DISTRIBUTION_MODEL);
        this.traitModel = traitModel;
        this.rootValue = (rootValue != null) ? rootValue.getParameterValues() : null;
        this.conditionOnRoot = conditionOnRoot;
        this.dim = traitModel.getTreeModel().getExternalNodeCount() * traitModel.getDimTrait();
        addModel(traitModel);
        this.distributionKnown = false;
    }

    /**
     * Creates the model without a root-value parameter.
     *
     * @param traitModel      trait likelihood providing the tree and process parameters
     * @param conditionOnRoot whether to compute mean and precision conditional on the root
     */
    public TreeTraitNormalDistributionModel(FullyConjugateMultivariateTraitLikelihood traitModel, boolean conditionOnRoot) {
        this(traitModel, null, conditionOnRoot);
    }

    /** Returns the tree of the underlying trait model. */
    public Tree getTree() {
        return this.traitModel.getTreeModel();
    }

    @Override // dr.math.distributions.MultivariateDistribution, dr.inference.distribution.DensityModel
    public double logPdf(double[] x) {
        checkDistribution();
        return this.distribution.logPdf(x);
    }

    @Override // dr.math.distributions.MultivariateDistribution
    public double[][] getScaleMatrix() {
        checkDistribution();
        return this.distribution.getScaleMatrix();
    }

    @Override // dr.math.distributions.MultivariateDistribution
    public double[] getMean() {
        checkDistribution();
        return this.distribution.getMean();
    }

    @Override // dr.math.distributions.MultivariateDistribution
    public String getType() {
        return "TreeTraitMVN";
    }

    /** Returns the per-node trait dimension of the underlying trait model. */
    public int getDimTrait() {
        // Use the accessor, as the constructor does, rather than reaching into the field.
        return this.traitModel.getDimTrait();
    }

    /** Any change in a registered model invalidates the cached distribution. */
    @Override // dr.inference.model.AbstractModel
    public void handleModelChangedEvent(Model model, Object object, int index) {
        this.distributionKnown = false;
    }

    /** Any change in a registered variable invalidates the cached distribution. */
    @Override // dr.inference.model.AbstractModel
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        this.distributionKnown = false;
    }

    /**
     * Saves the cached distribution reference and its validity flag. Storing the reference is
     * enough because checkDistribution() always assigns a fresh instance and never mutates the
     * cached one.
     */
    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
        this.storedDistribution = this.distribution;
        this.storedDistributionKnown = this.distributionKnown;
    }

    /** Restores the distribution reference and validity flag saved by {@link #storeState()}. */
    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        this.distributionKnown = this.storedDistributionKnown;
        this.distribution = this.storedDistribution;
    }

    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
        // Nothing to do: the stored snapshot is simply overwritten on the next storeState().
    }

    /** Rebuilds the cached distribution if it has been invalidated. */
    private void checkDistribution() {
        if (this.distributionKnown) {
            return;
        }
        this.distribution = createNewDistribution();
        this.distributionKnown = true;
    }

    private MultivariateNormalDistribution createNewDistribution() {
        return new MultivariateNormalDistribution(computeMean(), computePrecision());
    }

    /**
     * Computes the mean vector. Uses the OU variant when the trait model has a non-null
     * strengthOfSelection (presumably an Ornstein-Uhlenbeck process; TODO confirm), otherwise
     * the plain variant.
     */
    private double[] computeMean() {
        if (this.traitModel.strengthOfSelection != null) {
            return MultivariateTraitUtils.computeTreeTraitMeanOU(this.traitModel, this.rootValue, this.conditionOnRoot);
        }
        return MultivariateTraitUtils.computeTreeTraitMean(this.traitModel, this.rootValue, this.conditionOnRoot);
    }

    private double[][] computePrecision() {
        return MultivariateTraitUtils.computeTreeTraitPrecision(this.traitModel, this.conditionOnRoot);
    }

    /** Draws one sample from the current distribution. */
    @Override // dr.math.distributions.RandomGenerator
    public double[] nextRandom() {
        checkDistribution();
        return this.distribution.nextMultivariateNormal();
    }

    @Override // dr.math.distributions.RandomGenerator
    public double logPdf(Object x) {
        checkDistribution();
        return this.distribution.logPdf(x);
    }

    /** Not supported by this model. */
    @Override // dr.inference.distribution.DensityModel
    public Variable<Double> getLocationVariable() {
        throw new UnsupportedOperationException("Not implemented");
    }
}
