package dr.inference.operators.hmc;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.math.AdaptableCovariance;
import dr.math.AdaptableVector;
import dr.math.MathUtils;
import dr.math.MultivariateFunction;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.RobustEigenDecomposition;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;
import java.util.Arrays;

/* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner.class */
public interface MassPreconditioner {

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$AbstractMassPreconditioning.class */
    /**
     * Base class for preconditioners that keep a cached {@code inverseMass} array.
     *
     * <p>Subclasses decide how the array is first filled ({@link #initializeMass()}) and how it
     * is recomputed ({@link #computeInverseMass()}); {@link #updateMass()} simply replaces the
     * cached array with a freshly computed one.
     */
    public static abstract class AbstractMassPreconditioning implements MassPreconditioner {
        /** Dimension supplied at construction. */
        protected final int dim;
        /** Transform supplied at construction; subclasses in this file treat {@code null} as "no transform". */
        protected final Transform transform;
        /** Cached inverse mass; layout (diagonal or flattened matrix) is chosen by the subclass. */
        double[] inverseMass;

        /**
         * @param dimension          value stored in {@link #dim}
         * @param parameterTransform value stored in {@link #transform}
         */
        protected AbstractMassPreconditioning(int dimension, Transform parameterTransform) {
            this.transform = parameterTransform;
            this.dim = dimension;
        }

        /** Computes and returns a new inverse-mass array. */
        protected abstract double[] computeInverseMass();

        /** Sets {@link #inverseMass} to its starting value. */
        protected abstract void initializeMass();

        /** Replaces the cached inverse mass with the result of {@link #computeInverseMass()}. */
        @Override
        public void updateMass() {
            this.inverseMass = computeInverseMass();
        }

        @Override
        public abstract void storeSecant(ReadableVector readableVector, ReadableVector readableVector2);
    }

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$AdaptiveDiagonalPreconditioning.class */
    /**
     * Diagonal preconditioner whose inverse mass follows a running variance estimate.
     *
     * <p>Every vector passed as the second argument of {@link #storeSecant} is fed into
     * {@code variance}. Once more than {@code minimumUpdates} vectors have been recorded, each
     * call to {@link #computeInverseMass()} pushes the current variance into
     * {@code adaptiveDiagonal}; the returned inverse mass is the mean of {@code adaptiveDiagonal}
     * rescaled so that its entries sum to {@code dim}.
     */
    public static class AdaptiveDiagonalPreconditioning extends DiagonalPreconditioning {
        /** Number of recorded updates that must be exceeded before the variance is used. */
        private final int minimumUpdates = 100;
        private final AdaptableVector.AdaptableVariance variance;

        /**
         * @param i         dimension of the preconditioner
         * @param transform forwarded to the superclass
         */
        AdaptiveDiagonalPreconditioning(int i, Transform transform) {
            super(i, transform);
            // minimumUpdates is initialized at its declaration; a final field must not be assigned again here.
            this.variance = new AdaptableVector.AdaptableVariance(i);
        }

        /** Initializes the mass as the superclass does, then seeds {@code adaptiveDiagonal} with it. */
        @Override
        protected void initializeMass() {
            super.initializeMass();
            this.adaptiveDiagonal.update(new WrappedVector.Raw(this.inverseMass));
        }

        /**
         * @return the mean of {@code adaptiveDiagonal}, normalized so that the entries sum to {@code dim}
         */
        @Override
        protected double[] computeInverseMass() {
            if (this.variance.getUpdateCount() > this.minimumUpdates) {
                this.adaptiveDiagonal.update(new WrappedVector.Raw(this.variance.getVariance()));
            }
            return normalizeVector(this.adaptiveDiagonal.getMean(), this.dim);
        }

        /**
         * Rescales a vector so that its entries add up to {@code d}.
         *
         * <p>No guard is applied when the entries sum to zero; the division then yields
         * non-finite values.
         *
         * @param readableVector values to rescale (not modified)
         * @param d              target sum
         * @return a new array holding the rescaled values
         */
        private double[] normalizeVector(ReadableVector readableVector, double d) {
            final int length = readableVector.getDim();
            double sum = 0.0d;
            for (int i = 0; i < length; i++) {
                sum += readableVector.get(i);
            }
            final double multiplier = d / sum;
            double[] normalized = new double[length];
            for (int i = 0; i < length; i++) {
                normalized[i] = readableVector.get(i) * multiplier;
            }
            return normalized;
        }

