package dr.inference.operators.factorAnalysis;

import dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood;
import dr.evomodel.continuous.GaussianProcessFromTree;
import dr.inference.model.AdaptableSizeFastMatrixParameter;
import dr.inference.operators.AbstractAdaptableOperator;
import dr.inference.operators.AdaptationMode;
import dr.math.MathUtils;
import dr.math.distributions.GammaDistribution;
import dr.math.distributions.NormalDistribution;

/* loaded from: input_file:dr/inference/operators/factorAnalysis/LFMSplitMergeOperator.class */
/**
 * Trans-dimensional operator for a latent factor model that changes the number of
 * columns of three {@link AdaptableSizeFastMatrixParameter}s (dense loadings,
 * sparse loadings indicators and cutoffs) by one.
 *
 * <ul>
 *   <li><b>Increment ("split")</b>: a new last column is appended and filled from a
 *       randomly chosen existing column.</li>
 *   <li><b>Decrement ("merge")</b>: the last column is folded into a randomly chosen
 *       other column and then removed.</li>
 * </ul>
 *
 * The value returned by {@link #doOperation()} is the sum of the per-entry terms
 * accumulated by the chosen move and a correction for the asymmetric move-selection
 * probabilities at the dimension boundaries; presumably this is the log Hastings
 * ratio expected by the operator framework (confirm against the superclass contract).
 *
 * <p>The adaptable-parameter hooks are no-ops: this operator has nothing to tune.
 */
public class LFMSplitMergeOperator extends AbstractAdaptableOperator {
    /** Zero-mean normal used to perturb a loading when it is split in two. */
    NormalDistribution drawDistribution;
    /** Only notified of a change by this operator; never modified in this class. */
    AdaptableSizeFastMatrixParameter factors;
    /** Indicator matrix; entries are compared against 0.0 and 1.0. */
    AdaptableSizeFastMatrixParameter sparseLoadings;
    /** Loading values. */
    AdaptableSizeFastMatrixParameter denseLoadings;
    /** Cutoff values, drawn uniformly on [0, loading^2] or from {@link #gamma}. */
    AdaptableSizeFastMatrixParameter cutoffs;
    /** Stored for reference; not read in this class. */
    FullyConjugateMultivariateTraitLikelihood tree;
    /** Built from {@link #tree}; not read in this class. */
    GaussianProcessFromTree treeDraw;
    /** N(0, 1), used to draw fresh loading values. */
    NormalDistribution standardNormal;
    /** Gamma(1, 1), used for cutoffs of entries with row index below the new column index. */
    GammaDistribution gamma;
    /** Name of the last move attempted ("increment" or "decrement"); written only, kept for debugging. */
    private String lastCall;
    /**
     * When true, entries whose row index is smaller than the index of the added/removed
     * column are treated specially (only their cutoff is drawn/scored).
     */
    final boolean upperTriangular = true;

    /**
     * @param weight         operator weight, forwarded to {@code setWeight}
     * @param drawVariance   its square root is passed as the second argument of the
     *                       {@link NormalDistribution} used for split perturbations
     *                       (presumably the standard deviation, i.e. this is a variance)
     * @param factors        factor matrix; only receives a change notification
     * @param denseLoadings  loading values
     * @param sparseLoadings loading indicators
     * @param cutoffs        cutoff values
     * @param tree           trait likelihood used to build the {@link GaussianProcessFromTree}
     */
    public LFMSplitMergeOperator(double weight,
                                 double drawVariance,
                                 AdaptableSizeFastMatrixParameter factors,
                                 AdaptableSizeFastMatrixParameter denseLoadings,
                                 AdaptableSizeFastMatrixParameter sparseLoadings,
                                 AdaptableSizeFastMatrixParameter cutoffs,
                                 FullyConjugateMultivariateTraitLikelihood tree) {
        super(AdaptationMode.DEFAULT);
        this.standardNormal = new NormalDistribution(0.0d, 1.0d);
        this.gamma = new GammaDistribution(1.0d, 1.0d);
        // upperTriangular is a final field with an initializer; it must not be
        // assigned again here (doing so does not compile).
        setWeight(weight);
        this.drawDistribution = new NormalDistribution(0.0d, Math.sqrt(drawVariance));
        this.factors = factors;
        this.sparseLoadings = sparseLoadings;
        this.denseLoadings = denseLoadings;
        this.cutoffs = cutoffs;
        this.tree = tree;
        this.treeDraw = new GaussianProcessFromTree(tree);
    }

    @Override
    public String getOperatorName() {
        return "Latent Factor Model Split-Merge Operator";
    }

