package dr.math.distributions;

import dr.evomodel.tree.UniformNodeHeightPrior;
import dr.math.GammaFunction;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;

/* loaded from: input_file:dr/math/distributions/WishartDistribution.class */
public class WishartDistribution implements MultivariateDistribution, WishartStatistics {

    /** Type identifier returned by {@link #getType()}. */
    public static final String TYPE = "Wishart";

    /** Degrees of freedom; 0 marks the uninformative form built by {@link #WishartDistribution(int)}. */
    private final double df;
    /** Dimension of the square matrices this distribution is defined over. */
    private final int dim;
    /** Scale matrix S as supplied by the caller (held by reference, not copied); null for the uninformative form. */
    private final double[][] scaleMatrix;
    /** S^{-1} flattened in row-major order, used by the 2x2 fast path; null for the uninformative form. */
    private final double[] Sinv;
    /** S^{-1} as a {@link Matrix}, used by the general path; null for the uninformative form. */
    private final Matrix SinvMat;
    /** Cached log normalising constant, see {@link #computeNormalizationConstant(Matrix, double, int)}. */
    private double logNormalizationConstant;

    /**
     * Creates a Wishart distribution with the given degrees of freedom and scale matrix.
     *
     * @param d    degrees of freedom
     * @param dArr scale matrix S (dim x dim). It is stored by reference, so mutating it afterwards
     *             desynchronises the cached inverse and normalising constant.
     */
    public WishartDistribution(double d, double[][] dArr) {
        this.df = d;
        this.scaleMatrix = dArr;
        this.dim = dArr.length;
        this.SinvMat = new Matrix(dArr).inverse();

        // Cache S^{-1} flattened row-major for the closed-form 2x2 density.
        final double[][] inverse = this.SinvMat.toComponents();
        this.Sinv = new double[dim * dim];
        for (int row = 0; row < dim; row++) {
            System.arraycopy(inverse[row], 0, this.Sinv, row * dim, dim);
        }
        computeNormalizationConstant();
    }

    /**
     * Creates the uninformative form of the distribution: df = 0, no scale matrix and a zero
     * normalising constant. The log density then reduces to -(dim + 1)/2 * log|W|.
     *
     * @param i dimension of the matrices
     */
    public WishartDistribution(int i) {
        this.df = 0.0d;
        this.scaleMatrix = null;
        this.Sinv = null;
        this.SinvMat = null;
        this.dim = i;
        this.logNormalizationConstant = 0.0d;
    }

    /** Recomputes and caches the log normalising constant from the current fields. */
    private void computeNormalizationConstant() {
        this.logNormalizationConstant = computeNormalizationConstant(new Matrix(this.scaleMatrix), this.df, this.dim);
    }

    /**
     * Computes the log normalising constant of a Wishart(df, S) density:
     * -(df/2) log|S| - (df*dim/2) log 2 - (dim(dim-1)/4) log pi - sum_{j=1..dim} lnGamma((df + 1 - j)/2).
     *
     * @param matrix scale matrix S
     * @param d      degrees of freedom; 0 yields a constant of 0 (uninformative form)
     * @param i      dimension
     * @return the log normalising constant
     */
    public static double computeNormalizationConstant(Matrix matrix, double d, int i) {
        if (d == 0.0d) {
            return 0.0d;
        }
        double logNorm = 0.0d;
        try {
            logNorm = (-d / 2.0d) * Math.log(matrix.determinant());
        } catch (IllegalDimension e) {
            // Best effort (existing behaviour): report and continue without the determinant term.
            e.printStackTrace();
        }
        logNorm -= ((d * i) / 2.0d) * Math.log(2.0d);
        logNorm -= ((i * (i - 1)) / 4.0d) * Math.log(Math.PI);
        for (int j = 1; j <= i; j++) {
            logNorm -= GammaFunction.lnGamma(((d + 1.0d) - j) / 2.0d);
        }
        return logNorm;
    }

    @Override // dr.math.distributions.MultivariateDistribution
    public String getType() {
        return TYPE;
    }

    /**
     * Returns the scale matrix.
     *
     * @return the internal array itself (not a copy); null for the uninformative form
     */
    @Override // dr.math.distributions.MultivariateDistribution
    public double[][] getScaleMatrix() {
        return this.scaleMatrix;
    }

    /**
     * Not implemented for this distribution.
     *
     * @return always null
     */
    @Override // dr.math.distributions.MultivariateDistribution
    public double[] getMean() {
        return null;
    }

