package dr.evomodel.operators;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood;
import dr.evomodel.continuous.LatentTruncation;
import dr.inference.model.CompoundParameter;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.NormalDistribution;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.util.Citable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.logging.Logger;

/* loaded from: input_file:dr/evomodel/operators/LatentLiabilityGibbs.class */
/**
 * MCMC operator that redraws the latent trait values stored at a randomly chosen tip of the tree.
 *
 * <p>{@link #doOperation()} picks one external node uniformly at random and delegates to
 * {@link #sampleNode2(NodeRef)}, which draws from a multivariate normal built from the conditional
 * mean and precision reported by the trait likelihood, redrawing until the
 * {@link LatentTruncation} model accepts the value (or the attempt budget is exhausted).
 *
 * <p>An optional mask restricts which trait dimensions are redrawn: dimensions whose mask value is
 * exactly {@code 0} are held fixed, all others are resampled conditionally on the fixed ones.
 *
 * <p>The class also retains an older single-component sampler ({@link #sampleNode(NodeRef)}) together
 * with the post-order / pre-order traversals that fill {@link #postMeans}, {@link #postP},
 * {@link #preMeans} and {@link #preP}. Nothing in this class invokes those traversals.
 */
public class LatentLiabilityGibbs extends SimpleMCMCOperator {
    public static final String LATENT_LIABILITY_GIBBS_OPERATOR = "latentLiabilityGibbsOperator";
    public static final String TREE_MODEL = "treeModel";

    /** Maximum number of draws {@link #sampleNode2(NodeRef)} makes while looking for a valid tip value. */
    private static final int MAX_ATTEMPTS = 10000;

    private final LatentTruncation latentLiability;
    private final FullyConjugateMultivariateTraitLikelihood traitModel;
    private final CompoundParameter tipTraitParameter;
    protected double[] rootPriorMean;
    protected double rootPriorSampleSize;
    private final MatrixParameter precisionParam;
    private final MutableTreeModel treeModel;

    /** Number of trait dimensions, taken from the row dimension of the precision matrix. */
    private final int dim;

    /** Per-node means written by {@link #doPostOrderTraversal(NodeRef)}; indexed [node number][dimension]. */
    public double[][] postMeans;
    /** Per-node means written by {@link #doPreOrderTraversal(NodeRef)}; indexed [node number][dimension]. */
    public double[][] preMeans;
    /** Per-node scalar weights written by {@link #doPreOrderTraversal(NodeRef)}. */
    public double[] preP;
    /** Per-node scalar weights written by {@link #doPostOrderTraversal(NodeRef)}. */
    public double[] postP;

    private final Parameter mask;
    private final boolean hasMask;
    /** Number of leading entries of {@link #dontUpdate} that are in use. */
    private final int numFixed;
    /** Number of leading entries of {@link #doUpdate} that are in use. */
    private final int numUpdate;
    /** Indices of the trait dimensions that are resampled (first {@link #numUpdate} entries). */
    private final int[] doUpdate;
    /** Indices of the trait dimensions that are held fixed (first {@link #numFixed} entries). */
    private final int[] dontUpdate;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { // from class: dr.evomodel.operators.LatentLiabilityGibbs.1
        public static final String MASK = "mask";

        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule("weight"),
                new ElementRule(FullyConjugateMultivariateTraitLikelihood.class, "The model for the latent random variables"),
                new ElementRule(LatentTruncation.class, "The model that links latent and observed variables"),
                new ElementRule(MASK, Parameter.class, "Mask: 1 for latent variables that should be sampled", true),
                new ElementRule(CompoundParameter.class, "The parameter of tip locations from the tree")
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return LatentLiabilityGibbs.LATENT_LIABILITY_GIBBS_OPERATOR;
        }

