package dr.evomodel.operators;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.util.Taxon;
import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood;
import dr.evomodel.continuous.SampledMultivariateTraitLikelihood;
import dr.geo.GeoSpatialCollectionModel;
import dr.geo.GeoSpatialDistribution;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.model.MatrixParameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateDistribution;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

/* loaded from: input_file:dr/evomodel/operators/TraitGibbsOperator.class */
/**
 * Gibbs operator that resamples the multivariate trait stored at a single tree node.
 *
 * <p>On each call to {@link #doOperation()} one node is picked at random, the mean and
 * precision of a multivariate normal are built from the traits of the neighbouring nodes
 * (and, for the root, the root prior), and a new trait value is drawn from it. Draws may be
 * rejected by a per-taxon {@link GeoSpatialDistribution} or by a global
 * {@link GeoSpatialCollectionModel}; in that case the draw is repeated up to
 * {@link #MAX_TRIES} additional times.
 */
public class TraitGibbsOperator extends SimpleMCMCOperator implements GibbsOperator {
    public static final String GIBBS_OPERATOR = "traitGibbsOperator";
    public static final String INTERNAL_ONLY = "onlyInternalNodes";
    public static final String TIP_WITH_PRIORS_ONLY = "onlyTipsWithPriors";
    public static final String NODE_PRIOR = "nodePrior";
    public static final String NODE_LABEL = "taxon";
    public static final String ROOT_PRIOR = "rootPrior";

    /** Upper bound on the retry counter of the truncated sampling loop in {@link #doOperation()}. */
    private static final int MAX_TRIES = 10000;

    private final MutableTreeModel treeModel;
    private final MatrixParameter precisionMatrixParameter;
    private final SampledMultivariateTraitLikelihood traitModel;
    private final int dim;
    private final String traitName;

    /** Per-taxon truncation priors; created lazily, so {@code null} until the first one is registered. */
    private Map<Taxon, GeoSpatialDistribution> nodeGeoSpatialPrior;
    /** Per-taxon normal priors; created lazily. Sampling with these is not implemented yet. */
    private Map<Taxon, MultivariateNormalDistribution> nodeMVNPrior;

    private final boolean onlyInternalNodes;
    private final boolean onlyTipsWithPriors;

    /** Mean of the root prior; only set once {@link #setRootPrior} has been called. */
    private double[] rootPriorMean;
    /** Matrix returned by the root prior's {@code getScaleMatrix()}; used here as a precision matrix. */
    private double[][] rootPriorPrecision;

