package dr.inference.operators.factorAnalysis;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.model.LatentFactorModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.AbstractAdaptableOperator;
import dr.inference.operators.AdaptationMode;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.NormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.SymmetricMatrix;
import java.util.ArrayList;
import java.util.ListIterator;

/* loaded from: input_file:dr/inference/operators/factorAnalysis/LoadingsIndependenceOperator.class */
/**
 * Resamples the loadings of a {@link LatentFactorModel} one loadings column at a time from a
 * multivariate normal whose precision is built from the current factors, the column's entry of
 * {@code getColumnPrecision()}, and an independent normal prior on every loading entry.
 *
 * <p>For loadings column {@code c} only the first {@code min(c + 1, factorDimension)} entries are
 * redrawn. For {@code c < factorDimension} a draw is written back only when its entry {@code c} is
 * strictly positive; otherwise the current values are left untouched.
 *
 * <p>With {@code randomScan} a single uniformly chosen column is updated per call; otherwise every
 * column is updated in order. Scratch buffers for each truncation size are allocated once in the
 * constructor and overwritten on every draw.
 */
public class LoadingsIndependenceOperator extends AbstractAdaptableOperator {
    /** Normal prior applied independently to every loading entry. */
    NormalDistribution prior;
    /** Model whose loadings are updated. */
    LatentFactorModel LFM;
    /**
     * Scratch precision matrices, one per truncation size 1..factorDimension: ordered by increasing
     * size for the sequential scan and by decreasing size for the random scan.
     */
    ArrayList<double[][]> precisionArray;
    /** Scratch vectors for the precision-weighted mean; same sizes and order as precisionArray. */
    ArrayList<double[]> meanMidArray;
    /** Scratch vectors for the conditional mean; same sizes and order as precisionArray. */
    ArrayList<double[]> meanArray;
    /** When true each call updates one uniformly chosen loadings column instead of all of them. */
    boolean randomScan;
    /** Scale passed to the multivariate normal draw; adapted on the log scale. */
    double scaleFactor;
    /** Reciprocal of the prior variance, {@code 1 / sd^2}. */
    double priorPrecision;
    /** Prior mean times prior precision, {@code mean / sd^2}. */
    double priorMeanPrecision;

    /**
     * Creates the operator and preallocates its scratch buffers.
     *
     * @param latentFactorModel model whose loadings are sampled
     * @param distributionLikelihood likelihood wrapping the loadings prior; its distribution must be
     *     a {@link NormalDistribution}
     * @param weight operator weight
     * @param randomScan {@code true} to update one random column per call, {@code false} for all
     * @param scaleFactor initial scale passed to the multivariate normal draw
     * @param adaptationMode adaptation mode forwarded to the superclass
     * @throws ClassCastException if the likelihood's distribution is not a NormalDistribution
     */
    public LoadingsIndependenceOperator(LatentFactorModel latentFactorModel, DistributionLikelihood distributionLikelihood, double weight, boolean randomScan, double scaleFactor, AdaptationMode adaptationMode) {
        super(adaptationMode);
        setWeight(weight);
        this.scaleFactor = scaleFactor;
        this.prior = (NormalDistribution) distributionLikelihood.getDistribution();
        this.LFM = latentFactorModel;
        this.randomScan = randomScan;
        this.precisionArray = new ArrayList<>();
        this.meanMidArray = new ArrayList<>();
        this.meanArray = new ArrayList<>();

        int factorDimension = latentFactorModel.getFactorDimension();
        for (int k = 0; k < factorDimension; k++) {
            // Sizes run 1..K for the sequential scan and K..1 for the random scan (see bufferIndex).
            int size = randomScan ? factorDimension - k : k + 1;
            precisionArray.add(new double[size][size]);
            meanMidArray.add(new double[size]);
            meanArray.add(new double[size]);
        }

        this.priorPrecision = 1.0d / (prior.getSD() * prior.getSD());
        this.priorMeanPrecision = prior.getMean() * priorPrecision;
    }

    /** Number of leading entries redrawn for {@code column}: {@code min(column + 1, factorDimension)}. */
    private int truncatedSize(int column) {
        return Math.min(column + 1, LFM.getFactorDimension());
    }

    /**
     * Position in the scratch-buffer lists of the buffers sized {@code truncatedSize(column)}.
     * Columns at or beyond the factor dimension all share the full-size buffers.
     */
    private int bufferIndex(int column) {
        int size = truncatedSize(column);
        return randomScan ? LFM.getFactorDimension() - size : size - 1;
    }

    /**
     * Fills the leading {@code size x size} block of {@code precision} with
     * {@code (F F^T) * columnPrecision + priorPrecision * I}, where F holds the first {@code size}
     * rows of {@code factors}.
     */
    private void getPrecisionOfTruncated(MatrixParameterInterface factors, int size, int column, double[][] precision) {
        int sampleCount = factors.getColumnDimension();
        // Loop-invariant: the precision entry for this column.
        double columnPrecision = LFM.getColumnPrecision().getParameterValue(column, column);
        for (int a = 0; a < size; a++) {
            for (int b = a; b < size; b++) {
                double sum = 0.0d;
                for (int n = 0; n < sampleCount; n++) {
                    sum += factors.getParameterValue(a, n) * factors.getParameterValue(b, n);
                }
                precision[a][b] = sum * columnPrecision;
                if (a == b) {
                    precision[a][b] += priorPrecision;
                } else {
                    // Symmetric: mirror the upper triangle into the lower one.
                    precision[b][a] = precision[a][b];
                }
            }
        }
    }

