package dr.evomodel.operators;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood;
import dr.evomodel.continuous.SampledMultivariateTraitLikelihood;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.distribution.MultivariateNormalDistributionModel;
import dr.inference.distribution.WishartGammalDistributionModel;
import dr.inference.model.DiagonalConstrainedMatrixView;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.distributions.WishartDistribution;
import dr.math.distributions.WishartStatistics;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.util.Attribute;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.Iterator;

/* loaded from: input_file:dr/evomodel/operators/PrecisionMatrixGibbsOperator.class */
/**
 * Gibbs operator that redraws a precision matrix from its conjugate Wishart full conditional.
 *
 * <p>On each step the operator:
 * <ul>
 *   <li>accumulates a matrix of outer products {@code S} and an observation count {@code n} from
 *       one of three sources, described below;</li>
 *   <li>forms the posterior inverse scale {@code priorInverseScale + pathWeight * S};</li>
 *   <li>draws from {@code Wishart(priorDf + n * pathWeight, inverse(thatMatrix))};</li>
 *   <li>writes the draw into the precision parameter column by column.</li>
 * </ul>
 *
 * <p>The three sources are:
 * <ul>
 *   <li>a {@link MultivariateDistributionLikelihood} over a {@link MultivariateNormalDistributionModel}
 *       (outer products of the data centred on the distribution mean);</li>
 *   <li>a {@link SampledMultivariateTraitLikelihood} on a tree (outer products of parent-to-child trait
 *       differences scaled by {@code 1/sqrt(rescaled branch length)});</li>
 *   <li>a {@link ConjugateWishartStatisticsProvider} that supplies the sufficient statistics directly.</li>
 * </ul>
 */
public class PrecisionMatrixGibbsOperator extends SimpleMCMCOperator implements GibbsOperator {
    private static final String VARIANCE_OPERATOR = "precisionGibbsOperator";
    public static final String TREE_MODEL = "treeModel";
    public static final String DISTRIBUTION = "distribution";
    public static final String PRIOR = "prior";
    private static final String WORKING = "workingDistribution";
    private static final boolean DEBUG = false;

    private final AbstractMultivariateTraitLikelihood traitModel;
    private AbstractMultivariateTraitLikelihood debugModel = null;
    private final ConjugateWishartStatisticsProvider conjugateWishartProvider;
    private final MultivariateDistributionLikelihood multivariateLikelihood;
    private final Parameter meanParam;
    private final MatrixParameterInterface precisionParam;
    private Statistics priorStatistics;
    private Statistics workingStatistics;
    private double priorDf;
    private SymmetricMatrix priorInverseScaleMatrix;
    private final MutableTreeModel treeModel;
    private final int dim;
    private double numberObservations;
    private final String traitName;
    private final boolean isSampledTraitLikelihood;
    private double pathWeight = 1.0d;
    private boolean wishartIsModel = false;
    private WishartGammalDistributionModel priorModel = null;

    /** Degrees of freedom together with the inverse of a Wishart scale matrix (or {@code null}). */
    public static class Statistics {
        final double degreesOfFreedom;
        final double[][] rateMatrix;

        Statistics(double degreesOfFreedom, double[][] rateMatrix) {
            this.degreesOfFreedom = degreesOfFreedom;
            this.rateMatrix = rateMatrix;
        }
    }

    /**
     * Builds an operator for a precision matrix that parameterises a multivariate normal likelihood.
     *
     * @param multivariateDistributionLikelihood likelihood whose distribution must be a
     *        {@link MultivariateNormalDistributionModel}; its mean and precision parameters are used
     * @param wishartStatistics the Wishart prior on the precision matrix
     * @param d operator weight
     */
    public PrecisionMatrixGibbsOperator(MultivariateDistributionLikelihood multivariateDistributionLikelihood,
                                        WishartStatistics wishartStatistics, double d) {
        this.traitModel = null;
        this.treeModel = null;
        this.traitName = null;
        this.conjugateWishartProvider = null;
        this.isSampledTraitLikelihood = false;
        this.multivariateLikelihood = multivariateDistributionLikelihood;

        final MultivariateNormalDistributionModel normal =
                (MultivariateNormalDistributionModel) multivariateDistributionLikelihood.getDistribution();
        this.meanParam = normal.getMeanParameter();
        this.precisionParam = normal.getPrecisionMatrixParameter();
        this.dim = this.meanParam.getDimension();

        initPrior(wishartStatistics);
        setWeight(d);
    }