        /**
         * Builds the operator from its XML element.
         *
         * @param xo the element being parsed
         * @return a new {@link LatentLiabilityGibbs}
         * @throws XMLParseException if the element has fewer than three children
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            if (xo.getChildCount() < 3) {
                throw new XMLParseException("Element with id = '" + xo.getName() + "' should contain:\n\t 1 conjugate multivariateTraitLikelihood, 1 latentLiabilityLikelihood and one parameter \n");
            }

            final double weight = xo.getDoubleAttribute("weight");
            final FullyConjugateMultivariateTraitLikelihood traitLikelihood =
                    (FullyConjugateMultivariateTraitLikelihood) xo.getChild(FullyConjugateMultivariateTraitLikelihood.class);
            final LatentTruncation truncation = (LatentTruncation) xo.getChild(LatentTruncation.class);
            final CompoundParameter tipTraits = (CompoundParameter) xo.getChild(CompoundParameter.class);

            // The mask element is optional; a null mask means every dimension is resampled.
            Parameter maskParameter = null;
            if (xo.hasChildNamed(MASK)) {
                maskParameter = (Parameter) xo.getElementFirstChild(MASK);
            }

            return new LatentLiabilityGibbs(traitLikelihood, truncation, tipTraits, maskParameter, weight);
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "This element returns a gibbs sampler on tip latent trais for latent liability model.";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Creates the operator.
     *
     * @param traitModel        trait likelihood supplying the tree, the precision matrix, the root prior
     *                          and the per-node conditional mean / precision
     * @param latentLiability   model used to decide whether a sampled tip value is valid
     * @param tipTraitParameter compound parameter holding one sub-parameter per node number
     * @param mask              optional mask over trait dimensions ({@code null} for none); a value of
     *                          exactly {@code 0} marks a dimension as fixed
     * @param weight            operator weight passed to {@code setWeight}
     */
    public LatentLiabilityGibbs(FullyConjugateMultivariateTraitLikelihood traitModel, LatentTruncation latentLiability, CompoundParameter tipTraitParameter, Parameter mask, double weight) {
        this.latentLiability = latentLiability;
        this.traitModel = traitModel;
        this.tipTraitParameter = tipTraitParameter;
        this.rootPriorMean = traitModel.getPriorMean();
        this.rootPriorSampleSize = traitModel.getPriorSampleSize();
        this.precisionParam = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter();
        this.treeModel = traitModel.getTreeModel();
        this.dim = this.precisionParam.getRowDimension();
        this.mask = mask;
        this.hasMask = mask != null;

        final int nodeCount = this.treeModel.getNodeCount();
        this.postMeans = new double[nodeCount][this.dim];
        this.preMeans = new double[nodeCount][this.dim];
        this.preP = new double[nodeCount];
        this.postP = new double[nodeCount];

        // Partition the trait dimensions into "fixed" and "update" index lists according to the mask.
        this.dontUpdate = new int[this.dim];
        this.doUpdate = new int[this.dim];
        int fixedCount = 0;
        int updateCount = 0;
        if (this.hasMask) {
            for (int k = 0; k < this.dim; k++) {
                if (mask.getParameterValue(k) == 0.0d) {
                    this.dontUpdate[fixedCount++] = k;
                } else {
                    this.doUpdate[updateCount++] = k;
                }
            }
        }
        this.numFixed = fixedCount;
        this.numUpdate = updateCount;

        setWeight(weight);
    }

    /**
     * @return always {@code 1}
     */
    public int getStepCount() {
        return 1;
    }

    /** Logs the first row of {@code matrix} (columns {@code 0..dim-1}), values concatenated without separators. */
    private void printInformation(MatrixParameter matrix) {
        StringBuilder message = new StringBuilder("\n \n parameter \n");
        for (int col = 0; col < this.dim; col++) {
            message.append(matrix.getParameterValue(0, col));
        }
        Logger.getLogger("dr.evomodel").info(message.toString());
    }

    /** Logs one entry of {@code values} per tree node, concatenated without separators. */
    private void printInformation(double[] values) {
        StringBuilder message = new StringBuilder("\n \n double vector \n");
        final int nodeCount = this.treeModel.getNodeCount();
        for (int node = 0; node < nodeCount; node++) {
            message.append(values[node]);
        }
        Logger.getLogger("dr.evomodel").info(message.toString());
    }

