package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.TreeParameterModel;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MachineAccuracy;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/discrete/HyperParameterGradient.class */
/**
 * Gradient of a {@link TreeDataLikelihood} log-likelihood with respect to a hyper-parameter.
 *
 * <p>The analytical gradient is assembled from two ingredients: a per-branch gradient obtained
 * from the wrapped {@link GradientWrtParameterProvider}, and a per-branch "differential" vector
 * supplied by subclasses through {@link #getDifferential(Tree, NodeRef)}. For every branch index
 * {@code i} and every hyper-parameter component {@code j} the product
 * {@code branchGradient[i] * differential_i[j]} is accumulated into component {@code j}
 * (presumably a chain-rule expansion; the exact meaning of the differential is defined by the
 * subclass).
 *
 * <p>The diagonal Hessian and the reference values printed by {@link #getReport()} are computed
 * numerically by finite differences over {@link #numeric1}, which works by writing trial values
 * into the hyper-parameter. Methods of this class that use it restore the hyper-parameter to the
 * values it held on entry.
 */
public abstract class HyperParameterGradient implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable, Loggable {
    private final TreeDataLikelihood treeDataLikelihood;
    private final GradientWrtParameterProvider gradientWrtParameterProvider;
    private final Parameter parameter;
    private final Tree tree;
    private final boolean useHessian;
    /** Maps branch-parameter indices to tree node numbers; see {@link #getGradientLogDensity()}. */
    protected final TreeParameterModel branchParameter;
    // Not read anywhere in this class.
    private static final boolean DEBUG = true;
    protected static final boolean COUNT_TOTAL_OPERATIONS = true;

