package dr.evomodel.operators;

import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.distribution.WishartGammalDistributionModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.distributions.GammaDistribution;
import dr.math.distributions.WishartDistribution;
import dr.math.distributions.WishartStatistics;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/* loaded from: input_file:dr/evomodel/operators/CorrelationMatrixGibbsOperator.class */
/**
 * Gibbs operator that draws a Wishart matrix from statistics supplied by a
 * {@link ConjugateWishartStatisticsProvider} (optionally combined with a Wishart prior), rescales the
 * draw so that its inverse has a unit diagonal (i.e. the draw becomes an inverse correlation matrix),
 * and writes it column by column into the target matrix parameter.
 */
public class CorrelationMatrixGibbsOperator extends SimpleMCMCOperator implements GibbsOperator {
    private static final String CORRELATION_OPERATOR = "correlationGibbsOperator";
    public static final String TREE_MODEL = "treeModel";
    public static final String DISTRIBUTION = "distribution";
    public static final String PRIOR = "prior";

    private final ConjugateWishartStatisticsProvider conjugateWishartProvider;
    /** Target parameter that receives the sampled inverse correlation matrix. */
    private final MatrixParameterInterface inverseCorrelation;
    // NOTE(review): priorStatistics and workingStatistics are assigned but never read in this class.
    private Statistics priorStatistics;
    private Statistics workingStatistics;
    private double priorDf;
    /** Inverse of the prior scale matrix, or null when the prior provides no scale matrix. */
    private SymmetricMatrix priorInverseScaleMatrix;
    private final int dim;
    /** Degrees of freedom contributed by the data; refreshed on every operation. */
    private double numberObservations;
    private double pathWeight = 1.0d;
    /** True when the prior is a model whose parameters may change between operations. */
    private boolean wishartIsModel;
    private WishartGammalDistributionModel priorModel;

    /** Degrees of freedom plus the inverse of a Wishart scale matrix (null if no scale matrix). */
    private static final class Statistics {
        final double degreesOfFreedom;
        final double[][] rateMatrix;

        Statistics(double degreesOfFreedom, double[][] rateMatrix) {
            this.degreesOfFreedom = degreesOfFreedom;
            this.rateMatrix = rateMatrix;
        }
    }

    /**
     * @param conjugateWishartStatisticsProvider source of the data's Wishart sufficient statistics
     * @param matrixParameterInterface parameter to update; when null, the provider's precision parameter is used
     * @param wishartStatistics prior Wishart statistics; re-read on every operation if it is a
     *        {@link WishartGammalDistributionModel}
     * @param wishartStatistics2 optional working statistics (may be null)
     * @param d operator weight
     */
    public CorrelationMatrixGibbsOperator(ConjugateWishartStatisticsProvider conjugateWishartStatisticsProvider, MatrixParameterInterface matrixParameterInterface, WishartStatistics wishartStatistics, WishartStatistics wishartStatistics2, double d) {
        this.wishartIsModel = false;
        this.priorModel = null;
        this.conjugateWishartProvider = conjugateWishartStatisticsProvider;
        this.inverseCorrelation = matrixParameterInterface != null
                ? matrixParameterInterface
                : conjugateWishartStatisticsProvider.getPrecisionParameter();
        setupWishartStatistics(wishartStatistics);
        this.priorStatistics = setupStatistics(wishartStatistics);
        if (wishartStatistics instanceof WishartGammalDistributionModel) {
            this.wishartIsModel = true;
            this.priorModel = (WishartGammalDistributionModel) wishartStatistics;
        }
        if (wishartStatistics2 != null) {
            this.workingStatistics = setupStatistics(wishartStatistics2);
        }
        setWeight(d);
        this.dim = this.inverseCorrelation.getRowDimension();
    }

    /** Packages the degrees of freedom and the inverted scale matrix (if any) of {@code wishartStatistics}. */
    private Statistics setupStatistics(WishartStatistics wishartStatistics) {
        double[][] scaleMatrix = wishartStatistics.getScaleMatrix();
        double[][] inverseScale = null;
        if (scaleMatrix != null) {
            inverseScale = new SymmetricMatrix(scaleMatrix).inverse().toComponents();
        }
        return new Statistics(wishartStatistics.getDF(), inverseScale);
    }

    /** Caches the prior degrees of freedom and the inverse of the prior scale matrix (null if absent). */
    private void setupWishartStatistics(WishartStatistics wishartStatistics) {
        this.priorDf = wishartStatistics.getDF();
        this.priorInverseScaleMatrix = null;
        double[][] scaleMatrix = wishartStatistics.getScaleMatrix();
        if (scaleMatrix != null) {
            this.priorInverseScaleMatrix = (SymmetricMatrix) new SymmetricMatrix(scaleMatrix).inverse();
        }
    }

    /**
     * Rescales a precision matrix P in place to D P D, where D = diag(sqrt((P^-1)_ii)).
     * The inverse of the result, D^-1 P^-1 D^-1, has a unit diagonal, so the result is an
     * inverse correlation matrix.
     */
    private void normalizeToGetInverseCorrelation(double[][] matrix) {
        double[][] inverse = new SymmetricMatrix(matrix).inverse().toComponents();
        double[] scale = new double[this.dim];
        for (int i = 0; i < this.dim; i++) {
            scale[i] = Math.sqrt(inverse[i][i]);
        }
        for (int row = 0; row < this.dim; row++) {
            for (int col = 0; col < this.dim; col++) {
                matrix[row][col] = scale[row] * scale[col] * matrix[row][col];
            }
        }
    }