    /**
     * Builds an operator driven by a tree-based trait likelihood.
     *
     * @param matrixParameterInterface the precision matrix to update
     * @param abstractMultivariateTraitLikelihood a {@link SampledMultivariateTraitLikelihood} or a
     *        likelihood that also implements {@link ConjugateWishartStatisticsProvider}
     * @param wishartStatistics the Wishart prior on the precision matrix
     * @param d operator weight
     * @throws RuntimeException if the likelihood is neither of the supported kinds
     * @deprecated marked deprecated in the original source; NOTE(review): confirm whether the
     *             {@link ConjugateWishartStatisticsProvider}-based constructor is the intended replacement
     */
    @Deprecated
    public PrecisionMatrixGibbsOperator(MatrixParameterInterface matrixParameterInterface,
                                        AbstractMultivariateTraitLikelihood abstractMultivariateTraitLikelihood,
                                        WishartStatistics wishartStatistics, double d) {
        this.traitModel = abstractMultivariateTraitLikelihood;
        this.conjugateWishartProvider = null;
        this.meanParam = null;
        this.multivariateLikelihood = null;
        this.precisionParam = matrixParameterInterface;

        initPrior(wishartStatistics);
        setWeight(d);

        this.treeModel = abstractMultivariateTraitLikelihood.getTreeModel();
        this.traitName = abstractMultivariateTraitLikelihood.getTraitName();
        this.dim = matrixParameterInterface.getRowDimension();
        this.isSampledTraitLikelihood = abstractMultivariateTraitLikelihood instanceof SampledMultivariateTraitLikelihood;

        if (!this.isSampledTraitLikelihood
                && !(abstractMultivariateTraitLikelihood instanceof ConjugateWishartStatisticsProvider)) {
            throw new RuntimeException("Only implemented for a SampledMultivariateTraitLikelihood or ConjugateWishartStatisticsProvider");
        }
    }

    /**
     * Builds an operator driven by a {@link ConjugateWishartStatisticsProvider}.
     *
     * @param conjugateWishartStatisticsProvider source of the sufficient statistics
     * @param matrixParameterInterface precision matrix to update, or {@code null} to use the provider's own
     * @param wishartStatistics the Wishart prior on the precision matrix
     * @param wishartStatistics2 optional working distribution (may be {@code null}); only stored
     * @param d operator weight
     * @param abstractMultivariateTraitLikelihood optional debug comparison model (may be {@code null});
     *        see {@link #incrementOuterProduct(double[][], ConjugateWishartStatisticsProvider)}
     */
    public PrecisionMatrixGibbsOperator(ConjugateWishartStatisticsProvider conjugateWishartStatisticsProvider,
                                        MatrixParameterInterface matrixParameterInterface,
                                        WishartStatistics wishartStatistics, WishartStatistics wishartStatistics2,
                                        double d, AbstractMultivariateTraitLikelihood abstractMultivariateTraitLikelihood) {
        this.traitModel = null;
        this.debugModel = abstractMultivariateTraitLikelihood;
        this.conjugateWishartProvider = conjugateWishartStatisticsProvider;
        this.meanParam = null;
        this.multivariateLikelihood = null;
        this.precisionParam = matrixParameterInterface != null
                ? matrixParameterInterface
                : conjugateWishartStatisticsProvider.getPrecisionParameter();
        this.isSampledTraitLikelihood = false;
        this.treeModel = null;
        this.traitName = null;

        initPrior(wishartStatistics);
        if (wishartStatistics2 != null) {
            this.workingStatistics = setupStatistics(wishartStatistics2);
        }
        setWeight(d);
        this.dim = this.precisionParam.getRowDimension();
    }

    /**
     * Caches the prior's hyperparameters.
     *
     * <p>It also remembers whether the prior is itself a model
     * ({@link WishartGammalDistributionModel}) whose hyperparameters must be re-read on every step.
     */
    private void initPrior(WishartStatistics prior) {
        setupWishartStatistics(prior);
        this.priorStatistics = setupStatistics(prior);
        if (prior instanceof WishartGammalDistributionModel) {
            this.wishartIsModel = true;
            this.priorModel = (WishartGammalDistributionModel) prior;
        }
    }

