package dr.inference.operators.factorAnalysis;

import dr.evomodel.continuous.OrderedLatentLiabilityLikelihood;
import dr.inference.model.DiagonalMatrix;
import dr.inference.model.LatentFactorModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.NormalDistribution;

/* loaded from: input_file:dr/inference/operators/factorAnalysis/LatentFactorLiabilityGibbsOperator.class */
/**
 * Gibbs operator that resamples the latent liabilities stored in the latent factor model's scaled
 * data matrix, conditional on the discrete outcomes held by an
 * {@link OrderedLatentLiabilityLikelihood}.
 *
 * <p>Each resampled entry (row {@code r}, column {@code c}) is drawn from a normal distribution with
 * mean {@code lfm.getLxF()[c * rows + r]} and precision equal to the {@code (r, r)} entry of the
 * diagonal column precision, truncated to the interval implied by the observed outcome.
 */
public class LatentFactorLiabilityGibbsOperator extends SimpleMCMCOperator implements GibbsOperator {

    /** Maximum number of inverse-CDF attempts per truncated-normal draw. */
    private static final int MAX_DRAW_ATTEMPTS = 10000;

    LatentFactorModel lfm;
    OrderedLatentLiabilityLikelihood liabilityLikelihood;

    /**
     * @param d operator weight, passed to {@code setWeight}
     * @param latentFactorModel model whose scaled data holds the liabilities that are resampled
     * @param orderedLatentLiabilityLikelihood source of the outcomes, class counts and thresholds
     */
    public LatentFactorLiabilityGibbsOperator(double d, LatentFactorModel latentFactorModel, OrderedLatentLiabilityLikelihood orderedLatentLiabilityLikelihood) {
        setWeight(d);
        this.lfm = latentFactorModel;
        this.liabilityLikelihood = orderedLatentLiabilityLikelihood;
    }

    @Override
    public String getOperatorName() {
        return "LatentFactorLiabilityGibbsOperator";
    }

    /**
     * Resamples every latent liability once.
     *
     * @return always {@code 0.0} (presumably the log Hastings ratio of a Gibbs move; confirm
     *     against {@code SimpleMCMCOperator})
     */
    @Override
    public double doOperation() {
        // NOTE(review): a true getOrdering() selects the *unordered* sampler. This keeps the
        // existing dispatch; confirm the flag's meaning in OrderedLatentLiabilityLikelihood.
        if (this.liabilityLikelihood.getOrdering().booleanValue()) {
            doUnorderedOperation();
        } else {
            doOrderedOperation();
        }
        return 0.0d;
    }

