package dr.inference.operators.factorAnalysis;

import dr.evomodel.continuous.GibbsSampleFromTreeInterface;
import dr.inference.model.LatentFactorModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.PathDependent;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;

/* loaded from: input_file:dr/inference/operators/factorAnalysis/FactorTreeGibbsOperator.class */
/**
 * MCMC operator that redraws columns of the factor matrix of a {@link LatentFactorModel}.
 *
 * <p>Each column is replaced by a draw from a {@link MultivariateNormalDistribution} whose
 * precision and mean combine a tree-derived term (obtained from a
 * {@link GibbsSampleFromTreeInterface}) with a term built from the model's loadings, column
 * (error) precision and scaled data. The data term is multiplied by the path parameter.
 *
 * <p>NOTE(review): the constructor always sets {@code workingTree} to {@code null}, so the
 * interpolation branches in {@link #getTreePrec(int)} and {@link #getTreeMean(int)} are
 * currently never taken.
 */
public class FactorTreeGibbsOperator extends SimpleMCMCOperator implements PathDependent, GibbsOperator {
    private final LatentFactorModel lfm;
    private double pathParameter = 1.0d;
    private final GibbsSampleFromTreeInterface tree;
    private final GibbsSampleFromTreeInterface workingTree;
    private final MatrixParameterInterface factors;
    private final MatrixParameterInterface errorPrec;
    private final boolean randomScan;
    private final Parameter missingIndicator;

    /**
     * Creates the operator.
     *
     * @param weight            operator weight, passed to {@code setWeight}
     * @param latentFactorModel model supplying factors, loadings, column precision, scaled data
     *                          and the (possibly {@code null}) missing indicator
     * @param tree              source of the tree-derived precision factor and conditional mean
     * @param randomScan        if {@code true}, each operation updates one randomly chosen column;
     *                          otherwise every column is updated in order. Must not be
     *                          {@code null} (it is unboxed here).
     */
    public FactorTreeGibbsOperator(double weight, LatentFactorModel latentFactorModel, GibbsSampleFromTreeInterface tree, Boolean randomScan) {
        setWeight(weight);
        this.tree = tree;
        this.lfm = latentFactorModel;
        this.factors = latentFactorModel.getFactors();
        this.errorPrec = latentFactorModel.getColumnPrecision();
        this.randomScan = randomScan.booleanValue();
        this.workingTree = null;
        this.missingIndicator = latentFactorModel.getMissingIndicator();
    }

    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.MCMCOperator
    public String getOperatorName() {
        return "Factor Tree Gibbs Operator";
    }

    /**
     * Redraws either one randomly selected column of the factor matrix (random scan) or every
     * column in increasing order.
     *
     * @return always {@code 0.0}
     */
    @Override // dr.inference.operators.SimpleMCMCOperator
    public double doOperation() {
        if (this.randomScan) {
            resampleColumn(MathUtils.nextInt(this.factors.getColumnDimension()));
        } else {
            for (int column = 0; column < this.factors.getColumnDimension(); column++) {
                resampleColumn(column);
            }
        }
        return 0.0d;
    }

    /** Draws a new vector for {@code column} and writes it into the factor matrix row by row. */
    private void resampleColumn(int column) {
        double[] draw = (double[]) getMVN(column).nextRandom();
        for (int row = 0; row < this.factors.getRowDimension(); row++) {
            this.factors.setParameterValue(row, column, draw[row]);
        }
    }

    /**
     * Builds the multivariate normal from which {@code column} is drawn.
     *
     * @param column column index of the factor matrix
     * @return distribution constructed from {@link #getMean(int, double[][])} and
     *         {@link #getPrecision(int)}
     */
    MultivariateNormalDistribution getMVN(int column) {
        double[][] precision = getPrecision(column);
        return new MultivariateNormalDistribution(getMean(column, precision), precision);
    }

    /**
     * Computes the precision matrix for {@code column}: the tree precision plus, for every trait
     * row not flagged as missing, {@code loading * errorPrec * loading * pathParameter}.
     *
     * @param column column index of the factor matrix
     * @return symmetric precision matrix (the upper triangle is accumulated and mirrored)
     */
    double[][] getPrecision(int column) {
        double[][] precision = getTreePrec(column);
        final int factorCount = this.lfm.getLoadings().getColumnDimension();
        final int traitCount = this.lfm.getLoadings().getRowDimension();
        // The missing-data test depends only on (column, trait), so evaluate it once per trait
        // instead of once per (factor, factor, trait) triple.
        final boolean[] observed = observedRows(column, traitCount, traitCount);

        for (int a = 0; a < factorCount; a++) {
            for (int b = a; b < factorCount; b++) {
                for (int trait = 0; trait < traitCount; trait++) {
                    if (observed[trait]) {
                        precision[a][b] = precision[a][b]
                                + (this.lfm.getLoadings().getParameterValue(trait, a)
                                * this.errorPrec.getParameterValue(trait, trait)
                                * this.lfm.getLoadings().getParameterValue(trait, b)
                                * this.pathParameter);
                    }
                }
                precision[b][a] = precision[a][b];
            }
        }
        return precision;
    }