    /** Logs column 0 of {@code values} for every tree node, concatenated without separators. */
    private void printInformation(double[][] values) {
        StringBuilder message = new StringBuilder("\n \n double matrix \n");
        final int nodeCount = this.treeModel.getNodeCount();
        // Only the first trait dimension is printed.
        for (int node = 0; node < nodeCount; node++) {
            message.append(values[node][0]);
        }
        Logger.getLogger("dr.evomodel").info(message.toString());
    }

    /** Logs a single value. */
    private void printInformation(double value) {
        StringBuilder message = new StringBuilder("\n \n double \n");
        message.append(value);
        Logger.getLogger("dr.evomodel").info(message.toString());
    }

    /** Logs {@code label}, the citation prepend string and {@code value} on a new line. */
    private void printInformation(double value, String label) {
        StringBuilder message = new StringBuilder("\n");
        message.append(label);
        message.append(Citable.Utils.DEFAULT_PREPEND);
        message.append(value);
        Logger.getLogger("dr.evomodel").info(message.toString());
    }

    /**
     * Resamples the latent trait of one uniformly chosen external node and notifies listeners of the
     * tip trait parameter.
     *
     * @return the value returned by {@link #sampleNode2(NodeRef)}
     */
    @Override // dr.inference.operators.SimpleMCMCOperator
    public double doOperation() {
        final int tipIndex = MathUtils.nextInt(this.treeModel.getExternalNodeCount());
        final NodeRef tip = this.treeModel.getExternalNode(tipIndex);
        final double logRatio = sampleNode2(tip);
        this.tipTraitParameter.fireParameterChangedEvent();
        return logRatio;
    }

    /**
     * Recursively fills {@link #postMeans} and {@link #postP} for the subtree below {@code node}.
     *
     * <p>Tips copy their current trait values and use the reciprocal of the rescaled branch length as
     * weight. Internal non-root nodes combine their two children (only children 0 and 1 are read).
     * The root itself is visited but nothing is stored for it.
     *
     * @param node subtree root to process
     */
    public void doPostOrderTraversal(NodeRef node) {
        final int nodeNumber = node.getNumber();

        if (this.treeModel.isExternal(node)) {
            final double[] tipTrait = getNodeTrait(node);
            System.arraycopy(tipTrait, 0, this.postMeans[nodeNumber], 0, this.dim);
            this.postP[nodeNumber] = 1.0d / this.traitModel.getRescaledBranchLengthForPrecision(node);
            return;
        }

        final NodeRef leftChild = this.treeModel.getChild(node, 0);
        final NodeRef rightChild = this.treeModel.getChild(node, 1);
        doPostOrderTraversal(leftChild);
        doPostOrderTraversal(rightChild);

        if (this.treeModel.isRoot(node)) {
            return;
        }

        final int left = leftChild.getNumber();
        final int right = rightChild.getNumber();
        final double leftWeight = this.postP[left];
        final double rightWeight = this.postP[right];
        final double branchWeight = 1.0d / this.traitModel.getRescaledBranchLengthForPrecision(node);
        final double childWeight = leftWeight + rightWeight;

        // Combine the children's total weight with the weight of the branch above this node.
        this.postP[nodeNumber] = (childWeight * branchWeight) / (childWeight + branchWeight);
        for (int k = 0; k < this.dim; k++) {
            this.postMeans[nodeNumber][k] =
                    ((leftWeight * this.postMeans[left][k]) + (rightWeight * this.postMeans[right][k])) / (leftWeight + rightWeight);
        }
    }

    /**
     * @param node a node whose number indexes {@code tipTraitParameter}
     * @return the values of the sub-parameter stored for {@code node}
     */
    public double[] getNodeTrait(NodeRef node) {
        return this.tipTraitParameter.getParameter(node.getNumber()).getParameterValues();
    }

    /**
     * @param node  a node whose number indexes {@code tipTraitParameter}
     * @param index trait dimension to read
     * @return the value of dimension {@code index} stored for {@code node}
     */
    public double getNodeTrait(NodeRef node, int index) {
        return this.tipTraitParameter.getParameter(node.getNumber()).getParameterValue(index);
    }