    /**
     * Resamples liabilities for unordered traits.
     *
     * <p>A trait with {@code K} classes occupies {@code max(1, K - 1)} consecutive rows of the
     * scaled data. Within each column:
     * <ul>
     *   <li>{@code K == 1}: the row is drawn without truncation when its {@code continuous} entry
     *       is 0.</li>
     *   <li>{@code K == 2}: the single row is drawn below 0 for outcome 0 and above 0 otherwise.</li>
     *   <li>{@code K > 2}, outcome 0: every row of the trait is drawn below 0.</li>
     *   <li>{@code K > 2}, outcome {@code j > 0}: row {@code j - 1} is drawn above 0, then every
     *       other row of the trait is drawn below that value.</li>
     *   <li>Outcome codes {@code >= K}: every row of the trait whose {@code continuous} entry is 0
     *       is drawn without truncation. No class-specific draw follows.</li>
     * </ul>
     */
    void doUnorderedOperation() {
        double[] lxF = this.lfm.getLxF();
        DiagonalMatrix precision = (DiagonalMatrix) this.lfm.getColumnPrecision();
        Parameter continuous = this.lfm.getContinuous();
        MatrixParameterInterface scaledData = this.lfm.getScaledData();

        for (int col = 0; col < scaledData.getColumnDimension(); col++) {
            int[] outcomes = this.liabilityLikelihood.getData(col);
            int latent = 0; // first scaled-data row belonging to the current trait
            for (int trait = 0; trait < outcomes.length; trait++) {
                int outcome = outcomes[trait];
                int numClasses = (int) this.liabilityLikelihood.numClasses.getParameterValue(trait);
                int latentCount = Math.max(1, numClasses - 1);

                if (outcome >= numClasses) {
                    // Out-of-range outcome: leave the trait's liabilities unconstrained. This branch
                    // is exclusive so the class-specific draws below never see such an outcome.
                    for (int k = 0; k < latentCount; k++) {
                        if (continuous.getParameterValue(latent + k) == 0.0d) {
                            resample(scaledData, lxF, precision, latent + k, col,
                                    Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
                        }
                    }
                } else if (numClasses == 1) {
                    if (continuous.getParameterValue(latent) == 0.0d) {
                        resample(scaledData, lxF, precision, latent, col,
                                Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
                    }
                } else if (numClasses == 2) {
                    if (outcome == 0) {
                        resample(scaledData, lxF, precision, latent, col, Double.NEGATIVE_INFINITY, 0.0d);
                    } else {
                        resample(scaledData, lxF, precision, latent, col, 0.0d, Double.POSITIVE_INFINITY);
                    }
                } else if (outcome == 0) {
                    // Reference class: every liability of the trait lies below 0.
                    for (int k = 0; k < latentCount; k++) {
                        resample(scaledData, lxF, precision, latent + k, col, Double.NEGATIVE_INFINITY, 0.0d);
                    }
                } else {
                    if (outcome < 0) {
                        throw new IllegalStateException(
                                "Negative outcome " + outcome + " for trait " + trait + " in column " + col);
                    }
                    // Chosen class j: its liability (row j - 1) is positive and larger than the others.
                    double chosen = resample(scaledData, lxF, precision, latent + outcome - 1, col,
                            0.0d, Double.POSITIVE_INFINITY);
                    for (int cls = 1; cls < numClasses; cls++) {
                        if (cls != outcome) {
                            resample(scaledData, lxF, precision, latent + cls - 1, col,
                                    Double.NEGATIVE_INFINITY, chosen);
                        }
                    }
                }
                latent += latentCount;
            }
        }
    }

    /**
     * Resamples liabilities for ordered traits, one scaled-data row per trait.
     *
     * <p>For a trait with {@code K > 2} classes the cut points are
     * {@code -inf, 0, t_0, ..., t_{K-3}, +inf}. The {@code t} values are read consecutively from the
     * threshold parameter, and outcome {@code j} draws between cut points {@code j} and
     * {@code j + 1}. Binary traits are cut at 0, and single-class traits are drawn without
     * truncation when their {@code continuous} entry is 0.
     *
     * <p>Outcome codes {@code >= K} draw without truncation when the {@code continuous} entry is 0.
     * The threshold pointer advances by {@code K - 2} for every trait with {@code K > 2}, including
     * traits whose outcome is out of range, so later traits always read their own thresholds.
     */
    void doOrderedOperation() {
        double[] lxF = this.lfm.getLxF();
        DiagonalMatrix precision = (DiagonalMatrix) this.lfm.getColumnPrecision();
        Parameter continuous = this.lfm.getContinuous();
        MatrixParameterInterface scaledData = this.lfm.getScaledData();
        Parameter threshold = this.liabilityLikelihood.getThreshold();

        for (int col = 0; col < scaledData.getColumnDimension(); col++) {
            int[] outcomes = this.liabilityLikelihood.getData(col);
            int thresholdOffset = 0; // index of the current trait's first threshold
            for (int trait = 0; trait < outcomes.length; trait++) {
                int outcome = outcomes[trait];
                int numClasses = (int) this.liabilityLikelihood.numClasses.getParameterValue(trait);

                if (outcome >= numClasses || numClasses == 1) {
                    if (continuous.getParameterValue(trait) == 0.0d) {
                        resample(scaledData, lxF, precision, trait, col,
                                Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
                    }
                } else if (numClasses == 2) {
                    if (outcome == 0) {
                        resample(scaledData, lxF, precision, trait, col, Double.NEGATIVE_INFINITY, 0.0d);
                    } else {
                        resample(scaledData, lxF, precision, trait, col, 0.0d, Double.POSITIVE_INFINITY);
                    }
                } else {
                    if (outcome < 0) {
                        throw new IllegalStateException(
                                "Negative outcome " + outcome + " for trait " + trait + " in column " + col);
                    }
                    double lower = orderedCutPoint(threshold, thresholdOffset, numClasses, outcome);
                    double upper = orderedCutPoint(threshold, thresholdOffset, numClasses, outcome + 1);
                    resample(scaledData, lxF, precision, trait, col, lower, upper);
                }

                // Advance unconditionally so an out-of-range outcome cannot shift later thresholds.
                if (numClasses > 2) {
                    thresholdOffset += numClasses - 2;
                }
            }
        }
    }

    /**
     * Returns cut point {@code k} of an ordered trait with {@code numClasses} classes whose free
     * thresholds start at {@code offset}: {@code -inf} for {@code k <= 0}, {@code 0} for
     * {@code k == 1}, {@code +inf} for {@code k >= numClasses}, otherwise
     * {@code threshold[offset + k - 2]}.
     */
    private static double orderedCutPoint(Parameter threshold, int offset, int numClasses, int k) {
        if (k <= 0) {
            return Double.NEGATIVE_INFINITY;
        }
        if (k == 1) {
            return 0.0d;
        }
        if (k >= numClasses) {
            return Double.POSITIVE_INFINITY;
        }
        return threshold.getParameterValue(offset + k - 2);
    }

    /**
     * Draws one truncated-normal liability for entry {@code (row, col)}. The mean comes from
     * {@code lxF} (column-major index {@code col * rows + row}) and the precision from the diagonal.
     * The draw is stored in {@code scaledData} and also returned.
     */
    private double resample(MatrixParameterInterface scaledData, double[] lxF, DiagonalMatrix precision,
                            int row, int col, double lower, double upper) {
        double draw = drawTruncatedNormalDistribution(
                lxF[(col * scaledData.getRowDimension()) + row],
                precision.getParameterValue(row, row),
                lower,
                upper);
        scaledData.setParameterValue(row, col, draw);
        return draw;
    }

    /**
     * Draws from a normal distribution truncated to {@code (lower, upper)} by inverse-CDF sampling.
     *
     * <p>Up to {@value #MAX_DRAW_ATTEMPTS} attempts are made. If no attempt lands strictly inside
     * the interval, a fallback is returned:
     * <ul>
     *   <li>Non-finite final draw, both bounds infinite: the mean.</li>
     *   <li>Non-finite final draw, one bound infinite: the other bound.</li>
     *   <li>Non-finite final draw, both bounds finite: the midpoint.</li>
     *   <li>Finite final draw: that draw clamped to {@code [lower, upper]}.</li>
     * </ul>
     *
     * @param mean mean of the untruncated normal
     * @param precision precision of the untruncated normal (standard deviation
     *     {@code sqrt(1 / precision)})
     * @param lower lower truncation bound (may be {@code -infinity})
     * @param upper upper truncation bound (may be {@code +infinity})
     * @return the truncated draw, or the fallback described above
     */
    double drawTruncatedNormalDistribution(double mean, double precision, double lower, double upper) {
        NormalDistribution normal = new NormalDistribution(mean, Math.sqrt(1.0d / precision));
        double lowerCdf = normal.cdf(lower);
        double upperCdf = normal.cdf(upper);

        double draw = 0.0d;
        if (upperCdf > lowerCdf) {
            for (int attempt = 0; attempt < MAX_DRAW_ATTEMPTS; attempt++) {
                draw = normal.quantile((MathUtils.nextDouble() * (upperCdf - lowerCdf)) + lowerCdf);
                if (!Double.isNaN(draw) && draw > lower && draw < upper) {
                    return draw;
                }
            }
        } else {
            // No usable CDF mass (e.g. both bounds deep in one tail, or NaN CDFs): every attempt
            // would evaluate the same quantile, so evaluate it once instead of looping.
            draw = normal.quantile(lowerCdf);
        }

        if (Double.isNaN(draw) || Double.isInfinite(draw)) {
            if (Double.isInfinite(lower) && Double.isInfinite(upper)) {
                return mean;
            }
            if (Double.isInfinite(lower)) {
                return upper;
            }
            if (Double.isInfinite(upper)) {
                return lower;
            }
            return (lower + upper) / 2.0d;
        }
        // A finite draw that never entered the interval must still respect the truncation.
        return Math.min(Math.max(draw, lower), upper);
    }
}
