package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PathGradient;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.AbstractAdaptableOperator;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.GeneralOperator;
import dr.inference.operators.PathDependent;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.math.MathUtils;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;

/* loaded from: input_file:dr/inference/operators/hmc/HamiltonianMonteCarloOperator.class */
/**
 * Hamiltonian Monte Carlo (HMC) operator for a single {@link Parameter}.
 *
 * <p>Each call to {@link #doOperation(Likelihood)} draws a momentum from the
 * {@link MassPreconditioner}, integrates Hamiltonian dynamics with a leap-frog scheme using the
 * gradient supplied by {@link #gradientProvider}, and leaves the parameter at the end point of the
 * trajectory. The returned value is (initial kinetic energy + log-Jacobian) minus
 * (final kinetic energy + log-Jacobian). The likelihood change is presumably accounted for by the
 * calling MCMC machinery; TODO confirm.
 *
 * <p>When a {@link Transform} is supplied the dynamics run in the transformed space. When a mask
 * parameter is supplied, dimensions whose mask entry is zero receive zero initial momentum and a
 * zero gradient contribution.
 */
public class HamiltonianMonteCarloOperator extends AbstractAdaptableOperator implements GeneralOperator, PathDependent {

    /** Supplies the log-density and its gradient with respect to {@link #parameter}. */
    final GradientWrtParameterProvider gradientProvider;

    /** Current leap-frog step size; exposed to adaptation on a log scale. */
    protected double stepSize;

    /** Integrator performing the momentum/position updates (possibly in transformed space). */
    LeapFrogEngine leapFrogEngine;

    /** The parameter being sampled. */
    protected final Parameter parameter;

    /** Mass matrix used to draw momenta and to turn momenta into velocities. */
    protected final MassPreconditioner preconditioning;

    /** Tuning options fixed at construction time. */
    private final Options runtimeOptions;

    /** Per-dimension 0/1 mask, or {@code null} when no mask parameter was supplied. */
    protected final double[] mask;

    /** Optional transform into the space in which dynamics run; {@code null} for none. */
    protected final Transform transform;

    /** Strategies for reacting to numerically unstable trajectories. */
    public enum InstabilityHandler {

        /** Throws on NaN values and on positions outside the transform's interior domain. */
        REJECT {
            @Override
            void checkValue(double value) throws NumericInstabilityException {
                if (Double.isNaN(value)) {
                    throw new NumericInstabilityException();
                }
            }

            @Override
            void checkPosition(Transform transform, double[] position) throws NumericInstabilityException {
                if (!transform.isInInteriorDomain(position, 0, position.length)) {
                    throw new NumericInstabilityException();
                }
            }

            @Override
            boolean checkPositionTransform() {
                return true;
            }
        },

        /** Same checks as {@link #REJECT}, but prints a diagnostic to stderr before throwing. */
        DEBUG {
            @Override
            void checkValue(double value) throws NumericInstabilityException {
                if (Double.isNaN(value)) {
                    System.err.println("Numerical instability in HMC momentum; throwing exception");
                    throw new NumericInstabilityException();
                }
            }

            @Override
            void checkPosition(Transform transform, double[] position) throws NumericInstabilityException {
                if (!transform.isInInteriorDomain(position, 0, position.length)) {
                    System.err.println("Numerical instability in HMC position; throwing exception");
                    throw new NumericInstabilityException();
                }
            }

            @Override
            boolean checkPositionTransform() {
                return true;
            }
        },

        /** Performs no checks at all. */
        IGNORE {
            @Override
            void checkValue(double value) {
                // intentionally no check
            }

            @Override
            void checkPosition(Transform transform, double[] position) {
                // intentionally no check
            }

            @Override
            boolean checkPositionTransform() {
                return false;
            }
        };

        /**
         * Validates a single updated coordinate (momentum or position).
         *
         * @throws NumericInstabilityException if the value is considered unstable
         */
        abstract void checkValue(double value) throws NumericInstabilityException;

        /**
         * Validates an untransformed position against the transform's domain.
         *
         * @throws NumericInstabilityException if the position is considered invalid
         */
        abstract void checkPosition(Transform transform, double[] position) throws NumericInstabilityException;

        /** Returns whether {@link #checkPosition} should be invoked after each position update. */
        abstract boolean checkPositionTransform();
    }

    /** Performs the individual leap-frog momentum and position updates. */
    public interface LeapFrogEngine {

        /** Returns the starting position for a trajectory, in the space the dynamics run in. */
        double[] getInitialPosition();

