package dr.inference.operators.factorAnalysis;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.IndependentNormalDistributionModel;
import dr.inference.distribution.LatentFactorModelInterface;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.PathDependent;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.NormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Gibbs operator that redraws rows of the factor-loadings matrix of a latent factor model.
 *
 * <p>For each row it builds a conditional precision matrix from the factors, the row's column
 * precision and a normal prior. It inverts that matrix, forms the conditional mean, and draws the row
 * from the resulting multivariate normal through a Cholesky factor. When {@code upperTriangle} is set,
 * row {@code r < factorDimension} only has its first {@code r + 1} entries redrawn.
 *
 * <p>Rows are updated in one of three ways:
 * <ul>
 *   <li>systematic scan (all rows, in order);</li>
 *   <li>random scan (one uniformly chosen row per call);</li>
 *   <li>in parallel on a fixed-size thread pool.</li>
 * </ul>
 */
public class LoadingsGibbsOperator extends SimpleMCMCOperator implements PathDependent, GibbsOperator {
    /** Normal prior shared by every loading; set when a single prior likelihood is supplied. */
    NormalDistribution prior;
    /** Element-wise independent normal prior; used when {@link #prior} is not supplied. */
    IndependentNormalDistributionModel prior3;
    /** Optional working prior whose precision is blended with the prior's via the path parameter. */
    NormalDistribution workingPrior;
    LatentFactorModelInterface LFM;
    /** Reusable scratch precision matrices, one per truncated row size (sequential / random-scan paths). */
    ArrayList<double[][]> precisionArray;
    /** Reusable scratch buffers for the un-normalised conditional mean (precision-weighted data term). */
    ArrayList<double[]> meanMidArray;
    /** Reusable scratch buffers for the conditional mean. */
    ArrayList<double[]> meanArray;
    boolean randomScan;
    final Parameter missingIndicator;
    final MatrixParameterInterface loadings;
    final boolean upperTriangle;
    /**
     * Construction-time prior values. They are derived from {@link #prior} and {@link #workingPrior}.
     * When {@link #prior3} is used, the per-row values are computed locally on each draw and these
     * fields are not updated. This keeps parallel draws from sharing mutable state.
     */
    double priorPrecision;
    double priorMeanPrecision;
    double priorPrecisionWorking;
    double priorMeanPrecisionWorking;
    private static final boolean DEBUG = false;
    /** Thread pool for parallel row draws, or {@code null} when rows are drawn on the calling thread. */
    private final ExecutorService pool;
    double pathParameter = 1.0d;
    /** One task per loadings row, each owning its own scratch buffers; only used with {@link #pool}. */
    private final List<Callable<Double>> drawCallers = new ArrayList<>();

    /** Task that redraws a single loadings row using buffers private to this task. */
    class DrawCaller implements Callable<Double> {
        int i;
        double[][] precision;
        double[] midMean;
        double[] mean;

        public DrawCaller(int i, double[][] precision, double[] midMean, double[] mean) {
            this.i = i;
            this.precision = precision;
            this.midMean = midMean;
            this.mean = mean;
        }

        @Override // java.util.concurrent.Callable
        public Double call() throws Exception {
            LoadingsGibbsOperator.this.drawI(this.i, this.precision, this.midMean, this.mean, LoadingsGibbsOperator.this.LFM.getFactors());
            return null;
        }
    }