        /** Feeds {@code readableVector2} into the running variance; {@code readableVector} is unused. */
        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
            this.variance.update(readableVector2);
        }
    }

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$AdaptiveFullHessianPreconditioning.class */
    /**
     * Full-matrix preconditioner driven by an {@link AdaptableCovariance} instead of a Hessian
     * provider: the superclass is constructed with a {@code null} Hessian, and the covariance is
     * turned into an inverse mass with {@code PDTransformMatrix.Negate}.
     */
    public static class AdaptiveFullHessianPreconditioning extends FullHessianPreconditioning {
        private final AdaptableCovariance adaptableCovariance;
        private final GradientWrtParameterProvider gradientProvider;
        /**
         * Function view of the log-likelihood: evaluating it writes the argument into the
         * gradient provider's parameter and returns the resulting log-likelihood.
         * Not invoked anywhere in this class.
         */
        protected MultivariateFunction numeric1;

        /**
         * @param gradientWrtParameterProvider source of the parameter, gradient and likelihood
         * @param adaptableCovariance          accumulator updated by {@link #storeSecant}
         * @param transform                    forwarded to the superclass
         * @param i                            dimension forwarded to the superclass
         */
        AdaptiveFullHessianPreconditioning(GradientWrtParameterProvider gradientWrtParameterProvider, AdaptableCovariance adaptableCovariance, Transform transform, int i) {
            super(null, transform, i);
            this.gradientProvider = gradientWrtParameterProvider;
            this.adaptableCovariance = adaptableCovariance;
            this.numeric1 = new MultivariateFunction() {
                @Override
                public double evaluate(double[] argument) {
                    int index = 0;
                    while (index < argument.length) {
                        AdaptiveFullHessianPreconditioning.this.gradientProvider.getParameter().setParameterValue(index, argument[index]);
                        index++;
                    }
                    return AdaptiveFullHessianPreconditioning.this.gradientProvider.getLikelihood().getLogLikelihood();
                }

                @Override
                public int getNumArguments() {
                    return AdaptiveFullHessianPreconditioning.this.gradientProvider.getParameter().getDimension();
                }

                @Override
                public double getLowerBound(int n) {
                    return 0.0d;
                }

                @Override
                public double getUpperBound(int n) {
                    return Double.POSITIVE_INFINITY;
                }
            };
        }

        /** Builds the inverse mass from the current covariance estimate. */
        @Override
        protected double[] computeInverseMass() {
            WrappedMatrix.ArrayOfArray covariance = (WrappedMatrix.ArrayOfArray) this.adaptableCovariance.getCovariance();
            return computeInverseMass(covariance, this.gradientProvider, FullHessianPreconditioning.PDTransformMatrix.Negate);
        }

        /** Feeds {@code readableVector2} into the covariance accumulator; {@code readableVector} is unused. */
        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
            this.adaptableCovariance.update(readableVector2);
        }
    }

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$DiagonalHessianPreconditioning.class */
    /**
     * Diagonal preconditioner derived from the diagonal of the log-density Hessian.
     *
     * <p>Each update folds the (optionally transformed) diagonal Hessian into
     * {@code adaptiveDiagonal}; the inverse mass is then built from the running mean.
     */
    public static class DiagonalHessianPreconditioning extends DiagonalPreconditioning {
        protected final HessianWrtParameterProvider hessian;

        /**
         * @param hessianWrtParameterProvider source of the diagonal Hessian, gradient and parameter
         * @param transform                   forwarded to the superclass
         * @param memory                      when positive, a {@code LimitedMemory} accumulator of
         *                                    this size is used; otherwise the {@code Default} one
         */
        DiagonalHessianPreconditioning(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform, int memory) {
            super(hessianWrtParameterProvider.getDimension(), transform);
            this.hessian = hessianWrtParameterProvider;
            this.adaptiveDiagonal = (memory > 0)
                    ? new AdaptableVector.LimitedMemory(hessianWrtParameterProvider.getDimension(), memory)
                    : new AdaptableVector.Default(hessianWrtParameterProvider.getDimension());
        }

        /**
         * Records the current diagonal Hessian and returns the bounded, rescaled inverse mass
         * computed from the accumulated mean.
         */
        @Override
        protected double[] computeInverseMass() {
            double[] diagonal = this.hessian.getDiagonalHessianLogDensity();
            if (this.transform != null) {
                final double[] values = this.hessian.getParameter().getParameterValues();
                final double[] gradient = this.hessian.getGradientLogDensity();
                diagonal = this.transform.updateDiagonalHessianLogDensity(diagonal, gradient, values, 0, this.dim);
            }
            this.adaptiveDiagonal.update(new WrappedVector.Raw(diagonal));
            final WrappedVector mean = (WrappedVector) this.adaptiveDiagonal.getMean();
            return boundMassInverse(mean.getBuffer());
        }

        /**
         * Maps each entry {@code h} to {@code -1/h}, clamps it to {@code [0.01, 100]}, and then
         * multiplies every entry by the mean of the reciprocals of the clamped values.
         *
         * @param dArr diagonal Hessian values (not modified)
         * @return a new array of length {@code dim}
         */
        private double[] boundMassInverse(double[] dArr) {
            final double lowerBound = 0.01d;
            final double upperBound = 100.0d;
            final double[] bounded = new double[this.dim];
            double reciprocalSum = 0.0d;
            for (int k = 0; k < this.dim; k++) {
                double candidate = (-1.0d) / dArr[k];
                if (candidate < lowerBound) {
                    candidate = lowerBound;
                } else if (candidate > upperBound) {
                    candidate = upperBound;
                }
                bounded[k] = candidate;
                reciprocalSum += 1.0d / candidate;
            }
            final double meanReciprocal = reciprocalSum / this.dim;
            for (int k = 0; k < this.dim; k++) {
                bounded[k] *= meanReciprocal;
            }
            return bounded;
        }

        /** Intentionally a no-op for this preconditioner. */
        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
        }
    }

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$DiagonalPreconditioning.class */
    /**
     * Base class for preconditioners whose inverse mass is stored as one value per dimension.
     */
    public static abstract class DiagonalPreconditioning extends AbstractMassPreconditioning {
        /** Running accumulator; subclasses may replace the default instance. */
        protected AdaptableVector adaptiveDiagonal;

        /**
         * Creates a default accumulator and then calls {@link #initializeMass()}.
         *
         * @param dimension          number of dimensions
         * @param parameterTransform forwarded to the superclass
         */
        protected DiagonalPreconditioning(int dimension, Transform parameterTransform) {
            super(dimension, parameterTransform);
            this.adaptiveDiagonal = new AdaptableVector.Default(dimension);
            initializeMass();
        }

        /** Sets every inverse-mass entry to one. */
        @Override
        protected void initializeMass() {
            final double[] ones = new double[this.dim];
            Arrays.fill(ones, 1.0d);
            this.inverseMass = ones;
        }

        /**
         * @return a new vector whose entry {@code k} is a standard Gaussian draw multiplied by
         *         {@code sqrt(1 / inverseMass[k])}
         */
        @Override
        public WrappedVector drawInitialMomentum() {
            final double[] momentum = new double[this.dim];
            int k = 0;
            while (k < this.dim) {
                final double gaussian = MathUtils.nextGaussian();
                momentum[k] = gaussian * Math.sqrt(1.0d / this.inverseMass[k]);
                k++;
            }
            return new WrappedVector.Raw(momentum);
        }

        /** @return entry {@code i} of the vector multiplied by {@code inverseMass[i]} */
        @Override
        public double getVelocity(int i, ReadableVector readableVector) {
            final double scale = this.inverseMass[i];
            return readableVector.get(i) * scale;
        }

        /**
         * Returns a copy of the vector in which the two indexed entries are replaced by
         * inverse-mass-weighted combinations of the original pair.
         *
         * @param iArr           exactly two indices
         * @param readableVector input vector (not modified)
         * @throws RuntimeException if {@code iArr} does not have length two
         */
        @Override
        public ReadableVector doCollision(int[] iArr, ReadableVector readableVector) {
            if (iArr.length != 2) {
                throw new RuntimeException("Not implemented for more than two dimensions yet.");
            }
            final WrappedVector.Raw updated = new WrappedVector.Raw(new double[readableVector.getDim()]);
            for (int k = 0; k < readableVector.getDim(); k++) {
                updated.set(k, readableVector.get(k));
            }
            final int first = iArr[0];
            final int second = iArr[1];
            final double inverseFirst = this.inverseMass[first];
            final double inverseSecond = this.inverseMass[second];
            final double total = inverseFirst + inverseSecond;
            final double newFirst = (((inverseSecond - inverseFirst) * readableVector.get(first)) + ((2.0d * inverseSecond) * readableVector.get(second))) / total;
            final double newSecond = (((inverseFirst - inverseSecond) * readableVector.get(second)) + ((2.0d * inverseFirst) * readableVector.get(first))) / total;
            updated.set(first, newFirst);
            updated.set(second, newSecond);
            return updated;
        }
    }

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$FullHessianPreconditioning.class */
    public static class FullHessianPreconditioning extends HessianBased {

        /* JADX INFO: Access modifiers changed from: package-private */
        /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$FullHessianPreconditioning$PDTransformMatrix.class */
        public enum PDTransformMatrix {
            Invert("Transform inverse matrix into a PD matrix") { // from class: dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix.1
                @Override // dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    inverseNegateEigenvalues(doubleMatrix1D);
                }
            },
            Default("Transform matrix into a PD matrix") { // from class: dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix.2
                @Override // dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    negateEigenvalues(doubleMatrix1D);
                }
            },
            Negate("Transform negative matrix into a PD matrix") { // from class: dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix.3
                @Override // dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    negateEigenvalues(doubleMatrix1D);
                }

                @Override // dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix
                protected void normalizeEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    negateEigenvalues(doubleMatrix1D);
                    boundEigenvalues(doubleMatrix1D);
                    scaleEigenvalues(doubleMatrix1D);
                }
            },
            NegateInvert("Transform negative inverse matrix into a PD matrix") { // from class: dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix.4
                @Override // dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    inverseNegateEigenvalues(doubleMatrix1D);
                }

                @Override // dr.inference.operators.hmc.MassPreconditioner.FullHessianPreconditioning.PDTransformMatrix
                protected void normalizeEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    negateEigenvalues(doubleMatrix1D);
                    boundEigenvalues(doubleMatrix1D);
                    scaleEigenvalues(doubleMatrix1D);
                }
            };

            String desc;
            private static final double MIN_EIGENVALUE = -10.0d;
            private static final double MAX_EIGENVALUE = -0.5d;

            PDTransformMatrix(String str) {
                this.desc = str;
            }

            @Override // java.lang.Enum
            public String toString() {
                return this.desc;
            }

            protected void boundEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                for (int i = 0; i < doubleMatrix1D.cardinality(); i++) {
                    if (doubleMatrix1D.get(i) > MAX_EIGENVALUE) {
                        doubleMatrix1D.set(i, MAX_EIGENVALUE);
                    } else if (doubleMatrix1D.get(i) < MIN_EIGENVALUE) {
                        doubleMatrix1D.set(i, MIN_EIGENVALUE);
                    }
                }
            }

            protected void scaleEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                double d = 0.0d;
                for (int i = 0; i < doubleMatrix1D.cardinality(); i++) {
                    d += doubleMatrix1D.get(i);
                }
                double cardinality = (-d) / doubleMatrix1D.cardinality();
                for (int i2 = 0; i2 < doubleMatrix1D.cardinality(); i2++) {
                    doubleMatrix1D.set(i2, doubleMatrix1D.get(i2) / cardinality);
                }
            }

            protected void normalizeEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                boundEigenvalues(doubleMatrix1D);
                scaleEigenvalues(doubleMatrix1D);
            }

            protected void inverseNegateEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                for (int i = 0; i < doubleMatrix1D.cardinality(); i++) {
                    doubleMatrix1D.set(i, (-1.0d) / doubleMatrix1D.get(i));
                }
            }

            protected void negateEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                for (int i = 0; i < doubleMatrix1D.cardinality(); i++) {
                    doubleMatrix1D.set(i, -doubleMatrix1D.get(i));
                }
            }

            public double[] transformMatrix(double[][] dArr, int i) {
                Algebra algebra = new Algebra();
                RobustEigenDecomposition robustEigenDecomposition = new RobustEigenDecomposition(new DenseDoubleMatrix2D(dArr));
                DoubleMatrix1D realEigenvalues = robustEigenDecomposition.getRealEigenvalues();
                normalizeEigenvalues(realEigenvalues);
                DoubleMatrix2D v = robustEigenDecomposition.getV();
                transformEigenvalues(realEigenvalues);
                double[][] array = algebra.mult(algebra.mult(v, DoubleFactory2D.dense.diagonal(realEigenvalues)), algebra.inverse(v)).toArray();
                double[] dArr2 = new double[i * i];
                for (int i2 = 0; i2 < i; i2++) {
                    System.arraycopy(array[i2], 0, dArr2, i2 * i, i);
                }
                return dArr2;
            }

            protected abstract void transformEigenvalues(DoubleMatrix1D doubleMatrix1D);
        }

        FullHessianPreconditioning(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform) {
            super(hessianWrtParameterProvider, transform);
        }

        FullHessianPreconditioning(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform, int i) {
            super(hessianWrtParameterProvider, transform, i);
        }

        @Override // dr.inference.operators.hmc.MassPreconditioner.AbstractMassPreconditioning
        protected void initializeMass() {
            double[] dArr = new double[this.dim * this.dim];
            for (int i = 0; i < this.dim; i++) {
                dArr[(i * this.dim) + i] = 1.0d;
            }
            this.inverseMass = dArr;
        }

        /* JADX INFO: Access modifiers changed from: private */
        public double[] computeInverseMass(WrappedMatrix.ArrayOfArray arrayOfArray, GradientWrtParameterProvider gradientWrtParameterProvider, PDTransformMatrix pDTransformMatrix) {
            double[][] arrays = arrayOfArray.getArrays();
            if (this.transform != null) {
                arrays = this.transform.updateHessianLogDensity(arrays, new double[this.dim][this.dim], gradientWrtParameterProvider.getGradientLogDensity(), gradientWrtParameterProvider.getParameter().getParameterValues(), 0, this.dim);
            }
            return pDTransformMatrix.transformMatrix(arrays, this.dim);
        }

        @Override // dr.inference.operators.hmc.MassPreconditioner.AbstractMassPreconditioning
        protected double[] computeInverseMass() {
            return computeInverseMass(new WrappedMatrix.ArrayOfArray(this.hessian.getHessianLogDensity()), this.hessian, PDTransformMatrix.Invert);
        }

        @Override // dr.inference.operators.hmc.MassPreconditioner.AbstractMassPreconditioning, dr.inference.operators.hmc.MassPreconditioner
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
        }

        @Override // dr.inference.operators.hmc.MassPreconditioner
        public WrappedVector drawInitialMomentum() {
            return new WrappedVector.Raw(new MultivariateNormalDistribution(new double[this.dim], toArray(this.inverseMass, this.dim, this.dim)).nextMultivariateNormal());
        }

        @Override // dr.inference.operators.hmc.MassPreconditioner
        public double getVelocity(int i, ReadableVector readableVector) {
            double d = 0.0d;
            for (int i2 = 0; i2 < this.dim; i2++) {
                d += this.inverseMass[(i * this.dim) + i2] * readableVector.get(i2);
            }
            return d;
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v1, types: [double[], double[][]] */
        private static double[][] toArray(double[] dArr, int i, int i2) {
            ?? r0 = new double[i];
            for (int i3 = 0; i3 < i; i3++) {
                r0[i3] = new double[i2];
                System.arraycopy(dArr, i2 * i3, r0[i3], 0, i2);
            }
            return r0;
        }
    }

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$HessianBased.class */
    /**
     * Base class for preconditioners backed by a {@link HessianWrtParameterProvider}.
     * Construction stores the provider and then calls {@code initializeMass()}.
     */
    public static abstract class HessianBased extends AbstractMassPreconditioning {
        protected final HessianWrtParameterProvider hessian;

        /**
         * Uses the provider's own dimension.
         *
         * @param provider           Hessian source; dereferenced here, so it must not be {@code null}
         * @param parameterTransform forwarded to the superclass
         */
        HessianBased(HessianWrtParameterProvider provider, Transform parameterTransform) {
            this(provider, parameterTransform, provider.getDimension());
        }

        /**
         * @param provider           Hessian source; only stored by this constructor
         * @param parameterTransform forwarded to the superclass
         * @param dimension          dimension forwarded to the superclass
         */
        HessianBased(HessianWrtParameterProvider provider, Transform parameterTransform, int dimension) {
            super(dimension, parameterTransform);
            this.hessian = provider;
            initializeMass();
        }

        /**
         * Collisions are not supported by Hessian-based preconditioners.
         *
         * @throws RuntimeException always
         */
        @Override
        public ReadableVector doCollision(int[] iArr, ReadableVector readableVector) {
            throw new RuntimeException("Not yet implemented.");
        }
    }

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$NoPreconditioning.class */
    /**
     * Preconditioner that applies no scaling: momentum entries are plain Gaussian draws and
     * the velocity equals the momentum.
     */
    public static class NoPreconditioning implements MassPreconditioner {
        final int dim;

        /** @param dimension length of the momentum vectors produced by this instance */
        NoPreconditioning(int dimension) {
            this.dim = dimension;
        }

        /** @return a new vector of {@code dim} draws from {@code MathUtils.nextGaussian()} */
        @Override
        public WrappedVector drawInitialMomentum() {
            final double[] draws = new double[this.dim];
            int k = 0;
            while (k < this.dim) {
                draws[k] = MathUtils.nextGaussian();
                k++;
            }
            return new WrappedVector.Raw(draws);
        }

        /** @return entry {@code i} of the given vector, unchanged */
        @Override
        public double getVelocity(int i, ReadableVector readableVector) {
            return readableVector.get(i);
        }

        /** No state to record. */
        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
        }

        /** No mass to update. */
        @Override
        public void updateMass() {
        }

        /**
         * Returns a copy of the vector with the two indexed entries exchanged.
         *
         * @param iArr           exactly two indices
         * @param readableVector input vector (not modified)
         * @throws RuntimeException if {@code iArr} does not have length two
         */
        @Override
        public ReadableVector doCollision(int[] iArr, ReadableVector readableVector) {
            if (iArr.length != 2) {
                throw new RuntimeException("Not implemented for more than two dimensions yet.");
            }
            final WrappedVector.Raw swapped = new WrappedVector.Raw(new double[readableVector.getDim()]);
            for (int k = 0; k < readableVector.getDim(); ++k) {
                swapped.set(k, readableVector.get(k));
            }
            final int first = iArr[0];
            final int second = iArr[1];
            swapped.set(first, readableVector.get(second));
            swapped.set(second, readableVector.get(first));
            return swapped;
        }
    }

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$Secant.class */
    /**
     * Full-matrix preconditioner whose Hessian provider is a {@code SecantHessian}; recorded
     * secant pairs are forwarded to that object.
     */
    public static class Secant extends FullHessianPreconditioning {
        private final SecantHessian secantHessian;

        /**
         * @param secantProvider     used both as the superclass's Hessian provider and as the
         *                           target of {@link #storeSecant}
         * @param parameterTransform forwarded to the superclass
         */
        Secant(SecantHessian secantProvider, Transform parameterTransform) {
            super(secantProvider, parameterTransform);
            this.secantHessian = secantProvider;
        }

        /** Delegates both vectors, in the same order, to the secant Hessian. */
        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
            this.secantHessian.storeSecant(readableVector, readableVector2);
        }
    }

    /* loaded from: input_file:dr/inference/operators/hmc/MassPreconditioner$Type.class */
    /**
     * Selectable preconditioner kinds. Each constant carries a textual name used by
     * {@link #parseFromString(String)} and a factory that builds the matching implementation.
     */
    public enum Type {
        NONE("none") {
            /**
             * Uses the parameter's dimension, or the transform's dimension when the transform
             * is a {@code Transform.MultivariableTransform}.
             */
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                int dimension = gradientWrtParameterProvider.getParameter().getDimension();
                // instanceof is already false for null, so no separate null check is needed.
                if (transform instanceof Transform.MultivariableTransform) {
                    dimension = ((Transform.MultivariableTransform) transform).getDimension();
                }
                return new NoPreconditioning(dimension);
            }
        },
        DIAGONAL("diagonal") {
            /** Requires the gradient provider to also be a {@link HessianWrtParameterProvider}. */
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                return new DiagonalHessianPreconditioning((HessianWrtParameterProvider) gradientWrtParameterProvider, transform, options.preconditioningMemory);
            }
        },
        ADAPTIVE_DIAGONAL("adaptiveDiagonal") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                return new AdaptiveDiagonalPreconditioning(gradientWrtParameterProvider.getDimension(), transform);
            }
        },
        FULL("full") {
            /** Requires the gradient provider to also be a {@link HessianWrtParameterProvider}. */
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                return new FullHessianPreconditioning((HessianWrtParameterProvider) gradientWrtParameterProvider, transform);
            }
        },
        SECANT("secant") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                return new Secant(new SecantHessian(gradientWrtParameterProvider, options.preconditioningMemory), transform);
            }
        },
        ADAPTIVE("adaptive") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                return new AdaptiveFullHessianPreconditioning(gradientWrtParameterProvider, new AdaptableCovariance(gradientWrtParameterProvider.getDimension()), transform, gradientWrtParameterProvider.getDimension());
            }
        };

        private final String name;

        Type(String str) {
            this.name = str;
        }

        /**
         * Builds the preconditioner for this type.
         *
         * @param gradientWrtParameterProvider gradient source; some types cast it to
         *                                     {@link HessianWrtParameterProvider}
         * @param transform                    optional transform handed to the implementation
         * @param options                      operator options; only {@code preconditioningMemory}
         *                                     is read, and only by some types
         */
        public abstract MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options);

        /** @return the textual name given at construction */
        public String getName() {
            return this.name;
        }

        /**
         * Looks up a type by its textual name, ignoring case.
         *
         * @param str name to look up; may be {@code null}
         * @return the matching type, or {@link #NONE} when nothing matches or {@code str} is {@code null}
         */
        public static Type parseFromString(String str) {
            for (Type type : values()) {
                // equalsIgnoreCase is null-safe and independent of the default locale.
                if (type.name.equalsIgnoreCase(str)) {
                    return type;
                }
            }
            return NONE;
        }
    }

    /**
     * Draws a fresh momentum vector. The implementations in this file fill it with Gaussian
     * draws, scaled according to their current inverse mass where one exists.
     */
    WrappedVector drawInitialMomentum();

    /**
     * Returns the velocity for dimension {@code i} given a vector (presumably the momentum —
     * the implementations in this file multiply it by their inverse mass, or return the entry
     * unchanged when no preconditioning is applied).
     */
    double getVelocity(int i, ReadableVector readableVector);

    /**
     * Offers a pair of vectors to the preconditioner. Several implementations ignore the call;
     * the adaptive ones in this file use only {@code readableVector2}.
     * NOTE(review): the intended meaning of each argument is not visible here — confirm against callers.
     */
    void storeSecant(ReadableVector readableVector, ReadableVector readableVector2);

    /** Asks the preconditioner to refresh its mass; a no-op when no preconditioning is applied. */
    void updateMass();

    /**
     * Returns an updated copy of {@code readableVector} for a collision between the dimensions
     * listed in {@code iArr}. Implementations in this file either support exactly two indices
     * or always throw a {@link RuntimeException}.
     */
    ReadableVector doCollision(int[] iArr, ReadableVector readableVector);
}
