package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.treedatalikelihood.continuous.BranchRateGradient;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.evomodel.treedatalikelihood.hmc.MultivariateChainRule;
import dr.evomodel.treedatalikelihood.preorder.BranchSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider;
import dr.evomodel.treedatalikelihood.preorder.NormalSufficientStatistics;
import dr.math.matrixAlgebra.missingData.MissingOps;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/continuous/ContinuousTraitGradientForBranch.class */
public interface ContinuousTraitGradientForBranch {

    /**
     * Per-branch gradient with respect to one or more continuous-process parameters.
     *
     * <p>Each entry of {@link #derivationParameter} converts the matrix gradient {@code gradQInv}
     * and the vector gradient {@code gradN} computed by {@link Default} into the gradient of its
     * own parameter. The per-parameter pieces are concatenated in list order into a single array of
     * length {@link #getDimension()}.
     */
    public static class ContinuousProcessParameterGradient extends Default {
        ContinuousDataLikelihoodDelegate likelihoodDelegate;
        ContinuousDiffusionIntegrator cdi;
        DiffusionProcessDelegate diffusionProcessDelegate;
        final List<DerivationParameter> derivationParameter;

        /**
         * Parameters this gradient can be taken with respect to.
         *
         * <p>{@link #chainRule} handles a non-root branch and {@link #chainRuleRoot} handles the
         * root. {@link #getDimension(int)} gives the number of entries the parameter contributes
         * when the trait dimension is {@code dim}.
         */
        public enum DerivationParameter {
            /** Diffusion variance; delegates to the diffusion delegate, contributes dim*dim entries. */
            WRT_VARIANCE {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return delegate
                            .getGradientVarianceWrtVariance(node, cdi, likelihood, gradQInv)
                            .getData();
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    // The root is handled exactly like any other branch.
                    return chainRule(cdi, delegate, likelihood, statistics, node, gradQInv, gradN);
                }

                @Override
                public int getDimension(int dim) {
                    return dim * dim;
                }
            },
            /** Constant drift; requires an AbstractDriftDiffusionModelDelegate, contributes dim entries. */
            WRT_CONSTANT_DRIFT {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    final AbstractDriftDiffusionModelDelegate driftDelegate = (AbstractDriftDiffusionModelDelegate) delegate;
                    return driftDelegate
                            .getGradientDisplacementWrtDrift(node, cdi, likelihood, gradN)
                            .getData();
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    // The drift contributes nothing at the root: an all-zero vector sized like gradN.
                    return new double[gradN.getNumRows()];
                }

                @Override
                public int getDimension(int dim) {
                    return dim;
                }
            },
            /** Root mean; delegates to the diffusion delegate, contributes dim entries. */
            WRT_ROOT_MEAN {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return delegate.getGradientDisplacementWrtRoot(node, cdi, likelihood, gradN);
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return chainRule(cdi, delegate, likelihood, statistics, node, gradQInv, gradN);
                }

