package dr.inference.distribution;

import dr.inference.model.AbstractModel;
import dr.inference.model.DiagonalMatrix;
import dr.inference.model.DuplicatedParameter;
import dr.inference.model.GradientProvider;
import dr.inference.model.HessianProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inferencexml.distribution.MultivariateNormalDistributionModelParser;
import dr.math.distributions.GaussianProcessRandomGenerator;
import dr.math.distributions.MultivariateNormalDistribution;

/* Decompiled from class file: dr/inference/distribution/MultivariateNormalDistributionModel.class */
/**
 * Model wrapper around a {@link MultivariateNormalDistribution} parameterised by a mean
 * {@link Parameter} and a precision {@link MatrixParameter}.
 *
 * <p>The underlying distribution object is cached. Any change to a registered variable (the mean
 * or the precision) marks the cache stale, and the next density/gradient/Hessian/random-draw call
 * rebuilds it. {@link #storeState()} / {@link #restoreState()} save and restore the cached
 * instance together with its "known" flag.
 *
 * <p>When the precision is a {@link DiagonalMatrix} whose diagonal is a {@link DuplicatedParameter},
 * the distribution is built from the single scalar at index 0 of that diagonal rather than from the
 * full matrix (presumably one value repeated along the diagonal — confirm against
 * {@code DuplicatedParameter}).
 */
public class MultivariateNormalDistributionModel extends AbstractModel implements ParametricMultivariateDistributionModel, GaussianProcessRandomGenerator, GradientProvider, HessianProvider {

    // ---- parameters -------------------------------------------------------------------------

    private final Parameter mean;
    private final MatrixParameter precision;

    // Non-null only when the precision is a DiagonalMatrix backed by a DuplicatedParameter.
    private final Parameter singlePrecision;
    private final boolean hasSinglePrecision;

    // ---- cached distribution and its stored copy ----------------------------------------------

    private MultivariateNormalDistribution distribution;
    private boolean distributionKnown;
    private MultivariateNormalDistribution storedDistribution;
    private boolean storedDistributionKnown;

    /**
     * Builds the model and registers both parameters as variables of this model.
     *
     * @param meanParameter      mean vector. Unless it is a {@link DuplicatedParameter}, bounds
     *                           built from (+infinity, -infinity) are attached to it.
     * @param precisionParameter precision matrix of the distribution
     */
    public MultivariateNormalDistributionModel(Parameter meanParameter, MatrixParameter precisionParameter) {
        super(MultivariateNormalDistributionModelParser.NORMAL_DISTRIBUTION_MODEL);

        this.mean = meanParameter;
        addVariable(meanParameter);
        if (!(meanParameter instanceof DuplicatedParameter)) {
            meanParameter.addBounds(new Parameter.DefaultBounds(
                    Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, meanParameter.getDimension()));
        }

        this.precision = precisionParameter;
        addVariable(precisionParameter);

        this.singlePrecision = findSinglePrecision(precisionParameter);
        this.hasSinglePrecision = this.singlePrecision != null;

        this.distribution = createNewDistribution();
        this.distributionKnown = true;
    }

    /**
     * Returns the diagonal parameter of {@code precisionParameter} when it is a
     * {@link DiagonalMatrix} whose diagonal is a {@link DuplicatedParameter}, otherwise {@code null}.
     */
    private static Parameter findSinglePrecision(MatrixParameter precisionParameter) {
        if (!(precisionParameter instanceof DiagonalMatrix)) {
            return null;
        }
        Parameter diagonal = ((DiagonalMatrix) precisionParameter).getDiagonalParameter();
        if (diagonal instanceof DuplicatedParameter) {
            return diagonal;
        }
        return null;
    }

    // ---- accessors ----------------------------------------------------------------------------

    /** Returns the precision matrix parameter supplied at construction. */
    public MatrixParameter getPrecisionMatrixParameter() {
        return this.precision;
    }

    /** Returns the mean parameter supplied at construction. */
    public Parameter getMeanParameter() {
        return this.mean;
    }