    /**
     * Writes the first {@code dim} entries of {@code traitValues} into the node's sub-parameter and
     * then fires a change event on the trait model's sub-parameter for the same node number.
     *
     * @param node        node to update
     * @param traitValues new trait values (at least {@code dim} entries)
     */
    public void setNodeTrait(NodeRef node, double[] traitValues) {
        final int nodeNumber = node.getNumber();
        for (int k = 0; k < this.dim; k++) {
            this.tipTraitParameter.getParameter(nodeNumber).setParameterValue(k, traitValues[k]);
        }
        this.traitModel.getTraitParameter().getParameter(nodeNumber).fireParameterChangedEvent();
    }

    /**
     * Writes a single trait dimension for {@code node}. Unlike the array overload, this method does
     * not itself fire a change event on the trait model's parameter.
     *
     * @param node  node to update
     * @param index trait dimension to write
     * @param value new value
     */
    public void setNodeTrait(NodeRef node, int index, double value) {
        this.tipTraitParameter.getParameter(node.getNumber()).setParameterValue(index, value);
    }

    /**
     * Recursively fills {@link #preMeans} and {@link #preP} for {@code node} and its descendants.
     *
     * <p>The root takes the root prior mean and sample size. Every other node combines its parent's
     * pre-order values with its sister's post-order values, so {@link #postMeans} / {@link #postP}
     * must already hold the values wanted for the sister subtrees.
     *
     * @param node subtree root to process
     */
    public void doPreOrderTraversal(NodeRef node) {
        final int nodeNumber = node.getNumber();

        if (this.treeModel.isRoot(node)) {
            this.preP[nodeNumber] = this.rootPriorSampleSize;
            System.arraycopy(this.rootPriorMean, 0, this.preMeans[nodeNumber], 0, this.dim);
        } else {
            final int parent = this.treeModel.getParent(node).getNumber();
            final int sister = getSisterNode(node).getNumber();
            final double parentWeight = this.preP[parent];
            final double sisterWeight = this.postP[sister];
            final double branchWeight = 1.0d / this.traitModel.getRescaledBranchLengthForPrecision(node);
            final double aboveWeight = parentWeight + sisterWeight;

            this.preP[nodeNumber] = (aboveWeight * branchWeight) / (aboveWeight + branchWeight);
            for (int k = 0; k < this.dim; k++) {
                this.preMeans[nodeNumber][k] =
                        ((parentWeight * this.preMeans[parent][k]) + (sisterWeight * this.postMeans[sister][k])) / (parentWeight + sisterWeight);
            }
        }

        if (this.treeModel.isExternal(node)) {
            return;
        }
        doPreOrderTraversal(this.treeModel.getChild(node, 0));
        doPreOrderTraversal(this.treeModel.getChild(node, 1));
    }

    /**
     * Returns the other child of {@code node}'s parent, considering only children 0 and 1.
     * The comparison with child 0 is by reference.
     *
     * @param node a non-root node
     * @return child 1 of the parent if {@code node} is child 0, otherwise child 0
     */
    public NodeRef getSisterNode(NodeRef node) {
        final NodeRef parent = this.treeModel.getParent(node);
        final NodeRef firstChild = this.treeModel.getChild(parent, 0);
        if (firstChild == node) {
            return this.treeModel.getChild(parent, 1);
        }
        return firstChild;
    }

