package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.AutoCorrelatedBranchRatesDistribution;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;

/* loaded from: input_file:dr/evomodel/branchratemodel/BranchRateGradientWrtIncrements.class */
/**
 * Converts the gradient of a log density with respect to branch rates into a gradient with
 * respect to the branch-rate increments of an {@link AutoCorrelatedBranchRatesDistribution}.
 *
 * <p>The rate gradient comes from {@code rateGradientProvider}. The parameter and dimension
 * reported by this provider are those of {@code priorGradientProvider}, i.e. the increments.
 */
public class BranchRateGradientWrtIncrements implements GradientWrtParameterProvider, Reportable {

    /** Supplies d(log density)/d(branch rate), indexed by branch-rate parameter index. */
    private final GradientWrtParameterProvider rateGradientProvider;
    /** Supplies the increment parameter (and its dimension) that this gradient is taken wrt. */
    private final AutoCorrelatedGradientWrtIncrements priorGradientProvider;
    private final ArbitraryBranchRates branchRates;
    private final Tree tree;
    private final AutoCorrelatedBranchRatesDistribution.BranchVarianceScaling scaling;
    private final AutoCorrelatedBranchRatesDistribution.BranchRateUnits units;

    /**
     * Creates the provider.
     *
     * @param rateGradientProvider  gradient of the log density with respect to branch rates
     * @param priorGradientProvider increment-gradient provider. Its distribution supplies the
     *                              branch-rate model, tree, variance scaling and rate units
     *                              used for the chain rule.
     */
    public BranchRateGradientWrtIncrements(GradientWrtParameterProvider rateGradientProvider,
                                           AutoCorrelatedGradientWrtIncrements priorGradientProvider) {
        this.rateGradientProvider = rateGradientProvider;
        this.priorGradientProvider = priorGradientProvider;

        AutoCorrelatedBranchRatesDistribution distribution = priorGradientProvider.getDistribution();
        this.branchRates = distribution.getBranchRateModel();
        this.tree = distribution.getTree();
        this.scaling = distribution.getScaling();
        this.units = distribution.getUnits();
    }

    /** Returns the likelihood of the underlying rate-gradient provider. */
    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public Likelihood getLikelihood() {
        return rateGradientProvider.getLikelihood();
    }

    /** Returns the increment parameter owned by the prior-gradient provider. */
    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public Parameter getParameter() {
        return priorGradientProvider.getParameter();
    }

    /** Returns the dimension reported by the prior-gradient provider. */
    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public int getDimension() {
        return priorGradientProvider.getDimension();
    }

    /**
     * Computes the gradient of the log density with respect to the branch-rate increments.
     *
     * <p>The result has the same length as the rate gradient. Each non-root node writes the
     * slot given by its branch-rate parameter index. Slots not associated with any non-root
     * node remain {@code 0.0}.
     *
     * @return gradient with respect to the increments, indexed like the rate gradient
     */
    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public double[] getGradientLogDensity() {
        double[] gradientWrtRates = rateGradientProvider.getGradientLogDensity();
        double[] gradientWrtIncrements = new double[gradientWrtRates.length];

        accumulateGradientPostOrder(tree.getRoot(), gradientWrtRates, gradientWrtIncrements);

        return gradientWrtIncrements;
    }

    /**
     * Visits the subtree rooted at {@code node} in post-order, filling
     * {@code gradientWrtIncrements} for every non-root node in that subtree.
     *
     * <p>For each non-root node, the accumulated value is the sum of:
     * <ul>
     *   <li>its children's accumulated values, and</li>
     *   <li>its own rate gradient mapped through
     *       {@link AutoCorrelatedBranchRatesDistribution.BranchRateUnits#inverseTransformGradient}.</li>
     * </ul>
     * The subtree sum presumably reflects that, in the auto-correlated model, an increment on a
     * branch feeds into the rates of every branch below it. The stored entry is that sum passed
     * through the variance scaling together with the node's branch length.
     *
     * @param node                  root of the subtree to process
     * @param gradientWrtRates      input gradient with respect to branch rates
     * @param gradientWrtIncrements output array, written at each non-root node's parameter index
     * @return the accumulated (unscaled) gradient for {@code node}. This is {@code 0.0}-based for
     *         the root, which only sums its children.
     */
    private double accumulateGradientPostOrder(NodeRef node,
                                               double[] gradientWrtRates,
                                               double[] gradientWrtIncrements) {
        double subtreeGradient = 0.0;

        if (!tree.isExternal(node)) {
            // Visit every child, not just the first two, so multifurcating nodes contribute
            // fully. For bifurcations the additions happen in the same order as before.
            int childCount = tree.getChildCount(node);
            for (int i = 0; i < childCount; i++) {
                subtreeGradient += accumulateGradientPostOrder(
                        tree.getChild(node, i), gradientWrtRates, gradientWrtIncrements);
            }
        }

        if (!tree.isRoot(node)) {
            int index = branchRates.getParameterIndexFromNode(node);
            double rate = branchRates.getUntransformedBranchRate(tree, node);

            subtreeGradient += units.inverseTransformGradient(gradientWrtRates[index], rate);

            gradientWrtIncrements[index] =
                    scaling.inverseRescaleIncrement(subtreeGradient, tree.getBranchLength(node));
        }

        return subtreeGradient;
    }

    /**
     * Returns the standard gradient report. The check is run with unbounded lower/upper limits
     * and a {@code null} final argument (presumably the default tolerance — confirm against
     * {@code GradientWrtParameterProvider.getReportAndCheckForError}).
     */
    @Override // dr.xml.Reportable
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(
                this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, null);
    }
}
