package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.LikelihoodTreeTraversal;
import dr.evomodel.treedatalikelihood.ProcessOnTreeDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.TreeTraversal;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.evomodel.treedatalikelihood.continuous.cdi.MultivariateIntegrator;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.preorder.AbstractRealizedContinuousTraitDelegate;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.AbstractModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.List;

/**
 * Exposes the Wishart sufficient statistics of a continuous-trait {@link TreeDataLikelihood} as a
 * {@link ConjugateWishartStatisticsProvider}, e.g. for a precision Gibbs sampler.
 *
 * <p>Computing the statistics:
 * <ul>
 *   <li>Statistics are computed lazily. A post-order traversal is run with Wishart-statistics
 *       collection switched on in {@code outerProductDelegate}.</li>
 *   <li>The result is cached until any model or variable change is observed.</li>
 *   <li>The cache is saved and restored through {@link #storeState()} / {@link #restoreState()}.</li>
 * </ul>
 *
 * <p>When the likelihood delegate uses a {@link MultivariateIntegrator}, a separate
 * observed-data-only delegate is created. If the data model is a
 * {@link RepeatedMeasuresTraitDataModel}, it is replaced by an {@link EmptyTraitDataModel}.
 * Before each computation, the tips of that delegate are overwritten with the realized tip
 * samples (see {@link #simulateMissingTraits()}).
 *
 * <p>As a {@link Loggable}, this logs every entry of the statistics' scale matrix (columns
 * {@code OPij}) and, when the realized tip trait is available, every sampled tip value
 * (columns {@code TIPk}).
 */
public class WishartStatisticsWrapper extends AbstractModel implements ConjugateWishartStatisticsProvider, Loggable {
    public static final String PARSER_NAME = "wishartStatistics";
    public static final String TRAIT_NAME = "traitName";

    private final LikelihoodTreeTraversal treeTraversalDelegate;
    // Accessor for realized tip samples; may be null (getColumns() guards for this).
    private final TreeTrait tipSampleTrait;
    // Name used to look up tipSampleTrait; kept for error reporting.
    private final String tipTraitName;
    private final int dimTrait;
    private final int numTrait;
    private final int tipCount;
    // Length of one per-trait tip partial: dimTrait values plus one extra trailing slot.
    private final int dimPartial;
    private final ContinuousDataLikelihoodDelegate likelihoodDelegate;
    // Either likelihoodDelegate itself, or an observed-data-only copy (MultivariateIntegrator case).
    private final ContinuousDataLikelihoodDelegate outerProductDelegate;
    private final TreeDataLikelihood dataLikelihood;
    private boolean traitDataKnown;
    private boolean outerProductsKnown;
    private boolean savedTraitDataKnown;
    private boolean savedOuterProductsKnown;
    private WishartSufficientStatistics wishartStatistics;
    private WishartSufficientStatistics savedWishartStatistics;

    /** Log column reporting one entry of the current scale matrix, by flat index. */
    private class OuterProductColumn extends NumberColumn {
        private final int index;

        private OuterProductColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        @Override // dr.inference.loggers.NumberColumn
        public double getDoubleValue() {
            return getWishartStatistics().getScaleMatrix()[index];
        }
    }

    /** Log column reporting one entry of the realized tip-trait vector, by flat index. */
    private class TipSampleColumn extends NumberColumn {
        private final int index;

        private TipSampleColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        @Override // dr.inference.loggers.NumberColumn
        public double getDoubleValue() {
            // Re-fetched on every call so the logged value reflects the current realization.
            double[] tipSamples = (double[]) tipSampleTrait.getTrait(dataLikelihood.getTree(), null);
            return tipSamples[index];
        }
    }