    /**
     * Chooses between a decrement and an increment move, performs it, and notifies
     * all four parameters of the change.
     *
     * <p>A decrement is chosen with probability 1/2, always when the column count is
     * at its maximum, and never when there is a single column. The boundary cases make
     * the forward and reverse selection probabilities differ by a factor of two, which
     * is compensated with +/- log(2).
     *
     * @return the value accumulated by the chosen move plus the boundary correction
     */
    @Override
    public double doOperation() {
        final int columns = this.denseLoadings.getColumnDimension();
        final int maxColumns = this.denseLoadings.getMaxColumnDimension();
        final boolean atMaximum = columns == maxColumns;

        double selectionCorrection = 0.0d;
        final double moveValue;
        if ((MathUtils.nextDouble() < 0.5d || atMaximum) && columns != 1) {
            this.lastCall = "decrement";
            if (columns == 2) {
                // Reverse move (increment from 1 column) is chosen with probability 1.
                selectionCorrection = Math.log(2.0d);
            }
            if (atMaximum) {
                // This decrement was forced (probability 1 instead of 1/2).
                selectionCorrection += -Math.log(2.0d);
            }
            moveValue = decrement();
        } else {
            this.lastCall = "increment";
            if (columns == 1) {
                // This increment was forced (probability 1 instead of 1/2).
                selectionCorrection = -Math.log(2.0d);
            }
            if (columns == maxColumns - 1) {
                // Reverse move (decrement from the maximum) is chosen with probability 1.
                selectionCorrection += Math.log(2.0d);
            }
            moveValue = increment();
        }

        // The moves use quiet setters, so listeners are notified once here.
        this.denseLoadings.fireParameterChangedEvent();
        this.sparseLoadings.fireParameterChangedEvent();
        this.cutoffs.fireParameterChangedEvent();
        this.factors.fireParameterChangedEvent();
        return moveValue + selectionCorrection;
    }

    /** Nothing is adapted by this operator; always returns 0. */
    @Override
    protected double getAdaptableParameterValue() {
        return 0.0d;
    }

    /** Nothing is adapted by this operator; the value is ignored. */
    @Override
    public void setAdaptableParameterValue(double d) {
    }

    /** Nothing is adapted by this operator; always returns 0. */
    @Override
    public double getRawParameter() {
        return 0.0d;
    }

    /** Nothing is adapted by this operator; returns an empty name. */
    @Override
    public String getAdaptableParameterName() {
        return "";
    }

    /**
     * Appends one column to the dense loadings, sparse loadings and cutoffs and fills
     * it row by row from a uniformly chosen existing column.
     *
     * @return the accumulated per-entry terms for this move
     */
    private double increment() {
        double hastings = 0.0d;
        final int source = MathUtils.nextInt(this.denseLoadings.getColumnDimension());
        final int target = this.denseLoadings.getColumnDimension(); // index of the new last column
        this.denseLoadings.setColumnDimension(this.denseLoadings.getColumnDimension() + 1);
        this.sparseLoadings.setColumnDimension(this.sparseLoadings.getColumnDimension() + 1);
        this.cutoffs.setColumnDimension(this.cutoffs.getColumnDimension() + 1);

        for (int row = 0; row < this.denseLoadings.getRowDimension(); row++) {
            if (this.upperTriangular && row < target) {
                // Row index below the new column index: only a Gamma(1,1) cutoff is drawn.
                final double cutoff = this.gamma.nextGamma();
                hastings -= this.gamma.logPdf(cutoff);
                this.cutoffs.setParameterValueQuietly(row, target, cutoff);
            } else if (this.sparseLoadings.getParameterValue(row, source) == 0.0d) {
                // Source entry is switched off: the new entry is a fresh switched-off draw.
                hastings += sparseIncrement(row, target);
            } else {
                // The original code draws a three-way selector here but never uses it:
                // only the "split" move below was reachable. The two unreachable
                // alternatives were (0) a fresh sparseIncrement on the new column and
                // (1) copying the source entry to the new column and redrawing the source
                // with sparseIncrement. The draw is kept so the random stream is unchanged.
                MathUtils.nextInt(3);

                final double original = this.denseLoadings.getParameterValue(row, source);
                // Term for the cutoff the reverse (merge) move would draw on [0, original^2].
                hastings -= Math.log(Math.pow(original, 2.0d));
                this.sparseLoadings.setParameterValueQuietly(row, target, 1.0d);

                // Split original into (original + u, original - u).
                final double perturbation = ((Double) this.drawDistribution.nextRandom()).doubleValue();
                final double targetValue = original + perturbation;
                final double sourceValue = original - perturbation;
                this.denseLoadings.setParameterValueQuietly(row, target, targetValue);
                this.denseLoadings.setParameterValueQuietly(row, source, sourceValue);
                // Proposal density of u, and log(2) for the Jacobian of the split map.
                hastings = hastings - this.drawDistribution.logPdf(perturbation) + Math.log(2.0d);

                // Fresh cutoffs, uniform on [0, value^2] for each of the two entries.
                final double targetCutoff = MathUtils.nextDouble() * Math.pow(targetValue, 2.0d);
                final double sourceCutoff = MathUtils.nextDouble() * Math.pow(sourceValue, 2.0d);
                hastings = hastings
                        + Math.log(Math.pow(targetValue, 2.0d))
                        + Math.log(Math.pow(sourceValue, 2.0d));
                this.cutoffs.setParameterValueQuietly(row, target, targetCutoff);
                this.cutoffs.setParameterValueQuietly(row, source, sourceCutoff);
            }
        }
        return hastings;
    }