    /**
     * Sets the weight applied to the data's degrees of freedom.
     *
     * @throws IllegalArgumentException if {@code d} lies outside [0, 1]
     */
    @Override
    public void setPathParameter(double d) {
        if (d < 0.0d || d > 1.0d) {
            throw new IllegalArgumentException("Illegal path weight of " + d);
        }
        this.pathWeight = d;
    }

    public int getStepCount() {
        return 1;
    }

    /**
     * Draws one random scale factor per dimension:
     * sqrt(current diagonal entry / (2 * Gamma((dim + 1) / 2, 1))).
     */
    private double[] getDiagRescaleMatrix() {
        double[] scale = new double[this.dim];
        for (int i = 0; i < this.dim; i++) {
            scale[i] = Math.sqrt(this.inverseCorrelation.getParameterValue(i, i)
                    / (2.0d * GammaDistribution.nextGamma((this.dim + 1) / 2.0d, 1.0d)));
        }
        return scale;
    }

    /**
     * Adds the provider's outer-product matrix, rescaled on both sides by {@link #getDiagRescaleMatrix()},
     * into {@code accumulator}, and adds the provider's degrees of freedom to {@link #numberObservations}.
     * The provider's array is only read, never modified, because it may be the provider's internal state.
     */
    private void incrementOuterProductWithRescale(double[][] accumulator, ConjugateWishartStatisticsProvider provider) {
        WishartSufficientStatistics statistics = provider.getWishartStatistics();
        // Row-major flattened dim x dim matrix.
        double[] outerProducts = statistics.getScaleMatrix();
        double[] scale = getDiagRescaleMatrix();
        int length = accumulator.length;
        for (int row = 0; row < length; row++) {
            for (int col = 0; col < length; col++) {
                accumulator[row][col] += scale[col] * scale[row] * outerProducts[(row * length) + col];
            }
        }
        this.numberObservations += statistics.getDf();
    }

    /**
     * Builds the scale matrix for the Wishart draw as (priorInverseScale + rescaled outer product)^-1,
     * and as a side effect refreshes {@link #numberObservations}.
     *
     * @throws IllegalStateException if the prior and data matrices have incompatible dimensions
     */
    private double[][] getOperationScaleMatrixAndSetObservationCount2() {
        double[][] outerProduct = new double[this.dim][this.dim];
        this.numberObservations = 0.0d;
        incrementOuterProductWithRescale(outerProduct, this.conjugateWishartProvider);
        try {
            SymmetricMatrix inverseScale = new SymmetricMatrix(outerProduct);
            if (this.priorInverseScaleMatrix != null) {
                inverseScale = this.priorInverseScaleMatrix.add(inverseScale);
            }
            return ((SymmetricMatrix) inverseScale.inverse()).toComponents();
        } catch (IllegalDimension e) {
            throw new IllegalStateException("Prior and data scale matrices have incompatible dimensions", e);
        }
    }

    /** Samples a new inverse correlation matrix and writes it into the target parameter. */
    @Override
    public double doOperation() {
        if (this.wishartIsModel) {
            // The prior is a model whose parameters may have changed since the last call.
            setupWishartStatistics(this.priorModel);
            this.priorStatistics = setupStatistics(this.priorModel);
        }
        // Must run before the degrees of freedom are computed: it refreshes numberObservations.
        double[][] scaleMatrix = getOperationScaleMatrixAndSetObservationCount2();
        double df = this.priorDf + (this.numberObservations * this.pathWeight);
        double[][] draw = WishartDistribution.nextWishart(df, scaleMatrix);
        normalizeToGetInverseCorrelation(draw);
        for (int col = 0; col < this.dim; col++) {
            Parameter column = this.inverseCorrelation.getParameter(col);
            for (int row = 0; row < this.dim; row++) {
                column.setParameterValueQuietly(row, draw[row][col]);
            }
        }
        // Values were set quietly; notify listeners once for the whole matrix.
        this.inverseCorrelation.fireParameterChangedEvent();
        return 0.0d;
    }

    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return CORRELATION_OPERATOR;
    }

    /** XML parser for the {@code correlationGibbsOperator} element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule("weight"),
                new ElementRule(AbstractMultivariateTraitLikelihood.class, true),
                new ElementRule(ConjugateWishartStatisticsProvider.class, true),
                new ElementRule(MultivariateDistributionLikelihood.class, 1, 2),
                // NOTE(review): this optional child is accepted but never read by parseXMLObject.
                new ElementRule(MatrixParameterInterface.class, true)
        };

        @Override
        public String getParserName() {
            return CorrelationMatrixGibbsOperator.CORRELATION_OPERATOR;
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            double weight = xMLObject.getDoubleAttribute("weight");
            ConjugateWishartStatisticsProvider provider =
                    (ConjugateWishartStatisticsProvider) xMLObject.getChild(ConjugateWishartStatisticsProvider.class);
            if (provider == null) {
                throw new XMLParseException("A conjugate Wishart statistics provider is required");
            }
            MatrixParameterInterface precisionParameter = provider.getPrecisionParameter();
            MultivariateDistributionLikelihood likelihood =
                    (MultivariateDistributionLikelihood) xMLObject.getChild(MultivariateDistributionLikelihood.class);
            if (!(likelihood.getDistribution() instanceof WishartStatistics)) {
                throw new XMLParseException("Only a Wishart distribution is conjugate for Gibbs sampling");
            }
            if (precisionParameter.getColumnDimension() != precisionParameter.getRowDimension()) {
                throw new XMLParseException("The variance matrix is not square or of wrong dimension");
            }
            return new CorrelationMatrixGibbsOperator(provider, precisionParameter,
                    (WishartStatistics) likelihood.getDistribution(), null, weight);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a Gibbs operator on an inverse correlation matrix parameter.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };
}
