package dr.inference.hmc;

import dr.inference.model.DerivativeOrder;
import dr.inference.model.DerivativeProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.List;
import java.util.logging.Logger;

/* loaded from: input_file:dr/inference/hmc/DerivativeWrtParameterProvider.class */
/**
 * A source of derivatives of a log-density with respect to a single {@link Parameter}.
 *
 * <p>The interface also hosts helpers for two jobs:
 * <ul>
 *   <li>checking analytic gradients against numerical ones
 *       ({@link CheckDerivativeNumerically}, {@link #makeReport}, {@link #getReportAndCheckForError});</li>
 *   <li>combining the derivative orders supported by several providers
 *       ({@link #getHighestOrder(List)}).</li>
 * </ul>
 */
public interface DerivativeWrtParameterProvider {

    /**
     * Default relative tolerance. It is not referenced in this file; presumably it is
     * meant for {@link #getReportAndCheckForError} callers. TODO confirm.
     */
    Double TOLERANCE = Double.valueOf(0.1d);

    /**
     * Compares a provider's analytic derivative with a numerical gradient of its likelihood.
     *
     * <p>The numerical gradient comes from {@link NumericalDerivative#gradient}. Only
     * {@link DerivativeOrder#GRADIENT} is supported numerically. For any other order,
     * computing the numerical side throws a {@link RuntimeException}.
     */
    public static class CheckDerivativeNumerically {
        private final DerivativeWrtParameterProvider provider;
        private final DerivativeOrder type;
        private final Parameter parameter;
        private final double lowerBound;
        private final double upperBound;
        private final boolean checkValues;
        private final double tolerance;

        /** The log-likelihood as a function of the parameter values, used for numerical differentiation. */
        private final MultivariateFunction numeric = new MultivariateFunction() {
            @Override // dr.math.MultivariateFunction
            public double evaluate(double[] argument) {
                // Reject unsupported orders BEFORE moving the parameter, so that a failed
                // call never leaves the model at a perturbed point.
                if (type != DerivativeOrder.GRADIENT) {
                    throw new RuntimeException("Not yet implemented");
                }
                setParameter(argument);
                return provider.getLikelihood().getLogLikelihood();
            }

            @Override // dr.math.MultivariateFunction
            public int getNumArguments() {
                return parameter.getDimension();
            }

            @Override // dr.math.MultivariateFunction
            public double getLowerBound(int n) {
                return lowerBound;
            }

            @Override // dr.math.MultivariateFunction
            public double getUpperBound(int n) {
                return upperBound;
            }
        };

        /**
         * Creates a checker for one provider.
         *
         * @param provider          the provider whose analytic derivative is checked. Its
         *                          {@link DerivativeWrtParameterProvider#getParameter()} is the
         *                          parameter perturbed numerically.
         * @param type              the derivative order requested from {@code provider}
         * @param lowerBound        value reported for every dimension by the numerical
         *                          function's {@code getLowerBound}
         * @param upperBound        value reported for every dimension by the numerical
         *                          function's {@code getUpperBound}
         * @param nullableTolerance relative tolerance for the value check, or {@code null} to
         *                          only build the report without checking values
         */
        CheckDerivativeNumerically(DerivativeWrtParameterProvider provider, DerivativeOrder type,
                                   double lowerBound, double upperBound, Double nullableTolerance) {
            this.provider = provider;
            this.type = type;
            this.parameter = provider.getParameter();
            this.lowerBound = lowerBound;
            this.upperBound = upperBound;
            this.checkValues = nullableTolerance != null;
            this.tolerance = this.checkValues ? nullableTolerance.doubleValue() : 0.0d;
        }

        /**
         * Writes {@code values} into the parameter without per-entry notifications.
         * It then fires a single change event.
         */
        private void setParameter(double[] values) {
            for (int i = 0; i < values.length; i++) {
                parameter.setParameterValueQuietly(i, values[i]);
            }
            parameter.fireParameterChangedEvent();
        }

        /**
         * Computes the numerical gradient of the log-likelihood at the current parameter values.
         * The original values are restored even if the numerical evaluation fails.
         */
        private double[] getNumericalGradient() {
            double[] savedValues = parameter.getParameterValues();
            try {
                // Pass a separate copy so the saved values cannot be altered by the differentiator.
                return NumericalDerivative.gradient(numeric, parameter.getParameterValues());
            } finally {
                setParameter(savedValues);
            }
        }