        /** Returns the log-Jacobian term added to the kinetic energy at the current position. */
        double getParameterLogJacobian();

        /**
         * Adds {@code functionalStepSize * gradient} to {@code momentum} in place.
         *
         * @throws NumericInstabilityException if an updated momentum is rejected
         */
        void updateMomentum(double[] position, double[] momentum, double[] gradient,
                            double functionalStepSize) throws NumericInstabilityException;

        /**
         * Advances {@code position} in place along the velocity implied by {@code momentum}.
         *
         * @throws NumericInstabilityException if an updated position is rejected
         */
        void updatePosition(double[] position, WrappedVector momentum,
                            double functionalStepSize) throws NumericInstabilityException;

        /** Writes a position (in the dynamics' space) back to the underlying parameter. */
        void setParameter(double[] position);

        /** Returns the gradient used in the most recent completed momentum update, or {@code null}. */
        double[] getLastGradient();

        /** Returns the position passed to the most recent completed momentum update, or {@code null}. */
        double[] getLastPosition();

        /** Leap-frog engine operating directly on the untransformed parameter values. */
        public static class Default implements LeapFrogEngine {

            protected final Parameter parameter;
            final InstabilityHandler instabilityHandler;
            private final MassPreconditioner preconditioning;
            final double[] mask;
            double[] lastGradient;
            double[] lastPosition;

            /**
             * @param parameter          parameter updated by the trajectory
             * @param instabilityHandler policy applied to every updated coordinate
             * @param preconditioning    mass matrix converting momenta to velocities
             * @param mask               per-dimension 0/1 mask, or {@code null}
             */
            public Default(Parameter parameter, InstabilityHandler instabilityHandler,
                           MassPreconditioner preconditioning, double[] mask) {
                this.parameter = parameter;
                this.instabilityHandler = instabilityHandler;
                this.preconditioning = preconditioning;
                this.mask = mask;
            }

            @Override
            public double[] getInitialPosition() {
                return parameter.getParameterValues();
            }

            /** No transform, so the log-Jacobian is always zero. */
            @Override
            public double getParameterLogJacobian() {
                return 0.0;
            }

            @Override
            public double[] getLastGradient() {
                return lastGradient;
            }

            @Override
            public double[] getLastPosition() {
                return lastPosition;
            }

            /**
             * Adds {@code functionalStepSize * gradient[i]} to each momentum entry, validating each one.
             * On success the gradient and position arrays are remembered (by reference) for
             * preconditioner updates.
             */
            @Override
            public void updateMomentum(double[] position, double[] momentum, double[] gradient,
                                       double functionalStepSize) throws NumericInstabilityException {
                for (int i = 0; i < momentum.length; i++) {
                    momentum[i] += functionalStepSize * gradient[i];
                    instabilityHandler.checkValue(momentum[i]);
                }
                lastGradient = gradient;
                lastPosition = position;
            }

            /**
             * Adds {@code functionalStepSize * velocity[i]} to each position entry, where the velocity
             * comes from the preconditioner applied to {@code momentum}. Each entry is validated,
             * and the result is then written to the parameter.
             */
            @Override
            public void updatePosition(double[] position, WrappedVector momentum,
                                       double functionalStepSize) throws NumericInstabilityException {
                final int dim = momentum.getDim();
                for (int i = 0; i < dim; i++) {
                    position[i] += functionalStepSize * preconditioning.getVelocity(i, momentum);
                    instabilityHandler.checkValue(position[i]);
                }
                setParameter(position);
            }

            @Override
            public void setParameter(double[] position) {
                ReadableVector.Utils.setParameter(position, parameter);
            }
        }

        /**
         * Leap-frog engine whose dynamics run in the space defined by a {@link Transform}.
         * It caches the untransformed position for Jacobian and gradient computations.
         */
        public static class WithTransform extends Default {

            private final Transform transform;

            /** Untransformed values corresponding to the current position; set by
             *  {@link #getInitialPosition()} and {@link #setParameter(double[])}. */
            double[] unTransformedPosition;

            private WithTransform(Parameter parameter, Transform transform, InstabilityHandler instabilityHandler,
                                  MassPreconditioner preconditioning, double[] mask) {
                super(parameter, instabilityHandler, preconditioning, mask);
                this.transform = transform;
            }

            @Override
            public double getParameterLogJacobian() {
                return transform.getLogJacobian(unTransformedPosition, 0, unTransformedPosition.length);
            }

