package dr.inference.operators.factorAnalysis;

import cern.colt.matrix.impl.AbstractFormatter;
import dr.evomodel.tree.UniformNodeHeightPrior;
import dr.inference.distribution.DistributionLikelihood;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.NormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Locale;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Gibbs operator that redraws the rows (one row per trait) of a factor-analysis loadings matrix.
 *
 * <p>For each trait the new loadings are sampled from a multivariate normal distribution.
 * <ul>
 *   <li>Its precision is built from the current factor values, the trait's column precision and a
 *       normal prior.</li>
 *   <li>Its mean is the corresponding covariance times a data/prior "mid mean" vector.</li>
 *   <li>For the first {@code numberOfFactors} traits only the leading {@code trait + 1} loadings are
 *       drawn (triangular constraint); every later trait draws all {@code numberOfFactors}
 *       loadings.</li>
 * </ul>
 *
 * <p>Three update schemes are supported:
 * <ul>
 *   <li>a parallel sweep over all traits on a fixed thread pool;</li>
 *   <li>a random scan that redraws one uniformly chosen trait;</li>
 *   <li>a sequential sweep over all traits.</li>
 * </ul>
 * After every scheme the configured {@link ConstrainedSampler} is applied and the adaptor is told that
 * the loadings changed.
 */
public class NewLoadingsGibbsOperator extends SimpleMCMCOperator implements GibbsOperator, Reportable {

    /** Number of factor draws averaged by {@link #getReport()}. */
    private static final int REPORT_REPEATS = 100000;

    private static final boolean DEBUG = false;

    private NormalDistribution workingPrior;

    // Scratch space reused by the serial (random-scan / sequential) draws. Each list holds exactly
    // getNumberOfFactors() entries; see allocateWorkspaces() for the per-entry sizes.
    private final List<double[][]> precisionArray = new ArrayList<>();
    private final List<double[]> meanMidArray = new ArrayList<>();
    private final List<double[]> meanArray = new ArrayList<>();

    private final boolean randomScan;
    private final double priorPrecision;
    private final double priorMean;
    private final double priorPrecisionWorking;
    private final FactorAnalysisOperatorAdaptor adaptor;
    private final ConstrainedSampler constrainedSampler;

    // Null unless multithreading was requested; when non-null every trait is drawn via drawCallers.
    private final ExecutorService pool;
    private double pathParameter = 1.0d;
    private final List<Callable<Double>> drawCallers = new ArrayList<>();

    /** How the loadings are post-processed after each draw. */
    public enum ConstrainedSampler {
        /** Leave the freshly drawn loadings untouched. */
        NONE("none") {
            @Override
            void applyConstraint(FactorAnalysisOperatorAdaptor adaptor) {
            }
        },
        /** Ask the adaptor to reflect the loadings of every factor. */
        REFLECTION("reflection") {
            @Override
            void applyConstraint(FactorAnalysisOperatorAdaptor adaptor) {
                for (int factor = 0; factor < adaptor.getNumberOfFactors(); factor++) {
                    adaptor.reflectLoadingsForFactor(factor);
                }
            }
        };

        private final String name;

        ConstrainedSampler(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }

        /**
         * Looks up a sampler by its (case-insensitive) name.
         *
         * @param str sampler name, e.g. {@code "none"} or {@code "reflection"}
         * @return the matching sampler
         * @throws IllegalArgumentException if no sampler has that name
         */
        public static ConstrainedSampler parse(String str) {
            // Locale.ROOT: with the default locale, "REFLECTION" lower-cases to a dotless 'ı'
            // in Turkish and would never match.
            String lowerCase = str.toLowerCase(Locale.ROOT);
            for (ConstrainedSampler sampler : values()) {
                if (sampler.getName().equals(lowerCase)) {
                    return sampler;
                }
            }
            throw new IllegalArgumentException("Unknown sampler type: " + str);
        }

        abstract void applyConstraint(FactorAnalysisOperatorAdaptor adaptor);
    }

    /**
     * Draws the loadings of a single trait using scratch arrays owned by this caller, so that
     * different callers never share workspace.
     * NOTE(review): parallel use assumes the adaptor's getters and setLoadingsForTraitQuietly are
     * safe to call concurrently for distinct traits; TODO confirm.
     */
    class DrawCaller implements Callable<Double> {
        final int i;
        final double[][] precision;
        final double[] midMean;
        final double[] mean;

        DrawCaller(int i, double[][] precision, double[] midMean, double[] mean) {
            this.i = i;
            this.precision = precision;
            this.midMean = midMean;
            this.mean = mean;
        }