        /**
         * Builds a report comparing the analytic derivative with the numerical gradient.
         *
         * @return the report text
         * @throws MismatchException if values are checked and the two derivatives disagree
         */
        public String getReport() throws MismatchException {
            // The analytic side is evaluated first, at the unperturbed parameter values.
            double[] analytic = provider.getDerivativeLogDensity(type);
            double[] numerical = getNumericalGradient();
            return DerivativeWrtParameterProvider.makeReport("Gradient\n", analytic, numerical, checkValues, tolerance);
        }
    }

    /** Signals that an analytic derivative failed a numerical check (see {@link #makeReport}). */
    public static class MismatchException extends Exception {
    }

    /**
     * Adapts a {@link DerivativeProvider} to this interface.
     *
     * <p>It evaluates the provider at the current values of a fixed {@link Parameter} and
     * exposes an associated {@link Likelihood}.
     */
    public static class ParameterWrapper implements DerivativeWrtParameterProvider, Reportable {
        final DerivativeProvider provider;
        final Parameter parameter;
        final Likelihood likelihood;

        /**
         * @param derivativeProvider computes derivatives from raw parameter values
         * @param parameter          the parameter whose current values are passed to the provider
         * @param likelihood         the likelihood returned by {@link #getLikelihood()}
         */
        public ParameterWrapper(DerivativeProvider derivativeProvider, Parameter parameter, Likelihood likelihood) {
            this.provider = derivativeProvider;
            this.parameter = parameter;
            this.likelihood = likelihood;
        }

        @Override // dr.inference.hmc.DerivativeWrtParameterProvider
        public Likelihood getLikelihood() {
            return likelihood;
        }

        @Override // dr.inference.hmc.DerivativeWrtParameterProvider
        public Parameter getParameter() {
            return parameter;
        }

        @Override // dr.inference.hmc.DerivativeWrtParameterProvider
        public int getDimension(DerivativeOrder derivativeOrder) {
            return derivativeOrder.getDerivativeDimension(parameter.getDimension());
        }

        /**
         * Evaluates the wrapped provider at the parameter's current values.
         * With assertions enabled, it first checks that the provider supports the requested order.
         */
        @Override // dr.inference.hmc.DerivativeWrtParameterProvider
        public double[] getDerivativeLogDensity(DerivativeOrder derivativeOrder) {
            assert provider.getHighestOrder().getValue() >= derivativeOrder.getValue();
            return provider.getDerivativeLogDensity(parameter.getParameterValues(), derivativeOrder);
        }

        @Override // dr.inference.hmc.DerivativeWrtParameterProvider
        public DerivativeOrder getHighestOrder() {
            return provider.getHighestOrder();
        }

        /**
         * Always returns {@code null}.
         * NOTE(review): confirm that consumers of {@link Reportable} tolerate a null report.
         */
        @Override // dr.xml.Reportable
        public String getReport() {
            return null;
        }
    }

    /** @return the likelihood associated with this provider */
    Likelihood getLikelihood();

    /** @return the parameter the derivatives are taken with respect to */
    Parameter getParameter();

    /**
     * @param derivativeOrder the derivative order
     * @return the length of the array returned by {@link #getDerivativeLogDensity} for that order
     *         (presumed; TODO confirm for all implementations)
     */
    int getDimension(DerivativeOrder derivativeOrder);

    /**
     * @param derivativeOrder the derivative order to compute
     * @return the derivative of the log-density at the parameter's current values
     */
    double[] getDerivativeLogDensity(DerivativeOrder derivativeOrder);

    /** @return the highest derivative order this provider supports */
    DerivativeOrder getHighestOrder();

