package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.preorder.BranchConditionalDistributionDelegate;
import dr.evomodel.treedatalikelihood.preorder.BranchSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.ConditionalPrecisionAndTransform2;
import dr.evomodel.treedatalikelihood.preorder.MatrixSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.NormalSufficientStatistics;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.NumericalHessianFromGradient;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import dr.math.matrixAlgebra.missingData.PermutationIndices;
import dr.xml.Reportable;
import java.util.List;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/continuous/BranchRateGradient.class */
public class BranchRateGradient implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable, Loggable {
    private final TreeDataLikelihood treeDataLikelihood;
    private final TreeTrait<List<BranchSufficientStatistics>> treeTraitProvider;
    private final Tree tree;
    private final int nTraits;
    private final Parameter rateParameter;
    private final ArbitraryBranchRates branchRateModel;
    private final ContinuousTraitGradientForBranch branchProvider;
    /**
     * The tree-data log-likelihood exposed as a plain function of the rate-parameter values.
     * Each evaluation overwrites the current values of {@code rateParameter}, marks the
     * likelihood dirty and returns the recomputed log-likelihood.
     */
    private MultivariateFunction numeric1 = new MultivariateFunction() {
        @Override
        public double evaluate(double[] argument) {
            int index = 0;
            for (double value : argument) {
                rateParameter.setParameterValue(index, value);
                index++;
            }
            treeDataLikelihood.makeDirty();
            return treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return rateParameter.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0d;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };
    private static final boolean DEBUG = false;
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:dr/evomodel/treedatalikelihood/continuous/BranchRateGradient$ContinuousTraitGradientForBranch.class */
    public interface ContinuousTraitGradientForBranch {

        /* loaded from: input_file:dr/evomodel/treedatalikelihood/continuous/BranchRateGradient$ContinuousTraitGradientForBranch$Default.class */
        public static class Default implements ContinuousTraitGradientForBranch {
            private final DenseMatrix64F matrix0;
            private final DenseMatrix64F matrix1;
            private final DenseMatrix64F vector0;
            private final int dim;

            public Default(int i) {
                this.dim = i;
                this.matrix0 = new DenseMatrix64F(i, i);
                this.matrix1 = new DenseMatrix64F(i, i);
                this.vector0 = new DenseMatrix64F(i, 1);
            }

            /**
             * Computes one branch's gradient contribution as the sum of three terms: a trace term,
             * a quadratic-plus-trace term built from the joint statistics, and a cross term with
             * the scaled branch mean.
             *
             * <p>The instance-level scratch buffers are overwritten on every call, and
             * {@code matrix0} is deliberately reused for two different products, so the
             * statement order below must be kept.
             *
             * @param branchSufficientStatistics statistics below, along and above the branch
             * @param d scaling applied to the branch variance and to the branch mean
             * @return the sum of the three terms
             */
            @Override
            public double getGradientForBranch(BranchSufficientStatistics branchSufficientStatistics, double d) {
                final NormalSufficientStatistics belowStats = branchSufficientStatistics.getBelow();
                final MatrixSufficientStatistics branchStats = branchSufficientStatistics.getBranch();
                final NormalSufficientStatistics aboveStats = branchSufficientStatistics.getAbove();
                final DenseMatrix64F abovePrecision = aboveStats.getRawPrecision();

                // As implemented in this class: matrix0 = P * (d * Vbranch) and
                // matrix1 = P * (d * Vbranch) * P, with P the raw precision above the branch.
                makeGradientMatrices0(this.matrix1, this.matrix0, branchSufficientStatistics, d);

                double traceTerm = 0.0d;
                for (int k = 0; k < this.dim; k++) {
                    traceTerm -= 0.5d * this.matrix0.unsafe_get(k, k);
                }

                final NormalSufficientStatistics joint = computeJointStatistics(belowStats, aboveStats, this.dim);

                // matrix0 is no longer needed for the trace, so it now receives matrix1 * jointVariance.
                makeGradientMatrices1(this.matrix0, this.matrix1, joint);

                final DenseMatrix64F delta = this.vector0;
                makeDeltaVector(delta, joint, aboveStats);

                double quadraticTerm = 0.0d;
                for (int row = 0; row < this.dim; row++) {
                    for (int col = 0; col < this.dim; col++) {
                        quadraticTerm += 0.5d * delta.unsafe_get(row, 0) * this.matrix1.unsafe_get(row, col) * delta.unsafe_get(col, 0);
                    }
                    quadraticTerm += 0.5d * this.matrix0.unsafe_get(row, row);
                }

                final DenseMatrix64F scaledBranchMean = new DenseMatrix64F(this.dim, 1);
                CommonOps.scale(d, branchStats.getRawMean(), scaledBranchMean);

                double crossTerm = 0.0d;
                for (int row = 0; row < this.dim; row++) {
                    for (int col = 0; col < this.dim; col++) {
                        crossTerm += delta.unsafe_get(row, 0) * abovePrecision.unsafe_get(row, col) * scaledBranchMean.unsafe_get(col, 0);
                    }
                }

                return traceTerm + quadraticTerm + crossTerm;
            }

            /**
             * Builds the joint statistics for a node by dispatching on the diagonal of the first
             * argument's raw precision, as classified by {@link PermutationIndices}.
             *
             * @param normalSufficientStatistics statistics whose raw precision drives the dispatch
             * @param normalSufficientStatistics2 the other set of statistics to combine with
             * @param i trait dimension
             * @return the statistics produced by the matching helper
             */
            public static NormalSufficientStatistics computeJointStatistics(NormalSufficientStatistics normalSufficientStatistics, NormalSufficientStatistics normalSufficientStatistics2, int i) {
                final PermutationIndices indices = new PermutationIndices(normalSufficientStatistics.getRawPrecision());

                // Every diagonal entry infinite: use the first argument as is.
                if (indices.getNumberOfInfiniteDiagonals() == i) {
                    return computeJointFullyObserved(normalSufficientStatistics, i);
                }
                // Every diagonal entry zero: only the second argument carries information.
                if (indices.getNumberOfZeroDiagonals() == i) {
                    return computeJointFullyMissing(normalSufficientStatistics2, i);
                }
                // No zero entries, or no infinite entries: combine both by precision weighting.
                if (indices.getNumberOfZeroDiagonals() == 0 || indices.getNumberOfInfiniteDiagonals() == 0) {
                    return computeJointLatent(normalSufficientStatistics, normalSufficientStatistics2, i);
                }
                // Mixture of zero and infinite entries.
                return computeJointPartiallyMissing(normalSufficientStatistics, normalSufficientStatistics2, indices, i);
            }

            /**
             * Joint statistics when every diagonal entry of the precision is infinite: the given
             * mean and precision are reused and paired with a freshly allocated all-zero variance.
             */
            private static NormalSufficientStatistics computeJointFullyObserved(NormalSufficientStatistics below, int dim) {
                return new NormalSufficientStatistics(
                        below.getRawMean(),
                        below.getRawPrecision(),
                        new DenseMatrix64F(dim, dim));
            }

            /**
             * Joint statistics when every diagonal entry of the other precision is zero: a new
             * statistics object wrapping the given mean, precision and variance.
             * The dimension argument is not used.
             */
            private static NormalSufficientStatistics computeJointFullyMissing(NormalSufficientStatistics above, int dim) {
                return new NormalSufficientStatistics(
                        above.getRawMean(),
                        above.getRawPrecision(),
                        above.getRawVariance());
            }

            /**
             * Joint statistics obtained by adding the two raw precisions, inverting the sum with
             * {@code MissingOps.safeInvert2}, and filling the mean with
             * {@code MissingOps.safeWeightedAverage}.
             */
            private static NormalSufficientStatistics computeJointLatent(NormalSufficientStatistics below, NormalSufficientStatistics above, int dim) {
                final DenseMatrix64F jointMean = new DenseMatrix64F(dim, 1);
                final DenseMatrix64F jointPrecision = new DenseMatrix64F(dim, dim);
                final DenseMatrix64F jointVariance = new DenseMatrix64F(dim, dim);

                CommonOps.add(below.getRawPrecision(), above.getRawPrecision(), jointPrecision);
                MissingOps.safeInvert2(jointPrecision, jointVariance, false);

                // The wrappers view the underlying arrays, so the average is written into jointMean.
                final WrappedVector.Raw belowMean = new WrappedVector.Raw(below.getRawMean().getData(), 0, dim);
                final DenseMatrix64F belowPrecision = below.getRawPrecision();
                final WrappedVector.Raw aboveMean = new WrappedVector.Raw(above.getRawMean().getData(), 0, dim);
                final DenseMatrix64F abovePrecision = above.getRawPrecision();
                final WrappedVector.Raw jointMeanView = new WrappedVector.Raw(jointMean.getData(), 0, dim);
                MissingOps.safeWeightedAverage(belowMean, belowPrecision, aboveMean, abovePrecision, jointMeanView, jointVariance, dim);

                return new NormalSufficientStatistics(jointMean, jointPrecision, jointVariance);
            }

            /**
             * Joint statistics when the precision diagonal mixes zero and infinite entries.
             * Entries at the zero indices are filled from the conditional mean and variance given
             * the entries at the infinite indices; entries at the infinite indices take the mean
             * of the first argument and an infinite precision.
             *
             * @throws RuntimeException if any diagonal entry is finite and non-zero
             */
            private static NormalSufficientStatistics computeJointPartiallyMissing(NormalSufficientStatistics below, NormalSufficientStatistics above, PermutationIndices indices, int dim) {
                final DenseMatrix64F jointMean = new DenseMatrix64F(dim, 1);
                final DenseMatrix64F jointPrecision = new DenseMatrix64F(dim, dim);
                final DenseMatrix64F jointVariance = new DenseMatrix64F(dim, dim);

                if (indices.getNumberOfNonZeroFiniteDiagonals() != 0) {
                    throw new RuntimeException("Unsure if this works for latent trait below");
                }

                final ConditionalPrecisionAndTransform2 transform = new ConditionalPrecisionAndTransform2(
                        above.getRawPrecision(), indices.getZeroIndices(), indices.getInfiniteIndices());
                final double[] conditionalMean = transform.getConditionalMean(
                        below.getRawMean().getData(), 0, above.getRawMean().getData(), 0);

                MissingOps.copyRowsAndColumns(above.getRawPrecision(), jointPrecision,
                        indices.getZeroIndices(), indices.getZeroIndices(), false);
                MissingOps.scatterRowsAndColumns(transform.getConditionalVariance(), jointVariance,
                        indices.getZeroIndices(), indices.getZeroIndices(), false);

                // conditionalMean is read in the iteration order of the zero indices.
                int next = 0;
                for (int missing : indices.getZeroIndices()) {
                    jointMean.unsafe_set(missing, 0, conditionalMean[next]);
                    next++;
                }

                for (int observed : indices.getInfiniteIndices()) {
                    jointMean.unsafe_set(observed, 0, below.getMean(observed));
                    jointPrecision.unsafe_set(observed, observed, Double.POSITIVE_INFINITY);
                }

                return new NormalSufficientStatistics(jointMean, jointPrecision, jointVariance);
            }

            /**
             * Fills two caller-supplied matrices from the branch statistics. With P the raw
             * precision above the branch and V the raw branch variance, on return
             * {@code denseMatrix64F2 = P * (d * V)} and {@code denseMatrix64F = P * (d * V) * P}.
             * The first matrix is used as temporary storage for {@code d * V} along the way.
             *
             * @param denseMatrix64F output: P * (d * V) * P
             * @param denseMatrix64F2 output: P * (d * V)
             * @param branchSufficientStatistics source of the precision and the branch variance
             * @param d scalar multiplier applied to the branch variance
             */
            public void makeGradientMatrices0(DenseMatrix64F denseMatrix64F, DenseMatrix64F denseMatrix64F2, BranchSufficientStatistics branchSufficientStatistics, double d) {
                final NormalSufficientStatistics aboveStats = branchSufficientStatistics.getAbove();
                final MatrixSufficientStatistics branchStats = branchSufficientStatistics.getBranch();
                final DenseMatrix64F precision = aboveStats.getRawPrecision();

                final DenseMatrix64F sandwich = denseMatrix64F;
                final DenseMatrix64F leftProduct = denseMatrix64F2;

                CommonOps.scale(d, branchStats.getRawVariance(), sandwich);
                CommonOps.mult(precision, sandwich, leftProduct);
                CommonOps.mult(leftProduct, precision, sandwich);
            }

            /**
             * Writes the product of the second matrix and the raw variance of the given
             * statistics into the first matrix.
             *
             * @param denseMatrix64F output matrix
             * @param denseMatrix64F2 left factor of the product
             * @param normalSufficientStatistics statistics supplying the raw variance (right factor)
             */
            public void makeGradientMatrices1(DenseMatrix64F denseMatrix64F, DenseMatrix64F denseMatrix64F2, NormalSufficientStatistics normalSufficientStatistics) {
                final DenseMatrix64F variance = normalSufficientStatistics.getRawVariance();
                CommonOps.mult(denseMatrix64F2, variance, denseMatrix64F);
            }

            /**
             * Fills column 0 of the output with the element-wise difference between the raw mean
             * of the first statistics and the mean of the second, for the first {@code dim} rows.
             *
             * @param denseMatrix64F output; only column 0 is written
             * @param normalSufficientStatistics statistics supplying the minuend
             * @param normalSufficientStatistics2 statistics supplying the subtrahend
             */
            public void makeDeltaVector(DenseMatrix64F denseMatrix64F, NormalSufficientStatistics normalSufficientStatistics, NormalSufficientStatistics normalSufficientStatistics2) {
                for (int row = 0; row < this.dim; row++) {
                    final double first = normalSufficientStatistics.getRawMean().unsafe_get(row, 0);
                    final double second = normalSufficientStatistics2.getMean(row);
                    denseMatrix64F.unsafe_set(row, 0, first - second);
                }
            }
        }

        /**
         * Returns the gradient contribution of a single branch.
         *
         * @param branchSufficientStatistics sufficient statistics for the branch
         * @param d scaling factor supplied by the caller; in this file it is the branch-rate
         *          differential divided by the branch rate
         * @return the branch's contribution to the gradient
         */
        double getGradientForBranch(BranchSufficientStatistics branchSufficientStatistics, double d);
    }

    /**
     * Wires the gradient provider to a tree-data likelihood and its rate parameter.
     * Registers the branch-conditional-density trait on the delegate if the likelihood does
     * not already expose it, then looks the trait up for later use.
     *
     * @param str trait name used to derive the branch-conditional trait name
     * @param treeDataLikelihood likelihood whose tree, rate model and traits are used
     * @param continuousDataLikelihoodDelegate delegate asked to add the trait when it is missing
     * @param parameter the branch-rate parameter the gradient is taken with respect to
     * @throws RuntimeException if the likelihood delegate reports a trait count other than 1
     */
    public BranchRateGradient(String str, TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, Parameter parameter) {
        if (!$assertionsDisabled && treeDataLikelihood == null) {
            throw new AssertionError();
        }

        this.treeDataLikelihood = treeDataLikelihood;
        this.tree = treeDataLikelihood.getTree();
        this.rateParameter = parameter;

        // Only ArbitraryBranchRates is kept; any other rate model is recorded as null.
        final BranchRateModel rateModel = treeDataLikelihood.getBranchRateModel();
        if (rateModel instanceof ArbitraryBranchRates) {
            this.branchRateModel = (ArbitraryBranchRates) rateModel;
        } else {
            this.branchRateModel = null;
        }

        final String traitName = BranchConditionalDistributionDelegate.getName(str);
        if (treeDataLikelihood.getTreeTrait(traitName) == null) {
            continuousDataLikelihoodDelegate.addBranchConditionalDensityTrait(str);
        }
        this.treeTraitProvider = treeDataLikelihood.getTreeTrait(traitName);
        if (!$assertionsDisabled && this.treeTraitProvider == null) {
            throw new AssertionError();
        }

        this.nTraits = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount();
        if (this.nTraits != 1) {
            throw new RuntimeException("Not yet implemented for >1 traits");
        }

        final int traitDim = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();
        this.branchProvider = new ContinuousTraitGradientForBranch.Default(traitDim);
    }

    /** Returns the tree-data likelihood this gradient is computed for. */
    @Override
    public Likelihood getLikelihood() {
        return treeDataLikelihood;
    }

    /** Returns the rate parameter the gradient is taken with respect to. */
    @Override
    public Parameter getParameter() {
        return rateParameter;
    }

    /** Returns the dimension of the parameter reported by {@link #getParameter()}. */
    @Override
    public int getDimension() {
        final Parameter parameter = getParameter();
        return parameter.getDimension();
    }

    /**
     * Computes the gradient with one entry per rate-parameter dimension.
     * The likelihood is marked dirty first; then, for every non-root node, the per-trait
     * branch contributions are summed and stored at the parameter index mapped from the node.
     *
     * <p>NOTE(review): the constructor stores {@code null} in {@code branchRateModel} when the
     * rate model is not an {@code ArbitraryBranchRates}, yet it is dereferenced here without a
     * check - confirm that case cannot reach this method.
     *
     * @return a new array of length {@code rateParameter.getDimension()}
     */
    @Override
    public double[] getGradientLogDensity() {
        this.treeDataLikelihood.makeDirty();

        final double[] gradient = new double[this.rateParameter.getDimension()];
        for (int nodeNumber = 0; nodeNumber < this.tree.getNodeCount(); nodeNumber++) {
            final NodeRef node = this.tree.getNode(nodeNumber);
            if (this.tree.isRoot(node)) {
                continue; // the root has no branch above it
            }

            final List<BranchSufficientStatistics> statistics = this.treeTraitProvider.getTrait(this.tree, node);
            if (!$assertionsDisabled && statistics.size() != this.nTraits) {
                throw new AssertionError();
            }

            final double differential = this.branchRateModel.getBranchRateDifferential(this.tree, node);
            final double rate = this.branchRateModel.getBranchRate(this.tree, node);
            final double scaling = differential / rate;

            double branchGradient = 0.0d;
            for (int trait = 0; trait < this.nTraits; trait++) {
                branchGradient += this.branchProvider.getGradientForBranch(statistics.get(trait), scaling);
            }

            final int destination = getParameterIndexFromNode(node);
            if (!$assertionsDisabled && destination == -1) {
                throw new AssertionError();
            }
            gradient[destination] = branchGradient;
        }
        return gradient;
    }

    /**
     * Maps a node to its index in the rate parameter: delegated to the branch-rate model when
     * one was stored, otherwise the node's own number.
     */
    private int getParameterIndexFromNode(NodeRef nodeRef) {
        if (this.branchRateModel != null) {
            return this.branchRateModel.getParameterIndexFromNode(nodeRef);
        }
        return nodeRef.getNumber();
    }

    /**
     * Computes the diagonal of the Hessian with one entry per rate-parameter dimension.
     * For every non-root node, the per-trait branch contributions are summed and stored at
     * the parameter index mapped from the node. Unlike {@link #getGradientLogDensity()}, this
     * method does not mark the likelihood dirty.
     *
     * <p>NOTE(review): the constructor stores {@code null} in {@code branchRateModel} when the
     * rate model is not an {@code ArbitraryBranchRates}, yet it is dereferenced here without a
     * check - confirm that case cannot reach this method.
     *
     * @return a new array of length {@code rateParameter.getDimension()}
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        final double[] diagonal = new double[this.rateParameter.getDimension()];
        for (int nodeNumber = 0; nodeNumber < this.tree.getNodeCount(); nodeNumber++) {
            final NodeRef node = this.tree.getNode(nodeNumber);
            if (this.tree.isRoot(node)) {
                continue; // the root has no branch above it
            }

            final List<BranchSufficientStatistics> statistics = this.treeTraitProvider.getTrait(this.tree, node);
            if (!$assertionsDisabled && statistics.size() != this.nTraits) {
                throw new AssertionError();
            }

            final double rate = this.branchRateModel.getBranchRate(this.tree, node);
            final double firstScaling = this.branchRateModel.getBranchRateDifferential(this.tree, node) / rate;
            final double secondScaling = this.branchRateModel.getBranchRateSecondDifferential(this.tree, node) / rate;

            double branchHessian = 0.0d;
            for (int trait = 0; trait < this.nTraits; trait++) {
                branchHessian += getDiagonalHessianLogDensity(statistics.get(trait), firstScaling, secondScaling);
            }

            final int destination = getParameterIndexFromNode(node);
            if (!$assertionsDisabled && destination == -1) {
                throw new AssertionError();
            }
            diagonal[destination] = branchHessian;
        }
        return diagonal;
    }

    /**
     * Computes one branch's contribution to the diagonal Hessian as the sum of four partial
     * terms (d3 to d6). The three square work matrices are overwritten several times, so each
     * accumulation loop must run before the next group of products.
     *
     * <p>Notation used in the comments below, assuming the matrix helpers behave as implemented
     * in {@code ContinuousTraitGradientForBranch.Default}: P is the raw precision above the
     * branch, Vb the raw branch variance, Vj the raw variance of the joint statistics, and
     * delta the vector filled by {@code makeDeltaVector}.
     *
     * @param branchSufficientStatistics statistics below, along and above the branch
     * @param d scaling of the branch variance for the first three terms (the caller in this
     *          file passes rate differential / rate)
     * @param d2 scaling of the branch variance for the last term (the caller in this file
     *           passes rate second differential / rate)
     * @return d3 + d4 + d5 + d6
     */
    private double getDiagonalHessianLogDensity(BranchSufficientStatistics branchSufficientStatistics, double d, double d2) {
        int traitDim = this.treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();
        // Work matrices; their contents change meaning from section to section.
        DenseMatrix64F denseMatrix64F = new DenseMatrix64F(traitDim, traitDim);
        DenseMatrix64F denseMatrix64F2 = new DenseMatrix64F(traitDim, traitDim);
        DenseMatrix64F denseMatrix64F3 = new DenseMatrix64F(traitDim, traitDim);
        // Holds delta; allocated square, but only column 0 is written and read.
        DenseMatrix64F denseMatrix64F4 = new DenseMatrix64F(traitDim, traitDim);
        NormalSufficientStatistics below = branchSufficientStatistics.getBelow();
        MatrixSufficientStatistics branch = branchSufficientStatistics.getBranch();
        NormalSufficientStatistics above = branchSufficientStatistics.getAbove();
        NormalSufficientStatistics computeJointStatistics = ContinuousTraitGradientForBranch.Default.computeJointStatistics(below, above, traitDim);
        // The casts rely on the constructor always installing a Default provider.
        ((ContinuousTraitGradientForBranch.Default) this.branchProvider).makeDeltaVector(denseMatrix64F4, computeJointStatistics, above);
        DenseMatrix64F rawPrecision = above.getRawPrecision();
        DenseMatrix64F rawVariance = computeJointStatistics.getRawVariance();
        // m1 = P * (d * Vb), m2 = P * (d * Vb) * P
        ((ContinuousTraitGradientForBranch.Default) this.branchProvider).makeGradientMatrices0(denseMatrix64F2, denseMatrix64F, branchSufficientStatistics, d);
        // m3 = m1 * m1; d3 = 0.5 * trace(m3)
        CommonOps.mult(denseMatrix64F, denseMatrix64F, denseMatrix64F3);
        double d3 = 0.0d;
        for (int i = 0; i < traitDim; i++) {
            d3 += 0.5d * denseMatrix64F3.unsafe_get(i, i);
        }
        // m1 = m3 * P; m3 = Vj * m1; d4 = -(delta' * m1 * delta) - trace(m3)
        CommonOps.mult(denseMatrix64F3, rawPrecision, denseMatrix64F);
        CommonOps.mult(rawVariance, denseMatrix64F, denseMatrix64F3);
        double d4 = 0.0d;
        for (int i2 = 0; i2 < traitDim; i2++) {
            for (int i3 = 0; i3 < traitDim; i3++) {
                d4 -= (denseMatrix64F4.unsafe_get(i2, 0) * denseMatrix64F.unsafe_get(i2, i3)) * denseMatrix64F4.unsafe_get(i3, 0);
            }
            d4 -= denseMatrix64F3.unsafe_get(i2, i2);
        }
        // m3 = m2 * Vj; m1 = m3 * m2; m3 = m1 * Vj; d5 = delta' * m1 * delta + 0.5 * trace(m3)
        ((ContinuousTraitGradientForBranch.Default) this.branchProvider).makeGradientMatrices1(denseMatrix64F3, denseMatrix64F2, computeJointStatistics);
        CommonOps.mult(denseMatrix64F3, denseMatrix64F2, denseMatrix64F);
        CommonOps.mult(denseMatrix64F, rawVariance, denseMatrix64F3);
        double d5 = 0.0d;
        for (int i4 = 0; i4 < traitDim; i4++) {
            for (int i5 = 0; i5 < traitDim; i5++) {
                d5 += denseMatrix64F4.unsafe_get(i4, 0) * denseMatrix64F.unsafe_get(i4, i5) * denseMatrix64F4.unsafe_get(i5, 0);
            }
            d5 += 0.5d * denseMatrix64F3.unsafe_get(i4, i4);
        }
        // m1 = d2 * Vb; m3 = P * m1; m1 = m3 * P; m2 = Vj * m1;
        // d6 = delta' * m1 * delta - 0.5 * trace(m3) + 0.5 * trace(m2)
        // NOTE(review): the quadratic form here has no 0.5 factor, whereas the matching term in
        // Default.getGradientForBranch is multiplied by 0.5 - confirm this is intended.
        CommonOps.scale(d2, branch.getRawVariance(), denseMatrix64F);
        CommonOps.mult(rawPrecision, denseMatrix64F, denseMatrix64F3);
        CommonOps.mult(denseMatrix64F3, rawPrecision, denseMatrix64F);
        CommonOps.mult(rawVariance, denseMatrix64F, denseMatrix64F2);
        double d6 = 0.0d;
        for (int i6 = 0; i6 < traitDim; i6++) {
            for (int i7 = 0; i7 < traitDim; i7++) {
                d6 += denseMatrix64F4.unsafe_get(i6, 0) * denseMatrix64F.unsafe_get(i6, i7) * denseMatrix64F4.unsafe_get(i7, 0);
            }
            d6 = (d6 - (0.5d * denseMatrix64F3.unsafe_get(i6, i6))) + (0.5d * denseMatrix64F2.unsafe_get(i6, i6));
        }
        return d3 + d4 + d5 + d6;
    }

    /**
     * The full Hessian is not available from this provider; only the diagonal is implemented.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always; this is a {@code RuntimeException} subclass,
     *         so callers that catch {@code RuntimeException} are unaffected
     */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    /**
     * Computes a finite-difference gradient of the log-likelihood with respect to the rate
     * parameter via {@code NumericalDerivative.gradient}, then writes the values read at
     * entry back into the parameter, since evaluating {@code numeric1} overwrites them.
     *
     * @return the numerical gradient
     */
    public double[] getNumericalGradient() {
        final double[] saved = this.rateParameter.getParameterValues();
        final double[] numeric = NumericalDerivative.gradient(this.numeric1, this.rateParameter.getParameterValues());

        int index = 0;
        for (double value : saved) {
            this.rateParameter.setParameterValue(index, value);
            index++;
        }
        return numeric;
    }

    /**
     * Builds a text report comparing the analytic ("Peeling") gradient and diagonal Hessian
     * with their numerical counterparts.
     *
     * <p>Evaluating {@code numeric1} overwrites the rate parameter with each probe point.
     * {@link #getNumericalGradient()} restores the parameter itself; the direct
     * {@code NumericalDerivative.diagonalHessian} call below is wrapped in the same
     * save-and-restore, so the final numerical check and the caller see the original values.
     *
     * @return the report, one labelled vector per line
     */
    @Override
    public String getReport() {
        final double[] numericalGradient = getNumericalGradient();
        final NumericalHessianFromGradient numericalHessianFromGradient = new NumericalHessianFromGradient(this);

        final StringBuilder sb = new StringBuilder();
        sb.append("Peeling: ").append(new Vector(getGradientLogDensity()));
        sb.append("\n");
        sb.append("numeric: ").append(new Vector(numericalGradient));
        sb.append("\n");
        sb.append("Peeling diagonal hessian: ").append(new Vector(getDiagonalHessianLogDensity()));
        sb.append("\n");

        // Save, probe, restore: same protocol as getNumericalGradient().
        final double[] savedValues = this.rateParameter.getParameterValues();
        final double[] numericalDiagonalHessian =
                NumericalDerivative.diagonalHessian(this.numeric1, getParameter().getParameterValues());
        for (int i = 0; i < savedValues.length; i++) {
            this.rateParameter.setParameterValue(i, savedValues[i]);
        }

        sb.append("numeric diagonal hessian: ").append(new Vector(numericalDiagonalHessian));
        sb.append("\n");
        sb.append("Another numeric diagonal hessian: ").append(new Vector(numericalHessianFromGradient.getDiagonalHessianLogDensity()));
        sb.append("\n");
        return sb.toString();
    }

    /**
     * Exposes a single log column labelled "gradient report". The column wraps an object
     * whose {@code toString()} builds the report, so the report is generated only when the
     * object is converted to a string.
     */
    @Override
    public LogColumn[] getColumns() {
        final Object lazyReport = new Object() {
            @Override
            public String toString() {
                return "\n" + getReport();
            }
        };
        final LogColumn reportColumn = new LogColumn.Default("gradient report", lazyReport);
        return new LogColumn[]{reportColumn};
    }

    static {
        $assertionsDisabled = !BranchRateGradient.class.desiredAssertionStatus();
    }
}
