package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.loggers.Loggable;
import dr.inference.model.Parameter;
import dr.xml.Reportable;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/discrete/LocalBranchRateGradientForDiscreteTrait.class */
public class LocalBranchRateGradientForDiscreteTrait extends DiscreteTraitBranchRateGradient implements GradientWrtParameterProvider, Reportable, Loggable {

    /**
     * Creates the gradient provider. All arguments are forwarded unchanged to
     * {@link DiscreteTraitBranchRateGradient}; see that constructor for their meaning.
     */
    public LocalBranchRateGradientForDiscreteTrait(String str, TreeDataLikelihood treeDataLikelihood, BeagleDataLikelihoodDelegate beagleDataLikelihoodDelegate, Parameter parameter, boolean z) {
        super(str, treeDataLikelihood, beagleDataLikelihoodDelegate, parameter, z);
    }

    /**
     * Returns the gradient of the log density with respect to the local branch-rate parameters.
     * <p>
     * Let {@code g} be the per-branch gradient returned by the superclass, indexed by
     * {@code branchRateModel.getParameterIndexFromNode(n)}. For every non-root node {@code v},
     * the returned array holds, at {@code v}'s parameter index, the sum over every node {@code n}
     * in the subtree rooted at {@code v} (including {@code v} itself) of
     * {@code branchRateModel.getBranchRateDifferential(tree, n) * g[index(n)]}.
     * <p>
     * All subtree sums are produced in one post-order pass. Each node is therefore visited once
     * (O(n)), rather than once per ancestor (O(n^2) on unbalanced trees). For each node the terms
     * are added in the same order as a direct recursive evaluation: its own term first, then its
     * children's subtree sums in child-index order.
     *
     * @return array of length {@code tree.getNodeCount() - 1}, one entry per non-root node
     */
    @Override
    public double[] getGradientLogDensity() {
        final double[] branchGradient = super.getGradientLogDensity();
        final double[] result = new double[this.tree.getNodeCount() - 1];

        // The root has no parameter of its own, so start the traversal at its children.
        // Every non-root node lies in exactly one of these subtrees.
        final NodeRef root = this.tree.getRoot();
        for (int c = 0; c < this.tree.getChildCount(root); c++) {
            accumulateSubTreeGradient(this.tree, this.tree.getChild(root, c), branchGradient, result);
        }
        return result;
    }

    /**
     * Post-order helper. Computes the subtree-accumulated gradient for {@code node} and stores it in
     * {@code result}. The stored value is the node's own term
     * ({@code getBranchRateDifferential * branchGradient[index]}) plus the accumulated values of all
     * its children.
     *
     * @param tree           tree being traversed
     * @param node           non-root node whose subtree is accumulated
     * @param branchGradient per-branch gradient from the superclass, indexed by parameter index
     * @param result         output array; {@code result[index(node)]} is written for every node visited
     * @return the accumulated value written for {@code node}, so the parent can add it to its own sum
     */
    private double accumulateSubTreeGradient(Tree tree, NodeRef node, double[] branchGradient, double[] result) {
        final int index = this.branchRateModel.getParameterIndexFromNode(node);
        double sum = this.branchRateModel.getBranchRateDifferential(tree, node) * branchGradient[index];
        // Leaves have no children, so the loop is skipped and only the node's own term remains.
        for (int c = 0; c < tree.getChildCount(node); c++) {
            sum += accumulateSubTreeGradient(tree, tree.getChild(node, c), branchGradient, result);
        }
        result[index] = sum;
        return sum;
    }
}