            /** Caches the current parameter values and returns them mapped through the transform. */
            @Override
            public double[] getInitialPosition() {
                unTransformedPosition = super.getInitialPosition();
                return transform.transform(unTransformedPosition, 0, unTransformedPosition.length);
            }

            /**
             * Converts {@code gradient} into the transformed space using the cached untransformed
             * position, re-applies the mask, and then performs the default momentum update.
             */
            @Override
            public void updateMomentum(double[] position, double[] momentum, double[] gradient,
                                       double functionalStepSize) throws NumericInstabilityException {
                double[] transformedGradient = transform.updateGradientLogDensity(
                        gradient, unTransformedPosition, 0, unTransformedPosition.length);
                HamiltonianMonteCarloOperator.mask(transformedGradient, mask);
                super.updateMomentum(position, momentum, transformedGradient, functionalStepSize);
            }

            /** Default position update, optionally followed by a domain check on the untransformed values. */
            @Override
            public void updatePosition(double[] position, WrappedVector momentum,
                                       double functionalStepSize) throws NumericInstabilityException {
                super.updatePosition(position, momentum, functionalStepSize);
                if (instabilityHandler.checkPositionTransform()) {
                    checkPosition(unTransformedPosition);
                }
            }

            /** Maps {@code position} back through the inverse transform, caches it, and stores it in the parameter. */
            @Override
            public void setParameter(double[] position) {
                unTransformedPosition = transform.inverse(position, 0, position.length);
                super.setParameter(unTransformedPosition);
            }

            private void checkPosition(double[] untransformed) throws NumericInstabilityException {
                instabilityHandler.checkPosition(transform, untransformed);
            }
        }
    }

    /** Signals that a trajectory became numerically unstable and the proposal must be rejected. */
    public static class NumericInstabilityException extends Exception {
        NumericInstabilityException() {
        }
    }

    /** Immutable bundle of HMC tuning options. */
    public static class Options {
        final double initialStepSize;
        final int nSteps;
        final double randomStepCountFraction;
        final int preconditioningUpdateFrequency;
        final int preconditioningDelay;
        final int preconditioningMemory;
        final int gradientCheckCount;
        final double gradientCheckTolerance;
        final int checkStepSizeMaxIterations;
        final double checkStepSizeReductionFactor;
        final double targetAcceptanceProbability;

        /**
         * @param initialStepSize                starting leap-frog step size
         * @param nSteps                         nominal number of leap-frog steps per proposal
         * @param randomStepCountFraction        relative jitter applied to {@code nSteps}; values &lt;= 0 disable it
         * @param preconditioningUpdateFrequency update the mass matrix every this many operations; values &lt;= 0 disable updates
         * @param preconditioningDelay           mass-matrix updates only happen once the operation count exceeds this
         * @param preconditioningMemory          not read in this class; presumably consumed by the preconditioner factory
         * @param gradientCheckCount             number of initial operations during which gradients are checked numerically
         * @param gradientCheckTolerance         tolerance passed to {@code MathUtils.isClose} in the gradient check
         * @param checkStepSizeMaxIterations     maximum attempts when searching for a usable initial step size
         * @param checkStepSizeReductionFactor   multiplier applied to the step size after each failed attempt
         * @param targetAcceptanceProbability    target acceptance probability handed to the adaptation machinery
         */
        public Options(double initialStepSize, int nSteps, double randomStepCountFraction,
                       int preconditioningUpdateFrequency, int preconditioningDelay, int preconditioningMemory,
                       int gradientCheckCount, double gradientCheckTolerance,
                       int checkStepSizeMaxIterations, double checkStepSizeReductionFactor,
                       double targetAcceptanceProbability) {
            this.initialStepSize = initialStepSize;
            this.nSteps = nSteps;
            this.randomStepCountFraction = randomStepCountFraction;
            this.preconditioningUpdateFrequency = preconditioningUpdateFrequency;
            this.preconditioningDelay = preconditioningDelay;
            this.preconditioningMemory = preconditioningMemory;
            this.gradientCheckCount = gradientCheckCount;
            this.gradientCheckTolerance = gradientCheckTolerance;
            this.checkStepSizeMaxIterations = checkStepSizeMaxIterations;
            this.checkStepSizeReductionFactor = checkStepSizeReductionFactor;
            this.targetAcceptanceProbability = targetAcceptanceProbability;
        }
    }