    /**
     * Creates the wrapper and registers {@code dataLikelihood} as a sub-model, so that changes
     * to it invalidate the cached statistics.
     *
     * @param name               model name/id
     * @param traitName          trait name, used to look up the realized tip trait and to name
     *                           the {@link EmptyTraitDataModel} if one is created
     * @param dataLikelihood     the tree data likelihood to wrap
     * @param likelihoodDelegate the continuous delegate of {@code dataLikelihood}
     */
    public WishartStatisticsWrapper(String name, String traitName, TreeDataLikelihood dataLikelihood,
                                    ContinuousDataLikelihoodDelegate likelihoodDelegate) {
        super(name);
        this.dataLikelihood = dataLikelihood;
        this.likelihoodDelegate = likelihoodDelegate;
        this.dimTrait = likelihoodDelegate.getTraitDim();
        this.numTrait = likelihoodDelegate.getTraitCount();
        this.tipCount = dataLikelihood.getTree().getExternalNodeCount();
        this.dimPartial = dimTrait + 1;

        addModel(dataLikelihood);

        this.tipTraitName = AbstractRealizedContinuousTraitDelegate.getTipTraitName(traitName);
        this.tipSampleTrait = dataLikelihood.getTreeTrait(tipTraitName);

        this.treeTraversalDelegate = new LikelihoodTreeTraversal(dataLikelihood.getTree(),
                dataLikelihood.getBranchRateModel(), TreeTraversal.TraversalType.POST_ORDER);

        if (likelihoodDelegate.getIntegrator() instanceof MultivariateIntegrator) {
            ContinuousTraitPartialsProvider dataModel = likelihoodDelegate.getDataModel();
            // Repeated-measures models are swapped for an empty scalar-precision model before
            // building the observed-data-only delegate.
            ContinuousTraitPartialsProvider observedModel = dataModel instanceof RepeatedMeasuresTraitDataModel
                    ? new EmptyTraitDataModel(traitName, dataModel.getParameter(),
                            dataModel.getTraitDimension(), PrecisionType.SCALAR)
                    : dataModel;
            this.outerProductDelegate =
                    ContinuousDataLikelihoodDelegate.createObservedDataOnly(likelihoodDelegate, observedModel);
        } else {
            this.outerProductDelegate = likelihoodDelegate;
        }

        this.traitDataKnown = false;
        this.outerProductsKnown = false;
    }

    /**
     * Returns the Wishart sufficient statistics, recomputing them only if they were
     * invalidated since the last call.
     */
    @Override // dr.math.interfaces.ConjugateWishartStatisticsProvider
    public WishartSufficientStatistics getWishartStatistics() {
        if (!outerProductsKnown) {
            computeOuterProducts();
            outerProductsKnown = true;
        }
        return wishartStatistics;
    }

    /**
     * Copies the current realized tip samples into {@code outerProductDelegate} as tip data.
     *
     * <p>For each tip, every trait's {@code dimTrait} values are laid out in a partial of
     * length {@code dimPartial}. The extra trailing slot is set to
     * {@link Double#POSITIVE_INFINITY} (presumably a scalar precision, marking the sample as
     * exactly known; TODO confirm against the integrator's partial layout).
     *
     * @throws IllegalStateException if no realized tip trait is available from the likelihood
     */
    public void simulateMissingTraits() {
        if (tipSampleTrait == null) {
            throw new IllegalStateException("Realized tip trait '" + tipTraitName
                    + "' is not available from the tree data likelihood; cannot simulate missing traits");
        }

        // Forces the likelihood delegate to recompute before tip samples are read.
        likelihoodDelegate.fireModelChanged();
        final double[] tipSamples = (double[]) tipSampleTrait.getTrait(dataLikelihood.getTree(), null);

        ContinuousDiffusionIntegrator integrator = outerProductDelegate.getIntegrator();
        assert integrator instanceof ContinuousDiffusionIntegrator.Basic;

        // Trailing slots are set once; the arraycopy below only overwrites the value slots.
        final double[] partial = new double[dimPartial * numTrait];
        for (int trait = 0; trait < numTrait; ++trait) {
            partial[trait * dimPartial + dimTrait] = Double.POSITIVE_INFINITY;
        }

        for (int tip = 0; tip < tipCount; ++tip) {
            // tipSamples is laid out as [tip][trait][dimTrait]; partial as [trait][dimPartial].
            int src = tip * dimTrait * numTrait;
            int dst = 0;
            for (int trait = 0; trait < numTrait; ++trait) {
                System.arraycopy(tipSamples, src, partial, dst, dimTrait);
                src += dimTrait;
                dst += dimPartial;
            }
            outerProductDelegate.setTipDataDirectly(tip, partial);
        }
    }