    /** Packs the distribution's df and the inverse of its scale matrix (if any) into a {@link Statistics}. */
    private Statistics setupStatistics(WishartStatistics wishartStatistics) {
        final double[][] scale = wishartStatistics.getScaleMatrix();
        final double[][] rate = (scale == null) ? null : new SymmetricMatrix(scale).inverse().toComponents();
        return new Statistics(wishartStatistics.getDF(), rate);
    }

    /**
     * Reads the prior df into {@code priorDf} and its inverted scale matrix into {@code priorInverseScaleMatrix}.
     *
     * <p>If the prior has no scale matrix, {@code priorInverseScaleMatrix} is set to {@code null}.
     */
    private void setupWishartStatistics(WishartStatistics wishartStatistics) {
        this.priorDf = wishartStatistics.getDF();
        this.priorInverseScaleMatrix = null;
        final double[][] scale = wishartStatistics.getScaleMatrix();
        if (scale != null) {
            this.priorInverseScaleMatrix = (SymmetricMatrix) new SymmetricMatrix(scale).inverse();
        }
    }

    /**
     * Sets the path-sampling power applied to the data contribution.
     *
     * @param d weight in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code d} is outside {@code [0, 1]}
     */
    @Override
    public void setPathParameter(double d) {
        if (d < 0.0d || d > 1.0d) {
            throw new IllegalArgumentException("Illegal path weight of " + d);
        }
        this.pathWeight = d;
    }

    public int getStepCount() {
        return 1;
    }

    /** Adds {@code v v^T} to {@code outerProduct}, keeping the matrix symmetric. */
    private void accumulateOuterProduct(double[][] outerProduct, double[] v) {
        for (int i = 0; i < dim; i++) {
            for (int j = i; j < dim; j++) {
                final double updated = outerProduct[i][j] + v[i] * v[j];
                outerProduct[i][j] = updated;
                outerProduct[j][i] = updated;
            }
        }
    }

    /**
     * Accumulates outer products of each datum centred on the likelihood's mean.
     *
     * <p>{@code numberObservations} is set to the number of data points.
     */
    private void incrementOuterProduct(double[][] outerProduct, MultivariateDistributionLikelihood likelihood) {
        final double[] mean = likelihood.getDistribution().getMean();
        this.numberObservations = 0.0d;
        // Centre into a scratch array: getAttributeValue() may expose the likelihood's own storage,
        // and subtracting the mean in place would corrupt the data on every Gibbs step.
        final double[] centred = new double[dim];
        for (Attribute<double[]> datum : likelihood.getDataList()) {
            final double[] value = datum.getAttributeValue();
            for (int i = 0; i < dim; i++) {
                centred[i] = value[i] - mean[i];
            }
            accumulateOuterProduct(outerProduct, centred);
            this.numberObservations += 1.0d;
        }
    }

    /**
     * Copies the provider's sufficient statistics into {@code outerProduct}.
     *
     * <p>The flat scale array is read in row-major order. {@code numberObservations} is set to the
     * provider's df.
     *
     * <p>NOTE(review): when {@code debugModel} is non-null (the parser passes one whenever both a trait
     * likelihood and a separate provider are supplied), this prints both sets of statistics and
     * terminates the JVM with {@code System.exit(-1)}. This looks like a developer diagnostic; confirm
     * it is intended.
     */
    private void incrementOuterProduct(double[][] outerProduct, ConjugateWishartStatisticsProvider provider) {
        final WishartSufficientStatistics statistics = provider.getWishartStatistics();
        final double[] scale = statistics.getScaleMatrix();
        final double df = statistics.getDf();

        if (this.debugModel != null) {
            final WishartSufficientStatistics reference =
                    ((ConjugateWishartStatisticsProvider) this.debugModel).getWishartStatistics();
            System.err.println(df + " ?= " + reference.getDf());
            System.err.println(new Vector(scale));
            System.err.println("");
            System.err.println(new Vector(reference.getScaleMatrix()));
            System.exit(-1);
        }

        final int n = outerProduct.length;
        for (int row = 0; row < n; row++) {
            System.arraycopy(scale, row * n, outerProduct[row], 0, n);
        }
        this.numberObservations = df;
    }

