package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.AutoCorrelatedBranchRatesDistribution;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.Reportable;

/* loaded from: input_file:dr/evomodel/branchratemodel/AutoCorrelatedGradientWrtIncrements.class */
/**
 * Gradient provider that exposes the per-branch increments of an
 * {@link AutoCorrelatedBranchRatesDistribution} as a (proxy) parameter and returns the gradient of
 * the distribution's log density with respect to those increments.
 *
 * <p>Values written to the increments parameter with {@code setParameterValueQuietly} are staged in
 * a local buffer. They are only pushed to the distribution's own parameter when
 * {@code fireParameterChangedEvent} is invoked on the proxy. At that point the increments are
 * accumulated from the root towards the tips to rebuild one value per branch.
 *
 * <p>Both tree traversals in this class visit only children 0 and 1 of each internal node, so a
 * strictly bifurcating tree is assumed.
 */
public class AutoCorrelatedGradientWrtIncrements implements GradientWrtParameterProvider, Reportable {

    private final AutoCorrelatedBranchRatesDistribution distribution;
    private final ArbitraryBranchRates branchRates;
    private final Tree tree;
    private final AutoCorrelatedBranchRatesDistribution.BranchVarianceScaling scaling;
    private final AutoCorrelatedBranchRatesDistribution.BranchRateUnits units;

    /** Lazily created proxy returned by {@link #getParameter()}. */
    private Parameter parameter;

    /**
     * Increments staged via {@code setParameterValueQuietly} on the proxy and consumed by the next
     * change event. Stays {@code null} until first needed.
     */
    private double[] cachedIncrements;

    /**
     * Creates a gradient provider for the given distribution.
     *
     * @param autoCorrelatedBranchRatesDistribution distribution whose increments are differentiated.
     *     Its branch-rate model, tree, variance scaling and rate units are captured here.
     */
    public AutoCorrelatedGradientWrtIncrements(AutoCorrelatedBranchRatesDistribution autoCorrelatedBranchRatesDistribution) {
        this.distribution = autoCorrelatedBranchRatesDistribution;
        this.branchRates = autoCorrelatedBranchRatesDistribution.getBranchRateModel();
        this.tree = autoCorrelatedBranchRatesDistribution.getTree();
        this.scaling = autoCorrelatedBranchRatesDistribution.getScaling();
        this.units = autoCorrelatedBranchRatesDistribution.getUnits();
    }

    /** Returns the likelihood reported by the underlying distribution. */
    @Override
    public Likelihood getLikelihood() {
        return distribution.getLikelihood();
    }

    /** Returns the increments proxy parameter, creating it on first access. */
    @Override
    public Parameter getParameter() {
        if (parameter == null) {
            parameter = createParameter();
        }
        return parameter;
    }

    /** Returns the dimension of the underlying distribution (one increment per entry). */
    @Override
    public int getDimension() {
        return distribution.getDimension();
    }

    /**
     * Returns the gradient of the log density with respect to the increments.
     *
     * <p>The raw gradient comes from {@code distribution.getGradientWrtIncrements()}. When
     * {@code units.needsIncrementCorrection()} is true, a per-branch correction is subtracted in place
     * (see {@link #recursePostOrderToCorrectGradient}).
     *
     * @return the (possibly corrected) gradient array obtained from the distribution
     */
    @Override
    public double[] getGradientLogDensity() {
        double[] gradientWrtIncrements = distribution.getGradientWrtIncrements();
        if (units.needsIncrementCorrection()) {
            recursePostOrderToCorrectGradient(tree.getRoot(), gradientWrtIncrements);
        }
        return gradientWrtIncrements;
    }

    /** Returns the distribution this provider differentiates. */
    public AutoCorrelatedBranchRatesDistribution getDistribution() {
        return distribution;
    }

    /**
     * Post-order walk that subtracts a correction term from the gradient entry of every non-root node.
     *
     * <p>The term is {@code scaling.inverseRescaleIncrement(n, branchLength)}, where {@code n} is the
     * number of nodes in the subtree rooted at the node, including the node itself.
     * NOTE(review): this is presumably the derivative of a log-Jacobian term for the rate units;
     * confirm against the distribution's definition.
     *
     * @param node     root of the subtree to process
     * @param gradient gradient array, modified in place
     * @return number of nodes in the subtree rooted at {@code node}
     */
    private int recursePostOrderToCorrectGradient(NodeRef node, double[] gradient) {
        int subtreeSize = 1;
        if (!tree.isExternal(node)) {
            subtreeSize += recursePostOrderToCorrectGradient(tree.getChild(node, 0), gradient);
            subtreeSize += recursePostOrderToCorrectGradient(tree.getChild(node, 1), gradient);
        }
        if (!tree.isRoot(node)) {
            int index = branchRates.getParameterIndexFromNode(node);
            gradient[index] -= scaling.inverseRescaleIncrement((double) subtreeSize, tree.getBranchLength(node));
        }
        return subtreeSize;
    }