    /**
     * Runs a full post-order traversal with Wishart-statistics collection enabled and stores
     * the resulting statistics in {@code wishartStatistics}.
     */
    private void computeOuterProducts() {
        // Ensure the wrapped likelihood (and any realized-trait delegates) is up to date.
        dataLikelihood.getLogLikelihood();

        if (likelihoodDelegate != outerProductDelegate) {
            simulateMissingTraits();
        }

        treeTraversalDelegate.updateAllNodes();
        treeTraversalDelegate.dispatchTreeTraversalCollectBranchAndNodeOperations();

        final List<ProcessOnTreeDelegate.BranchOperation> branchOperations =
                treeTraversalDelegate.getBranchOperations();
        final List<ProcessOnTreeDelegate.NodeOperation> nodeOperations =
                treeTraversalDelegate.getNodeOperations();
        final NodeRef root = dataLikelihood.getTree().getRoot();

        outerProductDelegate.setComputeWishartStatistics(true);
        try {
            outerProductDelegate.calculateLikelihood(branchOperations, nodeOperations, root.getNumber());
        } finally {
            // Always switch collection back off so a failure cannot leave the (possibly shared)
            // delegate collecting statistics on later likelihood evaluations.
            outerProductDelegate.setComputeWishartStatistics(false);
        }
        wishartStatistics = outerProductDelegate.getWishartStatistics();
    }

    /** Returns the precision parameter of the wrapped delegate's diffusion model. */
    @Override // dr.math.interfaces.ConjugateWishartStatisticsProvider
    public MatrixParameterInterface getPrecisionParameter() {
        return likelihoodDelegate.getDiffusionModel().getPrecisionParameter();
    }

    /** Saves the validity flags and, if currently valid, a copy of the cached statistics. */
    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
        savedTraitDataKnown = traitDataKnown;
        savedOuterProductsKnown = outerProductsKnown;

        if (outerProductsKnown) {
            if (savedWishartStatistics == null) {
                savedWishartStatistics = wishartStatistics.m1028clone();
            } else {
                wishartStatistics.copyTo(savedWishartStatistics);
            }
        }
    }

    /** Restores the validity flags; if the saved state was valid, swaps the saved statistics back in. */
    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        traitDataKnown = savedTraitDataKnown;
        outerProductsKnown = savedOuterProductsKnown;

        if (outerProductsKnown) {
            WishartSufficientStatistics swap = wishartStatistics;
            wishartStatistics = savedWishartStatistics;
            savedWishartStatistics = swap;
        }
    }

    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
        // Nothing to do: the stored copy is simply overwritten on the next storeState().
    }

    @Override // dr.inference.model.AbstractModel
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        outerProductsKnown = false;
    }

    @Override // dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        outerProductsKnown = false;
    }

    /**
     * Builds the log columns:
     * <ul>
     *   <li>{@code dimTrait * dimTrait} scale-matrix columns named {@code OP<row><col>}
     *       (1-based), mapped to flat index {@code (row-1) * dimTrait + (col-1)};</li>
     *   <li>followed by one {@code TIP<k>} column per realized tip-trait entry, if that trait
     *       is available.</li>
     * </ul>
     */
    @Override // dr.inference.loggers.Loggable
    public LogColumn[] getColumns() {
        final int tipLength = tipSampleTrait != null
                ? ((double[]) tipSampleTrait.getTrait(dataLikelihood.getTree(), null)).length
                : 0;

        final LogColumn[] columns = new LogColumn[dimTrait * dimTrait + tipLength];
        int column = 0;
        for (int row = 0; row < dimTrait; ++row) {
            for (int col = 0; col < dimTrait; ++col) {
                columns[column] = new OuterProductColumn("OP" + (row + 1) + "" + (col + 1), column);
                ++column;
            }
        }
        for (int k = 0; k < tipLength; ++k) {
            columns[column] = new TipSampleColumn("TIP" + (k + 1), k);
            ++column;
        }
        return columns;
    }

    /** XML parser for {@code <wishartStatistics traitName="..."> <treeDataLikelihood/> </...>}. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] syntax = {
                new ElementRule(TreeDataLikelihood.class),
                AttributeRule.newStringRule(TRAIT_NAME, true)
        };

        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            final String id = xo.hasId() ? xo.getId() : PARSER_NAME;
            final String traitName = (String) xo.getAttribute(TRAIT_NAME, "trait");
            final TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);

            final DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
            if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
                throw new XMLParseException("May not provide a sequence data likelihood in the precision Gibbs sampler");
            }
            return new WishartStatisticsWrapper(id, traitName, treeDataLikelihood,
                    (ContinuousDataLikelihoodDelegate) delegate);
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return syntax;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "Provides Wishart sufficient statistics (trait outer products) from a continuous-trait "
                    + "tree data likelihood, e.g. for a precision Gibbs sampler";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return WishartStatisticsWrapper.class;
        }

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return PARSER_NAME;
        }
    };
}