        @Override
        public Double call() {
            drawI(i, precision, midMean, mean);
            return null;
        }
    }

    /**
     * @param adaptor                  access to factors, data, loadings and column precisions
     * @param priorLikelihood          prior on each loading; its distribution must be a
     *                                 {@link NormalDistribution}
     * @param weight                   operator weight
     * @param randomScan               when no pool is used, redraw one random trait per call
     *                                 instead of sweeping all traits
     * @param workingPriorLikelihood   optional (nullable) normal working prior; its precision is
     *                                 blended in by {@link #setPathParameter(double)}
     * @param multiThreaded            when true, all traits are drawn on a fixed thread pool
     * @param numThreads               pool size (only used when {@code multiThreaded})
     * @param constrainedSampler       post-processing applied after every operation
     */
    public NewLoadingsGibbsOperator(FactorAnalysisOperatorAdaptor adaptor, DistributionLikelihood priorLikelihood,
                                    double weight, boolean randomScan, DistributionLikelihood workingPriorLikelihood,
                                    boolean multiThreaded, int numThreads, ConstrainedSampler constrainedSampler) {
        setWeight(weight);
        this.adaptor = adaptor;
        this.randomScan = randomScan;
        this.constrainedSampler = constrainedSampler;

        NormalDistribution prior = (NormalDistribution) priorLikelihood.getDistribution();
        this.priorPrecision = 1.0d / (prior.getSD() * prior.getSD());
        this.priorMean = prior.getMean();

        if (workingPriorLikelihood == null) {
            this.priorPrecisionWorking = this.priorPrecision;
        } else {
            this.workingPrior = (NormalDistribution) workingPriorLikelihood.getDistribution();
            this.priorPrecisionWorking = 1.0d / (this.workingPrior.getSD() * this.workingPrior.getSD());
        }

        allocateWorkspaces();

        if (multiThreaded) {
            buildDrawCallers();
            this.pool = Executors.newFixedThreadPool(numThreads);
        } else {
            this.pool = null;
        }
    }

    /**
     * (Re)allocates the serial scratch space for the current number of factors.
     * In random-scan mode entries are stored largest-first (sizes F, F-1, ..., 1); otherwise
     * smallest-first (sizes 1, 2, ..., F).
     */
    private void allocateWorkspaces() {
        precisionArray.clear();
        meanArray.clear();
        meanMidArray.clear();
        final int numberOfFactors = adaptor.getNumberOfFactors();
        for (int k = 0; k < numberOfFactors; k++) {
            final int dim = randomScan ? numberOfFactors - k : k + 1;
            precisionArray.add(new double[dim][dim]);
            meanArray.add(new double[dim]);
            meanMidArray.add(new double[dim]);
        }
    }

    /** (Re)builds one {@link DrawCaller} per trait, each sized to that trait's truncated dimension. */
    private void buildDrawCallers() {
        drawCallers.clear();
        final int numberOfFactors = adaptor.getNumberOfFactors();
        final int numberOfTraits = adaptor.getNumberOfTraits();
        for (int trait = 0; trait < numberOfTraits; trait++) {
            final int dim = Math.min(trait + 1, numberOfFactors);
            drawCallers.add(new DrawCaller(trait, new double[dim][dim], new double[dim], new double[dim]));
        }
    }

    /**
     * Fills the leading {@code newRowDimension x newRowDimension} block of {@code precision}.
     * The value is pathParameter * columnPrecision * sum over non-missing taxa of
     * factor_i * factor_j, plus the adjusted prior precision on the diagonal.
     */
    private void getPrecisionOfTruncated(int newRowDimension, int dataColumn, double[][] precision) {
        final int numberOfTaxa = adaptor.getNumberOfTaxa();
        final double columnPrecision = adaptor.getColumnPrecision(dataColumn);
        final double adjustedPriorPrecision = getAdjustedPriorPrecision();
        for (int row = 0; row < newRowDimension; row++) {
            for (int col = row; col < newRowDimension; col++) {
                double sum = 0.0d;
                for (int taxon = 0; taxon < numberOfTaxa; taxon++) {
                    if (adaptor.isNotMissing(dataColumn, taxon)) {
                        sum += adaptor.getFactorValue(row, taxon) * adaptor.getFactorValue(col, taxon);
                    }
                }
                final double value = sum * columnPrecision * pathParameter;
                if (row == col) {
                    precision[row][col] = value + adjustedPriorPrecision;
                } else {
                    precision[row][col] = value;
                    precision[col][row] = value;
                }
            }
        }
    }