    /**
     * Monte-Carlo sanity check: averages 100000 draws from {@link #nextWishart()} and prints each
     * entry of the empirical mean to stderr as "S&lt;k&gt;: value", with k the 1-based row-major index.
     */
    public void testMe() {
        final int samples = 100000;
        final double[][] sum = new double[dim][dim];
        for (int s = 0; s < samples; s++) {
            final double[][] draw = nextWishart();
            for (int row = 0; row < dim; row++) {
                for (int col = 0; col < dim; col++) {
                    sum[row][col] += draw[row][col];
                }
            }
        }
        for (int row = 0; row < dim; row++) {
            for (int col = 0; col < dim; col++) {
                System.err.println("S" + (row * dim + col + 1) + ": " + (sum[row][col] / samples));
            }
        }
    }

    @Override // dr.math.distributions.WishartStatistics
    public double getDF() {
        return this.df;
    }

    /**
     * Draws a random matrix from this distribution.
     *
     * @return a fresh dim x dim sample
     */
    public double[][] nextWishart() {
        return nextWishart(this.df, this.scaleMatrix);
    }

    /**
     * Draws W = (L A)(L A)^T, where L is the lower Cholesky factor of S. A is lower-triangular,
     * with standard normals strictly below the diagonal and sqrt(Gamma((df - k)/2, 0.5)) on
     * diagonal entry k. The df should exceed dim - 1 so every Gamma shape is positive.
     *
     * @param d    degrees of freedom
     * @param dArr scale matrix S; only its upper triangle (including the diagonal) is read
     * @return a fresh dim x dim sample
     * @throws RuntimeException wrapping {@link IllegalDimension} from the Cholesky decomposition
     */
    public static double[][] nextWishart(double d, double[][] dArr) {
        final int n = dArr.length;

        // Lower-triangular factor A. Draw order (normals row by row, then diagonal) is kept fixed
        // so seeded runs stay reproducible.
        final double[][] a = new double[n][n];
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < row; col++) {
                a[row][col] = MathUtils.nextGaussian();
            }
        }
        for (int k = 0; k < n; k++) {
            // Presumably a chi-square(d - k) draw if nextGamma takes (shape, rate) -- TODO confirm.
            a[k][k] = Math.sqrt(MathUtils.nextGamma((d - k) * 0.5d, 0.5d));
        }

        // Symmetrise S from its upper triangle.
        final double[][] symmetric = new double[n][n];
        for (int row = 0; row < n; row++) {
            for (int col = row; col < n; col++) {
                final double v = dArr[row][col];
                symmetric[col][row] = v;
                symmetric[row][col] = v;
            }
        }

        final double[][] l;
        try {
            l = new CholeskyDecomposition(symmetric).getL();
        } catch (IllegalDimension e) {
            throw new RuntimeException("Numerical exception in WishartDistribution", e);
        }

        // la = L * A
        final double[][] la = new double[n][n];
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                double acc = 0.0d;
                for (int k = 0; k < n; k++) {
                    acc += l[row][k] * a[k][col];
                }
                la[row][col] = acc;
            }
        }

        // W = la * la^T
        final double[][] w = new double[n][n];
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                double acc = 0.0d;
                for (int k = 0; k < n; k++) {
                    acc += la[row][k] * la[col][k];
                }
                w[row][col] = acc;
            }
        }
        return w;
    }

    /**
     * Log density at a matrix given as a row-major flattened array. A 4-element input (2x2 matrix)
     * uses the closed-form fast path; anything else goes through {@link #logPdfSlow(double[])}.
     */
    @Override // dr.math.distributions.MultivariateDistribution, dr.inference.distribution.DensityModel
    public double logPdf(double[] dArr) {
        return dArr.length == 4
                ? logPdf2D(dArr, this.Sinv, this.df, this.dim, this.logNormalizationConstant)
                : logPdfSlow(dArr);
    }

    /**
     * General-dimension log density at a row-major flattened dim x dim matrix.
     *
     * @param dArr row-major entries of W
     * @return the log density
     */
    public double logPdfSlow(double[] dArr) {
        return logPdf(new Matrix(dArr, this.dim, this.dim), this.SinvMat, this.df, this.dim, this.logNormalizationConstant);
    }

    /**
     * Closed-form 2x2 log density:
     * (df - dim - 1)/2 * log|W| - tr(S^{-1} W)/2 + logNormalizationConstant.
     *
     * @param dArr  W flattened row-major (length 4)
     * @param dArr2 S^{-1} flattened row-major (length 4), or null to omit the trace term
     *              (uninformative form)
     * @param d     degrees of freedom
     * @param i     dimension (2)
     * @param d2    log normalising constant
     * @return the log density, or negative infinity when det(W) &lt;= 0
     */
    public static double logPdf2D(double[] dArr, double[] dArr2, double d, int i, double d2) {
        final double det = (dArr[0] * dArr[3]) - (dArr[1] * dArr[2]);
        if (det <= 0.0d) {
            return Double.NEGATIVE_INFINITY;
        }
        double logDensity = Math.log(det) * (0.5d * ((d - i) - 1.0d));
        if (dArr2 != null) {
            // tr(S^{-1} W) = sum_jk Sinv[j][k] * W[k][j]
            final double trace = (dArr2[0] * dArr[0]) + (dArr2[1] * dArr[2]) + (dArr2[2] * dArr[1]) + (dArr2[3] * dArr[3]);
            logDensity -= 0.5d * trace;
        }
        return logDensity + d2;
    }

    /**
     * General log density:
     * (df - dim - 1)/2 * log|W| - tr(S^{-1} W)/2 + logNormalizationConstant.
     *
     * @param matrix  W
     * @param matrix2 S^{-1}, or null to omit the trace term (uninformative form)
     * @param d       degrees of freedom
     * @param i       dimension
     * @param d2      log normalising constant
     * @return the log density, or negative infinity when log|W| is infinite, NaN or cannot be
     *         computed
     */
    public static double logPdf(Matrix matrix, Matrix matrix2, double d, int i, double d2) {
        final double logDeterminant;
        try {
            logDeterminant = matrix.logDeterminant();
        } catch (IllegalDimension e) {
            // An undefined determinant means W lies outside the support.
            e.printStackTrace();
            return Double.NEGATIVE_INFINITY;
        }
        if (Double.isInfinite(logDeterminant) || Double.isNaN(logDeterminant)) {
            return Double.NEGATIVE_INFINITY;
        }
        double logDensity = logDeterminant * 0.5d * ((d - i) - 1.0d);
        if (matrix2 != null) {
            final Matrix product = matrix2.product(matrix);
            for (int k = 0; k < i; k++) {
                logDensity -= 0.5d * product.component(k, k);
            }
        }
        return logDensity + d2;
    }

    /** Prints the fast (2x2) and slow (general) log densities for one 2x2 point, so they can be compared. */
    public static void testBivariateMethod() {
        System.out.println("Testing new computations ...");
        final WishartDistribution wishart = new WishartDistribution(5.0d, new double[][]{{2.0d, -0.5d}, {-0.5d, 2.0d}});
        final double[] x = {4.0d, 1.0d, 1.0d, 3.0d};
        System.out.println("Fast logPdf = " + wishart.logPdf(x));
        System.out.println("Slow logPdf = " + wishart.logPdfSlow(x));
    }

    /** Ad-hoc comparison of 1x1 Wishart log densities against Gamma log densities, followed by the 2x2 check. */
    public static void main(String[] strArr) {
        final WishartDistribution wishart = new WishartDistribution(2.0d, new double[][]{{500.0d}});
        final GammaDistribution gamma = new GammaDistribution(0.001d, 1000.0d);
        final double[] x = {1.0d};
        System.out.println("Wishart, df=2, scale = 500, PDF(1.0): " + wishart.logPdf(x));
        System.out.println("Gamma, shape = 1/1000, scale = 1000, PDF(1.0): " + gamma.logPdf(x[0]));

        final WishartDistribution wishart2 = new WishartDistribution(4.0d, new double[][]{{5.0d}});
        final GammaDistribution gamma2 = new GammaDistribution(2.0d, 10.0d);
        final double[] x2 = {1.0d};
        System.out.println("Wishart, df=4, scale = 5, PDF(1.0): " + wishart2.logPdf(x2));
        System.out.println("Gamma, shape = 2, scale = 10, PDF(1.0): " + gamma2.logPdf(x2[0]));

        final WishartDistribution uninformative = new WishartDistribution(1);
        System.out.println("Wishart, uninformative, PDF(0.1): " + uninformative.logPdf(new double[]{0.1d}));
        System.out.println("Wishart, uninformative, PDF(1.0): " + uninformative.logPdf(new double[]{1.0d}));
        System.out.println("Wishart, uninformative, PDF(10.0): " + uninformative.logPdf(new double[]{10.0d}));
        testBivariateMethod();
    }
}