    /**
     * Computes {@code midMean[a] = columnPrecision * sum_n F[a][n] * Y[column][n] + priorMeanPrecision}
     * and then {@code mean = variance * midMean} over the leading {@code size} entries, where Y is
     * the model's scaled data and F its factors.
     */
    private void getTruncatedMean(int size, int column, double[][] variance, double[] midMean, double[] mean) {
        MatrixParameterInterface scaledData = LFM.getScaledData();
        MatrixParameterInterface factors = LFM.getFactors();
        int sampleCount = scaledData.getColumnDimension();
        double columnPrecision = LFM.getColumnPrecision().getParameterValue(column, column);
        for (int a = 0; a < size; a++) {
            double sum = 0.0d;
            for (int n = 0; n < sampleCount; n++) {
                sum += factors.getParameterValue(a, n) * scaledData.getParameterValue(column, n);
            }
            midMean[a] = (sum * columnPrecision) + priorMeanPrecision;
        }
        for (int a = 0; a < size; a++) {
            double sum = 0.0d;
            for (int b = 0; b < size; b++) {
                sum += variance[a][b] * midMean[b];
            }
            mean[a] = sum;
        }
    }

    /** Fills {@code precision} with the conditional precision for loadings column {@code column}. */
    private void getPrecision(int column, double[][] precision) {
        getPrecisionOfTruncated(LFM.getFactors(), truncatedSize(column), column, precision);
    }

    /** Fills {@code mean} (via {@code midMean}) with the conditional mean for {@code column}. */
    private void getMean(int column, double[][] variance, double[] midMean, double[] mean) {
        getTruncatedMean(truncatedSize(column), column, variance, midMean, mean);
    }

    /** Writes {@code values} quietly into the leading entries of loadings column {@code column}. */
    private void copy(int column, double[] values) {
        Parameter parameter = LFM.getLoadings().getParameter(column);
        for (int k = 0; k < values.length; k++) {
            parameter.setParameterValueQuietly(k, values[k]);
        }
    }

    /**
     * Draws new leading entries for loadings column {@code column} and writes them back unless
     * the column lies inside the triangle ({@code column < factorDimension}) and the draw's entry
     * {@code column} is not positive.
     *
     * @throws RuntimeException if the Cholesky decomposition of the conditional variance fails
     */
    private void drawI(int column) {
        // Buffers are selected by size, so columns beyond the factor dimension reuse the
        // full-size buffers instead of running off the end of the lists.
        int index = bufferIndex(column);
        double[][] precision = precisionArray.get(index);
        double[] midMean = meanMidArray.get(index);
        double[] mean = meanArray.get(index);

        getPrecision(column, precision);
        double[][] variance = new SymmetricMatrix(precision).inverse().toComponents();
        double[][] cholesky;
        try {
            cholesky = new CholeskyDecomposition(variance).getL();
        } catch (IllegalDimension e) {
            // Without a Cholesky factor no draw is possible; fail loudly and keep the cause.
            throw new RuntimeException(
                    "Cholesky decomposition of the conditional variance failed for loadings column " + column, e);
        }
        getMean(column, variance, midMean, mean);

        double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, cholesky, scaleFactor);
        if (column >= draw.length || draw[column] > 0.0d) {
            copy(column, draw);
        }
    }

    /**
     * Returns the operator name.
     * NOTE(review): the string says "Gibbs" although the class is the independence operator;
     * kept unchanged because logs/operator analyses may key on it.
     */
    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.MCMCOperator
    public String getOperatorName() {
        return "loadingsGibbsOperator";
    }

    /**
     * Updates one random loadings column (random scan) or every column in order, then fires a
     * single change event on the loadings.
     *
     * @return always {@code 0.0}
     *     NOTE(review): no correction is returned even when scaleFactor != 1 or a draw is
     *     discarded — confirm this is intended.
     */
    @Override // dr.inference.operators.SimpleMCMCOperator
    public double doOperation() {
        int columnCount = LFM.getLoadings().getColumnDimension();
        if (randomScan) {
            drawI(MathUtils.nextInt(columnCount));
        } else {
            for (int column = 0; column < columnCount; column++) {
                drawI(column);
            }
        }
        LFM.getLoadings().fireParameterChangedEvent();
        return 0.0d;
    }

    /** Returns the adaptable parameter, {@code log(scaleFactor)}. */
    @Override // dr.inference.operators.AbstractAdaptableOperator
    protected double getAdaptableParameterValue() {
        return Math.log(scaleFactor);
    }

    /** Sets {@code scaleFactor = exp(value)}. */
    @Override // dr.inference.operators.AbstractAdaptableOperator
    public void setAdaptableParameterValue(double d) {
        this.scaleFactor = Math.exp(d);
    }

    /** Returns the untransformed scale factor. */
    @Override // dr.inference.operators.AdaptableMCMCOperator
    public double getRawParameter() {
        return scaleFactor;
    }

    /** Returns the name of the adaptable parameter. */
    @Override // dr.inference.operators.AdaptableMCMCOperator
    public String getAdaptableParameterName() {
        return "scaleFactor";
    }
}