    /**
     * Redraws one randomly chosen trait dimension of {@code node} from a univariate normal whose
     * parameters come from {@link #preMeans} / {@link #preP} and the precision matrix.
     *
     * <p>This method reads {@link #preMeans} and {@link #preP} but does not compute them, and it does
     * not consult the truncation model.
     *
     * @param node node whose trait is updated
     * @return log density of the old value minus log density of the new value under the sampling
     *         distribution
     */
    public double sampleNode(NodeRef node) {
        final int nodeNumber = node.getNumber();
        final double[] currentTrait = getNodeTrait(node);

        final double[] mean = new double[this.dim];
        System.arraycopy(this.preMeans[nodeNumber], 0, mean, 0, this.dim);

        // Scale the trait precision matrix by this node's scalar weight.
        final double weight = this.preP[nodeNumber];
        final double[][] scaledPrecision = new double[this.dim][this.dim];
        for (int row = 0; row < this.dim; row++) {
            for (int col = 0; col < this.dim; col++) {
                scaledPrecision[row][col] = weight * this.precisionParam.getParameterValue(row, col);
            }
        }

        final int component = MathUtils.nextInt(this.dim);
        final double condMean = getConditionalMean(component, scaledPrecision, currentTrait, mean);
        final double condSd = Math.sqrt(1.0d / scaledPrecision[component][component]);

        final double oldValue = getNodeTrait(node, component);
        final double newValue = (MathUtils.nextGaussian() * condSd) + condMean;

        final NormalDistribution proposal = new NormalDistribution(condMean, condSd);
        final double oldLogPdf = proposal.logPdf(oldValue);
        final double newLogPdf = proposal.logPdf(newValue);

        setNodeTrait(node, component, newValue);
        final double logRatio = oldLogPdf - newLogPdf;
        this.traitModel.getTraitParameter().getParameter(nodeNumber).fireParameterChangedEvent();
        return logRatio;
    }

    /**
     * Redraws the trait of {@code node} from a multivariate normal built from the trait model's
     * conditional mean and precision for that node, repeating the draw until the truncation model
     * reports a valid tip value or {@value #MAX_ATTEMPTS} draws have been made.
     *
     * <p>With a mask, only the unmasked dimensions are redrawn, conditionally on the fixed ones.
     *
     * <p>NOTE(review): if every attempt is invalid the last (invalid) draw stays stored in the
     * parameter and the method still returns normally; confirm that callers reject such a state.
     *
     * @param node node whose trait is updated
     * @return log density of the old value minus log density of the final draw under the sampling
     *         distribution
     */
    public double sampleNode2(NodeRef node) {
        final int nodeNumber = node.getNumber();
        final double[] conditionalMean = this.traitModel.getConditionalMean(nodeNumber);
        final double[][] conditionalPrecision = this.traitModel.getConditionalPrecision(nodeNumber);
        final double[] currentTrait = getNodeTrait(node);

        final double logRatio;
        if (this.hasMask) {
            logRatio = sampleMaskedComponents(node, nodeNumber, conditionalMean, conditionalPrecision, currentTrait);
        } else {
            logRatio = sampleAllComponents(node, nodeNumber, conditionalMean, conditionalPrecision, currentTrait);
        }

        this.traitModel.getTraitParameter().getParameter(nodeNumber).fireParameterChangedEvent();
        return logRatio;
    }

    /** Masked variant of {@link #sampleNode2(NodeRef)}: redraws only the dimensions listed in {@link #doUpdate}. */
    private double sampleMaskedComponents(NodeRef node, int nodeNumber, double[] mean, double[][] precision, double[] trait) {
        // Snapshot the current values and the precision sub-block of the dimensions being updated.
        final double[] oldComponents = new double[this.numUpdate];
        final double[][] subPrecision = new double[this.numUpdate][this.numUpdate];
        for (int a = 0; a < this.numUpdate; a++) {
            oldComponents[a] = trait[this.doUpdate[a]];
            for (int b = 0; b < this.numUpdate; b++) {
                subPrecision[a][b] = precision[this.doUpdate[a]][this.doUpdate[b]];
            }
        }

        final MultivariateNormalDistribution proposal = new MultivariateNormalDistribution(
                getComponentConditionalMean(precision, trait, mean, subPrecision), subPrecision);

        double[] newComponents = new double[this.numUpdate];
        boolean valid = false;
        for (int attempt = 0; !valid && attempt < MAX_ATTEMPTS; attempt++) {
            newComponents = proposal.nextMultivariateNormal();
            // Overwrite only the updated dimensions; fixed dimensions of 'trait' are left untouched.
            for (int a = 0; a < this.numUpdate; a++) {
                trait[this.doUpdate[a]] = newComponents[a];
            }
            setNodeTrait(node, trait);
            valid = this.latentLiability.validTraitForTip(nodeNumber);
        }

        return proposal.logPdf(oldComponents) - proposal.logPdf(newComponents);
    }

