package dr.util;

import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.util.Transform;

/* loaded from: input_file:dr/util/CorrelationToCholesky.class */
/**
 * Multivariate transform that maps a packed vector of {@code n * (n - 1) / 2}
 * values through a Cholesky decomposition and back, where {@code n} is the
 * matrix dimension supplied to the constructor.
 *
 * <p>The forward direction builds a symmetric matrix with
 * {@code SymmetricMatrix.compoundCorrelationSymmetricMatrix}, decomposes it with
 * {@code CholeskyDecomposition} and returns the factor's strictly upper-triangular
 * entries. The inverse direction rebuilds an upper-triangular matrix with
 * {@code WrappedUpperTriangularMatrix.fillDiagonal}, takes its
 * {@code transposedProduct()} and extracts the upper triangle again.
 *
 * <p>NOTE(review): judging by the helper names, the packed input is presumably the
 * off-diagonal part of a correlation matrix and {@code fillDiagonal} presumably
 * completes the diagonal of the Cholesky factor; confirm against those helpers.
 */
public class CorrelationToCholesky extends Transform.MultivariateTransform {
    /** Dimension {@code n} of the square matrices handled by this transform. */
    private final int dimVector;

    /**
     * Creates the transform for {@code i x i} matrices.
     *
     * @param i matrix dimension; the superclass is constructed with
     *          {@code i * (i - 1) / 2}, the number of strictly upper-triangular entries
     */
    public CorrelationToCholesky(int i) {
        super((i * (i - 1)) / 2);
        this.dimVector = i;
    }

    /**
     * Inverse direction: rebuilds an upper-triangular matrix from the packed values,
     * takes its {@code transposedProduct()} and returns the result of
     * {@code SymmetricMatrix.extractUpperTriangular} on that product.
     *
     * @param dArr packed values handed to {@code fillDiagonal}
     * @return the packed upper-triangular entries of the product
     */
    @Override // dr.util.Transform.MultivariateTransform
    protected double[] inverse(double[] dArr) {
        WrappedMatrix.WrappedUpperTriangularMatrix upper =
                WrappedMatrix.WrappedUpperTriangularMatrix.fillDiagonal(dArr, this.dimVector);
        return SymmetricMatrix.extractUpperTriangular(upper.transposedProduct());
    }

    /**
     * Forward direction: builds the symmetric matrix for the packed values,
     * Cholesky-decomposes it and returns the strictly upper-triangular entries
     * of the decomposition.
     *
     * @param dArr packed values handed to {@code compoundCorrelationSymmetricMatrix}
     * @return the strictly upper-triangular entries reported by the decomposition
     * @throws RuntimeException if the decomposition raises {@link IllegalDimension};
     *                          the original exception is attached as the cause
     */
    @Override // dr.util.Transform.MultivariateTransform
    protected double[] transform(double[] dArr) {
        try {
            return new CholeskyDecomposition(
                    SymmetricMatrix.compoundCorrelationSymmetricMatrix(dArr, this.dimVector))
                    .getStrictlyUpperTriangular();
        } catch (IllegalDimension e) {
            // Keep the cause so the underlying dimension problem stays visible in the stack trace.
            throw new RuntimeException("Unable to decompose matrix in LKJ inverse transform.", e);
        }
    }

    /**
     * Range-based inverse; not supported by this transform.
     *
     * @throws RuntimeException always
     */
    @Override // dr.util.Transform
    public double[] inverse(double[] dArr, int i, int i2, double d) {
        throw new RuntimeException("Not relevant for the correlation to Cholesky transform.");
    }

    /**
     * Domain check; not implemented.
     *
     * @throws RuntimeException always
     */
    @Override // dr.util.Transform.MultivariateTransform
    public boolean isInInteriorDomain(double[] dArr) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * @return the fixed name {@code "CorrelationToCholeskyTransform"}
     */
    @Override // dr.util.Transform
    public String getTransformName() {
        return "CorrelationToCholeskyTransform";
    }

