package dr.util;

import dr.math.matrixAlgebra.WrappedMatrix;
import dr.util.Transform;

/* loaded from: input_file:dr/util/LKJCholeskyTransformConstrained.class */
/**
 * Multivariate transform over the packed strictly-upper-triangular part of a
 * {@code dimVector x dimVector} matrix, i.e. over vectors of length
 * {@code dimVector * (dimVector - 1) / 2}.
 *
 * <p>{@link #inverse(double[])} takes bounded values (called "CPCs" in the assertion
 * message; presumably canonical partial correlations -- confirm against the callers)
 * and, column by column, multiplies each value by the running product of
 * {@code sqrt(1 - z^2)} of the values above it in that column.
 * {@link #transform(double[])} undoes that recursion by dividing instead.
 *
 * <p>The index arithmetic in this class ({@link #getLogJacobian(double[])},
 * {@link #getGradientLogJacobianInverse(double[])} and {@code posStrict}) walks the
 * vector row by row over the strict upper triangle:
 * (0,1), (0,2), ..., (0,n-1), (1,2), ...
 */
public class LKJCholeskyTransformConstrained extends Transform.MultivariateTransform {
    /** Number of rows/columns of the square matrix whose strict upper triangle is transformed. */
    int dimVector;

    /**
     * @param dimVector number of rows/columns of the underlying square matrix; the
     *                  transform itself acts on {@code dimVector * (dimVector - 1) / 2} values
     */
    public LKJCholeskyTransformConstrained(int dimVector) {
        super((dimVector * (dimVector - 1)) / 2);
        this.dimVector = dimVector;
    }

    /**
     * Maps the bounded values to the strictly-upper-triangular entries.
     *
     * <p>For every column {@code col} the entries are visited top to bottom; entry
     * {@code (row, col)} becomes {@code z * prod_{k < row} sqrt(1 - z_k^2)} where
     * {@code z_k} are the input values above it in the same column.
     *
     * @param values packed input values, each expected in {@code [-1, 1]}
     *               (only checked when assertions are enabled)
     * @return the backing buffer of the freshly filled output matrix
     */
    @Override // dr.util.Transform.MultivariateTransform
    public double[] inverse(double[] values) {
        for (int k = 0; k < this.dim; k++) {
            // Same condition as before: only values strictly outside [-1, 1] trip the
            // assertion (NaN compares false on both sides and therefore passes).
            assert !(values[k] > 1.0d || values[k] < -1.0d) : "CPCs must be between -1.0 and 1.0";
        }
        WrappedMatrix.WrappedStrictlyUpperTriangularMatrix result =
                new WrappedMatrix.WrappedStrictlyUpperTriangularMatrix(this.dimVector);
        // NOTE(review): the meaning of the third constructor argument (1.0) is defined
        // by WrappedMatrix and is not visible from this file.
        WrappedMatrix.WrappedStrictlyUpperTriangularMatrix cpcs =
                new WrappedMatrix.WrappedStrictlyUpperTriangularMatrix(values, this.dimVector, 1.0d);
        for (int col = 1; col < this.dimVector; col++) {
            // Running product of sqrt(1 - z^2) over the rows already visited in this column.
            double scale = 1.0d;
            for (int row = 0; row < col; row++) {
                double z = cpcs.get(row, col);
                result.set(row, col, z * scale);
                scale *= Math.sqrt(1.0d - (z * z));
            }
        }
        return result.getBuffer();
    }

