package dr.evomodel.treedatalikelihood.hmc;

import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient;
import dr.evomodel.treedatalikelihood.hmc.MultivariateChainRule;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;

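/**
 * Supplies gradients with respect to the precision and variance parameterizations of a matrix
 * and optionally exposes the object the gradients are derived from: either Wishart sufficient
 * statistics or a branch-specific gradient.
 */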
public interface GradientWrtPrecisionProvider {

    /**
     * Common base for {@link GradientWrtPrecisionProvider} implementations.
     *
     * <p>Holds the matrix dimension {@code dim}, which each concrete subclass assigns in its
     * constructor. It also supplies {@code null}-returning defaults for both optional accessors,
     * so a subclass only overrides the accessor for the object it is actually backed by.
     */
    public static abstract class AbstractGradientWrtPrecisionProvider implements GradientWrtPrecisionProvider {

        /** Dimension of the matrices handled by this provider; assigned by subclass constructors. */
        int dim;

        /**
         * Default accessor for providers not backed by a branch-specific gradient.
         *
         * @return always {@code null}
         */
        @Override
        public BranchSpecificGradient getBranchSpecificGradient() {
            return null;
        }

        /**
         * Default accessor for providers not backed by Wishart statistics.
         *
         * @return always {@code null}
         */
        @Override
        public ConjugateWishartStatisticsProvider getWishartStatistic() {
            return null;
        }
    }

    /**
     * Gradient provider backed by a {@link BranchSpecificGradient}.
     *
     * <p>The incoming gradient is treated as already being with respect to the variance.
     * {@link #getGradientWrtVariance} returns it unchanged, and {@link #getGradientWrtPrecision}
     * maps it through {@code MultivariateChainRule.InverseGeneral}.
     */
    public static class BranchSpecificGradientWrtPrecisionProvider extends AbstractGradientWrtPrecisionProvider {

        private final BranchSpecificGradient branchSpecificGradient;

        /**
         * Creates a provider backed by the given branch-specific gradient.
         *
         * @param branchSpecificGradient gradient source. Its likelihood is cast to
         *     {@link TreeDataLikelihood} (a different likelihood type raises
         *     {@code ClassCastException}), and that likelihood's delegate trait dimension
         *     becomes {@code dim}.
         */
        public BranchSpecificGradientWrtPrecisionProvider(BranchSpecificGradient branchSpecificGradient) {
            this.branchSpecificGradient = branchSpecificGradient;
            TreeDataLikelihood treeLikelihood = (TreeDataLikelihood) branchSpecificGradient.getLikelihood();
            this.dim = treeLikelihood.getDataLikelihoodDelegate().getTraitDim();
        }

        /**
         * Applies the general matrix-inverse chain rule, built from {@code vecV}, to
         * {@code gradient}.
         *
         * @param vecV flattened matrix used to build the inverse chain rule
         * @param gradient incoming gradient to be chained
         * @return a new array holding the chained gradient
         */
        @Override
        public double[] getGradientWrtPrecision(double[] vecV, double[] gradient) {
            MultivariateChainRule.InverseGeneral inverseRule = new MultivariateChainRule.InverseGeneral(vecV);
            return inverseRule.chainGradient(gradient);
        }

        /**
         * Returns {@code gradient} unchanged.
         *
         * @param vecP unused
         * @param vecV unused
         * @param gradient incoming gradient
         * @return the very same array instance that was passed as {@code gradient} (not a copy)
         */
        @Override
        public double[] getGradientWrtVariance(double[] vecP, double[] vecV, double[] gradient) {
            return gradient;
        }

        /**
         * @return the branch-specific gradient supplied at construction
         */
        @Override
        public BranchSpecificGradient getBranchSpecificGradient() {
            return branchSpecificGradient;
        }
    }

    /**
     * Gradient provider whose precision gradient comes from the conjugate Wishart sufficient
     * statistics exposed by a {@link ConjugateWishartStatisticsProvider}.
     *
     * <p>Element-wise, the gradient with respect to precision is {@code 0.5 * (df * V - S)}.
     * Here {@code df} and {@code S} are the degrees of freedom and scale matrix reported by
     * {@link WishartSufficientStatistics}, and {@code V} is the flattened matrix passed in by the
     * caller.
     */
    public static class WishartGradientWrtPrecisionProvider extends AbstractGradientWrtPrecisionProvider {

        private final ConjugateWishartStatisticsProvider wishartStatistics;

        /**
         * Creates a provider backed by the given Wishart statistics source.
         *
         * @param wishartStatistics source of sufficient statistics; the row dimension of its
         *     precision parameter becomes {@code dim}
         */
        public WishartGradientWrtPrecisionProvider(ConjugateWishartStatisticsProvider wishartStatistics) {
            this.wishartStatistics = wishartStatistics;
            this.dim = wishartStatistics.getPrecisionParameter().getRowDimension();
        }

        /**
         * Computes the gradient with respect to precision from the current Wishart sufficient
         * statistics.
         *
         * @param vecV flattened {@code dim x dim} matrix. NOTE(review): presumably the variance,
         *     i.e. the inverse of the precision; confirm against callers.
         * @param gradient ignored. The result depends only on {@code vecV} and the sufficient
         *     statistics.
         * @return a new array of length {@code dim * dim}
         */
        @Override
        public double[] getGradientWrtPrecision(double[] vecV, double[] gradient) {
            WishartSufficientStatistics statistics = wishartStatistics.getWishartStatistics();
            return gradientFromSufficientStatistics(vecV, statistics.getDf(), statistics.getScaleMatrix());
        }

        /**
         * Evaluates {@code 0.5 * (df * vecV[k] - vecS[k])} for every flattened entry {@code k}.
         * Preconditions are checked with {@code assert}, so they are only enforced under
         * {@code -ea}.
         */
        private double[] gradientFromSufficientStatistics(double[] vecV, int df, double[] vecS) {
            final int length = dim * dim;
            assert vecV.length == length : "vecV has length " + vecV.length + ", expected " + length;
            assert vecS.length == length : "scale matrix has length " + vecS.length + ", expected " + length;
            assert df > 0 : "degrees of freedom must be positive, got " + df;

            double[] result = new double[length];
            for (int k = 0; k < length; k++) {
                result[k] = 0.5 * (df * vecV[k] - vecS[k]);
            }
            return result;
        }

        /**
         * Applies the general matrix-inverse chain rule, built from {@code vecP}, to the Wishart
         * gradient with respect to precision.
         *
         * @param vecP flattened matrix used to build the inverse chain rule. NOTE(review):
         *     presumably the precision; confirm against callers.
         * @param vecV flattened matrix forwarded to {@link #getGradientWrtPrecision}
         * @param gradient forwarded, but ignored by {@link #getGradientWrtPrecision}
         * @return the chained gradient
         */
        @Override
        public double[] getGradientWrtVariance(double[] vecP, double[] vecV, double[] gradient) {
            return new MultivariateChainRule.InverseGeneral(vecP).chainGradient(getGradientWrtPrecision(vecV, gradient));
        }

        /**
         * @return the Wishart statistics source supplied at construction
         */
        @Override
        public ConjugateWishartStatisticsProvider getWishartStatistic() {
            return wishartStatistics;
        }
    }

    /**
     * Computes a gradient with respect to the precision matrix.
     *
     * @param vecV flattened matrix the gradient is evaluated at. NOTE(review): presumably the
     *     variance (inverse of the precision); confirm against callers.
     * @param gradient incoming gradient; whether and how it is used depends on the implementation
     * @return flattened gradient with respect to the precision
     */
    double[] getGradientWrtPrecision(double[] vecV, double[] gradient);

    /**
     * Computes a gradient with respect to the variance matrix.
     *
     * @param vecP flattened matrix. NOTE(review): presumably the precision; confirm against
     *     callers.
     * @param vecV flattened matrix. NOTE(review): presumably the variance; confirm against
     *     callers.
     * @param gradient incoming gradient; whether and how it is used depends on the implementation
     * @return flattened gradient with respect to the variance
     */
    double[] getGradientWrtVariance(double[] vecP, double[] vecV, double[] gradient);

    /**
     * Returns the Wishart statistics source backing this provider. The implementations in this
     * file return {@code null} when they are not backed by one.
     */
    ConjugateWishartStatisticsProvider getWishartStatistic();

    /**
     * Returns the branch-specific gradient backing this provider. The implementations in this
     * file return {@code null} when they are not backed by one.
     */
    BranchSpecificGradient getBranchSpecificGradient();
}