    /** Returns a report checking this gradient against a numerical one over an unbounded domain. */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, null);
    }

    /**
     * Returns the staging buffer for increments, creating it on first use.
     *
     * <p>A fresh buffer is seeded with the distribution's current increments. Entries that are never
     * explicitly staged therefore keep their current value, instead of silently becoming 0.0 when
     * the rates are rebuilt.
     *
     * @param dimension length of the buffer to create if none exists yet
     */
    private double[] stagedIncrements(int dimension) {
        if (cachedIncrements == null) {
            double[] seeded = new double[dimension];
            for (int i = 0; i < dimension; i++) {
                seeded[i] = distribution.getIncrement(i);
            }
            cachedIncrements = seeded;
        }
        return cachedIncrements;
    }

    /**
     * Builds the proxy parameter named {@code "increments"}.
     *
     * <p>The proxy behaves as follows:
     * <ul>
     *   <li>Reads delegate to {@code distribution.getIncrement(i)}.</li>
     *   <li>Notifying single-value setters are rejected.</li>
     *   <li>Quiet sets are staged.</li>
     *   <li>A change event converts the staged increments into values for the distribution's
     *       parameter and fires a change event on it.</li>
     * </ul>
     */
    private Parameter createParameter() {
        return new Parameter.Proxy("increments", distribution.getDimension()) {

            @Override
            public double getParameterValue(int i) {
                return distribution.getIncrement(i);
            }

            @Override
            public void setParameterValue(int i, double d) {
                throw new RuntimeException("Do not set single value at a time");
            }

            @Override
            public void setParameterValueQuietly(int i, double d) {
                stagedIncrements(getDimension())[i] = d;
            }

            @Override
            public void setParameterValueNotifyChangedAll(int i, double d) {
                throw new RuntimeException("Do not set single value at a time");
            }

            @Override
            public void fireParameterChangedEvent(int i, Variable.ChangeType changeType) {
                // The index and change type are ignored: every branch value is rebuilt from the staged
                // increments. The helper guarantees a non-null buffer even if nothing was staged.
                double[] branchValues = new double[getDimension()];
                recurse(tree.getRoot(), branchValues, stagedIncrements(getDimension()), 0.0);

                Parameter target = distribution.getParameter();
                for (int k = 0; k < branchValues.length; k++) {
                    target.setParameterValueQuietly(k, branchValues[k]);
                }
                target.fireParameterChangedEvent();
            }

            @Override
            public String toString() {
                int dimension = getDimension();
                if (dimension == 0) {
                    return "";
                }
                StringBuilder sb = new StringBuilder(String.valueOf(getParameterValue(0)));
                for (int k = 1; k < dimension; k++) {
                    sb.append(", ").append(getParameterValue(k));
                }
                return sb.toString();
            }
        };
    }

    /**
     * Pre-order walk that turns increments into per-branch values.
     *
     * <p>For each non-root node, {@code scaling.inverseRescaleIncrement(increment, branchLength)} is
     * added to the running sum inherited from its parent. The node's entry in {@code out} is set to
     * {@code units.inverseTransform(sum)}.
     *
     * <p>The decompiler recorded that this method was originally private. It is kept public here so
     * existing callers are not broken.
     *
     * @param nodeRef    node to process
     * @param out        output array indexed by {@code branchRates.getParameterIndexFromNode}
     * @param increments increments, indexed the same way as {@code out}
     * @param inherited  running sum accumulated along the path from the root to {@code nodeRef}'s parent
     */
    public void recurse(NodeRef nodeRef, double[] out, double[] increments, double inherited) {
        double sum = inherited;
        if (!tree.isRoot(nodeRef)) {
            int index = branchRates.getParameterIndexFromNode(nodeRef);
            sum += scaling.inverseRescaleIncrement(increments[index], tree.getBranchLength(nodeRef));
            out[index] = units.inverseTransform(sum);
        }
        if (!tree.isExternal(nodeRef)) {
            recurse(tree.getChild(nodeRef, 0), out, increments, sum);
            recurse(tree.getChild(nodeRef, 1), out, increments, sum);
        }
    }
}