    /**
     * Recovers the bounded values from the strictly-upper-triangular entries; the
     * reverse recursion of {@link #inverse(double[])}.
     *
     * <p>For every column, each entry is divided by the running product of
     * {@code sqrt(1 - z^2)} of the values already recovered above it.
     *
     * @param values packed strictly-upper-triangular entries
     * @return the backing buffer of the freshly filled matrix of recovered values
     */
    @Override // dr.util.Transform.MultivariateTransform
    public double[] transform(double[] values) {
        WrappedMatrix.WrappedStrictlyUpperTriangularMatrix entries =
                new WrappedMatrix.WrappedStrictlyUpperTriangularMatrix(values, this.dimVector);
        WrappedMatrix.WrappedStrictlyUpperTriangularMatrix cpcs =
                new WrappedMatrix.WrappedStrictlyUpperTriangularMatrix(this.dimVector);
        for (int col = 1; col < this.dimVector; col++) {
            double scale = 1.0d;
            for (int row = 0; row < col; row++) {
                double z = entries.get(row, col) / scale;
                cpcs.set(row, col, z);
                scale *= Math.sqrt(1.0d - (z * z));
            }
        }
        return cpcs.getBuffer();
    }

    /**
     * Checks that, for every column, the sum of squares of the strictly-upper-triangular
     * entries is strictly below one.
     *
     * <p>A column whose sum of squares is NaN is rejected as well: the comparison is
     * written as a negated {@code <} because every ordered comparison with NaN is false,
     * so a plain {@code >= 1.0} test would let NaN through as "inside the domain".
     *
     * @param values packed strictly-upper-triangular entries
     * @return {@code true} only if every column passes the check
     */
    @Override // dr.util.Transform.MultivariateTransform
    public boolean isInInteriorDomain(double[] values) {
        WrappedMatrix.WrappedStrictlyUpperTriangularMatrix entries =
                new WrappedMatrix.WrappedStrictlyUpperTriangularMatrix(values, this.dimVector);
        // NOTE(review): (0, 0) is a diagonal position; what the strictly-upper wrapper
        // returns there is defined by WrappedMatrix. Check kept exactly as it was.
        if (Math.abs(entries.get(0, 0)) >= 1.0d) {
            return false;
        }
        for (int col = 1; col < this.dimVector; col++) {
            double sumOfSquares = 0.0d;
            for (int row = 0; row < col; row++) {
                sumOfSquares += Math.pow(entries.get(row, col), 2.0d);
            }
            // Negated '<' so that a NaN sum is treated as outside the domain.
            if (!(sumOfSquares < 1.0d)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Range-based inverse; not supported by this transform.
     *
     * @throws RuntimeException always
     */
    @Override // dr.util.Transform
    public double[] inverse(double[] values, int from, int to, double sum) {
        throw new RuntimeException("Not relevant for the LKJ transform.");
    }

    /** @return the fixed name {@code "LKJCholeskyTransform"} */
    @Override // dr.util.Transform
    public String getTransformName() {
        return "LKJCholeskyTransform";
    }

    /**
     * Not implemented.
     *
     * @throws RuntimeException always
     */
    @Override // dr.util.Transform
    public double[] gradient(double[] values, int from, int to) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Not implemented.
     *
     * @throws RuntimeException always
     */
    @Override // dr.util.Transform
    public double[] gradientInverse(double[] values, int from, int to) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Log-Jacobian evaluated from the transformed values.
     *
     * <p>Computes {@code -0.5 * sum (col - row - 1) * log(1 - z^2)} over all positions
     * {@code (row, col)} with {@code col >= row + 2}, where {@code z} is the value
     * returned by {@link #transform(double[])} at that position. Positions directly
     * above the diagonal ({@code col == row + 1}) have weight zero and are skipped.
     *
     * @param values packed strictly-upper-triangular entries
     * @return the log-Jacobian term described above
     */
    @Override // dr.util.Transform.MultivariateTransform
    protected double getLogJacobian(double[] values) {
        double[] cpcs = transform(values);
        double total = 0.0d;
        int k = 0; // running index into the row-major packed vector
        for (int row = 0; row < this.dimVector - 2; row++) {
            k++; // skip (row, row + 1): its weight is zero
            for (int col = row + 2; col < this.dimVector; col++) {
                // log1p(-z) + log1p(z) == log(1 - z^2)
                total += ((col - row) - 1) * (Math.log1p(-cpcs[k]) + Math.log1p(cpcs[k]));
                k++;
            }
        }
        return (-0.5d) * total;
    }

    /**
     * Gradient, with respect to each bounded value {@code z}, of
     * {@code 0.5 * sum (col - row - 1) * log(1 - z^2)}; each component is
     * {@code -(col - row - 1) * z / (1 - z^2)}.
     *
     * <p>Components for positions with {@code col == row + 1} (and the last packed
     * position) are left at zero.
     *
     * @param values packed bounded values
     * @return a new array of the same length as {@code values}
     */
    @Override // dr.util.Transform.MultivariateTransform
    public double[] getGradientLogJacobianInverse(double[] values) {
        double[] gradient = new double[values.length];
        int k = 0; // running index into the row-major packed vector
        for (int row = 0; row < this.dimVector - 2; row++) {
            k++; // skip (row, row + 1): its weight is zero
            for (int col = row + 2; col < this.dimVector; col++) {
                gradient[k] = ((-((col - row) - 1)) * values[k]) / (1.0d - Math.pow(values[k], 2.0d));
                k++;
            }
        }
        return gradient;
    }

    /**
     * Builds the {@code dim x dim} Jacobian matrix of {@link #inverse(double[])}.
     *
     * <p>Row {@code posStrict(row, col)} holds the partial derivatives of the output
     * entries with respect to the input value at {@code (row, col)}.
     *
     * @param values packed bounded values
     * @return the Jacobian as a newly allocated square array
     */
    @Override // dr.util.Transform.MultivariateTransform
    public double[][] computeJacobianMatrixInverse(double[] values) {
        double[][] jacobian = new double[this.dim][this.dim];
        WrappedMatrix.WrappedStrictlyUpperTriangularMatrix cpcs =
                new WrappedMatrix.WrappedStrictlyUpperTriangularMatrix(values, this.dimVector, 1.0d);
        for (int col = 1; col < this.dimVector; col++) {
            for (int row = 0; row < col; row++) {
                recursionJacobian(jacobian, cpcs, row, col);
            }
        }
        return jacobian;
    }

    /**
     * Fills the Jacobian row belonging to the input value at {@code (row, col)}.
     *
     * <p>Only output entries in the same column at or below {@code row} depend on
     * that value, so only those positions are written. The Jacobian row is viewed
     * through a strictly-upper-triangular wrapper (presumably writing through to the
     * wrapped array -- confirm in WrappedMatrix).
     */
    private void recursionJacobian(double[][] jacobian,
                                   WrappedMatrix.WrappedStrictlyUpperTriangularMatrix cpcs,
                                   int row, int col) {
        WrappedMatrix.WrappedStrictlyUpperTriangularMatrix jacobianRow =
                new WrappedMatrix.WrappedStrictlyUpperTriangularMatrix(jacobian[posStrict(row, col)], this.dimVector);
        // Product of sqrt(1 - z^2) over the values above (row, col) in this column:
        // the derivative of the output at (row, col) with respect to its own input.
        double diagonalTerm = 1.0d;
        for (int above = 0; above < row; above++) {
            diagonalTerm *= Math.sqrt(1.0d - Math.pow(cpcs.get(above, col), 2.0d));
        }
        jacobianRow.set(row, col, diagonalTerm);
        double z = cpcs.get(row, col);
        // Derivative of the running scale with respect to z: d sqrt(1 - z^2)/dz = -z / sqrt(1 - z^2).
        double running = diagonalTerm * ((-z) / Math.sqrt(1.0d - Math.pow(z, 2.0d)));
        for (int below = row + 1; below < col; below++) {
            double zBelow = cpcs.get(below, col);
            jacobianRow.set(below, col, zBelow * running);
            running *= Math.sqrt(1.0d - Math.pow(zBelow, 2.0d));
        }
    }

    /**
     * Index of position {@code (row, col)}, {@code row < col}, in a row-major packing
     * of the strict upper triangle: the number of entries in the rows before
     * {@code row} plus the offset within the row.
     */
    private int posStrict(int row, int col) {
        return ((row * (((2 * this.dimVector) - row) - 1)) / 2) + ((col - row) - 1);
    }
}