                @Override
                public int getDimension(int dim) {
                    return dim;
                }
            },
            /** Drift and root mean sharing one parameter: element-wise sum of both gradients. */
            WRT_CONSTANT_DRIFT_AND_ROOT_MEAN {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    // Drift first, then root mean (same evaluation order as before).
                    final double[] driftPart = WRT_CONSTANT_DRIFT.chainRule(cdi, delegate, likelihood, statistics, node, gradQInv, gradN);
                    final double[] combined = WRT_ROOT_MEAN.chainRule(cdi, delegate, likelihood, statistics, node, gradQInv, gradN);
                    // Accumulates in place into the array returned for the root mean.
                    for (int k = 0; k < combined.length; ++k) {
                        combined[k] += driftPart[k];
                    }
                    return combined;
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    // WRT_CONSTANT_DRIFT is zero at the root, so only the root-mean part remains.
                    return WRT_ROOT_MEAN.chainRuleRoot(cdi, delegate, likelihood, statistics, node, gradQInv, gradN);
                }

                @Override
                public int getDimension(int dim) {
                    return dim;
                }
            },
            /** Diagonal selection strength (attenuation); requires an OUDiffusionModelDelegate. */
            WRT_DIAGONAL_SELECTION_STRENGTH {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    final OUDiffusionModelDelegate ouDelegate = (OUDiffusionModelDelegate) delegate;
                    // Variance part first; the displacement part is then added onto it in place.
                    final DenseMatrix64F total = ouDelegate.getGradientVarianceWrtAttenuation(node, cdi, statistics, gradQInv);
                    CommonOps.addEquals(total, ouDelegate.getGradientDisplacementWrtAttenuation(node, cdi, statistics, gradN));
                    return total.getData();
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    // The root contributes zeros.
                    return new double[likelihood.getTraitDim()];
                }

                @Override
                public int getDimension(int dim) {
                    return dim;
                }
            },
            /** Sampling variance: gradQInv itself is the gradient; contributes dim*dim entries. */
            WRT_SAMPLING_VARIANCE {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return gradQInv.getData();
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    throw new RuntimeException("Should never be called.");
                }

                @Override
                public int getDimension(int dim) {
                    return dim * dim;
                }
            };

            /** Gradient of this parameter for a non-root branch. */
            abstract double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN);

            /** Gradient of this parameter for the root. */
            abstract double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate delegate, ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN);

            /** Number of gradient entries this parameter contributes for trait dimension {@code dim}. */
            public abstract int getDimension(int dim);
        }

        /**
         * @param dim                 trait dimension
         * @param tree                tree whose branches are visited
         * @param likelihoodDelegate  supplies the integrator and the diffusion-process delegate
         * @param derivationParameter parameters to differentiate, in output order
         */
        public ContinuousProcessParameterGradient(int dim, Tree tree, ContinuousDataLikelihoodDelegate likelihoodDelegate, List<DerivationParameter> derivationParameter) {
            super(dim, tree);
            this.likelihoodDelegate = likelihoodDelegate;
            this.cdi = likelihoodDelegate.getIntegrator();
            this.diffusionProcessDelegate = likelihoodDelegate.getDiffusionProcessDelegate();
            this.derivationParameter = derivationParameter;
        }

        /** All branches map to index 0. */
        @Override
        public int getParameterIndexFromNode(NodeRef node) {
            return 0;
        }

        /** Sum of the dimensions of all requested derivation parameters. */
        @Override
        public int getDimension() {
            int total = 0;
            for (DerivationParameter parameter : derivationParameter) {
                total += parameter.getDimension(dim);
            }
            return total;
        }

        /**
         * Zeroes the rows and columns of {@code gradQInv} that belong to missing traits (in place),
         * then concatenates every parameter's non-root gradient.
         */
        @Override
        public double[] chainRule(BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
            removeMissing(gradQInv, statistics.getMissing());
            return concatenate(false, statistics, node, gradQInv, gradN);
        }

        /** Concatenates every parameter's root gradient. */
        @Override
        public double[] chainRuleRoot(BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
            return concatenate(true, statistics, node, gradQInv, gradN);
        }

        /** Lays out each parameter's gradient slice, in list order, into one array. */
        private double[] concatenate(boolean atRoot, BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
            final double[] gradient = new double[getDimension()];
            int offset = 0;
            for (DerivationParameter parameter : derivationParameter) {
                final int length = parameter.getDimension(dim);
                final double[] slice = atRoot
                        ? parameter.chainRuleRoot(cdi, diffusionProcessDelegate, likelihoodDelegate, statistics, node, gradQInv, gradN)
                        : parameter.chainRule(cdi, diffusionProcessDelegate, likelihoodDelegate, statistics, node, gradQInv, gradN);
                // Only the first `length` entries of the returned array are used.
                System.arraycopy(slice, 0, gradient, offset, length);
                offset += length;
            }
            return gradient;
        }

        /** Sets row {@code m} and column {@code m} of {@code matrix} to zero for every index in {@code missing}. */
        private static void removeMissing(DenseMatrix64F matrix, int[] missing) {
            final int columns = matrix.getNumCols();
            for (int m : missing) {
                for (int j = 0; j < columns; ++j) {
                    matrix.unsafe_set(m, j, 0.0d);
                    matrix.unsafe_set(j, m, 0.0d);
                }
            }
        }

        /** Returns the (live, not copied) list of derivation parameters. */
        public List<DerivationParameter> getDerivationParameter() {
            return derivationParameter;
        }
    }

    /**
     * Shared machinery for per-branch gradients.
     *
     * <p>From a branch's "above" and "below" sufficient statistics it forms:
     * <ul>
     *   <li>{@code Q}: the above precision;</li>
     *   <li>{@code W}: the above variance;</li>
     *   <li>{@code V}: the joint variance;</li>
     *   <li>{@code delta}: the joint mean minus the above mean.</li>
     * </ul>
     * It then computes a matrix gradient {@code gradQInv} and a vector gradient {@code gradN} and
     * passes both to the subclass's {@link #chainRule} or {@link #chainRuleRoot}.
     */
    public static abstract class Default implements ContinuousTraitGradientForBranch {
        private final DenseMatrix64F matrixGradientQInv;
        private final DenseMatrix64F matrixGradientN;
        final DenseMatrix64F matrixDelta;
        // Replaced by references into the branch statistics on each getSufficientStatistics call.
        DenseMatrix64F matrixQ;
        DenseMatrix64F matrixW;
        DenseMatrix64F matrixV;
        final int dim;
        final Tree tree;
        static final boolean DEBUG = false;

        /**
         * @param dim  trait dimension
         * @param tree tree whose branches are visited
         */
        public Default(int dim, Tree tree) {
            this.dim = dim;
            this.tree = tree;
            this.matrixGradientQInv = new DenseMatrix64F(dim, dim);
            this.matrixGradientN = new DenseMatrix64F(dim, 1);
            this.matrixDelta = new DenseMatrix64F(dim, 1);
            this.matrixQ = new DenseMatrix64F(dim, dim);
            this.matrixW = new DenseMatrix64F(dim, dim);
            this.matrixV = new DenseMatrix64F(dim, dim);
        }

        /** Uses the node number as the parameter index. */
        @Override
        public int getParameterIndexFromNode(NodeRef node) {
            return node.getNumber();
        }

        /** Computes both gradients, then applies the chain rule. */
        @Override
        public double[] getGradientForBranch(BranchSufficientStatistics statistics, NodeRef node) {
            return getGradientForBranch(statistics, node, true, true);
        }

        /**
         * Refreshes the sufficient statistics, optionally recomputes each gradient, and dispatches to
         * {@link #chainRuleRoot} for the root or {@link #chainRule} otherwise.
         *
         * <p>A gradient whose flag is {@code false} is not recomputed: its buffer keeps whatever it
         * held from the previous call.
         */
        double[] getGradientForBranch(BranchSufficientStatistics statistics, NodeRef node, boolean updateGradQInv, boolean updateGradN) {
            getSufficientStatistics(statistics);

            if (updateGradQInv) {
                getGradientQInvForBranch(matrixQ, matrixW, matrixV, matrixDelta, matrixGradientQInv);
            }
            if (updateGradN) {
                getGradientNForBranch(matrixQ, matrixDelta, matrixGradientN);
            }

            if (tree.isRoot(node)) {
                return chainRuleRoot(statistics, node, matrixGradientQInv, matrixGradientN);
            }
            return chainRule(statistics, node, matrixGradientQInv, matrixGradientN);
        }

        /**
         * Loads Q, W, V and delta from the branch statistics.
         *
         * <p>Q, W and V are references into the statistics objects, not copies.
         */
        void getSufficientStatistics(BranchSufficientStatistics statistics) {
            final NormalSufficientStatistics below = statistics.getBelow();
            final NormalSufficientStatistics above = statistics.getAbove();
            final NormalSufficientStatistics joint = BranchRateGradient.ContinuousTraitGradientForBranch.Default.computeJointStatistics(below, above, this.dim);

            matrixQ = above.getRawPrecision();
            matrixW = above.getRawVariance();
            matrixV = joint.getRawVariance();

            for (int row = 0; row < dim; ++row) {
                matrixDelta.unsafe_set(row, 0, joint.getMean(row) - above.getMean(row));
            }
        }

        abstract double[] chainRule(BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN);

        abstract double[] chainRuleRoot(BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN);

        /**
         * Writes {@code 0.5*W - 0.5*delta*delta^T - 0.5*V} into {@code gradQInv}. The result is then
         * passed through {@code MultivariateChainRule.InverseGeneral(Q).chainGradient}.
         */
        static void getGradientQInvForBranch(DenseMatrix64F Q, DenseMatrix64F W, DenseMatrix64F V, DenseMatrix64F delta, DenseMatrix64F gradQInv) {
            CommonOps.scale(0.5d, W, gradQInv);
            CommonOps.multAddTransB(-0.5d, delta, delta, gradQInv);
            CommonOps.addEquals(gradQInv, -0.5d, V);
            new MultivariateChainRule.InverseGeneral(Q).chainGradient(gradQInv);
        }

        /** Writes {@code Q^T * delta} into {@code gradN}. */
        private void getGradientNForBranch(DenseMatrix64F Q, DenseMatrix64F delta, DenseMatrix64F gradN) {
            CommonOps.multTransA(Q, delta, gradN);
        }
    }

    /**
     * Scalar per-branch gradient with respect to the branch rate.
     *
     * <p>The branch's raw variance and raw displacement are scaled by
     * {@code getBranchRateDifferential / getBranchRate}. Each scaled Jacobian is then contracted
     * element-wise with the matching gradient ({@code gradQInv} and {@code gradN} respectively).
     */
    public static class RateGradient extends Default {
        private final DenseMatrix64F matrixJacobianQInv;
        private final DenseMatrix64F matrixJacobianN;
        /** {@code null} when the model passed to the constructor is not an {@link ArbitraryBranchRates}. */
        private final ArbitraryBranchRates branchRateModel;

        /**
         * @param dim             trait dimension
         * @param tree            tree whose branches are visited
         * @param branchRateModel rate model; only an {@link ArbitraryBranchRates} supports
         *                        {@link #chainRule}
         */
        public RateGradient(int dim, Tree tree, BranchRateModel branchRateModel) {
            super(dim, tree);
            this.branchRateModel = branchRateModel instanceof ArbitraryBranchRates ? (ArbitraryBranchRates) branchRateModel : null;
            this.matrixJacobianQInv = new DenseMatrix64F(dim, dim);
            this.matrixJacobianN = new DenseMatrix64F(dim, 1);
        }

        /**
         * Returns 0 for the root. Otherwise returns the rate model's index for the node, falling
         * back to the node number when no {@link ArbitraryBranchRates} is available.
         */
        @Override
        public int getParameterIndexFromNode(NodeRef node) {
            if (tree.isRoot(node)) {
                return 0;
            }
            return branchRateModel == null ? node.getNumber() : branchRateModel.getParameterIndexFromNode(node);
        }

        @Override
        public int getDimension() {
            return 1;
        }

        /**
         * @throws IllegalStateException if this gradient was built with a rate model that is not an
         *                               {@link ArbitraryBranchRates}. The rate differential is
         *                               only available from that type.
         */
        @Override
        public double[] chainRule(BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
            if (branchRateModel == null) {
                // Previously this path failed with a bare NullPointerException.
                throw new IllegalStateException(
                        "RateGradient.chainRule requires an ArbitraryBranchRates branch-rate model");
            }

            final double scale = branchRateModel.getBranchRateDifferential(tree, node) / branchRateModel.getBranchRate(tree, node);

            double total = 0.0;

            // Contribution through the branch variance: sum_k J_Q[k] * gradQInv[k].
            CommonOps.scale(scale, statistics.getBranch().getRawVariance(), matrixJacobianQInv);
            for (int k = 0; k < matrixJacobianQInv.getNumElements(); k++) {
                total += matrixJacobianQInv.get(k) * gradQInv.get(k);
            }

            // Contribution through the branch displacement: sum_k J_N[k] * gradN[k].
            CommonOps.scale(scale, statistics.getBranch().getRawDisplacement(), matrixJacobianN);
            for (int k = 0; k < matrixJacobianN.numRows; k++) {
                total += matrixJacobianN.get(k) * gradN.get(k);
            }

            return new double[]{total};
        }

        /** The root contributes a zero gradient. */
        @Override
        public double[] chainRuleRoot(BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
            return new double[1];
        }
    }

    /**
     * Gradient with respect to the sampling (extension) variance of the data model.
     *
     * <p>Only external nodes contribute; every internal node yields a zero vector. The "above"
     * statistics are widened by the extension variance. The "below" side is modelled with infinite
     * precision for observed traits and zero precision for missing ones.
     */
    public static class SamplingVarianceGradient extends ContinuousProcessParameterGradient {
        ModelExtensionProvider.NormalExtensionProvider dataModel;

        /**
         * @param dim                trait dimension
         * @param tree               tree whose branches are visited
         * @param likelihoodDelegate likelihood delegate passed to the superclass
         * @param dataModel          provides the extension variance
         */
        public SamplingVarianceGradient(int dim, Tree tree, ContinuousDataLikelihoodDelegate likelihoodDelegate, ModelExtensionProvider.NormalExtensionProvider dataModel) {
            super(dim, tree, likelihoodDelegate, Arrays.asList(ContinuousProcessParameterGradient.DerivationParameter.WRT_SAMPLING_VARIANCE));
            this.dataModel = dataModel;
        }

        /**
         * Returns zeros for internal nodes. For tips it computes only {@code gradQInv}, since
         * {@code gradN} is not recomputed.
         */
        @Override
        public double[] getGradientForBranch(BranchSufficientStatistics statistics, NodeRef node) {
            if (tree.isExternal(node)) {
                return getGradientForBranch(statistics, node, true, false);
            }
            return new double[getDimension()];
        }

        /** Builds the widened "above" statistics and the joint statistics with the observed tip. */
        @Override
        void getSufficientStatistics(BranchSufficientStatistics statistics) {
            final NormalSufficientStatistics below = statistics.getBelow();
            final NormalSufficientStatistics above = statistics.getAbove();
            final DenseMatrix64F extension = dataModel.getExtensionVariance();

            // Widened "above" distribution: variance = above variance + extension variance.
            final DenseMatrix64F widenedVariance = new DenseMatrix64F(dim, dim);
            final DenseMatrix64F widenedPrecision = new DenseMatrix64F(dim, dim);
            CommonOps.add(above.getRawVariance(), extension, widenedVariance);
            MissingOps.safeInvert2(widenedVariance, widenedPrecision, false);
            final NormalSufficientStatistics widened = new NormalSufficientStatistics(above.getRawMeanCopy(), widenedPrecision, widenedVariance);

            // Tip precision: infinite on the diagonal for observed traits, zero for missing ones.
            final int[] missing = statistics.getMissing();
            final DenseMatrix64F tipPrecision = new DenseMatrix64F(dim, dim);
            for (int d = 0; d < dim; ++d) {
                tipPrecision.unsafe_set(d, d, Double.POSITIVE_INFINITY);
            }
            for (int m : missing) {
                tipPrecision.unsafe_set(m, m, 0.0d);
            }

            final NormalSufficientStatistics tip = new NormalSufficientStatistics(below.getRawMeanCopy(), tipPrecision);
            final NormalSufficientStatistics joint = BranchRateGradient.ContinuousTraitGradientForBranch.Default.computeJointStatistics(tip, widened, this.dim);

            matrixQ = widened.getRawPrecision();
            matrixW = widened.getRawVariance();
            matrixV = joint.getRawVariance();
            for (int d = 0; d < dim; ++d) {
                matrixDelta.unsafe_set(d, 0, joint.getMean(d) - widened.getMean(d));
            }
        }

        /** All branches map to index 0. */
        @Override
        public int getParameterIndexFromNode(NodeRef node) {
            return 0;
        }
    }

    /**
     * Computes the gradient contribution of the branch identified by {@code node}.
     *
     * @param statistics sufficient statistics for the branch
     * @param node       node identifying the branch
     * @return gradient array; implementations in this file return {@link #getDimension()} entries
     */
    double[] getGradientForBranch(BranchSufficientStatistics statistics, NodeRef node);

    /**
     * Returns the parameter index associated with {@code node}.
     *
     * @param node node identifying the branch
     * @return parameter index for that node
     */
    int getParameterIndexFromNode(NodeRef node);

    /** Returns the number of entries in each per-branch gradient. */
    int getDimension();
}
