package dr.inference.operators.factorAnalysis;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.LatentFactorModelInterface;
import dr.inference.distribution.MomentDistributionModel;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.PathDependent;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.NormalDistribution;
import dr.math.matrixAlgebra.SymmetricMatrix;

/**
 * Gibbs-style operator that redraws the entries of one randomly chosen column of the loadings
 * matrix of a latent factor model, one row at a time.
 *
 * <p>For each entry the operator builds the Gaussian full conditional of the whole loadings row
 * (precision from the factors and the column precision, plus the prior precision on the
 * diagonal), reduces it to the univariate conditional of the selected entry given the other
 * entries of the row, and then:
 * <ul>
 *   <li>if the prior is a {@link MomentDistributionModel}, draws from that normal restricted to
 *       the two tails outside {@code [-sqrt(cutoff), +sqrt(cutoff)]};</li>
 *   <li>otherwise draws from the unrestricted normal.</li>
 * </ul>
 *
 * <p>The data-dependent part of the conditional is scaled by the path parameter set through
 * {@link #setPathParameter(double)}.
 */
public class LoadingsGibbsTruncatedOperator extends SimpleMCMCOperator implements PathDependent, GibbsOperator {
    /** Maximum number of attempts made to obtain a draw outside the cutoff interval. */
    private static final int MAX_DRAW_ATTEMPTS = 10;

    Likelihood prior;
    LatentFactorModelInterface LFM;
    double[][] precisionArray;
    double[] meanMidArray;
    double[] meanArray;
    boolean randomScan;
    double pathParameter = 1.0;
    final Parameter missingIndicator;
    double priorPrecision;
    double priorMeanPrecision;
    MatrixParameterInterface loadings;
    DistributionLikelihood cutoffPrior;

    /**
     * Creates the operator.
     *
     * <p>The prior precision and prior mean-times-precision are read from the prior when it is a
     * {@link MomentDistributionModel} (first entry of its scale matrix and mean) or a
     * {@link DistributionLikelihood} (reciprocal variance and mean of its distribution). For any
     * other prior type both values stay at {@code 0.0}.
     *
     * @param latentFactorModelInterface the latent factor model supplying factors, data,
     *                                   column precision and the missing-data indicator
     * @param likelihood                 the prior on the loadings
     * @param d                          the operator weight
     * @param z                          random-scan flag; stored but not consulted by this class
     * @param matrixParameterInterface   the loadings matrix that this operator updates
     * @param distributionLikelihood     prior on the cutoff, used only by {@code getCutoffDraw}
     */
    public LoadingsGibbsTruncatedOperator(LatentFactorModelInterface latentFactorModelInterface, Likelihood likelihood, double d, boolean z, MatrixParameterInterface matrixParameterInterface, DistributionLikelihood distributionLikelihood) {
        setWeight(d);
        this.loadings = matrixParameterInterface;
        this.prior = likelihood;
        this.LFM = latentFactorModelInterface;
        this.randomScan = z;
        if (likelihood instanceof MomentDistributionModel) {
            MomentDistributionModel momentPrior = (MomentDistributionModel) likelihood;
            this.priorPrecision = momentPrior.getScaleMatrix()[0][0];
            this.priorMeanPrecision = momentPrior.getMean()[0] * this.priorPrecision;
        } else if (likelihood instanceof DistributionLikelihood) {
            DistributionLikelihood distributionPrior = (DistributionLikelihood) likelihood;
            this.priorPrecision = 1.0 / distributionPrior.getDistribution().variance();
            this.priorMeanPrecision = distributionPrior.getDistribution().mean() * this.priorPrecision;
        }
        this.cutoffPrior = distributionLikelihood;
        this.missingIndicator = latentFactorModelInterface.getMissingIndicator();
    }

    /**
     * Returns the cutoff matrix of the prior.
     *
     * <p>Must only be called when {@link #prior} is a {@link MomentDistributionModel}; the casts
     * fail otherwise.
     */
    private MatrixParameterInterface cutoffMatrix() {
        return (MatrixParameterInterface) ((MomentDistributionModel) this.prior).getCutoff();
    }

    /**
     * Tells whether the data entry at ({@code row}, {@code column}) contributes to the sums.
     *
     * <p>An entry is skipped only when a missing indicator exists and its value at flat index
     * {@code column * rowStride + row} is exactly {@code 1.0}.
     */
    private boolean isObserved(int row, int column, int rowStride) {
        return this.missingIndicator == null
                || this.missingIndicator.getParameterValue((column * rowStride) + row) != 1.0;
    }

    /** A draw is usable when it is not NaN and lies on or outside the cutoff bounds. */
    private static boolean isOutsideCutoff(double value, double lower, double upper) {
        return (value >= upper || value <= lower) && !Double.isNaN(value);
    }

