package dr.evomodel.operators;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.continuous.DummyLatentTruncationProvider;
import dr.evomodel.continuous.LatentTruncation;
import dr.evomodel.tree.UniformNodeHeightPrior;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.RepeatedMeasuresTraitDataModel;
import dr.evomodel.treedatalikelihood.preorder.ConditionalPrecisionAndTransform;
import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.WrappedTipFullConditionalDistributionDelegate;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.ejml.alg.dense.mult.MatrixVectorMult;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/* loaded from: input_file:dr/evomodel/operators/NewLatentLiabilityGibbs.class */
public class NewLatentLiabilityGibbs extends SimpleMCMCOperator {
    // Name returned by both the XML parser (getParserName) and getOperatorName().
    private static final String NEW_LATENT_LIABILITY_GIBBS_OPERATOR = "newlatentLiabilityGibbsOperator";
    // XML attribute: upper bound on rejection-sampling attempts per operation.
    private static final String MAX_ATTEMPTS = "numAttempts";
    // XML attribute: whether the mask is indexed per trait column (true) or per tip and column (false).
    private static final String MISSING_BY_COLUMN = "missingByColumn";
    // XML attribute: when true the parser builds an all-ones mask of length traitDim.
    private static final String FORCE_ALL_MISSING = "forceAllMissing";
    // Used to test whether a freshly drawn tip value is acceptable (validTraitForTip).
    private final LatentTruncation latentLiability;
    // Tip trait values; one sub-parameter per tip, addressed by node number.
    private final CompoundParameter tipTraitParameter;
    // Tree trait queried in doOperation(); element 0 of the returned list is the statistic used.
    private final TreeTrait<List<WrappedNormalSufficientStatistics>> fullConditionalDensity;
    // Optional (may be null); when present its tip partial is folded into the sampling distribution.
    private final RepeatedMeasuresTraitDataModel repeatedMeasuresModel;
    // Maximum number of draws tried in sampleNode before giving up.
    private int maxAttempts;
    private final Tree treeModel;
    // Trait dimension reported by the continuous data likelihood delegate.
    private final int dim;
    // Optional mask (may be null); entries equal to 1.0 mark dimensions that are sampled.
    private Parameter mask;
    private final MaskIndicesDelegate maskDelegate;
    private final Boolean missingByColumn;
    // Indices of external nodes that have at least one non-observed dimension.
    private final int[] needSampling;
    // Scratch buffers refilled on every call to sampleNode.
    private double[] fcdMean;
    private double[][] fcdPrecision;
    // NOTE(review): allocated in the constructor but never read in the visible code; the name is a misspelling of "variance".
    private double[][] fcdVaraince;
    // Written by addMaskOnContiuousTraitsPrecisionSpace, read by sampleNode.
    private double[] maskedMean;
    private double[][] maskedPrecision;
    /**
     * XML parser that builds a {@link NewLatentLiabilityGibbs} operator.
     *
     * <p>Attributes: {@code weight} (required), {@code numAttempts},
     * {@code missingByColumn} and {@code forceAllMissing} (all optional).
     * Child elements: a {@link TreeDataLikelihood}, a {@link LatentTruncation},
     * a {@link CompoundParameter}, and optionally a {@code mask} parameter and
     * a {@code partialsProvider}.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private static final String MASK = "mask";
        private static final String PARTIALS_PROVIDER = "partialsProvider";

        // Every attribute that parseXMLObject reads is declared here, so that
        // numAttempts and forceAllMissing are documented and are not treated
        // as undeclared attributes by rule-based validation.
        private XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule("weight"),
                AttributeRule.newIntegerRule(NewLatentLiabilityGibbs.MAX_ATTEMPTS, true),
                AttributeRule.newBooleanRule(NewLatentLiabilityGibbs.MISSING_BY_COLUMN, true),
                AttributeRule.newBooleanRule(NewLatentLiabilityGibbs.FORCE_ALL_MISSING, true),
                new ElementRule(TreeDataLikelihood.class, "The model for the latent random variables"),
                new ElementRule(LatentTruncation.class, "The model that links latent and observed variables"),
                new ElementRule("mask", Parameter.class, "Mask: 1 for latent variables that should be sampled", true),
                new ElementRule(CompoundParameter.class, "The parameter of tip locations from the tree"),
                new ElementRule(PARTIALS_PROVIDER, RepeatedMeasuresTraitDataModel.class, "Provides information about model extensions", true)
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return NewLatentLiabilityGibbs.NEW_LATENT_LIABILITY_GIBBS_OPERATOR;
        }

        /**
         * Reads the operator configuration from {@code xMLObject} and constructs the operator.
         *
         * @throws XMLParseException if the element has fewer than three children
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            if (xMLObject.getChildCount() < 3) {
                throw new XMLParseException("Element with id = '" + xMLObject.getName() + "' should contain:\n\t 1 conjugate multivariateTraitLikelihood, 1 latentLiabilityLikelihood and one parameter \n");
            }
            double weight = xMLObject.getDoubleAttribute("weight");
            TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) xMLObject.getChild(TreeDataLikelihood.class);
            LatentTruncation latentTruncation = (LatentTruncation) xMLObject.getChild(LatentTruncation.class);
            CompoundParameter tipTraits = (CompoundParameter) xMLObject.getChild(CompoundParameter.class);
            int numAttempts = ((Integer) xMLObject.getAttribute(NewLatentLiabilityGibbs.MAX_ATTEMPTS, Integer.valueOf(UniformNodeHeightPrior.DEFAULT_MC_SAMPLE))).intValue();
            boolean byColumn = ((Boolean) xMLObject.getAttribute(NewLatentLiabilityGibbs.MISSING_BY_COLUMN, true)).booleanValue();
            Parameter maskParameter = xMLObject.hasChildNamed("mask") ? (Parameter) xMLObject.getElementFirstChild("mask") : null;
            RepeatedMeasuresTraitDataModel partialsProvider = xMLObject.hasChildNamed(PARTIALS_PROVIDER) ? (RepeatedMeasuresTraitDataModel) xMLObject.getElementFirstChild(PARTIALS_PROVIDER) : null;
            if (((Boolean) xMLObject.getAttribute(NewLatentLiabilityGibbs.FORCE_ALL_MISSING, false)).booleanValue()) {
                // forceAllMissing overrides any supplied mask: every trait column
                // is marked as sampled, and the mask is interpreted by column.
                int traitDim = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();
                maskParameter = new Parameter.Default(traitDim);
                for (int i = 0; i < traitDim; i++) {
                    maskParameter.setParameterValue(i, 1.0d);
                }
                byColumn = true;
            }
            return new NewLatentLiabilityGibbs(treeDataLikelihood, latentTruncation, tipTraits, partialsProvider, maskParameter, weight, "latent", numAttempts, byColumn);
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "This element returns a gibbs sampler on tip latent trais for latent liability model.";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /* loaded from: input_file:dr/evomodel/operators/NewLatentLiabilityGibbs$MaskIndices.class */
    /**
     * Simple holder pairing two index arrays. It is not referenced by any
     * other code visible in this file.
     */
    protected class MaskIndices {
        final int[] discreteIndices;
        final int[] continuousIndex;