    /** Unmasked variant of {@link #sampleNode2(NodeRef)}: redraws every trait dimension jointly. */
    private double sampleAllComponents(NodeRef node, int nodeNumber, double[] mean, double[][] precision, double[] oldTrait) {
        final MultivariateNormalDistribution proposal = new MultivariateNormalDistribution(mean, precision);

        double[] newTrait = oldTrait;
        boolean valid = false;
        for (int attempt = 0; !valid && attempt < MAX_ATTEMPTS; attempt++) {
            newTrait = proposal.nextMultivariateNormal();
            setNodeTrait(node, newTrait);
            valid = this.latentLiability.validTraitForTip(nodeNumber);
        }

        return proposal.logPdf(oldTrait) - proposal.logPdf(newTrait);
    }

    /**
     * Computes the mean of the updated dimensions conditional on the fixed ones:
     * {@code mean_u - inverse(P_uu) * P_uf * (value_f - mean_f)}, where {@code u} / {@code f} index the
     * entries of {@link #doUpdate} / {@link #dontUpdate}.
     *
     * @param precision    full precision matrix
     * @param value        full current trait vector
     * @param mean         full mean vector
     * @param subPrecision the {@code P_uu} block of {@code precision}
     * @return conditional mean of length {@link #numUpdate}
     * @throws IllegalStateException if the matrix products report a dimension mismatch
     */
    private double[] getComponentConditionalMean(double[][] precision, double[] value, double[] mean, double[][] subPrecision) {
        // P_uf: rows are updated dimensions, columns are fixed dimensions.
        final double[][] crossPrecision = new double[this.numUpdate][this.numFixed];
        for (int a = 0; a < this.numUpdate; a++) {
            for (int b = 0; b < this.numFixed; b++) {
                crossPrecision[a][b] = precision[this.doUpdate[a]][this.dontUpdate[b]];
            }
        }

        // Deviation of the fixed dimensions from their mean.
        final double[] fixedDeviation = new double[this.numFixed];
        for (int b = 0; b < this.numFixed; b++) {
            fixedDeviation[b] = value[this.dontUpdate[b]] - mean[this.dontUpdate[b]];
        }

        final Vector shift;
        try {
            shift = new SymmetricMatrix(subPrecision).inverse().product(new Matrix(crossPrecision)).product(new Vector(fixedDeviation));
        } catch (IllegalDimension e) {
            // Continuing with a zero shift would silently sample from the wrong distribution.
            throw new IllegalStateException("Dimension mismatch while computing the conditional mean of the masked trait components", e);
        }

        final double[] conditionalMean = new double[this.numUpdate];
        for (int a = 0; a < this.numUpdate; a++) {
            conditionalMean[a] = mean[this.doUpdate[a]] - shift.component(a);
        }
        return conditionalMean;
    }

    /**
     * Computes the mean of dimension {@code index} conditional on all other dimensions of
     * {@code value}, for a normal distribution with the given precision matrix and mean.
     *
     * @param index     dimension being conditioned
     * @param precision precision matrix
     * @param value     current values of all dimensions
     * @param mean      mean of all dimensions
     * @return the conditional mean of dimension {@code index}
     */
    private double getConditionalMean(int index, double[][] precision, double[] value, double[] mean) {
        double weightedDeviation = 0.0d;
        for (int other = 0; other < this.dim; other++) {
            if (other == index) {
                continue;
            }
            weightedDeviation += precision[index][other] * (value[other] - mean[other]);
        }
        return mean[index] - (weightedDeviation / precision[index][index]);
    }

    /**
     * @return always {@code null}
     */
    public String getPerformanceSuggestion() {
        return null;
    }

    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.MCMCOperator
    public String getOperatorName() {
        return LATENT_LIABILITY_GIBBS_OPERATOR;
    }
}