    /** XML parser that builds a {@link TraitGibbsOperator} from a {@code traitGibbsOperator} element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final String[] names = {TraitGibbsOperator.GIBBS_OPERATOR, "internalTraitGibbsOperator"};

        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule("weight"),
                AttributeRule.newBooleanRule(TraitGibbsOperator.INTERNAL_ONLY, true),
                AttributeRule.newBooleanRule(TraitGibbsOperator.TIP_WITH_PRIORS_ONLY, true),
                new ElementRule(SampledMultivariateTraitLikelihood.class),
                new ElementRule(MultivariateDistributionLikelihood.class, 0, Integer.MAX_VALUE),
                new ElementRule(TraitGibbsOperator.ROOT_PRIOR,
                        new XMLSyntaxRule[]{new ElementRule(MultivariateDistributionLikelihood.class)}, true),
                new ElementRule(GeoSpatialCollectionModel.class, true)
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return TraitGibbsOperator.GIBBS_OPERATOR;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String[] getParserNames() {
            return this.names;
        }

        /**
         * Builds the operator, then attaches the optional root prior, any per-taxon
         * geo-spatial priors and the optional geo-spatial collection model.
         *
         * @throws XMLParseException if the root prior is not a multivariate normal distribution,
         *                           or a geo-spatial prior is labelled with a taxon that is not in the tree
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            double weight = xo.getDoubleAttribute("weight");
            boolean internalOnly =
                    ((Boolean) xo.getAttribute(TraitGibbsOperator.INTERNAL_ONLY, true)).booleanValue();
            boolean tipsWithPriorsOnly =
                    ((Boolean) xo.getAttribute(TraitGibbsOperator.TIP_WITH_PRIORS_ONLY, true)).booleanValue();
            SampledMultivariateTraitLikelihood traitLikelihood =
                    (SampledMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class);

            TraitGibbsOperator operator = new TraitGibbsOperator(traitLikelihood, internalOnly, tipsWithPriorsOnly);
            operator.setWeight(weight);

            XMLObject rootPriorElement = xo.getChild(TraitGibbsOperator.ROOT_PRIOR);
            if (rootPriorElement != null) {
                MultivariateDistributionLikelihood rootLikelihood =
                        (MultivariateDistributionLikelihood) rootPriorElement.getChild(MultivariateDistributionLikelihood.class);
                // The guard must test the exact type that is cast to below; testing the broader
                // MultivariateDistribution type would let other priors through to a ClassCastException.
                if (!(rootLikelihood.getDistribution() instanceof MultivariateNormalDistribution)) {
                    throw new XMLParseException("Only multivariate normal priors allowed for Gibbs sampling the root trait");
                }
                operator.setRootPrior((MultivariateNormalDistribution) rootLikelihood.getDistribution());
            }

            // Direct MultivariateDistributionLikelihood children carry per-taxon priors. Only
            // geo-spatial distributions are registered; any other distribution is skipped.
            for (int i = 0; i < xo.getChildCount(); i++) {
                if (xo.getChild(i) instanceof MultivariateDistributionLikelihood) {
                    MultivariateDistribution distribution =
                            ((MultivariateDistributionLikelihood) xo.getChild(i)).getDistribution();
                    if (distribution instanceof GeoSpatialDistribution) {
                        GeoSpatialDistribution geoPrior = (GeoSpatialDistribution) distribution;
                        Taxon taxon = getTaxon(traitLikelihood.getTreeModel(), geoPrior.getLabel());
                        operator.setTaxonPrior(taxon, geoPrior);
                        System.err.println("Adding truncated prior for taxon '" + taxon + "'");
                    }
                }
            }

            GeoSpatialCollectionModel collectionModel =
                    (GeoSpatialCollectionModel) xo.getChild(GeoSpatialCollectionModel.class);
            if (collectionModel != null) {
                operator.setParameterPrior(collectionModel);
                System.err.println("Adding truncated prior '" + collectionModel.getId() + "' for parameter '" + collectionModel.getParameter().getId() + "'");
            }
            return operator;
        }

        /**
         * Looks up a taxon of the tree by its label.
         *
         * @throws XMLParseException if the tree reports index -1 for the label
         */
        private Taxon getTaxon(MutableTreeModel tree, String label) throws XMLParseException {
            int taxonIndex = tree.getTaxonIndex(label);
            if (taxonIndex == -1) {
                throw new XMLParseException("Taxon '" + label + "' not found for geoSpatialDistribution element in traitGibbsOperator element");
            }
            return tree.getTaxon(taxonIndex);
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "This element returns a multivariate Gibbs operator on traits for possible all nodes.";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /** Optional global truncation model; a draw is rejected while its log-likelihood is negative infinity. */
    private GeoSpatialCollectionModel parameterPrior = null;
    /** Whether the root may be resampled; switched on by {@link #setRootPrior}. */
    private boolean sampleRoot = false;

    /** Simple holder for the mean vector and precision matrix of a sampling distribution. */
    public class MeanPrecision {
        final double[] mean;
        final double[][] precision;

        MeanPrecision(double[] mean, double[][] precision) {
            this.mean = mean;
            this.precision = precision;
        }
    }

    /**
     * Creates the operator for the trait described by the given likelihood.
     *
     * @param traitModel         likelihood supplying the tree, the trait name and the diffusion precision
     * @param onlyInternalNodes  if {@code true}, only internal nodes are ever selected
     * @param onlyTipsWithPriors if {@code true} (and tips are eligible), tips without a geo-spatial
     *                           prior are never selected
     */
    public TraitGibbsOperator(SampledMultivariateTraitLikelihood traitModel, boolean onlyInternalNodes,
                              boolean onlyTipsWithPriors) {
        this.traitModel = traitModel;
        this.treeModel = traitModel.getTreeModel();
        this.precisionMatrixParameter = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter();
        this.traitName = traitModel.getTraitName();
        this.onlyInternalNodes = onlyInternalNodes;
        this.onlyTipsWithPriors = onlyTipsWithPriors;
        // The trait dimension is read off the root's current trait value.
        this.dim = this.treeModel.getMultivariateNodeTrait(this.treeModel.getRoot(), this.traitName).length;
        Logger.getLogger("dr.evomodel").info("Using *NEW* trait Gibbs operator");
    }

    /**
     * Registers a multivariate normal prior on the root trait and enables sampling of the root.
     *
     * @param rootPrior prior whose mean and scale matrix are stored for use when the root is resampled
     */
    public void setRootPrior(MultivariateNormalDistribution rootPrior) {
        this.rootPriorMean = rootPrior.getMean();
        this.rootPriorPrecision = rootPrior.getScaleMatrix();
        this.sampleRoot = true;
    }

    /**
     * Registers a prior for the node associated with the given taxon.
     *
     * @param taxon        taxon the prior applies to
     * @param distribution a {@link GeoSpatialDistribution} or a {@link MultivariateNormalDistribution}
     * @throws RuntimeException if the distribution is of neither supported type
     */
    public void setTaxonPrior(Taxon taxon, MultivariateDistribution distribution) {
        if (distribution instanceof GeoSpatialDistribution) {
            if (this.nodeGeoSpatialPrior == null) {
                this.nodeGeoSpatialPrior = new HashMap<Taxon, GeoSpatialDistribution>();
            }
            this.nodeGeoSpatialPrior.put(taxon, (GeoSpatialDistribution) distribution);
            return;
        }
        if (!(distribution instanceof MultivariateNormalDistribution)) {
            throw new RuntimeException("Only flat/truncated geospatial and multivariate normal distributions allowed");
        }
        if (this.nodeMVNPrior == null) {
            this.nodeMVNPrior = new HashMap<Taxon, MultivariateNormalDistribution>();
        }
        this.nodeMVNPrior.put(taxon, (MultivariateNormalDistribution) distribution);
    }

    /** Sets the global geo-spatial collection model used to reject draws in {@link #doOperation()}. */
    public void setParameterPrior(GeoSpatialCollectionModel geoSpatialCollectionModel) {
        this.parameterPrior = geoSpatialCollectionModel;
    }

    /** @return always 1 */
    public int getStepCount() {
        return 1;
    }

    /** @return {@code true} if a geo-spatial prior has been registered for the node's taxon */
    private boolean nodeGeoSpatialPriorExists(NodeRef node) {
        return this.nodeGeoSpatialPrior != null
                && this.nodeGeoSpatialPrior.containsKey(this.treeModel.getNodeTaxon(node));
    }

    /** @return {@code true} if a multivariate normal prior has been registered for the node's taxon */
    private boolean nodeMVNPriorExists(NodeRef node) {
        return this.nodeMVNPrior != null
                && this.nodeMVNPrior.containsKey(this.treeModel.getNodeTaxon(node));
    }

    /**
     * Picks a random eligible node and replaces its trait with a draw from the normal
     * distribution built by {@link #operateNotRoot} or {@link #operateRoot}.
     *
     * <p>A draw is discarded when the node's geo-spatial prior gives it a log-density of negative
     * infinity, or when, after the draw has been written to the tree, the global parameter prior
     * reports a log-likelihood of negative infinity.
     *
     * @return 0.0 once a draw has been accepted (presumably the log Hastings ratio of the move;
     *         confirm against {@code SimpleMCMCOperator})
     * @throws RuntimeException if no acceptable draw is found within the retry budget; the node's
     *                          original trait value is written back first
     */
    @Override // dr.inference.operators.SimpleMCMCOperator
    public double doOperation() {
        final NodeRef root = this.treeModel.getRoot();

        // Rejection-sample a node. NOTE(review): this loop has no exit if no node is eligible
        // (e.g. the only internal node is the root and no root prior has been set).
        NodeRef node = null;
        while (node == null) {
            if (this.onlyInternalNodes) {
                node = this.treeModel.getInternalNode(MathUtils.nextInt(this.treeModel.getInternalNodeCount()));
            } else {
                node = this.treeModel.getNode(MathUtils.nextInt(this.treeModel.getNodeCount()));
                boolean isTip = this.treeModel.getChildCount(node) == 0;
                if (this.onlyTipsWithPriors && isTip && !nodeGeoSpatialPriorExists(node)) {
                    node = null;
                }
            }
            if (!this.sampleRoot && node == root) {
                node = null;
            }
        }

        // Remembered so the tree can be restored if sampling gets stuck.
        final double[] initialValue = this.treeModel.getMultivariateNodeTrait(node, this.traitName);
        final MeanPrecision conditional = (node != root) ? operateNotRoot(node) : operateRoot(node);

        final Taxon taxon = this.treeModel.getNodeTaxon(node);
        final boolean hasGeoPrior = nodeGeoSpatialPriorExists(node);
        // The prior for this node does not change between retries, so look it up once.
        final GeoSpatialDistribution geoPrior = hasGeoPrior ? this.nodeGeoSpatialPrior.get(taxon) : null;
        final boolean hasParameterPrior = this.parameterPrior != null;

        for (int tries = 0; tries <= MAX_TRIES; tries++) {
            double[] draw = MultivariateNormalDistribution.nextMultivariateNormalPrecision(
                    conditional.mean, conditional.precision);
            if (hasGeoPrior && geoPrior.logPdf(draw) == Double.NEGATIVE_INFINITY) {
                continue; // draw falls outside the node's truncation region
            }
            // The global prior is evaluated on the tree state, so the draw is written first.
            this.treeModel.setMultivariateTrait(node, this.traitName, draw);
            if (!hasParameterPrior || this.parameterPrior.getLogLikelihood() != Double.NEGATIVE_INFINITY) {
                return 0.0d;
            }
        }

        this.treeModel.setMultivariateTrait(node, this.traitName, initialValue);
        throw new RuntimeException("Truncated Gibbs is stuck!");
    }

    /**
     * Builds the sampling distribution for a non-root node.
     *
     * <p>Each neighbour (the parent, plus every child) gets weight
     * {@code 1 / getRescaledBranchLengthForPrecision(...)} of the branch joining it to the node.
     * The mean is the weighted average of the neighbours' traits and the precision is the
     * diffusion precision matrix multiplied by the sum of the weights.
     *
     * @throws RuntimeException if a multivariate normal prior is registered for this node's taxon
     */
    private MeanPrecision operateNotRoot(NodeRef node) {
        double[][] precision = this.precisionMatrixParameter.getParameterAsMatrix();
        NodeRef parent = this.treeModel.getParent(node);
        double[] mean = new double[this.dim];

        double parentWeight = 1.0d / this.traitModel.getRescaledBranchLengthForPrecision(node);
        double[] parentTrait = this.treeModel.getMultivariateNodeTrait(parent, this.traitName);
        for (int i = 0; i < this.dim; i++) {
            mean[i] = parentTrait[i] * parentWeight;
        }
        double weightTotal = parentWeight;

        for (int c = 0; c < this.treeModel.getChildCount(node); c++) {
            NodeRef child = this.treeModel.getChild(node, c);
            double[] childTrait = this.treeModel.getMultivariateNodeTrait(child, this.traitName);
            double childWeight = 1.0d / this.traitModel.getRescaledBranchLengthForPrecision(child);
            for (int i = 0; i < this.dim; i++) {
                mean[i] = mean[i] + (childTrait[i] * childWeight);
            }
            weightTotal += childWeight;
        }

        for (int i = 0; i < this.dim; i++) {
            mean[i] = mean[i] / weightTotal;
            // Scale the upper triangle and mirror it into the lower triangle.
            for (int j = i; j < this.dim; j++) {
                double scaled = precision[i][j] * weightTotal;
                precision[i][j] = scaled;
                precision[j][i] = scaled;
            }
        }

        if (nodeMVNPriorExists(node)) {
            throw new RuntimeException("Still trying to implement multivariate normal taxon priors");
        }
        return new MeanPrecision(mean, precision);
    }

    /**
     * Builds the sampling distribution for the root by combining its children with the root prior.
     *
     * <p>With {@code P} the diffusion precision matrix and {@code w_c} the reciprocal rescaled
     * branch length of child {@code c}, the precision is {@code P * sum(w_c) + rootPriorPrecision}
     * and the mean is the inverse of that matrix applied to
     * {@code sum(P * w_c * trait_c) + rootPriorPrecision * rootPriorMean}.
     * Reads the root prior fields, which are only populated by {@link #setRootPrior}.
     */
    private MeanPrecision operateRoot(NodeRef node) {
        double weightTotal = 0.0d;
        double[] weightedSum = new double[this.dim];
        double[][] precision = this.precisionMatrixParameter.getParameterAsMatrix();

        for (int c = 0; c < this.treeModel.getChildCount(node); c++) {
            NodeRef child = this.treeModel.getChild(node, c);
            double[] childTrait = this.treeModel.getMultivariateNodeTrait(child, this.traitName);
            double childWeight = 1.0d / this.traitModel.getRescaledBranchLengthForPrecision(child);
            for (int row = 0; row < this.dim; row++) {
                for (int col = 0; col < this.dim; col++) {
                    weightedSum[row] = weightedSum[row] + (precision[row][col] * childWeight * childTrait[col]);
                }
            }
            weightTotal += childWeight;
        }

        // Fold in the root prior; the precision matrix is only rescaled here, after the loop
        // above has finished reading the unscaled diffusion precision.
        for (int row = 0; row < this.dim; row++) {
            for (int col = 0; col < this.dim; col++) {
                weightedSum[row] = weightedSum[row] + (this.rootPriorPrecision[row][col] * this.rootPriorMean[col]);
                precision[row][col] = (precision[row][col] * weightTotal) + this.rootPriorPrecision[row][col];
            }
        }

        double[][] variance = new SymmetricMatrix(precision).inverse().toComponents();
        double[] mean = new double[this.dim];
        for (int row = 0; row < this.dim; row++) {
            for (int col = 0; col < this.dim; col++) {
                mean[row] = mean[row] + (variance[row][col] * weightedSum[col]);
            }
        }
        return new MeanPrecision(mean, precision);
    }

    /** @return always {@code null}; this operator offers no performance suggestion */
    public String getPerformanceSuggestion() {
        return null;
    }

    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.MCMCOperator
    public String getOperatorName() {
        return GIBBS_OPERATOR;
    }
}