    /**
     * Creates the operator.
     *
     * @param latentFactorModelInterface the latent factor model whose loadings are sampled
     * @param distributionLikelihood normal prior on every loading; if null, {@code independentNormalDistributionModel} is used
     * @param independentNormalDistributionModel element-wise normal prior, used only when {@code distributionLikelihood} is null
     * @param matrixParameterInterface loadings to update; if null, {@code latentFactorModelInterface.getLoadings()} is used
     * @param d operator weight
     * @param z random scan: update one random row per call instead of all rows
     * @param distributionLikelihood2 optional working prior (may be null)
     * @param z2 draw all rows in parallel on a thread pool
     * @param i number of pool threads (used only when {@code z2} is true)
     * @param z3 apply the upper-triangular constraint to the first {@code factorDimension} rows
     */
    public LoadingsGibbsOperator(LatentFactorModelInterface latentFactorModelInterface, DistributionLikelihood distributionLikelihood, IndependentNormalDistributionModel independentNormalDistributionModel, MatrixParameterInterface matrixParameterInterface, double d, boolean z, DistributionLikelihood distributionLikelihood2, boolean z2, int i, boolean z3) {
        setWeight(d);
        this.upperTriangle = z3;
        if (distributionLikelihood != null) {
            this.prior = (NormalDistribution) distributionLikelihood.getDistribution();
        } else {
            this.prior3 = independentNormalDistributionModel;
        }
        if (distributionLikelihood2 != null) {
            this.workingPrior = (NormalDistribution) distributionLikelihood2.getDistribution();
        }
        this.loadings = matrixParameterInterface != null ? matrixParameterInterface : latentFactorModelInterface.getLoadings();
        this.LFM = latentFactorModelInterface;
        this.randomScan = z;
        this.precisionArray = new ArrayList<>();
        this.meanArray = new ArrayList<>();
        this.meanMidArray = new ArrayList<>();
        allocateWorkspace(latentFactorModelInterface.getFactorDimension());
        if (distributionLikelihood != null) {
            this.priorPrecision = 1.0d / (this.prior.getSD() * this.prior.getSD());
            this.priorMeanPrecision = this.prior.getMean() * this.priorPrecision;
        }
        if (distributionLikelihood2 == null) {
            this.priorMeanPrecisionWorking = this.priorMeanPrecision;
            this.priorPrecisionWorking = this.priorPrecision;
        } else {
            this.priorPrecisionWorking = 1.0d / (this.workingPrior.getSD() * this.workingPrior.getSD());
            this.priorMeanPrecisionWorking = this.workingPrior.getMean() * this.priorPrecisionWorking;
        }
        if (z2) {
            // Use the resolved loadings: the argument may be null (see fallback above).
            // NOTE(review): tasks are sized once here; a later change in factor dimension is not reflected.
            for (int row = 0; row < this.loadings.getRowDimension(); row++) {
                int size = truncatedSize(row);
                this.drawCallers.add(new DrawCaller(row, new double[size][size], new double[size], new double[size]));
            }
            this.pool = Executors.newFixedThreadPool(i);
        } else {
            this.pool = null;
        }
        this.missingIndicator = latentFactorModelInterface.getMissingIndicator();
    }

    /**
     * (Re)builds the scratch buffers used by the sequential and random-scan paths.
     * Random scan stores them largest-first (sizes K..1); systematic scan stores them smallest-first (sizes 1..K).
     */
    private void allocateWorkspace(int factorDimension) {
        this.precisionArray.clear();
        this.meanArray.clear();
        this.meanMidArray.clear();
        for (int k = 0; k < factorDimension; k++) {
            int size = this.randomScan ? factorDimension - k : k + 1;
            this.precisionArray.add(new double[size][size]);
            this.meanArray.add(new double[size]);
            this.meanMidArray.add(new double[size]);
        }
    }

    /** Number of leading loadings redrawn for {@code row}: all factors, or {@code row + 1} under the upper-triangular constraint. */
    private int truncatedSize(int row) {
        int factorDimension = this.LFM.getFactorDimension();
        return (row >= factorDimension || !this.upperTriangle) ? factorDimension : row + 1;
    }

    /** True when the data entry at flattened {@code index} is not flagged missing (indicator value 1). */
    private boolean isObserved(int index) {
        return this.missingIndicator == null || this.missingIndicator.getParameterValue(index) != 1.0d;
    }

    /**
     * Path-adjusted prior precision for loading (row, factor).
     * Computed in locals rather than shared fields so concurrent {@link DrawCaller} tasks cannot race.
     */
    private double rowAdjustedPriorPrecision(int factor, int row) {
        double precision = this.priorPrecision;
        double working = this.priorPrecisionWorking;
        if (this.prior3 != null) {
            int index = (factor * this.loadings.getRowDimension()) + row;
            precision = this.prior3.getVariance() == null
                    ? this.prior3.getPrecision().getParameterValue(index)
                    : 1.0d / this.prior3.getVariance().getParameterValue(index);
            if (this.workingPrior == null) {
                working = precision;
            }
        }
        return (precision * this.pathParameter) + ((1.0d - this.pathParameter) * working);
    }

    /** Prior precision-weighted mean for loading (row, factor), computed without mutating shared state. */
    private double rowPriorMeanPrecision(int factor, int row) {
        if (this.prior3 != null && this.prior3.getVariance() != null) {
            int index = (factor * this.loadings.getRowDimension()) + row;
            return (1.0d / this.prior3.getVariance().getParameterValue(index)) * this.prior3.getMean().getParameterValue(index);
        }
        // NOTE(review): a precision-parameterised prior3 falls back to the construction-time value
        // here, so its mean is not used — confirm this is intended.
        return this.priorMeanPrecision;
    }