    /** Maps an index of a vector with one coordinate removed back to the full-length index. */
    private static int skipIndex(int reducedIndex, int removed) {
        return reducedIndex < removed ? reducedIndex : reducedIndex + 1;
    }

    /**
     * Fills {@code dArr} with the symmetric full-conditional precision of one loadings row.
     *
     * <p>Entry (a, b) is the sum over observed columns of
     * {@code matrix[a][col] * matrix[b][col]}, times the column precision at
     * ({@code i2}, {@code i2}) and the path parameter; the prior precision is added on the
     * diagonal.
     *
     * @param matrixParameterInterface matrix whose first {@code i} rows are cross-multiplied
     * @param i                        size of the (square) result
     * @param i2                       loadings row being updated
     * @param dArr                     output array, at least {@code i} by {@code i}
     */
    private void getPrecisionOfTruncated(MatrixParameterInterface matrixParameterInterface, int i, int i2, double[][] dArr) {
        final int columnCount = matrixParameterInterface.getColumnDimension();
        // The stride is only needed (and only fetched) when a missing indicator is present.
        final int rowStride = this.missingIndicator == null ? 0 : this.LFM.getScaledData().getRowDimension();
        final double rowPrecision = this.LFM.getColumnPrecision().getParameterValue(i2, i2);

        for (int a = 0; a < i; a++) {
            for (int b = a; b < i; b++) {
                double crossProduct = 0.0;
                for (int col = 0; col < columnCount; col++) {
                    if (isObserved(i2, col, rowStride)) {
                        crossProduct += matrixParameterInterface.getParameterValue(a, col)
                                * matrixParameterInterface.getParameterValue(b, col);
                    }
                }
                final double scaled = crossProduct * rowPrecision * this.pathParameter;
                if (a == b) {
                    dArr[a][a] = scaled + this.priorPrecision;
                } else {
                    // Only the upper triangle is computed; mirror it to keep the matrix symmetric.
                    dArr[a][b] = scaled;
                    dArr[b][a] = scaled;
                }
            }
        }
    }

    /**
     * Computes the full-conditional mean of one loadings row.
     *
     * <p>First {@code dArr2} receives, per factor, the sum over observed columns of
     * {@code factor * data} times the column precision plus the prior mean-times-precision;
     * then {@code dArr3} receives the product of {@code dArr} with {@code dArr2}.
     *
     * @param i     number of entries to compute
     * @param i2    loadings row being updated
     * @param dArr  matrix multiplied onto the intermediate vector (the caller passes the
     *              inverse of the conditional precision)
     * @param dArr2 output: intermediate vector
     * @param dArr3 output: resulting mean
     */
    private void getTruncatedMean(int i, int i2, double[][] dArr, double[] dArr2, double[] dArr3) {
        final MatrixParameterInterface scaledData = this.LFM.getScaledData();
        final MatrixParameterInterface factors = this.LFM.getFactors();
        final int columnCount = scaledData.getColumnDimension();
        final int rowStride = scaledData.getRowDimension();
        final double rowPrecision = this.LFM.getColumnPrecision().getParameterValue(i2, i2);

        for (int a = 0; a < i; a++) {
            double crossProduct = 0.0;
            for (int col = 0; col < columnCount; col++) {
                if (isObserved(i2, col, rowStride)) {
                    crossProduct += factors.getParameterValue(a, col) * scaledData.getParameterValue(i2, col);
                }
            }
            dArr2[a] = (crossProduct * rowPrecision) + this.priorMeanPrecision;
        }

        for (int a = 0; a < i; a++) {
            double sum = 0.0;
            for (int b = 0; b < i; b++) {
                sum += dArr[a][b] * dArr2[b];
            }
            dArr3[a] = sum;
        }
    }

    /** Fills {@code dArr} with the conditional precision of loadings row {@code i}. */
    private void getPrecision(int i, double[][] dArr) {
        getPrecisionOfTruncated(this.LFM.getFactors(), this.loadings.getColumnDimension(), i, dArr);
    }

    /**
     * Fills {@code dArr3} with the conditional mean of loadings row {@code i}, scaled by the
     * path parameter.
     *
     * <p>NOTE(review): the scaling is applied to the whole mean, so it also scales the
     * prior-mean contribution added in {@code getTruncatedMean}; confirm that this is intended.
     */
    private void getMean(int i, double[][] dArr, double[] dArr2, double[] dArr3) {
        getTruncatedMean(this.loadings.getColumnDimension(), i, dArr, dArr2, dArr3);
        for (int k = 0; k < dArr3.length; k++) {
            dArr3[k] *= this.pathParameter;
        }
    }

    /** Writes {@code dArr} into loadings row {@code i} without firing change events. */
    private void copy(int i, double[] dArr) {
        for (int k = 0; k < dArr.length; k++) {
            this.loadings.setParameterValueQuietly(i, k, dArr[k]);
        }
    }

