package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.preorder.ContinuousExtensionDelegate;
import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.inference.model.CompoundParameter;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.missingData.MissingOps;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.List;
import org.ejml.data.DenseMatrix64F;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/continuous/RepeatedMeasuresTraitDataModel.class */
/**
 * Continuous trait data model for repeated measurements: every tip partial returned by the base
 * {@link ContinuousTraitDataModel} has a sampling variance (the inverse of the sampling precision
 * matrix held in {@code samplingPrecisionParameter}) added to it before being handed to the
 * likelihood.
 *
 * <p>The sampling precision is cached as a {@link Matrix}. Its inverse, the sampling variance, is
 * computed lazily and invalidated whenever the underlying parameter fires a change event.
 */
public class RepeatedMeasuresTraitDataModel extends ContinuousTraitDataModel implements ContinuousTraitPartialsProvider, ModelExtensionProvider.NormalExtensionProvider {

    private static final boolean DEBUG = false;
    private static final String REPEATED_MEASURES_MODEL = "repeatedMeasuresModel";
    private static final String PRECISION = "samplingPrecision";

    private final String traitName;
    private final MatrixParameterInterface samplingPrecisionParameter;

    // Never set to true within this class; when true only the diagonal of the precision is used.
    private boolean diagonalOnly = false;
    // True when the cached samplingPrecision may be out of date w.r.t. samplingPrecisionParameter.
    private boolean variableChanged = true;
    // True when samplingVariance holds the inverse of the current samplingPrecision.
    private boolean varianceKnown = false;

    private Matrix samplingPrecision;
    private Matrix samplingVariance;

    private Matrix storedSamplingPrecision;
    private Matrix storedSamplingVariance;
    private boolean storedVarianceKnown = false;
    private boolean storedVariableChanged = true;

