package dr.evomodel.treedatalikelihood.hmc;

import dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.DiagonalMatrix;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;

/**
 * Gradient of a likelihood's log density with respect to the diagonal entries of a
 * {@link DiagonalMatrix} attenuation parameter.
 * <p>
 * The analytic gradient is obtained from a {@link BranchSpecificGradient}; the first
 * {@code dim} entries of that gradient are returned, where {@code dim} is the column
 * dimension of the attenuation matrix. As a {@link Reportable}, the class compares the
 * analytic gradient against a finite-difference approximation.
 */
public class DiagonalAttenuationGradient implements GradientWrtParameterProvider, Reportable {
    private final Likelihood likelihood;
    private final int dim;
    private final BranchSpecificGradient branchSpecificGradient;
    private final DiagonalMatrix attenuation;

    /**
     * @param branchSpecificGradient   source of the analytic gradient; its first {@code dim}
     *                                 entries are reported by {@link #getGradientLogDensity()}
     * @param likelihood               likelihood whose log density is differentiated
     * @param matrixParameterInterface attenuation matrix; must be a {@link DiagonalMatrix}
     * @throws AssertionError     if assertions are enabled and the matrix is not a DiagonalMatrix
     * @throws ClassCastException if assertions are disabled and the matrix is not a DiagonalMatrix
     */
    public DiagonalAttenuationGradient(BranchSpecificGradient branchSpecificGradient,
                                       Likelihood likelihood,
                                       MatrixParameterInterface matrixParameterInterface) {
        assert matrixParameterInterface instanceof DiagonalMatrix
                : "DiagonalAttenuationGradient can only be applied to a DiagonalMatrix.";
        this.attenuation = (DiagonalMatrix) matrixParameterInterface;
        this.branchSpecificGradient = branchSpecificGradient;
        this.likelihood = likelihood;
        this.dim = matrixParameterInterface.getColumnDimension();
    }

    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public Likelihood getLikelihood() {
        return likelihood;
    }

    /** Returns the parameter holding the diagonal of the attenuation matrix. */
    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public Parameter getParameter() {
        return attenuation.getDiagonalParameter();
    }

    /** Returns the column dimension of the attenuation matrix captured at construction. */
    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public int getDimension() {
        return dim;
    }

    /** Returns the leading {@code dim} entries of the branch-specific analytic gradient. */
    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public double[] getGradientLogDensity() {
        return extractDiagonalGradient(branchSpecificGradient.getGradientLogDensity());
    }

    /**
     * Copies the first {@code dim} entries of {@code fullGradient} into a new array.
     * NOTE(review): assumes the branch-specific gradient stores the diagonal-attenuation
     * derivatives in its leading {@code dim} entries — confirm against BranchSpecificGradient.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code fullGradient} has fewer than {@code dim} entries
     */
    private double[] extractDiagonalGradient(double[] fullGradient) {
        double[] diagonal = new double[dim];
        System.arraycopy(fullGradient, 0, diagonal, 0, dim);
        return diagonal;
    }

    /** Formats the analytic and numeric gradients side by side for the report. */
    String getReportString(double[] analytic, double[] numeric) {
        return getClass().getCanonicalName() + "\nanalytic: " + new Vector(analytic) + "\nnumeric: " + new Vector(numeric) + "\n";
    }

    /** Reports the analytic gradient next to a finite-difference check of it. */
    @Override // dr.xml.Reportable
    public String getReport() {
        return checkNumeric(getGradientLogDensity());
    }

    /**
     * Writes {@code values} into the attenuation matrix (index {@code i} -> entry {@code i})
     * and marks the likelihood dirty so that the next {@code getLogLikelihood()} call does
     * not return a value cached for the previous attenuation.
     */
    private void setAttenuation(double[] values) {
        for (int i = 0; i < values.length; i++) {
            attenuation.setParameterValue(i, values[i]);
        }
        likelihood.makeDirty();
    }

    /**
     * Returns the log likelihood as a function of the attenuation diagonal, with bounds
     * {@code [0, +infinity)} in every coordinate, for finite-difference differentiation.
     * Each evaluation overwrites the attenuation matrix with its argument.
     */
    MultivariateFunction getNumeric() {
        return new MultivariateFunction() {
            @Override // dr.math.MultivariateFunction
            public double evaluate(double[] point) {
                setAttenuation(point);
                return likelihood.getLogLikelihood();
            }

            @Override // dr.math.MultivariateFunction
            public int getNumArguments() {
                return attenuation.getColumnDimension();
            }

            @Override // dr.math.MultivariateFunction
            public double getLowerBound(int i) {
                return 0.0d;
            }

            @Override // dr.math.MultivariateFunction
            public double getUpperBound(int i) {
                return Double.POSITIVE_INFINITY;
            }
        };
    }

    /**
     * Computes a numerical gradient at the current attenuation diagonal and formats it
     * alongside {@code analytic}. The attenuation values are restored, and the likelihood
     * is marked dirty, even if the numerical differentiation throws.
     */
    String checkNumeric(double[] analytic) {
        System.err.println("Numeric at: \n" + new Vector(attenuation.getParameterValues()));

        double[] storedValues = attenuation.getDiagonalParameter().getParameterValues();
        double[] numeric;
        try {
            // Pass a copy so the baseline used for restoration cannot be altered by the
            // finite-difference routine, whatever it does with its argument.
            numeric = NumericalDerivative.gradient(getNumeric(), storedValues.clone());
        } finally {
            // Evaluations leave the model at a perturbed point; always put it back.
            setAttenuation(storedValues);
        }
        return getReportString(analytic, numeric);
    }
}