    /**
     * Draws (or evaluates) a loadings entry under {@code normalDistribution} restricted to the
     * region outside {@code [-sqrt(cutoff), +sqrt(cutoff)]}.
     *
     * <p>When {@code z} is true, up to {@value #MAX_DRAW_ATTEMPTS} draws are attempted: a tail
     * is picked in proportion to its probability mass and the value is obtained by inverting
     * the CDF within that tail. A usable draw is written to the loadings; if none is obtained
     * the loadings entry is left untouched. When {@code z} is false the current loadings value
     * is evaluated instead.
     *
     * @return the log density of the value under the truncated normal, or
     *         {@link Double#NEGATIVE_INFINITY} if the value or the log normalising constant
     *         is NaN
     */
    private double getTruncatedDraw(int i, int i2, NormalDistribution normalDistribution, boolean z) {
        final double lower = -Math.sqrt(cutoffMatrix().getParameterValue(i, i2));
        final double upper = -lower;
        final double lowerCdf = normalDistribution.cdf(lower);
        final double upperCdf = normalDistribution.cdf(upper);
        // Share of the total tail mass that lies in the lower tail.
        final double lowerTailShare = lowerCdf / (lowerCdf + (1.0 - upperCdf));

        double value = 0.0;
        if (z) {
            int attempts = 0;
            while (attempts < MAX_DRAW_ATTEMPTS && !isOutsideCutoff(value, lower, upper)) {
                value = MathUtils.nextDouble() < lowerTailShare
                        ? normalDistribution.quantile(MathUtils.nextDouble() * lowerCdf)
                        : normalDistribution.quantile((MathUtils.nextDouble() * (1.0 - upperCdf)) + upperCdf);
                attempts++;
            }
            // Guard on the draw itself rather than on the attempt counter, so that a valid
            // draw obtained on the final attempt is stored instead of being discarded.
            if (isOutsideCutoff(value, lower, upper)) {
                this.loadings.setParameterValue(i, i2, value);
            }
        } else {
            value = this.loadings.getParameterValue(i, i2);
        }

        final double logTailMass = Math.log(1.0 - (upperCdf - lowerCdf));
        if (Double.isNaN(value) || Double.isNaN(logTailMass)) {
            return Double.NEGATIVE_INFINITY;
        }
        return normalDistribution.logPdf(value) - logTailMass;
    }

    /**
     * Updates (or evaluates) the loadings entry at row {@code i}, column {@code i2}.
     *
     * <p>If the model's current loading at that position is exactly zero, a zero-mean normal
     * with the prior variance is used; otherwise the univariate conditional derived from the
     * row's full conditional is used.
     *
     * <p>With a {@link MomentDistributionModel} prior the entry is handled by the truncated
     * draw and {@code z} selects between drawing and evaluating. With any other prior the entry
     * is always redrawn from the unrestricted normal, regardless of {@code z}.
     *
     * @param i  loadings row
     * @param i2 loadings column
     * @param z  true to draw a new value, false to evaluate the current one (moment prior only)
     * @return the truncated log density for a moment prior, {@code 0.0} otherwise
     */
    public double drawI(int i, int i2, boolean z) {
        final int dimension = this.loadings.getColumnDimension();
        this.precisionArray = new double[dimension][dimension];
        this.meanMidArray = new double[dimension];
        this.meanArray = new double[dimension];
        getPrecision(i, this.precisionArray);

        final NormalDistribution conditional;
        if (this.LFM.getLoadings().getParameterValue(i, i2) != 0.0) {
            double[][] variance = new SymmetricMatrix(this.precisionArray).inverse().toComponents();
            getMean(i, variance, this.meanMidArray, this.meanArray);
            conditional = this.LFM.getFactorDimension() != 1
                    ? getConditionalDistribution(this.meanArray, variance, i2, i)
                    : new NormalDistribution(this.meanArray[0], Math.sqrt(variance[0][0]));
        } else {
            conditional = new NormalDistribution(0.0, Math.sqrt(1.0 / this.priorPrecision));
        }

        if (this.prior instanceof MomentDistributionModel) {
            // Both outcomes of the original coin flip performed the same call. The random
            // number is still consumed so that seeded runs see an unchanged random stream.
            MathUtils.nextDouble();
            return getTruncatedDraw(i, i2, conditional, z);
        }
        this.loadings.setParameterValue(i, i2, conditional.quantile(MathUtils.nextDouble()));
        return 0.0;
    }