    /**
     * Creates the model and registers the sampling precision parameter as a model variable.
     *
     * @param name               trait name (also passed to the base model)
     * @param parameter          compound parameter holding the tip trait values
     * @param missingIndices     indices of missing entries, forwarded to the base model
     * @param useMissingIndices  forwarded to the base model unchanged
     * @param traitDimension     dimension of a single trait; the parser passes the precision
     *                           matrix's column dimension here
     * @param samplingPrecision  sampling precision matrix parameter
     */
    public RepeatedMeasuresTraitDataModel(String name, CompoundParameter parameter, List<Integer> missingIndices, boolean useMissingIndices, int traitDimension, MatrixParameterInterface samplingPrecision) {
        super(name, parameter, missingIndices, useMissingIndices, traitDimension, PrecisionType.FULL);
        this.traitName = name;
        this.samplingPrecisionParameter = samplingPrecision;
        addVariable(samplingPrecision);
        calculatePrecisionInfo();
        // The variance is computed lazily by recomputeVariance().
        this.samplingVariance = null;
        // NOTE(review): this lower bound of 0 applies to every entry, including off-diagonal
        // precision entries, which may legitimately be negative for a full SPD matrix — confirm
        // whether this is intended.
        this.samplingPrecisionParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0d, samplingPrecision.getDimension()));
    }

    /**
     * Returns the tip partial for {@code taxonIndex} with the sampling variance added.
     *
     * <p>The partial from the base model is modified in place: the {@code dimTrait x dimTrait}
     * block starting at offset {@code dimTrait + dimTrait^2} is increased by the sampling variance,
     * and its (safe) inverse is written into the block starting at offset {@code dimTrait}.
     *
     * @param taxonIndex     index of the tip
     * @param fullyObserved  when true, a zero-filled array of length {@code dimTrait + 1} is
     *                       returned without consulting the base partial (name follows the
     *                       provider interface — TODO confirm semantics)
     * @return the adjusted partial
     */
    @Override
    public double[] getTipPartial(int taxonIndex, boolean fullyObserved) {
        assert numTraits == 1;

        recomputeVariance();

        assert samplingPrecision.rows() == dimTrait && samplingPrecision.columns() == dimTrait;

        if (fullyObserved) {
            return new double[dimTrait + 1];
        }

        final double[] partial = super.getTipPartial(taxonIndex, fullyObserved);
        final int varianceOffset = dimTrait + dimTrait * dimTrait;
        final DenseMatrix64F variance = MissingOps.wrap(partial, varianceOffset, dimTrait, dimTrait);

        if (diagonalOnly) {
            for (int k = 0; k < dimTrait; k++) {
                variance.set(k, k, variance.get(k, k) + (1.0d / samplingPrecision.component(k, k)));
            }
        } else {
            for (int row = 0; row < dimTrait; row++) {
                for (int col = 0; col < dimTrait; col++) {
                    variance.set(row, col, variance.get(row, col) + samplingVariance.component(row, col));
                }
            }
        }

        final DenseMatrix64F precision = new DenseMatrix64F(dimTrait, dimTrait);
        MissingOps.safeInvert2(variance, precision, false);
        MissingOps.unwrap(precision, partial, dimTrait);
        MissingOps.unwrap(variance, partial, varianceOffset);
        return partial;
    }

    /** Refreshes the cached precision if needed, then inverts it unless the variance is current. */
    private void recomputeVariance() {
        checkVariableChanged();
        if (varianceKnown) {
            return;
        }
        samplingVariance = samplingPrecision.inverse();
        varianceKnown = true;
    }

    /** Returns the sampling variance (inverse of the sampling precision), recomputing if stale. */
    public Matrix getSamplingVariance() {
        recomputeVariance();
        return samplingVariance;
    }

    /** Returns the trait name supplied at construction. */
    public String getTraitName() {
        return traitName;
    }

    /**
     * Marks the cached precision and variance as stale when the sampling precision parameter
     * changes, and notifies listeners. (Declared protected in the compiled class; kept public here
     * so existing callers are unaffected.)
     */
    @Override
    public void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType changeType) {
        super.handleVariableChangedEvent(variable, index, changeType);
        if (variable == samplingPrecisionParameter) {
            variableChanged = true;
            varianceKnown = false;
            fireModelChanged();
        }
    }

    /** Copies the current parameter values into the cached precision matrix. */
    private void calculatePrecisionInfo() {
        samplingPrecision = new Matrix(samplingPrecisionParameter.getParameterAsMatrix());
    }

    /** Rebuilds the cached precision (and invalidates the variance) after a parameter change. */
    private void checkVariableChanged() {
        if (variableChanged) {
            calculatePrecisionInfo();
            variableChanged = false;
            varianceKnown = false;
        }
    }

    /**
     * Saves copies of the cached matrices and flags.
     *
     * <p>{@code samplingVariance} is null until the first {@link #recomputeVariance()} call, so it
     * is only cloned when present; {@link #restoreState()} recomputes it because the stored
     * {@code varianceKnown} is false in that case.
     */
    @Override
    protected void storeState() {
        // NOTE(review): m1031clone() appears to be the decompiler's name for Matrix.clone().
        storedSamplingPrecision = samplingPrecision.m1031clone();
        storedSamplingVariance = (samplingVariance == null) ? null : samplingVariance.m1031clone();
        storedVarianceKnown = varianceKnown;
        storedVariableChanged = variableChanged;
    }

    /** Swaps the stored matrices back in and restores the cache flags. */
    @Override
    protected void restoreState() {
        final Matrix currentPrecision = samplingPrecision;
        samplingPrecision = storedSamplingPrecision;
        storedSamplingPrecision = currentPrecision;

        final Matrix currentVariance = samplingVariance;
        samplingVariance = storedSamplingVariance;
        storedSamplingVariance = currentVariance;

        varianceKnown = storedVarianceKnown;
        variableChanged = storedVariableChanged;
    }

    @Override
    public ContinuousExtensionDelegate getExtensionDelegate(ContinuousDataLikelihoodDelegate likelihoodDelegate, TreeTrait treeTrait, Tree tree) {
        checkVariableChanged();
        return new ContinuousExtensionDelegate.MultivariateNormalExtensionDelegate(likelihoodDelegate, treeTrait, this, tree);
    }

    /** Returns the sampling variance wrapped as a {@code dimTrait x dimTrait} EJML matrix. */
    @Override
    public DenseMatrix64F getExtensionVariance() {
        recomputeVariance();
        return DenseMatrix64F.wrap(dimTrait, dimTrait, samplingVariance.toArrayComponents());
    }

    /** Returns the sampling precision parameter itself (not a copy). */
    @Override
    public MatrixParameterInterface getExtensionPrecision() {
        checkVariableChanged();
        return samplingPrecisionParameter;
    }

    /** Identity transform: tree traits are returned unchanged. */
    @Override
    public double[] transformTreeTraits(double[] treeTraits) {
        return treeTraits;
    }

    @Override
    public int getDataDimension() {
        return dimTrait;
    }

    /** XML parser for the {@code repeatedMeasuresModel} element. */
    public static AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() {
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            // Look the tree up by the same type the syntax rules declare (MutableTreeModel).
            final MutableTreeModel treeModel = (MutableTreeModel) xo.getChild(MutableTreeModel.class);
            final TreeTraitParserUtilities.TraitsAndMissingIndices traits =
                    new TreeTraitParserUtilities().parseTraitsFromTaxonAttributes(xo, "trait", treeModel, true);
            final CompoundParameter traitParameter = traits.traitParameter;
            final List<Integer> missingIndices = traits.missingIndices;
            final MatrixParameterInterface samplingPrecision =
                    (MatrixParameterInterface) xo.getChild(RepeatedMeasuresTraitDataModel.PRECISION).getChild(MatrixParameterInterface.class);

            try {
                if (!new CholeskyDecomposition(samplingPrecision.getParameterAsMatrix()).isSPD()) {
                    throw new XMLParseException("samplingPrecision must be a positive definite matrix.");
                }
            } catch (IllegalDimension e) {
                throw new XMLParseException("samplingPrecision must be a square matrix.");
            }

            return new RepeatedMeasuresTraitDataModel(traits.traitName, traitParameter, missingIndices, true, samplingPrecision.getColumnDimension(), samplingPrecision);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return RepeatedMeasuresTraitDataModel.rules;
        }

        @Override
        public String getParserDescription() {
            return "Continuous trait data model for repeated measures: adds the inverse of the samplingPrecision matrix as sampling variance to each tip partial.";
        }

        @Override
        public Class getReturnType() {
            return RepeatedMeasuresTraitDataModel.class;
        }

        @Override
        public String getParserName() {
            return RepeatedMeasuresTraitDataModel.REPEATED_MEASURES_MODEL;
        }
    };

    // Initialised after PARSER, preserving the original static-initialisation order.
    private static final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
            new ElementRule(PRECISION, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(MutableTreeModel.class),
            AttributeRule.newStringRule("traitName"),
            new ElementRule("traitParameter", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule("missingIndicator", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true)
    };
}
