package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:dr/inference/operators/hmc/SecantHessian.class */
/**
 * Quasi-Newton approximation of the Hessian of the log density, maintained with the BFGS
 * rank-two update from successive (gradient, position) pairs passed to {@link #storeSecant}.
 *
 * <p>Until the first pair is stored the approximation is the identity matrix. The first pair only
 * rescales the diagonal; every later pair applies one BFGS update built from its difference to the
 * pair stored immediately before it. Gradient queries are delegated to the wrapped
 * {@link GradientWrtParameterProvider}.
 *
 * <p>The class holds mutable state and is not synchronized.
 */
public class SecantHessian implements HessianWrtParameterProvider {
    /** Number of parameters, taken from the gradient provider at construction. */
    private final int dim;
    private final GradientWrtParameterProvider gradientProvider;
    /** Capacity of the ring buffer of stored pairs. */
    private final int secantSize;
    /** Ring buffer of the most recently stored pairs; slots are null until first written. */
    private final Secant[] queue;
    /** Slot that the next call to {@link #storeSecant} writes. */
    private int secantIndex;
    /** Number of pairs stored so far. */
    private int secantUpdateCount;
    /** Current {@code dim x dim} Hessian approximation, updated in place. */
    private double[][] secantHessian;

    /**
     * One stored (gradient, position) pair, plus the differences {@code sk} (position) and
     * {@code yk} (gradient) to another pair once {@link #updateSkYk(Secant)} has been called.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public class Secant {
        ReadableVector gradient;
        ReadableVector position;
        WrappedVector sk;
        WrappedVector yk;
        /** {@code 1 / (sk . yk)} as of the last {@link #updateSkYk} call; 0 before that. */
        private double reciprocalInnerProduct = 0.0d;

        /**
         * @param gradient gradient at {@code position}; stored via {@code WrappedVector.Utils.copy}
         * @param position parameter values; stored via {@code WrappedVector.Utils.copy}
         */
        Secant(ReadableVector gradient, ReadableVector position) {
            this.gradient = WrappedVector.Utils.copy(gradient);
            this.position = WrappedVector.Utils.copy(position);
            this.sk = new WrappedVector.Raw(new double[position.getDim()]);
            this.yk = new WrappedVector.Raw(new double[position.getDim()]);
        }

        double getPosition(int i) {
            return this.position.get(i);
        }

        double getGradient(int i) {
            return this.gradient.get(i);
        }

        /**
         * Sets {@code sk = other.position - this.position} and
         * {@code yk = other.gradient - this.gradient}, and caches {@code 1 / (sk . yk)}.
         * The cached value is infinite or NaN when {@code sk . yk} is zero or NaN.
         *
         * @param other the pair to difference against
         */
        void updateSkYk(Secant other) {
            double innerProduct = 0.0d;
            for (int i = 0; i < SecantHessian.this.dim; i++) {
                final double s = other.getPosition(i) - getPosition(i);
                final double y = other.getGradient(i) - getGradient(i);
                this.sk.set(i, s);
                this.yk.set(i, y);
                innerProduct += s * y;
            }
            this.reciprocalInnerProduct = 1.0d / innerProduct;
        }

        double getReciprocalInnerProduct() {
            return this.reciprocalInnerProduct;
        }

        double getYk(int i) {
            return this.yk.get(i);
        }