    /**
     * Returns the univariate normal of coordinate {@code i} of a multivariate normal, given
     * that every other coordinate equals the current loadings value in row {@code i2}.
     *
     * <p>With c the selected coordinate and r the remaining ones, this evaluates
     * {@code mean_c + S_cr * inv(S_rr) * (x_r - mean_r)} and
     * {@code S_cc - S_cr * inv(S_rr) * S_rc}.
     *
     * @param dArr  mean vector
     * @param dArr2 covariance matrix
     * @param i     coordinate (loadings column) whose conditional is wanted
     * @param i2    loadings row supplying the values conditioned on
     */
    private NormalDistribution getConditionalDistribution(double[] dArr, double[][] dArr2, int i, int i2) {
        final int reduced = dArr.length - 1;

        // Covariance of the remaining coordinates: row and column i struck out.
        double[][] restVariance = new double[reduced][reduced];
        for (int a = 0; a < reduced; a++) {
            for (int b = 0; b < reduced; b++) {
                restVariance[a][b] = dArr2[skipIndex(a, i)][skipIndex(b, i)];
            }
        }
        double[][] restPrecision = new SymmetricMatrix(restVariance).inverse().toComponents();

        // Deviation of the current loadings from the mean on the remaining coordinates.
        final MatrixParameterInterface currentLoadings = this.LFM.getLoadings();
        double[] residual = new double[reduced];
        for (int a = 0; a < reduced; a++) {
            final int full = skipIndex(a, i);
            residual[a] = currentLoadings.getParameterValue(i2, full) - dArr[full];
        }

        double[] weightedResidual = new double[reduced];
        for (int a = 0; a < reduced; a++) {
            for (int b = 0; b < reduced; b++) {
                weightedResidual[a] += restPrecision[a][b] * residual[b];
            }
        }
        double conditionalMean = dArr[i];
        for (int a = 0; a < reduced; a++) {
            conditionalMean += weightedResidual[a] * dArr2[skipIndex(a, i)][i];
        }

        // inv(S_rr) * S_rc. The skipped-index mapping must follow the summation index b:
        // keying it on the output index a pairs restPrecision[a][b] with the wrong covariance
        // entry whenever a and b fall on opposite sides of i.
        double[] weightedCovariance = new double[reduced];
        for (int a = 0; a < reduced; a++) {
            for (int b = 0; b < reduced; b++) {
                weightedCovariance[a] += restPrecision[a][b] * dArr2[skipIndex(b, i)][i];
            }
        }
        double conditionalVariance = dArr2[i][i];
        for (int a = 0; a < reduced; a++) {
            conditionalVariance -= weightedCovariance[a] * dArr2[skipIndex(a, i)][i];
        }

        return new NormalDistribution(conditionalMean, Math.sqrt(conditionalVariance));
    }

    /**
     * Proposes a new cutoff for entry ({@code i}, {@code i2}) and accepts or rejects it.
     *
     * <p>The proposal is a random fraction of the absolute loadings value. It is accepted when
     * a fresh random number is below the ratio of (cutoff-prior density at the squared bound
     * divided by the mass of {@code normalDistribution} outside the symmetric bound) evaluated
     * at the proposal versus at the current bound. The cutoff is stored as the square of the
     * bound. Requires {@link #prior} to be a {@link MomentDistributionModel}.
     */
    void getCutoffDraw(int i, int i2, NormalDistribution normalDistribution) {
        final double proposed = MathUtils.nextDouble() * Math.abs(this.loadings.getParameterValue(i, i2));
        final double current = Math.sqrt(cutoffMatrix().getParameterValue(i, i2));

        final double proposedWeight = this.cutoffPrior.getDistribution().pdf(Math.pow(proposed, 2.0))
                / (1.0 - (normalDistribution.cdf(proposed) - normalDistribution.cdf(-proposed)));
        final double currentWeight = this.cutoffPrior.getDistribution().pdf(Math.pow(current, 2.0))
                / (1.0 - (normalDistribution.cdf(current) - normalDistribution.cdf(-current)));

        if (MathUtils.nextDouble() < proposedWeight / currentWeight) {
            cutoffMatrix().setParameterValue(i, i2, Math.pow(proposed, 2.0));
        }
    }

    /** Always returns {@code 0}. */
    public int getStepCount() {
        return 0;
    }

    @Override
    public String getOperatorName() {
        return "loadingsGibbsTruncatedOperator";
    }

    /**
     * Picks one loadings column at random, updates its entry in every row, and then fires a
     * single change event on the loadings.
     *
     * @return {@code 0.0} always
     */
    @Override
    public double doOperation() {
        final int rowCount = this.LFM.getLoadings().getRowDimension();
        final int column = MathUtils.nextInt(this.LFM.getLoadings().getColumnDimension());
        for (int row = 0; row < rowCount; row++) {
            drawI(row, column, true);
        }
        this.loadings.fireParameterChangedEvent();
        return 0.0;
    }

    /** Sets the factor by which the data-dependent part of the conditional is scaled. */
    @Override
    public void setPathParameter(double d) {
        this.pathParameter = d;
    }
}