    /**
     * Computes the mean vector for {@code column} as {@code inverse(precision) * v}, where
     * {@code v} sums a tree term and a data term.
     *
     * <p>NOTE(review): the tree term uses only the diagonal of {@link #getTreePrec(int)}. That
     * matrix is diagonal while {@code workingTree} is {@code null}; if a working tree were ever
     * supplied, its off-diagonal entries would be ignored here - confirm intent before enabling.
     *
     * @param column    column index of the factor matrix
     * @param precision precision matrix for the same column, as returned by
     *                  {@link #getPrecision(int)}
     * @return mean vector with one entry per loadings column
     */
    double[] getMean(int column, double[][] precision) {
        Matrix variance = new SymmetricMatrix(precision).inverse();
        final int factorCount = this.lfm.getLoadings().getColumnDimension();
        double[] weighted = new double[factorCount];
        double[] treeMean = getTreeMean(column);
        double[][] treePrec = getTreePrec(column);
        for (int f = 0; f < weighted.length; f++) {
            weighted[f] = weighted[f] + (treePrec[f][f] * treeMean[f]);
        }

        final int traitCount = this.lfm.getLoadings().getRowDimension();
        // The missing indicator is indexed with the scaled data's row count as stride here.
        final boolean[] observed = observedRows(column, this.lfm.getScaledData().getRowDimension(), traitCount);
        if (factorCount > 0) {
            for (int trait = 0; trait < traitCount; trait++) {
                if (!observed[trait]) {
                    continue;
                }
                // Factor-independent part of the product; multiplication order is unchanged.
                double dataTimesPrec = this.lfm.getScaledData().getParameterValue(trait, column)
                        * this.errorPrec.getParameterValue(trait, trait);
                for (int f = 0; f < factorCount; f++) {
                    weighted[f] = weighted[f]
                            + (dataTimesPrec * this.lfm.getLoadings().getParameterValue(trait, f) * this.pathParameter);
                }
            }
        }

        double[] mean = new double[weighted.length];
        for (int r = 0; r < mean.length; r++) {
            for (int c = 0; c < mean.length; c++) {
                mean[r] = mean[r] + (variance.component(r, c) * weighted[c]);
            }
        }
        return mean;
    }

    /**
     * Flags which trait rows of {@code column} are not marked missing.
     *
     * @param column column index of the factor matrix
     * @param stride row count used to flatten (column, row) into the missing-indicator index
     * @param count  number of rows to evaluate
     * @return array whose entry is {@code true} when there is no missing indicator or the
     *         indicator value at {@code column * stride + row} differs from {@code 1.0}
     */
    private boolean[] observedRows(int column, int stride, int count) {
        boolean[] observed = new boolean[count];
        for (int row = 0; row < count; row++) {
            observed[row] = this.missingIndicator == null
                    || this.missingIndicator.getParameterValue((column * stride) + row) != 1.0d;
        }
        return observed;
    }

    /**
     * Returns the tree-derived precision matrix for {@code column}: the tree's precision factor
     * on the diagonal, sized by the number of factor rows. When a working tree is present, each
     * entry is blended with the working tree's conditional precision using the path parameter.
     *
     * @param column column index of the factor matrix
     * @return freshly allocated square matrix
     */
    public double[][] getTreePrec(int column) {
        double precisionFactor = this.tree.getPrecisionFactor(column);
        final int dim = this.factors.getRowDimension();
        double[][] treePrec = new double[dim][dim];
        for (int d = 0; d < dim; d++) {
            treePrec[d][d] = precisionFactor;
        }
        if (this.workingTree != null) {
            double[][] workingPrec = this.workingTree.getConditionalPrecision(column);
            for (int r = 0; r < treePrec.length; r++) {
                for (int c = 0; c < treePrec.length; c++) {
                    treePrec[r][c] = (treePrec[r][c] * this.pathParameter)
                            + (workingPrec[r][c] * (1.0d - this.pathParameter));
                }
            }
        }
        return treePrec;
    }

    /**
     * Returns the tree's conditional mean for {@code column}. When a working tree is present the
     * array obtained from the tree is overwritten in place with the path-parameter blend of both
     * trees' conditional means.
     *
     * @param column column index of the factor matrix
     * @return the array returned by the tree (possibly modified in place)
     */
    public double[] getTreeMean(int column) {
        double[] mean = this.tree.getConditionalMean(column);
        if (this.workingTree != null) {
            double[] workingMean = this.workingTree.getConditionalMean(column);
            for (int k = 0; k < mean.length; k++) {
                mean[k] = (mean[k] * this.pathParameter) + (workingMean[k] * (1.0d - this.pathParameter));
            }
        }
        return mean;
    }

    /**
     * Sets the path parameter that scales the data contribution in
     * {@link #getPrecision(int)} and {@link #getMean(int, double[][])}.
     *
     * @param d new path parameter value
     */
    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.PathDependent
    public void setPathParameter(double d) {
        this.pathParameter = d;
    }
}