    /**
     * @param mode             adaptation mode for the step size
     * @param weight           operator weight
     * @param gradientProvider source of the log-density gradient with respect to {@code parameter}
     * @param parameter        parameter to sample
     * @param transform        optional transform for the dynamics, or {@code null}
     * @param maskParameter    optional mask; entries equal to 0 are masked out; may be {@code null}
     * @param options          tuning options
     * @param preconditioningType factory for the mass preconditioner
     */
    public HamiltonianMonteCarloOperator(AdaptationMode mode, double weight,
                                         GradientWrtParameterProvider gradientProvider,
                                         Parameter parameter, Transform transform, Parameter maskParameter,
                                         Options options, MassPreconditioner.Type preconditioningType) {
        super(mode, options.targetAcceptanceProbability);
        setWeight(weight);

        this.gradientProvider = gradientProvider;
        this.runtimeOptions = options;
        this.stepSize = options.initialStepSize;
        this.preconditioning = preconditioningType.factory(gradientProvider, transform, options);
        this.parameter = parameter;
        this.mask = buildMask(maskParameter);
        this.transform = transform;

        // Overridable hook: subclasses may supply their own engine. It runs after all fields above are set.
        this.leapFrogEngine = constructLeapFrogEngine(transform);
    }

    /** Builds the leap-frog engine: a transforming engine when {@code transform} is non-null, else the default. */
    protected LeapFrogEngine constructLeapFrogEngine(Transform transform) {
        if (transform != null) {
            return new LeapFrogEngine.WithTransform(parameter, transform,
                    getDefaultInstabilityHandler(), preconditioning, mask);
        }
        return new LeapFrogEngine.Default(parameter, getDefaultInstabilityHandler(), preconditioning, mask);
    }

    @Override
    public String getOperatorName() {
        return "VanillaHMC(" + parameter.getParameterName() + ")";
    }

    /**
     * Returns true when updates are enabled, the operation count is a multiple of the update
     * frequency, and the count has passed the configured delay.
     */
    private boolean shouldUpdatePreconditioning() {
        final int frequency = runtimeOptions.preconditioningUpdateFrequency;
        return frequency > 0
                && getCount() % (long) frequency == 0
                && getCount() > (long) runtimeOptions.preconditioningDelay;
    }

    /** Converts a mask parameter into a 0/1 array (0 where the value equals 0.0); {@code null} in, {@code null} out. */
    private static double[] buildMask(Parameter maskParameter) {
        if (maskParameter == null) {
            return null;
        }
        double[] result = new double[maskParameter.getDimension()];
        for (int i = 0; i < result.length; i++) {
            result[i] = maskParameter.getParameterValue(i) == 0.0 ? 0.0 : 1.0;
        }
        return result;
    }

    /** Not supported; this operator requires the likelihood, see {@link #doOperation(Likelihood)}. */
    @Override
    public double doOperation() {
        throw new RuntimeException("Should not be executed");
    }

    /**
     * Performs one HMC proposal. It may first tune the initial step size, check the gradient
     * and update the mass matrix.
     *
     * @return the energy difference from {@link #leapFrog()}, or negative infinity if the
     *         trajectory was numerically unstable
     */
    @Override
    public double doOperation(Likelihood likelihood) {
        if (shouldCheckStepSize()) {
            checkStepSize();
        }
        if (shouldCheckGradient()) {
            checkGradient(likelihood);
        }
        if (shouldUpdatePreconditioning()) {
            preconditioning.storeSecant(
                    new WrappedVector.Raw(leapFrogEngine.getLastGradient()),
                    new WrappedVector.Raw(leapFrogEngine.getLastPosition()));
            preconditioning.updateMass();
        }

        try {
            return leapFrog();
        } catch (NumericInstabilityException e) {
            return Double.NEGATIVE_INFINITY;
        }
    }

    /** Forwards the path parameter to the gradient provider when it is a {@link PathGradient}. */
    @Override
    public void setPathParameter(double beta) {
        if (gradientProvider instanceof PathGradient) {
            ((PathGradient) gradientProvider).setPathParameter(beta);
        }
    }

    /** The step-size search only runs before the first operation, and only when adaptation is on. */
    private boolean shouldCheckStepSize() {
        return getCount() < 1 && getMode() == AdaptationMode.ADAPTATION_ON;
    }