    /**
     * Recursively accumulates scaled parent-to-child trait differences over the subtree rooted at {@code node}.
     *
     * <p>Each non-root node with a positive rescaled branch length {@code t} contributes
     * {@code d d^T}, where {@code d = (child - parent) / sqrt(t)}, and increments
     * {@code numberObservations} by one.
     */
    private void incrementOuterProduct(double[][] outerProduct, NodeRef node) {
        if (!treeModel.isRoot(node)) {
            final double[] parentTrait = treeModel.getMultivariateNodeTrait(treeModel.getParent(node), traitName);
            final double[] childTrait = treeModel.getMultivariateNodeTrait(node, traitName);
            final double branchLength = traitModel.getRescaledBranchLengthForPrecision(node);
            if (branchLength > 0.0d) {
                final double scale = Math.sqrt(branchLength);
                final double[] increment = new double[dim];
                for (int i = 0; i < dim; i++) {
                    increment[i] = (childTrait[i] - parentTrait[i]) / scale;
                }
                accumulateOuterProduct(outerProduct, increment);
                this.numberObservations += 1.0d;
            }
        }
        for (int c = 0; c < treeModel.getChildCount(node); c++) {
            incrementOuterProduct(outerProduct, treeModel.getChild(node, c));
        }
    }

    /**
     * Computes the posterior Wishart scale matrix.
     *
     * <p>The result is {@code inverse(priorInverseScale + pathWeight * S)}. As a side effect,
     * {@code numberObservations} is set to the count of the data contributing to {@code S}.
     *
     * @throws IllegalStateException if the matrix algebra reports inconsistent dimensions
     */
    private double[][] getOperationScaleMatrixAndSetObservationCount() {
        final double[][] outerProduct = new double[dim][dim];
        this.numberObservations = 0.0d;

        if (isSampledTraitLikelihood) {
            incrementOuterProduct(outerProduct, treeModel.getRoot());
        } else if (traitModel != null) {
            incrementOuterProduct(outerProduct, (ConjugateWishartStatisticsProvider) traitModel);
        } else if (conjugateWishartProvider != null) {
            incrementOuterProduct(outerProduct, conjugateWishartProvider);
        } else {
            incrementOuterProduct(outerProduct, multivariateLikelihood);
        }

        try {
            SymmetricMatrix inverseScale = new SymmetricMatrix(outerProduct);
            if (pathWeight != 1.0d) {
                inverseScale = (SymmetricMatrix) inverseScale.product(pathWeight);
            }
            if (priorInverseScaleMatrix != null) {
                inverseScale = priorInverseScaleMatrix.add(inverseScale);
            }
            return inverseScale.inverse().toComponents();
        } catch (IllegalDimension e) {
            // Previously swallowed (printStackTrace) and followed by an assertion/NPE; surface it with its cause.
            throw new IllegalStateException("Inconsistent dimensions while forming the posterior Wishart scale matrix", e);
        }
    }

    /**
     * Performs one Gibbs update and notifies listeners that the precision parameter changed.
     *
     * @return 0, because the Hastings ratio of a Gibbs move is 1
     */
    @Override
    public double doOperation() {
        doOperationDontFireChange();
        this.precisionParam.fireParameterChangedEvent();
        return 0.0d;
    }

    /** Draws a new precision matrix and writes it quietly (without firing change events). */
    public void doOperationDontFireChange() {
        if (wishartIsModel) {
            // The prior is a model whose hyperparameters may have changed since the last step.
            setupWishartStatistics(priorModel);
            this.priorStatistics = setupStatistics(priorModel);
        }

        // Must run BEFORE reading numberObservations: it recomputes that count. Previously the df
        // was evaluated first (left-to-right argument order) and so used the previous step's count.
        final double[][] scaleMatrix = getOperationScaleMatrixAndSetObservationCount();
        final double df = priorDf + numberObservations * pathWeight;
        final double[][] draw = WishartDistribution.nextWishart(df, scaleMatrix);

        for (int col = 0; col < dim; col++) {
            final Parameter column = precisionParam.getParameter(col);
            for (int row = 0; row < dim; row++) {
                column.setParameterValueQuietly(row, draw[row][col]);
            }
        }
    }

    public MatrixParameterInterface getPrecisionParam() {
        return this.precisionParam;
    }

