package dr.evomodel.treedatalikelihood.hmc;

import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/hmc/MultivariateChainRule.class */
public interface MultivariateChainRule {

    /* loaded from: input_file:dr/evomodel/treedatalikelihood/hmc/MultivariateChainRule$Chain.class */
    /**
     * Composite rule that applies a sequence of chain rules one after another,
     * in the order in which they appear in the backing array.
     */
    public static class Chain implements MultivariateChainRule {
        private final MultivariateChainRule[] rules;

        /**
         * @param multivariateChainRuleArr rules to apply, first element first; the array is
         *                                 stored by reference, not copied
         */
        Chain(MultivariateChainRule[] multivariateChainRuleArr) {
            this.rules = multivariateChainRuleArr;
        }

        /**
         * Feeds the gradient through every rule in turn, each rule receiving the
         * output of the previous one.
         *
         * @param dArr gradient handed to the first rule
         * @return the output of the last rule, or {@code dArr} itself when there are no rules
         */
        @Override
        public double[] chainGradient(double[] dArr) {
            double[] current = dArr;
            for (int k = 0; k < rules.length; k++) {
                current = rules[k].chainGradient(current);
            }
            return current;
        }

        /**
         * Passes the same matrix instance to every rule in turn.
         *
         * @param denseMatrix64F gradient matrix handed to each rule
         */
        @Override
        public void chainGradient(DenseMatrix64F denseMatrix64F) {
            for (int k = 0; k < rules.length; k++) {
                rules[k].chainGradient(denseMatrix64F);
            }
        }
    }

    /* loaded from: input_file:dr/evomodel/treedatalikelihood/hmc/MultivariateChainRule$Inverse.class */
    /**
     * Element-wise chain rule: every incoming gradient entry {@code g[i]} is mapped to
     * {@code -g[i] * vecP[i] / vecV[i]}.
     */
    public static class Inverse implements MultivariateChainRule {
        private final double[] vecP;
        private final double[] vecV;
        private final int dim;

        /**
         * Both arrays are stored by reference, not copied.
         *
         * @param dArr  values used as the multiplier ({@code vecP}); {@code dim} is derived from
         *              its length as the truncated integer square root
         * @param dArr2 values used as the divisor ({@code vecV})
         */
        Inverse(double[] dArr, double[] dArr2) {
            this.vecP = dArr;
            this.vecV = dArr2;
            this.dim = (int) Math.sqrt(dArr.length);
        }

        /**
         * Computes {@code -dArr[i] * vecP[i] / vecV[i]} for the first {@code dim * dim} entries.
         *
         * @param dArr incoming gradient; asserted to hold exactly {@code dim * dim} values
         * @return a newly allocated array of length {@code dim * dim}
         * @throws RuntimeException if any divisor in {@code vecV} is 0 or NaN
         */
        @Override
        public double[] chainGradient(double[] dArr) {
            final int length = dim * dim;
            assert dArr.length == length;

            final double[] result = new double[length];
            for (int i = 0; i < length; i++) {
                final double divisor = vecV[i];
                // Fail before dividing: a 0 or NaN divisor would silently poison the gradient.
                if (divisor == 0.0d || Double.isNaN(divisor)) {
                    throw new RuntimeException("0 or NaN value in variance. check start value or use smaller step size for hmc");
                }
                result[i] = ((-dArr[i]) * vecP[i]) / divisor;
            }
            return result;
        }

        /**
         * Matrix form of the rule; not available for this implementation.
         *
         * @param denseMatrix64F ignored
         * @throws UnsupportedOperationException always
         */
        @Override
        public void chainGradient(DenseMatrix64F denseMatrix64F) {
            throw new UnsupportedOperationException("not yet implemented");
        }
    }

    /* loaded from: input_file:dr/evomodel/treedatalikelihood/hmc/MultivariateChainRule$InverseGeneral.class */
    /**
     * Matrix chain rule that maps an incoming gradient {@code G} to
     * {@code -(M * G * M)}, where {@code M} is the square matrix supplied at construction.
     *
     * <p>An instance-level scratch matrix is reused by both {@code chainGradient} overloads, so
     * concurrent calls on the same instance would interfere with each other.
     */
    public static class InverseGeneral implements MultivariateChainRule {
        private final DenseMatrix64F matrix;
        private final DenseMatrix64F temp;
        private final int dim;

        /**
         * Wraps a flat array as a {@code dim x dim} matrix without copying it.
         *
         * @param dArr matrix entries; {@code dim} is the integer square root of its length,
         *             which is asserted to be a perfect square
         */
        public InverseGeneral(double[] dArr) {
            this.dim = (int) Math.sqrt(dArr.length);
            assert this.dim * this.dim == dArr.length : "Inverse rule is only valid for square matrices.";
            this.matrix = DenseMatrix64F.wrap(this.dim, this.dim, dArr);
            this.temp = new DenseMatrix64F(this.dim, this.dim);
        }

        /**
         * Uses the given matrix directly (stored by reference, not copied).
         *
         * @param denseMatrix64F matrix {@code M}; asserted to be square
         */
        public InverseGeneral(DenseMatrix64F denseMatrix64F) {
            this.dim = denseMatrix64F.getNumCols();
            assert this.dim == denseMatrix64F.getNumRows() : "Inverse rule is only valid for square matrices.";
            this.matrix = denseMatrix64F;
            this.temp = new DenseMatrix64F(this.dim, this.dim);
        }

        /**
         * Computes {@code -(M * G * M)} for a gradient given as a flat array.
         *
         * @param dArr incoming gradient, wrapped (not copied) as a {@code dim x dim} matrix;
         *             asserted to hold exactly {@code dim * dim} values
         * @return the backing array of a newly allocated {@code dim x dim} result matrix
         */
        @Override
        public double[] chainGradient(double[] dArr) {
            assert dArr.length == dim * dim;

            final DenseMatrix64F result = new DenseMatrix64F(dim, dim);
            final DenseMatrix64F incoming = DenseMatrix64F.wrap(dim, dim, dArr);
            CommonOps.mult(matrix, incoming, temp);      // temp = M * G
            CommonOps.mult(-1.0d, temp, matrix, result); // result = -(temp * M)
            return result.getData();
        }

        /**
         * Computes {@code -(M * G * M)} and writes the result back into the argument.
         *
         * @param denseMatrix64F gradient {@code G}, overwritten with the result; asserted to be
         *                       {@code dim x dim}. NOTE(review): passing the very matrix this
         *                       rule was constructed with would make it both input and output of
         *                       the second product - presumably unsupported; confirm with EJML.
         */
        @Override
        public void chainGradient(DenseMatrix64F denseMatrix64F) {
            assert denseMatrix64F.getNumRows() == dim && denseMatrix64F.getNumCols() == dim;

            CommonOps.mult(matrix, denseMatrix64F, temp);        // temp = M * G
            CommonOps.mult(-1.0d, temp, matrix, denseMatrix64F); // G = -(temp * M)
        }
    }

    /**
     * Applies this rule to a gradient supplied as a flat array.
     *
     * @param dArr incoming gradient values
     * @return the transformed gradient; depending on the implementation this may be a newly
     *         allocated array or the input array itself
     */
    double[] chainGradient(double[] dArr);

    /**
     * Applies this rule to a gradient supplied as a matrix. Nothing is returned, so the
     * implementations that support this overload write their result into the argument.
     * Not every implementation supports it; one of the nested implementations always throws.
     *
     * @param denseMatrix64F gradient matrix to transform
     */
    void chainGradient(DenseMatrix64F denseMatrix64F);
}