        /**
         * @param discrete   stored as {@code discreteIndices} (kept by reference, not copied)
         * @param continuous stored as {@code continuousIndex} (kept by reference, not copied)
         */
        private MaskIndices(int[] discrete, int[] continuous) {
            continuousIndex = continuous;
            discreteIndices = discrete;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dr/evomodel/operators/NewLatentLiabilityGibbs$MaskIndicesDelegate.class */
    /**
     * Translates the enclosing operator's {@code mask} into per-tip arrays of
     * latent (to be sampled) and observed (held fixed) trait dimensions.
     *
     * <p>Three situations are handled:
     * <ul>
     *   <li>no mask: every dimension of every tip is latent;</li>
     *   <li>mask by column: the mask has one entry per trait dimension and the
     *       same split applies to every tip (computed once at construction);</li>
     *   <li>mask per tip: the mask entry for dimension {@code k} of tip {@code t}
     *       is read at index {@code dim * t + k}.</li>
     * </ul>
     * A mask entry equal to {@code 1.0} marks a latent dimension; any other
     * value marks it observed, so the two index sets always partition
     * {@code 0..dim-1}.
     *
     * <p>The cached arrays are returned by reference; callers must not modify them.
     */
    public class MaskIndicesDelegate {
        int[] latentColumns;
        int[] observedColumns;

        private MaskIndicesDelegate() {
            this.latentColumns = null;
            this.observedColumns = null;
            final int traitDim = NewLatentLiabilityGibbs.this.dim;
            if (NewLatentLiabilityGibbs.this.mask == null) {
                // Without a mask the other methods of the operator treat all
                // dimensions as sampled; mirror that here instead of leaving
                // the arrays null (which made every lookup throw).
                this.latentColumns = new int[traitDim];
                for (int k = 0; k < traitDim; k++) {
                    this.latentColumns[k] = k;
                }
                this.observedColumns = new int[0];
                return;
            }
            if (!NewLatentLiabilityGibbs.this.missingByColumn.booleanValue()) {
                // Per-tip masks are resolved lazily in the getters.
                return;
            }
            final List<Integer> latent = new ArrayList<Integer>();
            final List<Integer> observed = new ArrayList<Integer>();
            for (int k = 0; k < traitDim; k++) {
                if (NewLatentLiabilityGibbs.this.mask.getParameterValue(k) == 1.0d) {
                    latent.add(Integer.valueOf(k));
                } else {
                    observed.add(Integer.valueOf(k));
                }
            }
            this.latentColumns = NewLatentLiabilityGibbs.this.convertListToArray(latent);
            this.observedColumns = NewLatentLiabilityGibbs.this.convertListToArray(observed);
        }

        /** @return latent dimensions (0-based, within the tip) for the given node */
        public int[] getLatentIndices(NodeRef nodeRef) {
            return getLatentIndices(nodeRef.getNumber());
        }

        /**
         * @param i node number of the tip
         * @return latent dimensions (0-based, within the tip)
         */
        public int[] getLatentIndices(int i) {
            if (usesCachedColumns()) {
                return this.latentColumns;
            }
            return collectTipIndices(i, true);
        }

        /**
         * @param i node number of the tip
         * @return observed dimensions (0-based, within the tip)
         */
        public int[] getObservedIndices(int i) {
            if (usesCachedColumns()) {
                return this.observedColumns;
            }
            return collectTipIndices(i, false);
        }

        /** @return observed dimensions (0-based, within the tip) for the given node */
        public int[] getObservedIndices(NodeRef nodeRef) {
            return getObservedIndices(nodeRef.getNumber());
        }

        /** True when the split is identical for all tips and was computed in the constructor. */
        private boolean usesCachedColumns() {
            return NewLatentLiabilityGibbs.this.mask == null
                    || NewLatentLiabilityGibbs.this.missingByColumn.booleanValue();
        }

        /** Scans the mask entries of one tip and returns the dimensions whose latent flag equals {@code wantLatent}. */
        private int[] collectTipIndices(int tipNumber, boolean wantLatent) {
            final int traitDim = NewLatentLiabilityGibbs.this.dim;
            final int offset = traitDim * tipNumber;
            final List<Integer> selected = new ArrayList<Integer>();
            for (int k = 0; k < traitDim; k++) {
                final boolean isLatent = NewLatentLiabilityGibbs.this.mask.getParameterValue(offset + k) == 1.0d;
                if (isLatent == wantLatent) {
                    selected.add(Integer.valueOf(k));
                }
            }
            return NewLatentLiabilityGibbs.this.convertListToArray(selected);
        }
    }

    /**
     * Creates the operator.
     *
     * @param treeDataLikelihood             supplies the tree, the continuous data likelihood delegate
     *                                       (cast to {@link ContinuousDataLikelihoodDelegate}) and the
     *                                       full-conditional tree trait
     * @param latentTruncation               decides whether a drawn tip value is valid
     * @param compoundParameter              tip trait values, one sub-parameter per tip
     * @param repeatedMeasuresTraitDataModel optional (may be null) provider of tip partials
     * @param parameter                      optional (may be null) mask; 1.0 marks sampled dimensions
     * @param d                              operator weight
     * @param str                            trait name used to look up / register the full-conditional trait
     * @param i                              maximum number of draws per operation
     * @param z                              true when the mask is indexed by trait column only
     */
    public NewLatentLiabilityGibbs(TreeDataLikelihood treeDataLikelihood, LatentTruncation latentTruncation, CompoundParameter compoundParameter, RepeatedMeasuresTraitDataModel repeatedMeasuresTraitDataModel, Parameter parameter, double d, String str, int i, boolean z) {
        // Plain collaborators and settings.
        this.maxAttempts = i;
        this.mask = parameter;
        this.missingByColumn = Boolean.valueOf(z);
        this.repeatedMeasuresModel = repeatedMeasuresTraitDataModel;
        this.tipTraitParameter = compoundParameter;
        this.latentLiability = latentTruncation;
        this.treeModel = treeDataLikelihood.getTree();

        final ContinuousDataLikelihoodDelegate likelihoodDelegate =
                (ContinuousDataLikelihoodDelegate) treeDataLikelihood.getDataLikelihoodDelegate();
        this.dim = likelihoodDelegate.getTraitDim();

        // Register the full-conditional trait on demand, then fetch it.
        final String traitKey = WrappedTipFullConditionalDistributionDelegate.getName(str);
        if (treeDataLikelihood.getTreeTrait(traitKey) == null) {
            likelihoodDelegate.addWrappedFullConditionalDensityTrait(str);
        }
        this.fullConditionalDensity = castTreeTrait(treeDataLikelihood.getTreeTrait(traitKey));

        // Scratch buffers sized by the trait dimension.
        this.fcdMean = new double[this.dim];
        this.fcdVaraince = new double[this.dim][this.dim];
        this.fcdPrecision = new double[this.dim][this.dim];

        // The delegate reads dim, mask and missingByColumn; setupNeedSampling
        // reads treeModel and the delegate, so both must come after those.
        this.maskDelegate = new MaskIndicesDelegate();
        this.needSampling = setupNeedSampling();

        setWeight(d);
    }

    /**
     * @return the step count of this operator, which is always one
     */
    public int getStepCount() {
        final int stepsPerOperation = 1;
        return stepsPerOperation;
    }

    /**
     * Picks one tip uniformly at random among those that have something to
     * sample, redraws its latent values, and notifies listeners of the tip
     * trait parameter.
     *
     * @return the value returned by {@code sampleNode} for the chosen tip, or
     *         {@code 0.0} (leaving the state untouched) when no tip needs sampling
     */
    @Override // dr.inference.operators.SimpleMCMCOperator
    public double doOperation() {
        if (this.needSampling.length == 0) {
            // Nothing to sample; avoid asking the RNG for an integer in an
            // empty range and leave the state unchanged.
            return 0.0d;
        }
        final int tipIndex = this.needSampling[MathUtils.nextInt(this.needSampling.length)];
        final NodeRef tip = this.treeModel.getExternalNode(tipIndex);
        final WrappedNormalSufficientStatistics tipStatistics =
                this.fullConditionalDensity.getTrait(this.treeModel, tip).get(0);
        final double logRatio = sampleNode(tip, tipStatistics);
        this.tipTraitParameter.fireParameterChangedEvent();
        return logRatio;
    }

    /**
     * @return the current values of the sub-parameter belonging to the given tip,
     *         as returned by {@code getParameterValues()}
     */
    private double[] getNodeTrait(NodeRef nodeRef) {
        final Parameter tipValues = this.tipTraitParameter.getParameter(nodeRef.getNumber());
        return tipValues.getParameterValues();
    }

    /**
     * Writes a drawn value into the tip's sub-parameter without firing
     * per-element events, then fires a single all-values-changed event.
     *
     * <p>Without a mask, the first {@code dim} entries of {@code dArr} replace
     * every dimension. With a mask, {@code dArr[k]} is written to the k-th
     * latent dimension of the tip and the remaining dimensions are untouched.
     */
    private void setNodeTrait(NodeRef nodeRef, double[] dArr) {
        final Parameter tipValues = this.tipTraitParameter.getParameter(nodeRef.getNumber());
        if (this.mask != null) {
            final int[] latent = this.maskDelegate.getLatentIndices(nodeRef);
            for (int k = 0; k < latent.length; k++) {
                tipValues.setParameterValueQuietly(latent[k], dArr[k]);
            }
        } else {
            for (int k = 0; k < this.dim; k++) {
                tipValues.setParameterValueQuietly(k, dArr[k]);
            }
        }
        tipValues.fireParameterChangedEvent(-1, Variable.ChangeType.ALL_VALUES_CHANGED);
    }

    /**
     * Redraws the latent dimensions of one tip by rejection sampling from a
     * multivariate normal and returns the log ratio used by the caller.
     *
     * <p>Steps:
     * <ol>
     *   <li>Copy the mean and (scalar-scaled) precision of the supplied
     *       statistics into the scratch buffers.</li>
     *   <li>If a repeated-measures model is present, combine them with that
     *       model's tip partial.</li>
     *   <li>If a mask is present and the tip has observed dimensions, condition
     *       on those dimensions.</li>
     *   <li>Draw up to {@code maxAttempts} times until
     *       {@code latentLiability.validTraitForTip} accepts the tip.</li>
     * </ol>
     *
     * @return {@code 0.0} if every dimension of the tip is observed;
     *         {@code Double.NEGATIVE_INFINITY} if no valid draw was found
     *         (the tip then still holds the last rejected draw; the caller is
     *         presumably expected to restore state on rejection, TODO confirm);
     *         {@code Double.POSITIVE_INFINITY} when the truncation model is a
     *         {@link DummyLatentTruncationProvider}; otherwise
     *         {@code logPdf(old) - logPdf(new)} under the unconditioned distribution
     */
    private double sampleNode(NodeRef nodeRef, WrappedNormalSufficientStatistics wrappedNormalSufficientStatistics) {
        final int tipNumber = nodeRef.getNumber();
        final int observedCount = this.maskDelegate.getObservedIndices(nodeRef).length;
        if (observedCount == this.dim) {
            // Every dimension is observed: nothing to resample.
            return 0.0d;
        }

        // 1. Copy the full conditional mean and precision into the scratch buffers.
        final WrappedVector mean = wrappedNormalSufficientStatistics.getMean();
        final WrappedMatrix precision = wrappedNormalSufficientStatistics.getPrecision();
        final double precisionScalar = wrappedNormalSufficientStatistics.getPrecisionScalar();
        final int statDim = mean.getDim();
        for (int row = 0; row < statDim; row++) {
            this.fcdMean[row] = mean.get(row);
            for (int col = 0; col < statDim; col++) {
                this.fcdPrecision[row][col] = precision.get(row, col) * precisionScalar;
            }
        }

        // 2. Optionally fold in the repeated-measures tip partial.
        if (this.repeatedMeasuresModel != null) {
            final DenseMatrix64F treePrecision = new DenseMatrix64F(this.fcdPrecision);
            final double[] tipPartial = this.repeatedMeasuresModel.getTipPartial(tipNumber, false);
            // The first dim entries of the partial are used as a vector and the
            // dim x dim block starting at offset dim as a matrix (presumably the
            // tip's mean and precision, confirm against getTipPartial).
            final DenseMatrix64F combinedPrecision = MissingOps.wrap(tipPartial, this.dim, this.dim, this.dim);
            final DenseMatrix64F treeMean = new DenseMatrix64F(this.dim, 1);
            final DenseMatrix64F tipVector = new DenseMatrix64F(this.dim, 1);
            for (int k = 0; k < this.dim; k++) {
                treeMean.set(k, 0, this.fcdMean[k]);
                tipVector.set(k, 0, tipPartial[k]);
            }
            // weighted = treePrecision * treeMean + tipMatrix * tipVector
            // (combinedPrecision still holds only the tip matrix at this point).
            final DenseMatrix64F weighted = new DenseMatrix64F(this.dim, 1);
            MatrixVectorMult.mult(treePrecision, treeMean, weighted);
            MatrixVectorMult.multAdd(combinedPrecision, tipVector, weighted);
            // combinedPrecision = tipMatrix + treePrecision
            CommonOps.addEquals(combinedPrecision, treePrecision);
            final DenseMatrix64F combinedVariance = new DenseMatrix64F(this.dim, this.dim);
            CommonOps.invert(combinedPrecision, combinedVariance);
            // combinedMean = combinedPrecision^-1 * weighted
            final DenseMatrix64F combinedMean = new DenseMatrix64F(this.dim, 1);
            MatrixVectorMult.mult(combinedVariance, weighted, combinedMean);
            for (int row = 0; row < this.dim; row++) {
                this.fcdMean[row] = combinedMean.get(row);
                for (int col = 0; col < this.dim; col++) {
                    this.fcdPrecision[row][col] = combinedPrecision.get(row, col);
                }
            }
        }

        // 3. Build the proposal, conditioning on observed dimensions when needed.
        final MultivariateNormalDistribution fullDistribution =
                new MultivariateNormalDistribution(this.fcdMean, this.fcdPrecision);
        final MultivariateNormalDistribution proposal;
        if (this.mask == null || observedCount <= 0) {
            proposal = fullDistribution;
        } else {
            addMaskOnContiuousTraitsPrecisionSpace(tipNumber);
            proposal = new MultivariateNormalDistribution(this.maskedMean, this.maskedPrecision);
        }

        // 4. Rejection sampling.
        final double[] oldValue = getNodeTrait(nodeRef);
        boolean accepted = false;
        for (int attempt = 0; attempt < this.maxAttempts && !accepted; attempt++) {
            setNodeTrait(nodeRef, proposal.nextMultivariateNormal());
            accepted = this.latentLiability.validTraitForTip(tipNumber);
        }
        // Fail only when no valid draw was found. Comparing the attempt count
        // with maxAttempts would also discard a draw accepted on the very last
        // allowed attempt.
        if (!accepted) {
            return Double.NEGATIVE_INFINITY;
        }

        final double[] newValue = getNodeTrait(nodeRef);
        if (this.latentLiability instanceof DummyLatentTruncationProvider) {
            return Double.POSITIVE_INFINITY;
        }
        return fullDistribution.logPdf(oldValue) - fullDistribution.logPdf(newValue);
    }

    /**
     * Conditions the current full conditional ({@code fcdMean}, {@code fcdPrecision})
     * on the tip's present values and stores the result in {@code maskedMean} and
     * {@code maskedPrecision}.
     *
     * <p>The tip's values are read from the flattened tip trait parameter at
     * positions {@code i * dim .. i * dim + dim - 1}.
     *
     * @param i node number of the tip
     */
    private void addMaskOnContiuousTraitsPrecisionSpace(int i) {
        // Fetch the flattened values once; calling getParameterValues() inside
        // a per-dimension loop would rebuild the whole array dim times.
        final double[] allTipValues = this.tipTraitParameter.getParameterValues();
        final double[] tipValues = new double[this.dim];
        System.arraycopy(allTipValues, i * this.dim, tipValues, 0, this.dim);

        final ConditionalPrecisionAndTransform conditional = new ConditionalPrecisionAndTransform(
                new Matrix(this.fcdPrecision),
                this.maskDelegate.getLatentIndices(i),
                this.maskDelegate.getObservedIndices(i));
        this.maskedPrecision = conditional.getConditionalPrecision().toComponents();
        this.maskedMean = conditional.getConditionalMean(tipValues, 0, this.fcdMean, 0);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Copies a list of boxed integers into a primitive array, preserving order.
     *
     * @param list values to copy; must not contain {@code null} elements
     * @return a new array with the same length and contents as {@code list}
     */
    public int[] convertListToArray(List<Integer> list) {
        final int[] values = new int[list.size()];
        int position = 0;
        for (Integer boxed : list) {
            values[position] = boxed.intValue();
            position += 1;
        }
        return values;
    }

    /**
     * Collects the external-node indices that have at least one dimension
     * that is not observed, i.e. the tips this operator can act on.
     *
     * @return indices in ascending order; empty if every tip is fully observed
     */
    private int[] setupNeedSampling() {
        final List<Integer> tipsToSample = new ArrayList<Integer>();
        final int tipCount = this.treeModel.getExternalNodeCount();
        int tip = 0;
        while (tip < tipCount) {
            final boolean fullyObserved = this.maskDelegate.getObservedIndices(tip).length == this.dim;
            if (!fullyObserved) {
                tipsToSample.add(Integer.valueOf(tip));
            }
            tip++;
        }
        return convertListToArray(tipsToSample);
    }

    /**
     * Narrows a raw tree trait to the parameterised type used by this operator.
     * The conversion is unchecked: no runtime verification of the element type
     * takes place.
     */
    @SuppressWarnings("unchecked")
    private TreeTrait<List<WrappedNormalSufficientStatistics>> castTreeTrait(TreeTrait treeTrait) {
        final TreeTrait<List<WrappedNormalSufficientStatistics>> typedTrait =
                (TreeTrait<List<WrappedNormalSufficientStatistics>>) treeTrait;
        return typedTrait;
    }

    /**
     * @return always {@code null}; this operator offers no performance suggestion
     */
    public String getPerformanceSuggestion() {
        final String noSuggestion = null;
        return noSuggestion;
    }

    /**
     * @return the operator's name, identical to the XML parser name
     */
    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.MCMCOperator
    public String getOperatorName() {
        final String operatorName = NEW_LATENT_LIABILITY_GIBBS_OPERATOR;
        return operatorName;
    }
}