    /**
     * Repeatedly runs a trial trajectory, shrinking {@link #stepSize} by
     * {@code checkStepSizeReductionFactor} until the resulting log-likelihood is finite. The
     * parameter is restored after every attempt.
     *
     * @throws RuntimeException if no acceptable step size is found within the iteration budget
     */
    private void checkStepSize() {
        final double[] initialValues = parameter.getParameterValues();

        boolean acceptable = false;
        for (int attempt = 0; !acceptable && attempt < runtimeOptions.checkStepSizeMaxIterations; attempt++) {
            try {
                leapFrog();
                double logLikelihood = gradientProvider.getLikelihood().getLogLikelihood();
                acceptable = !Double.isNaN(logLikelihood) && !Double.isInfinite(logLikelihood);
            } catch (Exception ignored) {
                // Deliberate: any failure of the trial trajectory (e.g. NumericInstabilityException)
                // is treated as evidence that the step size is too large.
            }
            if (!acceptable) {
                stepSize *= runtimeOptions.checkStepSizeReductionFactor;
            }
            ReadableVector.Utils.setParameter(initialValues, parameter);
        }

        if (!acceptable) {
            throw new RuntimeException("Unable to find acceptable initial HMC step-size");
        }
    }

    /** Gradients are checked during the first {@code gradientCheckCount} operations. */
    public boolean shouldCheckGradient() {
        return getCount() < (long) runtimeOptions.gradientCheckCount;
    }

    /**
     * Compares the analytic gradient against a numerical gradient of the log-likelihood. When a
     * transform is present, the comparison is made in transformed space and includes the
     * log-Jacobian. The parameter is restored to its original values afterwards, even on failure.
     *
     * @throws RuntimeException if the dimensions differ or the gradients do not agree within tolerance
     */
    public void checkGradient(final Likelihood likelihood) {
        if (parameter.getDimension() != gradientProvider.getDimension()) {
            throw new RuntimeException("Unequal dimensions");
        }

        final Parameter param = parameter;
        final Transform xform = transform;

        // Note: evaluate() overwrites the parameter as a side effect.
        MultivariateFunction numeric = new MultivariateFunction() {
            @Override
            public double evaluate(double[] argument) {
                if (xform == null) {
                    ReadableVector.Utils.setParameter(argument, param);
                    return likelihood.getLogLikelihood();
                }
                double[] untransformed = xform.inverse(argument, 0, argument.length);
                ReadableVector.Utils.setParameter(untransformed, param);
                return likelihood.getLogLikelihood() - xform.getLogJacobian(untransformed, 0, untransformed.length);
            }

            @Override
            public int getNumArguments() {
                return param.getDimension();
            }

            @Override
            public double getLowerBound(int n) {
                return param.getBounds().getLowerLimit(n).doubleValue();
            }

            @Override
            public double getUpperBound(int n) {
                return param.getBounds().getUpperLimit(n).doubleValue();
            }
        };

        final double[] analytic = gradientProvider.getGradientLogDensity();
        // Snapshot taken before any numerical evaluation, because evaluate() mutates the parameter.
        // All later uses of "the current values" must refer to this snapshot.
        final double[] original = parameter.getParameterValues();

        try {
            if (transform == null) {
                double[] numericGradient = NumericalDerivative.gradient(numeric, original.clone());
                if (!MathUtils.isClose(analytic, numericGradient, runtimeOptions.gradientCheckTolerance)) {
                    throw new RuntimeException("Gradients do not match:\n\tAnalytic: " + new WrappedVector.Raw(analytic)
                            + "\n\tNumeric : " + new WrappedVector.Raw(numericGradient) + "\n");
                }
            } else {
                double[] transformed = transform.transform(original.clone(), 0, original.length);
                double[] numericGradient = NumericalDerivative.gradient(numeric, transformed);
                // Use the unperturbed snapshot: the analytic gradient was computed there.
                double[] analyticTransformed = transform.updateGradientLogDensity(
                        analytic, original.clone(), 0, original.length);
                if (!MathUtils.isClose(analyticTransformed, numericGradient, runtimeOptions.gradientCheckTolerance)) {
                    throw new RuntimeException("Transformed Gradients do not match:\n\tAnalytic: " + new WrappedVector.Raw(analyticTransformed)
                            + "\n\tNumeric : " + new WrappedVector.Raw(numericGradient)
                            + "\n\tParameter : " + new WrappedVector.Raw(original)
                            + "\n\tTransformed Parameter : " + new WrappedVector.Raw(transformed) + "\n");
                }
            }
        } finally {
            ReadableVector.Utils.setParameter(original, parameter);
        }
    }