    /**
     * Fills the leading {@code i x i} block of {@code dArr} with the conditional precision of row {@code i2}.
     * That is F F' * columnPrecision(i2) * pathParameter, computed over observed entries, plus the adjusted
     * prior precision on the diagonal.
     */
    private void getPrecisionOfTruncated(MatrixParameterInterface matrixParameterInterface, int i, int i2, double[][] dArr) {
        int columnDimension = matrixParameterInterface.getColumnDimension();
        int dataRows = this.missingIndicator == null ? 0 : this.LFM.getScaledData().getRowDimension();
        double columnPrecision = this.LFM.getColumnPrecision().getParameterValue(i2, i2);
        for (int a = 0; a < i; a++) {
            for (int b = a; b < i; b++) {
                double sum = 0.0d;
                for (int col = 0; col < columnDimension; col++) {
                    if (isObserved((col * dataRows) + i2)) {
                        sum += matrixParameterInterface.getParameterValue(a, col) * matrixParameterInterface.getParameterValue(b, col);
                    }
                }
                double value = sum * columnPrecision;
                if (a == b) {
                    dArr[a][a] = (value * this.pathParameter) + rowAdjustedPriorPrecision(a, i2);
                } else {
                    value *= this.pathParameter;
                    dArr[a][b] = value;
                    dArr[b][a] = value;
                }
            }
        }
    }

    /**
     * Computes, for row {@code i2}, the precision-weighted data term plus prior term into {@code dArr2}.
     * It then writes the product of {@code dArr} (the inverted precision) with that vector into {@code dArr3}.
     */
    private void getTruncatedMean(int i, int i2, double[][] dArr, double[] dArr2, double[] dArr3, MatrixParameterInterface matrixParameterInterface) {
        MatrixParameterInterface scaledData = this.LFM.getScaledData();
        int columnDimension = scaledData.getColumnDimension();
        int dataRows = scaledData.getRowDimension();
        double columnPrecision = this.LFM.getColumnPrecision().getParameterValue(i2, i2);
        for (int a = 0; a < i; a++) {
            double sum = 0.0d;
            for (int col = 0; col < columnDimension; col++) {
                if (isObserved((col * dataRows) + i2)) {
                    sum += matrixParameterInterface.getParameterValue(a, col) * scaledData.getParameterValue(i2, col);
                }
            }
            dArr2[a] = (sum * columnPrecision) + rowPriorMeanPrecision(a, i2);
        }
        for (int a = 0; a < i; a++) {
            double sum = 0.0d;
            for (int b = 0; b < i; b++) {
                sum += dArr[a][b] * dArr2[b];
            }
            dArr3[a] = sum;
        }
    }

    private void getPrecision(int i, double[][] dArr, MatrixParameterInterface matrixParameterInterface) {
        getPrecisionOfTruncated(matrixParameterInterface, truncatedSize(i), i, dArr);
    }

    /** Computes the conditional mean of row {@code i} into {@code dArr3}, scaled by the path parameter. */
    private void getMean(int i, double[][] dArr, double[] dArr2, double[] dArr3, MatrixParameterInterface matrixParameterInterface) {
        getTruncatedMean(truncatedSize(i), i, dArr, dArr2, dArr3, matrixParameterInterface);
        for (int k = 0; k < dArr3.length; k++) {
            dArr3[k] *= this.pathParameter;
        }
    }

    /** Writes {@code dArr} into the leading entries of loadings row {@code i} without firing change events. */
    private void copy(int i, double[] dArr) {
        for (int col = 0; col < dArr.length; col++) {
            this.loadings.setParameterValueQuietly(i, col, dArr[col]);
        }
    }

    /**
     * Draws loadings row {@code i} from its full conditional and writes it into the loadings (quietly).
     *
     * @throws IllegalStateException if the conditional covariance cannot be Cholesky-decomposed
     */
    public void drawI(int i, double[][] dArr, double[] dArr2, double[] dArr3, MatrixParameterInterface matrixParameterInterface) {
        getPrecision(i, dArr, matrixParameterInterface);
        double[][] variance = new SymmetricMatrix(dArr).inverse().toComponents();
        double[][] choleskyFactor;
        try {
            choleskyFactor = new CholeskyDecomposition(variance).getL();
        } catch (IllegalDimension e) {
            // Continuing with a null factor only fails later with an opaque NPE; fail here with the cause.
            throw new IllegalStateException("Cholesky decomposition failed for loadings row " + i, e);
        }
        getMean(i, variance, dArr2, dArr3, matrixParameterInterface);
        double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(dArr3, choleskyFactor);
        copy(i, draw);
        if (DEBUG) {
            System.err.println("draw: " + new Vector(draw));
        }
    }