    public ConjugateWishartStatisticsProvider getConjugateWishartProvider() {
        return this.conjugateWishartProvider;
    }

    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return VARIANCE_OPERATOR;
    }

    /** XML parser for {@code <precisionGibbsOperator>}. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule("weight"),
                new ElementRule(AbstractMultivariateTraitLikelihood.class, true),
                new ElementRule(ConjugateWishartStatisticsProvider.class, true),
                new ElementRule(MultivariateDistributionLikelihood.class, 1, 2),
                new ElementRule(MatrixParameterInterface.class, true)
        };

        @Override
        public String getParserName() {
            return VARIANCE_OPERATOR;
        }

        /** Fetches the explicitly supplied unconstrained precision matrix child, or fails. */
        private MatrixParameterInterface requireUnconstrainedPrecision(XMLObject xo) throws XMLParseException {
            final MatrixParameterInterface unconstrained =
                    (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class);
            if (unconstrained == null) {
                throw new XMLParseException("Must provide unconstrained precision matrix");
            }
            return unconstrained;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            final double weight = xo.getDoubleAttribute("weight");

            final AbstractMultivariateTraitLikelihood traitLikelihood =
                    (AbstractMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class);
            ConjugateWishartStatisticsProvider provider =
                    (ConjugateWishartStatisticsProvider) xo.getChild(ConjugateWishartStatisticsProvider.class);
            // A trait likelihood that also implements the provider interface matches both rules.
            if (provider == traitLikelihood) {
                provider = null;
            }

            MultivariateDistributionLikelihood prior = null;
            MultivariateDistributionLikelihood normalLikelihood = null;
            MatrixParameterInterface precision = null;

            if (traitLikelihood != null) {
                precision = traitLikelihood.getDiffusionModel().getPrecisionParameter();
                prior = (MultivariateDistributionLikelihood) xo.getChild(MultivariateDistributionLikelihood.class);
            }
            if (provider != null) {
                precision = provider.getPrecisionParameter();
                prior = (MultivariateDistributionLikelihood) xo.getChild(MultivariateDistributionLikelihood.class);
            }
            if (traitLikelihood == null && provider == null) {
                for (int i = 0; i < xo.getChildCount(); i++) {
                    final Object child = xo.getChild(i);
                    // Skip non-likelihood children such as the optional MatrixParameterInterface.
                    if (!(child instanceof MultivariateDistributionLikelihood)) {
                        continue;
                    }
                    final MultivariateDistributionLikelihood density = (MultivariateDistributionLikelihood) child;
                    if (density.getDistribution() instanceof WishartStatistics) {
                        prior = density;
                    } else if (density.getDistribution() instanceof MultivariateNormalDistributionModel) {
                        normalLikelihood = density;
                        precision = ((MultivariateNormalDistributionModel) density.getDistribution())
                                .getPrecisionMatrixParameter();
                    }
                }
                if (prior == null || normalLikelihood == null) {
                    throw new XMLParseException("Must provide a multivariate normal likelihood and Wishart prior in element '" + xo.getName() + "'\n");
                }
            }

            if (!(prior.getDistribution() instanceof WishartStatistics)) {
                throw new XMLParseException("Only a Wishart distribution is conjugate for Gibbs sampling");
            }
            if (precision.getColumnDimension() != precision.getRowDimension()) {
                throw new XMLParseException("The variance matrix is not square or of wrong dimension");
            }
            final WishartStatistics wishartPrior = (WishartStatistics) prior.getDistribution();

            if (traitLikelihood != null && provider == null) {
                if (precision instanceof DiagonalConstrainedMatrixView) {
                    precision = requireUnconstrainedPrecision(xo);
                }
                return new PrecisionMatrixGibbsOperator(precision, traitLikelihood, wishartPrior, weight);
            }
            if (provider == null) {
                return new PrecisionMatrixGibbsOperator(normalLikelihood, wishartPrior, weight);
            }

            final MatrixParameterInterface unconstrained = (precision instanceof DiagonalConstrainedMatrixView)
                    ? requireUnconstrainedPrecision(xo)
                    : null;
            final WishartStatistics working = xo.hasChildNamed(WORKING)
                    ? (WishartStatistics) xo.getElementFirstChild(WORKING)
                    : null;
            return new PrecisionMatrixGibbsOperator(provider, unconstrained, wishartPrior, working, weight, traitLikelihood);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a Gibbs operator that samples a precision matrix from its conjugate Wishart full conditional.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }
    };
}
