package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MachineAccuracy;
import dr.xml.Reportable;

/* loaded from: input_file:dr/inference/operators/hmc/NumericalHessianFromGradient.class */
/**
 * Hessian provider that approximates the Hessian of the log density by central finite
 * differences of the gradient supplied by a wrapped {@link GradientWrtParameterProvider}.
 *
 * <p>Likelihood, parameter, dimension and gradient queries are delegated unchanged to the
 * wrapped provider. Each Hessian evaluation temporarily perturbs the wrapped provider's
 * parameter in place, one coordinate at a time, and restores every coordinate before
 * returning (also when a gradient evaluation throws). No synchronization is performed around
 * these temporary perturbations.
 */
public class NumericalHessianFromGradient implements HessianWrtParameterProvider, Reportable {

    /** Source of the gradient that is differenced to build the Hessian. */
    private final GradientWrtParameterProvider gradientProvider;

    /**
     * Wraps a gradient provider.
     *
     * @param gradientWrtParameterProvider provider whose gradient is differenced; its parameter
     *     is temporarily modified during each Hessian evaluation
     */
    public NumericalHessianFromGradient(GradientWrtParameterProvider gradientWrtParameterProvider) {
        this.gradientProvider = gradientWrtParameterProvider;
    }

    /**
     * Returns the diagonal of the numerically approximated Hessian.
     *
     * <p>The full matrix is built first. The central-difference scheme needs two gradient
     * evaluations per coordinate either way, so keeping only the diagonal would save no
     * gradient calls.
     *
     * @return array of length {@code getDimension()} holding the Hessian diagonal
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        double[][] hessian = getNumericalHessianCentral();
        int dimension = hessian.length;
        double[] diagonal = new double[dimension];
        for (int i = 0; i < dimension; i++) {
            diagonal[i] = hessian[i][i];
        }
        return diagonal;
    }

    /**
     * Returns the full numerically approximated Hessian.
     *
     * @return symmetric {@code getDimension() x getDimension()} matrix
     */
    @Override
    public double[][] getHessianLogDensity() {
        return getNumericalHessianCentral();
    }

    /** Delegates to the wrapped gradient provider. */
    @Override
    public Likelihood getLikelihood() {
        return this.gradientProvider.getLikelihood();
    }

    /** Delegates to the wrapped gradient provider. */
    @Override
    public Parameter getParameter() {
        return this.gradientProvider.getParameter();
    }

    /** Delegates to the wrapped gradient provider. */
    @Override
    public int getDimension() {
        return this.gradientProvider.getDimension();
    }

    /** Delegates to the wrapped gradient provider. */
    @Override
    public double[] getGradientLogDensity() {
        return this.gradientProvider.getGradientLogDensity();
    }

    /**
     * Builds a symmetric Hessian approximation from central differences of the gradient.
     *
     * <p>For each coordinate {@code i}, the step is {@code SQRT_SQRT_EPSILON * (|x_i| + 1)}.
     * This makes it roughly relative for large {@code |x_i|} and absolute near zero. The
     * gradient is evaluated at {@code x_i + h_i} and {@code x_i - h_i}, and then
     * {@code x_i} is restored. Entry {@code (i, j)} is the average of the two one-sided
     * estimates {@code dg_i/dx_j} and {@code dg_j/dx_i}, which makes the result exactly
     * symmetric.
     *
     * @return dense {@code dimension x dimension} Hessian approximation
     */
    private double[][] getNumericalHessianCentral() {
        final Parameter parameter = this.gradientProvider.getParameter();
        final int dimension = this.gradientProvider.getDimension();
        final double[] x = parameter.getParameterValues();

        final double[] step = new double[dimension];
        // Rows are filled below, so only the outer arrays are allocated here.
        final double[][] gradientPlus = new double[dimension][];
        final double[][] gradientMinus = new double[dimension][];

        for (int i = 0; i < dimension; i++) {
            step[i] = MachineAccuracy.SQRT_SQRT_EPSILON * (Math.abs(x[i]) + 1.0d);
            try {
                parameter.setParameterValue(i, x[i] + step[i]);
                // Copy defensively: a provider may hand back a reused internal buffer, which
                // would make every stored row alias the same array.
                gradientPlus[i] = this.gradientProvider.getGradientLogDensity().clone();
                parameter.setParameterValue(i, x[i] - step[i]);
                gradientMinus[i] = this.gradientProvider.getGradientLogDensity().clone();
            } finally {
                // Always restore the coordinate so a failed gradient call cannot leave the
                // model in a perturbed state.
                parameter.setParameterValue(i, x[i]);
            }
        }

        final double[][] hessian = new double[dimension][dimension];
        for (int i = 0; i < dimension; i++) {
            for (int j = i; j < dimension; j++) {
                // (g+ - g-) / (2h) per direction, averaged over both directions -> divide by 4h.
                double value = (gradientPlus[j][i] - gradientMinus[j][i]) / (4.0d * step[j])
                        + (gradientPlus[i][j] - gradientMinus[i][j]) / (4.0d * step[i]);
                hessian[i][j] = value;
                hessian[j][i] = value;
            }
        }
        return hessian;
    }

    /**
     * Returns the report produced by
     * {@link GradientWrtParameterProvider#getReportAndCheckForError}.
     *
     * <p>It is called with {@code -Infinity}, {@code +Infinity} and
     * {@link GradientWrtParameterProvider#TOLERANCE}. The infinities are presumably
     * lower/upper bounds for the check; confirm against that method.
     */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, GradientWrtParameterProvider.TOLERANCE);
    }
}