    /**
     * Returns the highest order that every provider in {@code list} supports.
     *
     * <p>This is the minimum of their individual highest orders, capped at
     * {@link DerivativeOrder#FULL_HESSIAN}.
     *
     * @param list the providers to combine
     * @return {@link DerivativeOrder#ZEROTH} for an empty list, otherwise the common highest order
     */
    static DerivativeOrder getHighestOrder(List<DerivativeWrtParameterProvider> list) {
        if (list.isEmpty()) {
            return DerivativeOrder.ZEROTH;
        }
        DerivativeOrder common = DerivativeOrder.FULL_HESSIAN;
        for (DerivativeWrtParameterProvider provider : list) {
            DerivativeOrder order = provider.getHighestOrder();
            if (order.getValue() < common.getValue()) {
                common = order;
            }
        }
        return common;
    }

    /**
     * Formats analytic and numerical derivatives side by side, and optionally checks that they agree.
     *
     * <p>Entries are compared by the relative difference {@code 2 (a - b) / (a + b)}.
     * <ul>
     *   <li>Exactly equal entries (including {@code 0, 0}) count as agreeing.</li>
     *   <li>A NaN relative difference counts as a mismatch, rather than silently passing.</li>
     *   <li>Arrays of different lengths also count as a mismatch.</li>
     * </ul>
     * On a mismatch, the report is logged at INFO to the {@code "dr.inference.hmc"} logger
     * before throwing.
     *
     * @param header      text placed at the start of the report
     * @param analytic    the analytic derivative
     * @param numeric     the numerical derivative
     * @param checkValues whether to compare the values at all
     * @param tolerance   maximum allowed absolute relative difference
     * @return the report text
     * @throws MismatchException if {@code checkValues} is set and the derivatives disagree
     */
    static String makeReport(String header, double[] analytic, double[] numeric, boolean checkValues, double tolerance) throws MismatchException {
        StringBuilder sb = new StringBuilder(header);
        sb.append("analytic: ").append(new Vector(analytic));
        sb.append("\n");
        sb.append("numeric : ").append(new Vector(numeric));
        if (checkValues) {
            if (analytic.length != numeric.length) {
                // Without this guard, a longer analytic array would throw
                // ArrayIndexOutOfBoundsException and a shorter one would compare only a prefix.
                sb.append("\nDimension mismatch: ").append(analytic.length).append(" ").append(numeric.length).append("\n");
                Logger.getLogger("dr.inference.hmc").info(sb.toString());
                throw new MismatchException();
            }
            for (int i = 0; i < analytic.length; i++) {
                double a = analytic[i];
                double b = numeric[i];
                // Exact agreement short-circuits the 0/0 case. Any other NaN (a NaN entry,
                // or an infinity against a finite value) must fail, because a NaN comparison
                // with the tolerance is always false.
                double relativeDifference = (a == b) ? 0.0d : (2.0d * (a - b)) / (a + b);
                if (Double.isNaN(relativeDifference) || Math.abs(relativeDifference) > tolerance) {
                    sb.append("\nDifference @ ").append(i + 1).append(": ").append(a).append(" ").append(b).append(" ").append(relativeDifference).append("\n");
                    Logger.getLogger("dr.inference.hmc").info(sb.toString());
                    throw new MismatchException();
                }
            }
        }
        return sb.toString();
    }

    /**
     * Runs a numerical derivative check and returns its report.
     *
     * @param derivativeWrtParameterProvider the provider to check
     * @param derivativeOrder                the derivative order to check
     * @param lowerBound                     lower bound reported to the numerical function
     * @param upperBound                     upper bound reported to the numerical function
     * @param nullableTolerance              relative tolerance, or {@code null} to report without checking
     * @return the report text
     * @throws RuntimeException if the check fails. The message is the parameter name (or a
     *                          generic text), and the {@link MismatchException} is the cause.
     */
    static String getReportAndCheckForError(DerivativeWrtParameterProvider derivativeWrtParameterProvider, DerivativeOrder derivativeOrder, double lowerBound, double upperBound, Double nullableTolerance) {
        try {
            return new CheckDerivativeNumerically(derivativeWrtParameterProvider, derivativeOrder, lowerBound, upperBound, nullableTolerance).getReport();
        } catch (MismatchException e) {
            String message = e.getMessage();
            if (message == null) {
                message = derivativeWrtParameterProvider.getParameter().getParameterName();
            }
            if (message == null) {
                message = "Gradient check failure";
            }
            // Chain the cause so the stack trace of the failed check is preserved.
            throw new RuntimeException(message, e);
        }
    }
}