    /** Returns the current mean values, as obtained from {@code getParameterValues()}. */
    @Override
    public double[] getMean() {
        return this.mean.getParameterValues();
    }

    /** Returns the precision matrix values; this model uses the precision as its scale matrix. */
    @Override
    public double[][] getScaleMatrix() {
        return this.precision.getParameterAsMatrix();
    }

    /** Returns the current precision matrix values. */
    @Override
    public double[][] getPrecisionMatrix() {
        return this.precision.getParameterAsMatrix();
    }

    /** Returns the dimension of the mean parameter. */
    @Override
    public int getDimension() {
        return this.mean.getDimension();
    }

    /** Returns the mean parameter as the location variable. */
    @Override
    public Variable<Double> getLocationVariable() {
        return this.mean;
    }

    /** Returns the type string of the currently cached distribution (the cache is not refreshed). */
    @Override
    public String getType() {
        return this.distribution.getType();
    }

    /** Always returns {@code null}: this model exposes no associated likelihood. */
    @Override
    public Likelihood getLikelihood() {
        return null;
    }

    // ---- cached distribution ------------------------------------------------------------------

    /** Rebuilds the cached distribution if a variable change has marked it stale. */
    private void checkDistribution() {
        if (!this.distributionKnown) {
            this.distribution = createNewDistribution();
            this.distributionKnown = true;
        }
    }

    /**
     * Creates a distribution from the current parameter values: a scalar precision when the
     * single-precision shortcut applies, otherwise the full scale (precision) matrix.
     */
    private MultivariateNormalDistribution createNewDistribution() {
        if (this.hasSinglePrecision) {
            double scalarPrecision = this.singlePrecision.getParameterValue(0);
            return new MultivariateNormalDistribution(getMean(), scalarPrecision);
        }
        return new MultivariateNormalDistribution(getMean(), getScaleMatrix());
    }

    // ---- delegating density / sampling methods ------------------------------------------------

    /** Log density of {@code x} under the (refreshed) distribution. */
    @Override
    public double logPdf(double[] x) {
        checkDistribution();
        return this.distribution.logPdf(x);
    }

    /** Log density of {@code x} under the (refreshed) distribution. */
    @Override
    public double logPdf(Object x) {
        checkDistribution();
        return this.distribution.logPdf(x);
    }

    /** Draws a random vector from the (refreshed) distribution. */
    @Override
    public double[] nextRandom() {
        checkDistribution();
        return this.distribution.nextMultivariateNormal();
    }

    /** Gradient of the log density at {@code x}, delegated to the (refreshed) distribution. */
    @Override
    public double[] getGradientLogDensity(Object x) {
        checkDistribution();
        return this.distribution.getGradientLogDensity(x);
    }

    /** Diagonal of the log-density Hessian at {@code x}, delegated to the (refreshed) distribution. */
    @Override
    public double[] getDiagonalHessianLogDensity(Object x) {
        checkDistribution();
        return this.distribution.getDiagonalHessianLogDensity(x);
    }

    /** Full log-density Hessian at {@code x}, delegated to the (refreshed) distribution. */
    @Override
    public double[][] getHessianLogDensity(Object x) {
        checkDistribution();
        return this.distribution.getHessianLogDensity(x);
    }

    // ---- model state management ---------------------------------------------------------------

    /** Ignores model-change events; no sub-models are registered here. */
    @Override
    public void handleModelChangedEvent(Model model, Object object, int index) {
    }

    /** Any change to the mean or precision invalidates the cached distribution. */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        this.distributionKnown = false;
    }

    /** Saves the cached distribution reference and its "known" flag. */
    @Override
    protected void storeState() {
        this.storedDistributionKnown = this.distributionKnown;
        this.storedDistribution = this.distribution;
    }

    /** Restores the cached distribution reference and its "known" flag saved by storeState(). */
    @Override
    protected void restoreState() {
        this.distribution = this.storedDistribution;
        this.distributionKnown = this.storedDistributionKnown;
    }

    /** Nothing to do on accept: the current cache simply remains in use. */
    @Override
    protected void acceptState() {
    }
}