    /**
     * Computes two vectors of length {@code newRowDimension}.
     * <ul>
     *   <li>{@code midMean} = columnPrecision * sum over non-missing taxa of factor * data, plus
     *       priorMean * priorPrecision.</li>
     *   <li>{@code mean} = variance * midMean.</li>
     * </ul>
     */
    private void getTruncatedMean(int newRowDimension, int dataColumn, double[][] variance,
                                  double[] midMean, double[] mean) {
        final int numberOfTaxa = adaptor.getNumberOfTaxa();
        final double columnPrecision = adaptor.getColumnPrecision(dataColumn);
        final double priorTerm = priorMean * priorPrecision;
        for (int row = 0; row < newRowDimension; row++) {
            double sum = 0.0d;
            for (int taxon = 0; taxon < numberOfTaxa; taxon++) {
                if (adaptor.isNotMissing(dataColumn, taxon)) {
                    sum += adaptor.getFactorValue(row, taxon) * adaptor.getDataValue(dataColumn, taxon);
                }
            }
            midMean[row] = (sum * columnPrecision) + priorTerm;
        }
        for (int row = 0; row < newRowDimension; row++) {
            double sum = 0.0d;
            for (int k = 0; k < newRowDimension; k++) {
                sum += variance[row][k] * midMean[k];
            }
            mean[row] = sum;
        }
    }

    /** Precision for trait {@code i}, truncated to {@code i + 1} loadings for the first F traits. */
    private void getPrecision(int i, double[][] precision) {
        final int numberOfFactors = adaptor.getNumberOfFactors();
        getPrecisionOfTruncated(i < numberOfFactors ? i + 1 : numberOfFactors, i, precision);
    }

    /**
     * Mean for trait {@code i}, truncated like {@link #getPrecision}, then scaled by pathParameter.
     * NOTE(review): the prior-mean term is scaled by pathParameter too, and it uses priorPrecision
     * rather than getAdjustedPriorPrecision(). That is presumably exact only when
     * pathParameter == 1 or priorMean == 0; TODO confirm the intended behaviour otherwise.
     */
    private void getMean(int i, double[][] variance, double[] midMean, double[] mean) {
        final int numberOfFactors = adaptor.getNumberOfFactors();
        getTruncatedMean(i < numberOfFactors ? i + 1 : numberOfFactors, i, variance, midMean, mean);
        for (int k = 0; k < mean.length; k++) {
            mean[k] *= pathParameter;
        }
    }