    /**
     * Gradient of the transform; not implemented.
     *
     * @throws RuntimeException always
     */
    @Override // dr.util.Transform
    public double[] gradient(double[] dArr, int i, int i2) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Gradient of the inverse transform; not implemented.
     *
     * @throws RuntimeException always
     */
    @Override // dr.util.Transform
    public double[] gradientInverse(double[] dArr, int i, int i2) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Computes {@code -sum_{k=0}^{n-2} (n - k - 1) * log(W[k][k])}, where {@code W} is the
     * upper-triangular matrix built by {@code fillDiagonal} from {@code transform(dArr)}
     * and {@code n} is the matrix dimension.
     *
     * @param dArr untransformed packed values; they are run through {@link #transform(double[])} first
     * @return the negated weighted sum of log-diagonal entries
     */
    @Override // dr.util.Transform.MultivariateTransform
    protected double getLogJacobian(double[] dArr) {
        WrappedMatrix.WrappedUpperTriangularMatrix factor =
                WrappedMatrix.WrappedUpperTriangularMatrix.fillDiagonal(transform(dArr), this.dimVector);
        double weightedLogSum = 0.0d;
        // The last diagonal entry gets weight zero, so the loop stops at n - 2.
        for (int k = 0; k < this.dimVector - 1; k++) {
            weightedLogSum += ((this.dimVector - k) - 1) * Math.log(factor.get(k, k));
        }
        return -weightedLogSum;
    }

    /**
     * Fills one entry per strictly upper-triangular position {@code (row, col)}, visited
     * row by row, with {@code -(n - col - 1) * W[row][col] / W[col][col]^2}, where
     * {@code W} is the matrix built by {@code fillDiagonal} from {@code dArr}.
     *
     * @param dArr packed values handed to {@code fillDiagonal}
     * @return array of the same length as {@code dArr} holding the entries described above
     */
    @Override // dr.util.Transform.MultivariateTransform
    protected double[] getGradientLogJacobianInverse(double[] dArr) {
        WrappedMatrix.WrappedUpperTriangularMatrix factor =
                WrappedMatrix.WrappedUpperTriangularMatrix.fillDiagonal(dArr, this.dimVector);
        double[] gradient = new double[dArr.length];
        int packedIndex = 0; // advances in row-major order over the strict upper triangle
        for (int row = 0; row < this.dimVector - 1; row++) {
            for (int col = row + 1; col < this.dimVector; col++) {
                gradient[packedIndex] = ((-((this.dimVector - col) - 1)) * factor.get(row, col))
                        / Math.pow(factor.get(col, col), 2.0d);
                packedIndex++;
            }
        }
        return gradient;
    }

    /**
     * Builds a {@code dim x dim} matrix indexed by packed strictly upper-triangular
     * positions (see {@link #posStrict(int, int)}). For every position {@code (row, col)}
     * with {@code row < col}, the column {@code posStrict(row, col)} receives:
     * <ul>
     *   <li>at row {@code posStrict(k, row)}, for {@code k < row}:
     *       {@code W[k][col] - W[k][row] * W[row][col] / W[row][row]};</li>
     *   <li>at row {@code posStrict(k, col)}, for {@code k < row}: {@code W[k][row]};</li>
     *   <li>at row {@code posStrict(row, col)}: {@code W[row][row]}.</li>
     * </ul>
     * All other entries stay zero. {@code W} is the matrix built by {@code fillDiagonal}
     * from {@code dArr}.
     *
     * @param dArr packed values handed to {@code fillDiagonal}
     * @return the matrix described above
     */
    @Override // dr.util.Transform.MultivariateTransform
    public double[][] computeJacobianMatrixInverse(double[] dArr) {
        double[][] jacobian = new double[this.dim][this.dim];
        WrappedMatrix.WrappedUpperTriangularMatrix factor =
                WrappedMatrix.WrappedUpperTriangularMatrix.fillDiagonal(dArr, this.dimVector);
        for (int row = 0; row < this.dimVector - 1; row++) {
            for (int col = row + 1; col < this.dimVector; col++) {
                double ratio = factor.get(row, col) / factor.get(row, row);
                for (int k = 0; k < row; k++) {
                    jacobian[posStrict(k, row)][posStrict(row, col)] =
                            factor.get(k, col) - (factor.get(k, row) * ratio);
                    jacobian[posStrict(k, col)][posStrict(row, col)] = factor.get(k, row);
                }
                jacobian[posStrict(row, col)][posStrict(row, col)] = factor.get(row, row);
            }
        }
        return jacobian;
    }

    /**
     * Row-major index of position {@code (i, i2)}, {@code i < i2}, within the packed
     * strictly upper triangle of an {@code n x n} matrix: rows {@code 0..i-1} contribute
     * {@code i * (2n - i - 1) / 2} entries, followed by the offset {@code i2 - i - 1}
     * inside row {@code i}.
     *
     * @param i  row index
     * @param i2 column index, expected to be greater than {@code i}
     * @return the packed index
     */
    private int posStrict(int i, int i2) {
        return ((i * (((2 * this.dimVector) - i) - 1)) / 2) + ((i2 - i) - 1);
    }
}