        double getSk(int i) {
            return this.sk.get(i);
        }
    }

    /**
     * Creates an approximation initialised to the identity matrix.
     *
     * @param gradientProvider source of the dimension, likelihood, parameter and gradients
     * @param secantSize capacity of the ring buffer of stored pairs; must be at least 1
     * @throws IllegalArgumentException if {@code secantSize < 1}
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public SecantHessian(GradientWrtParameterProvider gradientProvider, int secantSize) {
        if (secantSize < 1) {
            // storeSecant indexes into, and takes a modulus by, the buffer length, so an empty
            // buffer could never be used; fail here instead of on the first storeSecant call.
            throw new IllegalArgumentException("secantSize must be >= 1, but was " + secantSize);
        }
        this.gradientProvider = gradientProvider;
        this.secantSize = secantSize;
        this.dim = gradientProvider.getDimension();
        this.secantHessian = new double[this.dim][this.dim];
        for (int i = 0; i < this.dim; i++) {
            this.secantHessian[i][i] = 1.0d;
        }
        this.queue = new Secant[secantSize];
        this.secantIndex = 0;
        this.secantUpdateCount = 0;
    }

    /** Returns a new array holding the diagonal of the current approximation. */
    @Override // dr.inference.hmc.HessianWrtParameterProvider
    public double[] getDiagonalHessianLogDensity() {
        double[] diagonal = new double[this.dim];
        for (int i = 0; i < this.dim; i++) {
            diagonal[i] = this.secantHessian[i][i];
        }
        return diagonal;
    }

    /**
     * Returns a deep copy of the current approximation.
     *
     * <p>Each row is copied individually: {@code clone()} on a {@code double[][]} is shallow and
     * would hand callers this object's internal rows.
     */
    @Override // dr.inference.hmc.HessianWrtParameterProvider
    public double[][] getHessianLogDensity() {
        double[][] copy = new double[this.secantHessian.length][];
        for (int i = 0; i < this.secantHessian.length; i++) {
            copy[i] = this.secantHessian[i].clone();
        }
        return copy;
    }

    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public Likelihood getLikelihood() {
        return this.gradientProvider.getLikelihood();
    }

    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public Parameter getParameter() {
        return this.gradientProvider.getParameter();
    }

    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public int getDimension() {
        return this.gradientProvider.getDimension();
    }

    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public double[] getGradientLogDensity() {
        return this.gradientProvider.getGradientLogDensity();
    }

    /**
     * Records a new (gradient, position) pair and updates the approximation.
     *
     * <p>The first pair rescales the diagonal; each later pair applies a BFGS update using its
     * difference to the previously stored pair.
     *
     * @param gradient gradient at {@code position}
     * @param position parameter values
     */
    public void storeSecant(ReadableVector gradient, ReadableVector position) {
        // Look up the previous pair BEFORE writing the new one: with secantSize == 1 both live in
        // the same slot, and reading afterwards would difference the new pair against itself.
        final int previousIndex = (this.secantIndex + this.queue.length - 1) % this.queue.length;
        final Secant previous = this.queue[previousIndex];
        final Secant current = new Secant(gradient, position);
        this.queue[this.secantIndex] = current;
        // The previous slot is empty only on the very first call.
        if (previous == null) {
            initializeSecantHessian(current);
        } else {
            updateSecantHessian(current, previous);
        }
        this.secantIndex = (this.secantIndex + 1) % this.queue.length;
        this.secantUpdateCount++;
    }

    /**
     * Sets every diagonal entry to {@code (g . x) / (g . g)}, computed from the first pair's raw
     * gradient {@code g} and position {@code x}, or to 1.0 when {@code g . g == 0}. Off-diagonal
     * entries keep their constructor value (zero).
     */
    // NOTE(review): heuristic scale from a single point (raw g and x rather than differences);
    // confirm intent before changing.
    private void initializeSecantHessian(Secant first) {
        double gradientDotPosition = 0.0d;
        double gradientDotGradient = 0.0d;
        for (int i = 0; i < this.dim; i++) {
            gradientDotPosition += first.getGradient(i) * first.getPosition(i);
            gradientDotGradient += first.getGradient(i) * first.getGradient(i);
        }
        final double scale =
                gradientDotGradient == 0.0d ? 1.0d : gradientDotPosition / gradientDotGradient;
        for (int i = 0; i < this.dim; i++) {
            this.secantHessian[i][i] = scale;
        }
    }

    /**
     * Applies the BFGS update {@code B += y y^T / (s^T y) - (B s)(B s)^T / (s^T B s)}, where
     * {@code s} and {@code y} are the position and gradient differences between the two pairs.
     *
     * <p>The update is skipped, leaving the approximation unchanged, when either denominator is
     * zero or non-finite (for example, when the same position is stored twice). Applying it
     * would multiply by an infinite factor and leave the matrix NaN from then on.
     */
    private void updateSecantHessian(Secant current, Secant previous) {
        current.updateSkYk(previous);
        final double reciprocalSkYk = current.getReciprocalInnerProduct();
        if (!Double.isFinite(reciprocalSkYk) || reciprocalSkYk == 0.0d) {
            return;
        }

        // hSk = B s (using B before this update) and skHSk = s^T B s.
        final double[] hSk = new double[this.dim];
        double skHSk = 0.0d;
        for (int i = 0; i < this.dim; i++) {
            double rowDot = 0.0d;
            for (int j = 0; j < this.dim; j++) {
                rowDot += this.secantHessian[i][j] * current.getSk(j);
            }
            hSk[i] = rowDot;
            skHSk += current.getSk(i) * rowDot;
        }
        final double negReciprocalSkHSk = -1.0d / skHSk;
        if (!Double.isFinite(negReciprocalSkHSk) || negReciprocalSkHSk == 0.0d) {
            return;
        }

        for (int i = 0; i < this.dim; i++) {
            final double[] row = this.secantHessian[i];
            final double yi = current.getYk(i);
            for (int j = 0; j < this.dim; j++) {
                row[j] += reciprocalSkYk * yi * current.getYk(j)
                        + negReciprocalSkHSk * hSk[i] * hSk[j];
            }
        }
    }
}