    /**
     * Draws new loadings for trait {@code i} and stores them quietly (no change notification).
     * Intended for internal use by this operator only.
     *
     * @throws RuntimeException if the covariance matrix cannot be Cholesky-decomposed
     */
    public void drawI(int i, double[][] precision, double[] midMean, double[] mean) {
        getPrecision(i, precision);
        double[][] variance = new SymmetricMatrix(precision).inverse().toComponents();
        double[][] cholesky;
        try {
            cholesky = new CholeskyDecomposition(variance).getL();
        } catch (IllegalDimension e) {
            // Previously swallowed, which only deferred the failure to an NPE in the sampler below.
            throw new RuntimeException("Cholesky decomposition of loadings covariance failed for trait " + i, e);
        }
        getMean(i, variance, midMean, mean);
        double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, cholesky);
        adaptor.setLoadingsForTraitQuietly(i, draw);
        if (DEBUG) {
            System.err.println("draw: " + new Vector(draw));
        }
    }

    @Override
    public String getOperatorName() {
        return "newLoadingsGibbsOperator";
    }

    /**
     * Monte-Carlo summary of the factor draws. It repeatedly redraws the factors and reports their
     * empirical mean and covariance. The flattened index is {@code factor * numberOfTaxa + taxon}.
     */
    @Override
    public String getReport() {
        final int numberOfFactors = adaptor.getNumberOfFactors();
        final int numberOfTaxa = adaptor.getNumberOfTaxa();
        final int dim = numberOfFactors * numberOfTaxa;

        double[] sum = new double[dim];
        double[][] sumOfProducts = new double[dim][dim];
        double[] draw = new double[dim];
        for (int rep = 0; rep < REPORT_REPEATS; rep++) {
            adaptor.fireLoadingsChanged();
            adaptor.drawFactors();
            // Read every factor value once per repetition instead of once per product.
            for (int factor = 0; factor < numberOfFactors; factor++) {
                for (int taxon = 0; taxon < numberOfTaxa; taxon++) {
                    draw[factor * numberOfTaxa + taxon] = adaptor.getFactorValue(factor, taxon);
                }
            }
            for (int a = 0; a < dim; a++) {
                sum[a] += draw[a];
                for (int b = 0; b < dim; b++) {
                    sumOfProducts[a][b] += draw[a] * draw[b];
                }
            }
        }

        // Divide by the number of draws actually taken.
        double[] mean = new double[dim];
        for (int a = 0; a < dim; a++) {
            mean[a] = sum[a] / REPORT_REPEATS;
        }
        double[][] covariance = new double[dim][dim];
        for (int a = 0; a < dim; a++) {
            for (int b = 0; b < dim; b++) {
                covariance[a][b] = sumOfProducts[a][b] / REPORT_REPEATS - mean[a] * mean[b];
            }
        }

        StringBuilder sb = new StringBuilder();
        sb.append(getOperatorName() + "Report:\n");
        sb.append("Factor mean:\n");
        sb.append(new Vector(mean));
        sb.append(AbstractFormatter.DEFAULT_SLICE_SEPARATOR);
        sb.append("Factor covariance:\n");
        sb.append(new Matrix(covariance));
        return sb.toString();
    }

    /**
     * Redraws the factors, then the loadings (parallel, random-scan or sequential), applies the
     * constraint and notifies the adaptor.
     *
     * @return always 0 (Gibbs move)
     */
    @Override
    public double doOperation() {
        if (DEBUG) {
            System.err.println("Start doOp");
        }
        adaptor.drawFactors();

        // Both workspace layouts hold one entry per factor, so the list size is a reliable and
        // empty-safe staleness test. Comparing the first entry's length was always "stale" in
        // sequential mode, whose first entry has size 1.
        if (precisionArray.size() != adaptor.getNumberOfFactors()) {
            if (DEBUG) {
                System.err.println("!= length");
            }
            allocateWorkspaces();
            if (pool != null) {
                buildDrawCallers();
            }
        }

        if (pool != null) {
            if (DEBUG) {
                System.err.println("!= poll");
            }
            drawAllTraitsInParallel();
        } else {
            if (DEBUG) {
                System.err.println("inner");
            }
            if (randomScan) {
                drawRandomTrait();
            } else {
                drawAllTraitsSequentially();
            }
        }
        constrainedSampler.applyConstraint(adaptor);
        adaptor.fireLoadingsChanged();

        if (DEBUG) {
            printWorkspaces();
            System.err.println("End doOp");
        }
        return 0.0d;
    }

    /** Draws every trait on the pool and rethrows the first task failure. */
    private void drawAllTraitsInParallel() {
        try {
            for (Future<Double> result : pool.invokeAll(drawCallers)) {
                result.get(); // surfaces any exception raised inside drawI
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while drawing loadings in parallel", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Parallel loadings draw failed", e.getCause());
        }
    }

    /** Draws one uniformly chosen trait using the random-scan (largest-first) workspaces. */
    private void drawRandomTrait() {
        final int trait = MathUtils.nextInt(adaptor.getNumberOfTraits());
        final int numberOfFactors = adaptor.getNumberOfFactors();
        final int slot = trait < numberOfFactors ? numberOfFactors - trait - 1 : 0;
        drawI(trait, precisionArray.get(slot), meanMidArray.get(slot), meanArray.get(slot));
    }

    /** Draws every trait in order using the sequential (smallest-first) workspaces. */
    private void drawAllTraitsSequentially() {
        final int numberOfFactors = adaptor.getNumberOfFactors();
        final int numberOfTraits = adaptor.getNumberOfTraits();
        for (int trait = 0; trait < numberOfTraits; trait++) {
            // Traits beyond the triangular block all reuse the full-size (last) entry.
            final int slot = Math.min(trait, numberOfFactors - 1);
            drawI(trait, precisionArray.get(slot), meanMidArray.get(slot), meanArray.get(slot));
        }
    }

    /** Debug dump of the serial scratch space. */
    private void printWorkspaces() {
        for (double[] mean : meanArray) {
            System.err.println(new Vector(mean));
        }
        for (double[] midMean : meanMidArray) {
            System.err.println(new Vector(midMean));
        }
        for (double[][] precision : precisionArray) {
            System.err.println(new Matrix(precision));
        }
    }

    @Override
    public void setPathParameter(double d) {
        this.pathParameter = d;
    }

    /** Prior precision linearly interpolated between the working prior (0) and the prior (1). */
    private double getAdjustedPriorPrecision() {
        return (priorPrecision * pathParameter) + ((1.0d - pathParameter) * priorPrecisionWorking);
    }
}