    /**
     * Multiplies {@code vector} element-wise by {@code maskValues} in place; a {@code null} mask is a no-op.
     *
     * @return {@code vector}, for chaining
     */
    public static double[] mask(double[] vector, double[] maskValues) {
        assert maskValues == null || maskValues.length == vector.length;

        if (maskValues != null) {
            for (int i = 0; i < vector.length; i++) {
                vector[i] *= maskValues[i];
            }
        }
        return vector;
    }

    /**
     * Multiplies {@code vector} element-wise by {@code maskValues} in place; a {@code null} mask is a no-op.
     *
     * @return {@code vector}, for chaining
     */
    public static WrappedVector mask(WrappedVector vector, double[] maskValues) {
        assert maskValues == null || maskValues.length == vector.getDim();

        if (maskValues != null) {
            for (int i = 0; i < vector.getDim(); i++) {
                vector.set(i, vector.get(i) * maskValues[i]);
            }
        }
        return vector;
    }

    /**
     * Returns {@code nSteps}, optionally jittered. When {@code randomStepCountFraction > 0} the count
     * is scaled by {@code 1 + fraction * (u - 0.5)} with {@code u = MathUtils.nextDouble()}
     * (presumably uniform on [0, 1)), truncated, and kept at least 1.
     */
    private int getNumberOfSteps() {
        int steps = runtimeOptions.nSteps;
        if (runtimeOptions.randomStepCountFraction > 0.0) {
            double jitter = 1.0 + runtimeOptions.randomStepCountFraction * (MathUtils.nextDouble() - 0.5);
            steps = Math.max(1, (int) (steps * jitter));
        }
        return steps;
    }

    /** Returns {@code 0.5 * sum_i p_i * v_i}, where {@code v_i} is the preconditioner's velocity for momentum {@code p}. */
    public double getKineticEnergy(ReadableVector momentum) {
        final int dim = momentum.getDim();
        double total = 0.0;
        for (int i = 0; i < dim; i++) {
            total += momentum.get(i) * preconditioning.getVelocity(i, momentum);
        }
        return total / 2.0;
    }

    /**
     * Runs one leap-frog trajectory: a half momentum step, alternating full position/momentum
     * steps, and a closing half momentum step.
     *
     * @return (initial kinetic energy + log-Jacobian) minus (final kinetic energy + log-Jacobian)
     * @throws NumericInstabilityException if the engine rejects a value, or if an
     *         {@link ArithmeticException} occurs inside the main loop
     */
    private double leapFrog() throws NumericInstabilityException {
        final double[] position = leapFrogEngine.getInitialPosition();
        final WrappedVector momentum = mask(preconditioning.drawInitialMomentum(), this.mask);

        final double initialEnergy = getKineticEnergy(momentum) + leapFrogEngine.getParameterLogJacobian();

        // Opening half step for the momentum.
        leapFrogEngine.updateMomentum(position, momentum.getBuffer(),
                mask(gradientProvider.getGradientLogDensity(), this.mask), stepSize / 2.0);

        final int steps = getNumberOfSteps();
        for (int step = 0; step < steps; step++) {
            try {
                leapFrogEngine.updatePosition(position, momentum, stepSize);
                // Full momentum steps between positions; the last one is replaced by the closing half step.
                if (step < steps - 1) {
                    leapFrogEngine.updateMomentum(position, momentum.getBuffer(),
                            mask(gradientProvider.getGradientLogDensity(), this.mask), stepSize);
                }
            } catch (ArithmeticException e) {
                throw new NumericInstabilityException();
            }
        }

        // Closing half step for the momentum.
        leapFrogEngine.updateMomentum(position, momentum.getBuffer(),
                mask(gradientProvider.getGradientLogDensity(), this.mask), stepSize / 2.0);

        final double finalEnergy = getKineticEnergy(momentum) + leapFrogEngine.getParameterLogJacobian();
        return initialEnergy - finalEnergy;
    }

    /** Adaptation works on {@code log(stepSize)}. */
    @Override
    protected double getAdaptableParameterValue() {
        return Math.log(stepSize);
    }

    /** Sets {@code stepSize = exp(value)}. */
    @Override
    public void setAdaptableParameterValue(double value) {
        stepSize = Math.exp(value);
    }

    @Override
    public double getRawParameter() {
        return stepSize;
    }

    /** Instability policy handed to the leap-frog engine; subclasses may override. */
    public InstabilityHandler getDefaultInstabilityHandler() {
        return InstabilityHandler.REJECT;
    }

    @Override
    public String getAdaptableParameterName() {
        return "stepSize";
    }
}