    /**
     * Switches entry ({@code i}, {@code i2}) off and redraws it: the sparse indicator is
     * set to 0, the dense value is drawn from N(0, 1) and the cutoff uniformly on
     * [0, value^2].
     *
     * @param i  row index
     * @param i2 column index
     * @return {@code -logPdf(value) + log(value^2)} for the drawn value
     */
    double sparseIncrement(int i, int i2) {
        this.sparseLoadings.setParameterValueQuietly(i, i2, 0.0d);
        final double value = ((Double) this.standardNormal.nextRandom()).doubleValue();
        this.denseLoadings.setParameterValueQuietly(i, i2, value);
        final double negLogDensity = 0.0d - this.standardNormal.logPdf(value);
        final double cutoff = MathUtils.nextDouble() * Math.pow(value, 2.0d);
        final double contribution = negLogDensity + Math.log(Math.pow(value, 2.0d));
        this.cutoffs.setParameterValueQuietly(i, i2, cutoff);
        return contribution;
    }

    /**
     * Folds the last column into a uniformly chosen other column, row by row, and then
     * removes the last column from the dense loadings, sparse loadings and cutoffs.
     *
     * @return the accumulated per-entry terms for this move
     */
    private double decrement() {
        double hastings = 0.0d;
        final int kept = MathUtils.nextInt(this.denseLoadings.getColumnDimension() - 1);
        final int removed = this.denseLoadings.getColumnDimension() - 1; // last column

        for (int row = 0; row < this.denseLoadings.getRowDimension(); row++) {
            if (this.upperTriangular && row < removed) {
                // Mirror of the Gamma(1,1) cutoff draw in increment().
                hastings += this.gamma.logPdf(this.cutoffs.getParameterValue(row, removed));
                continue;
            }

            final double keptIndicator = this.sparseLoadings.getParameterValue(row, kept);
            final double removedIndicator = this.sparseLoadings.getParameterValue(row, removed);
            final double keptValue = this.denseLoadings.getParameterValue(row, kept);
            final double removedValue = this.denseLoadings.getParameterValue(row, removed);

            if (keptIndicator == 1.0d && removedIndicator == 1.0d) {
                // Both switched on: reverse of the split. The perturbation that would have
                // produced this pair is (removed - kept) / 2 and the merged value is the mean.
                // NOTE(review): increment() subtracts log(original^2) for the cutoff the
                // reverse move draws, but no matching +log(merged^2) is added here for the
                // cutoff drawn below - confirm whether that omission is intentional.
                hastings = hastings
                        - Math.log(Math.pow(keptValue, 2.0d))
                        - Math.log(Math.pow(removedValue, 2.0d))
                        + this.drawDistribution.logPdf((removedValue - keptValue) / 2.0d)
                        - Math.log(2.0d);
                final double merged = (keptValue + removedValue) / 2.0d;
                this.denseLoadings.setParameterValueQuietly(row, kept, merged);
                this.cutoffs.setParameterValueQuietly(row, kept, MathUtils.nextDouble() * Math.pow(merged, 2.0d));
            } else if (keptIndicator == 0.0d && removedIndicator == 1.0d) {
                // Only the removed entry is switched on: score the kept entry as a
                // sparseIncrement draw, then move the removed entry into its place.
                hastings = hastings
                        + this.standardNormal.logPdf(keptValue)
                        - Math.log(Math.pow(keptValue, 2.0d));
                this.denseLoadings.setParameterValueQuietly(row, kept, removedValue);
                this.cutoffs.setParameterValueQuietly(row, kept, this.cutoffs.getParameterValue(row, removed));
                this.sparseLoadings.setParameterValueQuietly(row, kept, 1.0d);
            } else {
                // Every other combination (including kept on / removed off and both off):
                // the removed entry is scored as a sparseIncrement draw and simply dropped.
                hastings = hastings
                        + this.standardNormal.logPdf(removedValue)
                        - Math.log(Math.pow(removedValue, 2.0d));
            }
        }

        this.denseLoadings.setColumnDimension(this.denseLoadings.getColumnDimension() - 1);
        this.sparseLoadings.setColumnDimension(this.sparseLoadings.getColumnDimension() - 1);
        this.cutoffs.setColumnDimension(this.cutoffs.getColumnDimension() - 1);
        return hastings;
    }
}