    /**
     * Log-likelihood as a function of the hyper-parameter values, used for numerical
     * differentiation. Evaluating it overwrites the hyper-parameter with the supplied point and
     * does not restore it afterwards. Every argument is declared to range over
     * {@code [0, +infinity)}.
     */
    protected MultivariateFunction numeric1 = new MultivariateFunction() {
        @Override
        public double evaluate(double[] argument) {
            for (int i = 0; i < argument.length; i++) {
                parameter.setParameterValue(i, argument[i]);
            }
            return treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return parameter.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0d;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };

    /**
     * Counter printed by {@link #getReport()}. It is never modified inside this class; any
     * bookkeeping is left to subclasses.
     */
    protected long getGradientLogDensityCount = 0;

    /**
     * @param treeDataLikelihood           likelihood whose log-density is differentiated; also
     *                                     supplies the tree
     * @param gradientWrtParameterProvider provider of the per-branch gradient that is combined
     *                                     with the differentials
     * @param parameter                    the hyper-parameter the gradient is taken with respect to
     * @param treeParameterModel           branch parameter used to map gradient indices to nodes
     * @param z                            whether {@link #getReport()} should include the diagonal
     *                                     Hessian section
     */
    public HyperParameterGradient(TreeDataLikelihood treeDataLikelihood, GradientWrtParameterProvider gradientWrtParameterProvider, Parameter parameter, TreeParameterModel treeParameterModel, boolean z) {
        this.treeDataLikelihood = treeDataLikelihood;
        this.gradientWrtParameterProvider = gradientWrtParameterProvider;
        this.parameter = parameter;
        this.useHessian = z;
        this.tree = treeDataLikelihood.getTree();
        this.branchParameter = treeParameterModel;
    }

    /** @return the tree-data likelihood passed to the constructor */
    @Override
    public Likelihood getLikelihood() {
        return treeDataLikelihood;
    }

    /** @return the hyper-parameter passed to the constructor */
    @Override
    public Parameter getParameter() {
        return parameter;
    }

    /** @return the dimension of the hyper-parameter */
    @Override
    public int getDimension() {
        return parameter.getDimension();
    }

    /**
     * Computes the analytical gradient with respect to the hyper-parameter.
     *
     * @return a new array of length {@link #getDimension()} where component {@code j} is the sum
     *         over branch indices {@code i} of {@code branchGradient[i] * differential_i[j]}
     * @throws RuntimeException if the wrapped provider's gradient does not have exactly
     *                          {@code tree.getNodeCount() - 1} entries
     */
    @Override
    public double[] getGradientLogDensity() {
        final double[] branchGradient = gradientWrtParameterProvider.getGradientLogDensity();
        if (branchGradient.length != tree.getNodeCount() - 1) {
            throw new RuntimeException("Dimension mismatch!");
        }

        final double[] result = new double[getDimension()];
        final int branchCount = branchParameter.getParameterSize();
        for (int i = 0; i < branchCount; i++) {
            // Branch index i is translated to the tree node it belongs to before asking the
            // subclass for that branch's differential.
            final NodeRef node = tree.getNode(branchParameter.getNodeNumberFromParameterIndex(i));
            final double[] differential = getDifferential(tree, node);
            for (int j = 0; j < result.length; j++) {
                result[j] += branchGradient[i] * differential[j];
            }
        }
        return result;
    }

    /**
     * Computes the diagonal of the Hessian by numerical differentiation of {@link #numeric1}
     * around the current hyper-parameter values.
     *
     * <p>The hyper-parameter is written back to its entry values before returning, so callers do
     * not depend on where the numerical routine happens to leave it.
     *
     * @return the numerically estimated diagonal Hessian, one entry per hyper-parameter component
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        final double[] savedValues = parameter.getParameterValues();
        try {
            // Work on a copy so the snapshot used for restoration cannot be altered.
            return NumericalDerivative.diagonalHessian(numeric1, savedValues.clone());
        } finally {
            restoreParameterValues(savedValues);
        }
    }

    /**
     * Supplies the per-branch differential that is multiplied with the branch gradient.
     *
     * @param tree    the tree of the likelihood
     * @param nodeRef the node identifying the branch
     * @return an array with at least {@link #getDimension()} entries, one per hyper-parameter
     *         component
     */
    abstract double[] getDifferential(Tree tree, NodeRef nodeRef);

    /**
     * Tells whether every value is far enough from zero for the numerical checks in
     * {@link #getReport()} to be attempted.
     *
     * @param dArr values to inspect
     * @return {@code false} if any absolute value is below
     *         {@code MachineAccuracy.SQRT_EPSILON * 1.2}, {@code true} otherwise (including for
     *         an empty array)
     */
    protected boolean valuesAreSufficientlyLarge(double[] dArr) {
        final double threshold = MachineAccuracy.SQRT_EPSILON * 1.2d;
        for (double value : dArr) {
            if (Math.abs(value) < threshold) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds a textual comparison of the analytical gradient against a numerical one, optionally
     * followed by the diagonal Hessian, the gradient counter and the likelihood's own report.
     *
     * <p>Numerical quantities are only computed when {@link #valuesAreSufficientlyLarge(double[])}
     * accepts the current hyper-parameter values; otherwise a "too close to 0" note is printed.
     * All numerical derivatives are taken at the values the hyper-parameter held on entry, and
     * those values are restored before the analytical gradient is evaluated.
     *
     * @return the multi-line report
     */
    @Override
    public String getReport() {
        final double[] savedValues = parameter.getParameterValues();
        final boolean largeEnough = valuesAreSufficientlyLarge(savedValues);

        double[] numericGradient = null;
        double[] numericHessian = null;
        try {
            if (largeEnough) {
                // Each numerical routine gets its own copy of the snapshot: evaluating numeric1
                // moves the parameter, so re-reading it here could start from a perturbed point.
                numericGradient = NumericalDerivative.gradient(numeric1, savedValues.clone());
                if (useHessian) {
                    numericHessian = NumericalDerivative.diagonalHessian(numeric1, savedValues.clone());
                }
            }
        } finally {
            restoreParameterValues(savedValues);
        }

        final StringBuilder sb = new StringBuilder();
        sb.append("Gradient Peeling: ").append(new Vector(getGradientLogDensity()));
        sb.append("\n");
        if (numericGradient == null) {
            sb.append("Gradient numeric: too close to 0");
        } else {
            sb.append("Gradient numeric: ").append(new Vector(numericGradient));
        }
        sb.append("\n");

        if (useHessian) {
            if (largeEnough) {
                sb.append("Hessian Peeling: ").append(new Vector(getDiagonalHessianLogDensity()));
                sb.append("\n");
            }
            if (numericHessian == null) {
                sb.append("Hessian numeric: too close to 0");
            } else {
                sb.append("Hessian numeric: ").append(new Vector(numericHessian));
            }
            sb.append("\n");
        }

        sb.append("\n\tgetGradientLogDensityCount = ").append(getGradientLogDensityCount).append("\n");
        sb.append(treeDataLikelihood.getReport());
        return sb.toString();
    }

    /**
     * Exposes the report as a single log column labelled {@code "gradient report"}; the column's
     * value renders as a newline followed by {@link #getReport()}.
     *
     * @return a one-element array holding that column
     */
    @Override
    public LogColumn[] getColumns() {
        final Object reportOnDemand = new Object() {
            @Override
            public String toString() {
                return "\n" + getReport();
            }
        };
        return new LogColumn[]{new LogColumn.Default("gradient report", reportOnDemand)};
    }

    /** Writes {@code values} back into the hyper-parameter, component by component. */
    private void restoreParameterValues(double[] values) {
        for (int i = 0; i < values.length; i++) {
            parameter.setParameterValue(i, values[i]);
        }
    }
}