    public int getStepCount() {
        return 0;
    }

    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.MCMCOperator
    public String getOperatorName() {
        return "loadingsGibbsOperator";
    }

    /** Performs one Gibbs update of the loadings; always returns a log Hastings ratio of 0. */
    @Override // dr.inference.operators.SimpleMCMCOperator
    public double doOperation() {
        MatrixParameterInterface factors = this.LFM.getFactors();
        if (DEBUG) {
            System.err.println("Start doOp");
        }
        int rowDimension = this.loadings.getRowDimension();
        int factorDimension = this.LFM.getFactorDimension();
        // The workspace holds one buffer per factor, so its size tracks the factor dimension in both scan modes.
        if (factorDimension != this.precisionArray.size()) {
            if (DEBUG) {
                System.err.println("!= length");
            }
            allocateWorkspace(factorDimension);
        }
        if (this.pool != null) {
            if (DEBUG) {
                System.err.println("!= poll");
            }
            drawAllRowsInParallel();
        } else {
            if (DEBUG) {
                System.err.println("inner");
            }
            if (this.randomScan) {
                drawRandomRow(factorDimension, factors);
            } else {
                drawAllRowsSequentially(rowDimension, factorDimension, factors);
            }
        }
        if (DEBUG) {
            for (double[] mean : this.meanArray) {
                System.err.println(new Vector(mean));
            }
            for (double[] midMean : this.meanMidArray) {
                System.err.println(new Vector(midMean));
            }
            for (double[][] precision : this.precisionArray) {
                System.err.println(new Matrix(precision));
            }
            System.err.println("End doOp");
        }
        return 0.0d;
    }

    /**
     * Runs every row task on the pool, then fires a single change event.
     * It surfaces task failures instead of dropping them and preserves the thread's interrupt status.
     */
    private void drawAllRowsInParallel() {
        try {
            List<Future<Double>> results = this.pool.invokeAll(this.drawCallers);
            for (Future<Double> result : results) {
                result.get(); // invokeAll has already waited; this only rethrows task failures
            }
            this.loadings.fireParameterChangedEvent();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (ExecutionException e) {
            throw new IllegalStateException("Parallel loadings draw failed", e.getCause());
        }
    }

    /** Redraws one uniformly chosen row, reusing a matching scratch buffer when one exists. */
    private void drawRandomRow(int factorDimension, MatrixParameterInterface factors) {
        int row = MathUtils.nextInt(this.loadings.getRowDimension());
        double[][] precision;
        double[] midMean;
        double[] mean;
        if (row < factorDimension && this.upperTriangle) {
            // Random-scan buffers are stored largest-first, so this index holds size row + 1.
            int index = (factorDimension - row) - 1;
            precision = this.precisionArray.get(index);
            midMean = this.meanMidArray.get(index);
            mean = this.meanArray.get(index);
        } else if (row == factorDimension) {
            precision = this.precisionArray.get(0);
            midMean = this.meanMidArray.get(0);
            mean = this.meanArray.get(0);
        } else {
            int columns = this.loadings.getColumnDimension();
            precision = new double[columns][columns];
            midMean = new double[columns];
            mean = new double[columns];
        }
        drawI(row, precision, midMean, mean, factors);
        this.loadings.fireParameterChangedEvent(row, null);
    }

    /** Redraws every row in order on the calling thread, then fires a single change event. */
    private void drawAllRowsSequentially(int rowDimension, int factorDimension, MatrixParameterInterface factors) {
        int columns = this.loadings.getColumnDimension();
        double[][] precision = new double[columns][columns];
        double[] midMean = new double[columns];
        double[] mean = new double[columns];
        for (int row = 0; row < rowDimension; row++) {
            if (row < factorDimension && this.upperTriangle) {
                // Systematic-scan buffers are stored smallest-first, so entry `row` has size row + 1.
                // Rows past the triangle keep reusing the last (full-size) buffer.
                precision = this.precisionArray.get(row);
                midMean = this.meanMidArray.get(row);
                mean = this.meanArray.get(row);
            }
            drawI(row, precision, midMean, mean, factors);
        }
        this.loadings.fireParameterChangedEvent();
    }

    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.PathDependent
    public void setPathParameter(double d) {
        this.pathParameter = d;
    }

    /** Prior precision blended with the working-prior precision by the path parameter, using the stored fields. */
    public double getAdjustedPriorPrecision() {
        return (this.priorPrecision * this.pathParameter) + ((1.0d - this.pathParameter) * this.priorPrecisionWorking);
    }
}
