package dr.evomodel.arg.likelihood;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.AminoAcids;
import dr.evolution.datatype.Codons;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.datatype.OldHiddenNucleotides;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.evomodel.arg.ARGModel;
import dr.evomodel.arg.ARGTree;
import dr.evomodel.arg.operators.ARGPartitioningOperator;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.oldevomodel.sitemodel.SiteModel;
import dr.oldevomodel.substmodel.FrequencyModel;
import dr.oldevomodel.treelikelihood.AminoAcidLikelihoodCore;
import dr.oldevomodel.treelikelihood.GeneralLikelihoodCore;
import dr.oldevomodel.treelikelihood.LikelihoodCore;
import dr.oldevomodel.treelikelihood.NativeAminoAcidLikelihoodCore;
import dr.oldevomodel.treelikelihood.NativeCovarionLikelihoodCore;
import dr.oldevomodel.treelikelihood.NativeNucleotideLikelihoodCore;
import dr.oldevomodel.treelikelihood.NucleotideLikelihoodCore;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

/* loaded from: input_file:dr/evomodel/arg/likelihood/ARGLikelihood.class */
public class ARGLikelihood extends AbstractARGLikelihood {
    /** Parser name returned by {@link #PARSER}; also the name handed to the superclass constructor. */
    public static final String ARG_LIKELIHOOD = "argTreeLikelihood";
    /** Spelling of the XML attribute that selects whether tip ambiguities are used. */
    public static final String USE_AMBIGUITIES = "useAmbiguities";
    /** Spelling of the XML attribute that selects whether the likelihood core state is stored/restored. */
    public static final String STORE_PARTIALS = "storePartials";
    /** Spelling of the XML attribute for partial-likelihood scaling. */
    public static final String USE_SCALING = "useScaling";
    // Not referenced anywhere in the visible part of this class.
    private static final boolean NO_CACHING = false;
    // Scratch set reused by reconstructTree(): internal nodes of the rebuilt tree
    // that found no number in the previous numbering. Allocated lazily.
    private Set<NodeRef> unsetNodes;
    /**
     * XML parser that turns an {@code argTreeLikelihood} element into an
     * {@link ARGLikelihood}.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { // from class: dr.evomodel.arg.likelihood.ARGLikelihood.1
        // Three boolean attributes followed by the child elements. The trailing
        // 'true' arguments presumably mark a rule as optional (the parse method
        // below tolerates their absence) - confirm against AttributeRule/ElementRule.
        private final XMLSyntaxRule[] rules = {
                AttributeRule.newBooleanRule(ARGLikelihood.USE_AMBIGUITIES, true),
                AttributeRule.newBooleanRule(ARGLikelihood.STORE_PARTIALS, true),
                AttributeRule.newBooleanRule(ARGLikelihood.USE_SCALING, true),
                new ElementRule(PatternList.class),
                new ElementRule(ARGModel.class),
                new ElementRule(SiteModel.class),
                new ElementRule(BranchRateModel.class, true)
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return ARGLikelihood.ARG_LIKELIHOOD;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "This element represents the likelihood of a patternlist on a tree given the site model.";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return Likelihood.class;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        /**
         * Reads the three boolean attributes and the child models, then
         * constructs the likelihood.
         *
         * @param xMLObject the element being parsed
         * @return a new {@link ARGLikelihood}
         * @throws XMLParseException propagated from the attribute/child accessors
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            // Absent attributes fall back to: ambiguities off, partials stored, scaling off.
            final boolean ambiguities = xMLObject.hasAttribute(ARGLikelihood.USE_AMBIGUITIES)
                    && xMLObject.getBooleanAttribute(ARGLikelihood.USE_AMBIGUITIES);
            final boolean keepPartials = !xMLObject.hasAttribute(ARGLikelihood.STORE_PARTIALS)
                    || xMLObject.getBooleanAttribute(ARGLikelihood.STORE_PARTIALS);
            final boolean scaling = xMLObject.hasAttribute(ARGLikelihood.USE_SCALING)
                    && xMLObject.getBooleanAttribute(ARGLikelihood.USE_SCALING);

            final PatternList patterns = (PatternList) xMLObject.getChild(PatternList.class);
            final ARGModel arg = (ARGModel) xMLObject.getChild(ARGModel.class);
            final SiteModel sites = (SiteModel) xMLObject.getChild(SiteModel.class);
            // The constructor accepts a null branch rate model and substitutes a default.
            final BranchRateModel rates = (BranchRateModel) xMLObject.getChild(BranchRateModel.class);

            return new ARGLikelihood(patterns, arg, sites, rates, ambiguities, keepPartials, scaling);
        }
    };
    // Frequency model obtained from the site model in the constructor; supplies
    // the root frequencies in traverse().
    protected FrequencyModel frequencyModel;
    // Site model: category count, per-category rates/proportions and the substitution model.
    protected SiteModel siteModel;
    // Branch rate model; the constructor substitutes a DefaultBranchRateModel when none is given.
    protected BranchRateModel branchRateModel;
    // When true, store/restoreState delegate to the likelihood core; otherwise a
    // restore marks every node for update.
    private boolean storePartials;
    // Copied from siteModel.integrateAcrossCategories() in the constructor.
    private boolean integrateAcrossCategories;
    // Per-pattern category index; only filled when not integrating across categories.
    protected int[] siteCategories;
    // Buffer for the root partials, sized patternCount * stateCount on first use.
    protected double[] rootPartials;
    // Per-pattern log likelihoods written by the likelihood core at the root.
    protected double[] patternLogLikelihoods;
    // Number of rate categories reported by the site model.
    protected int categoryCount;
    // Scratch buffer (stateCount * stateCount) for one transition probability matrix.
    protected double[] probabilities;
    // Likelihood core implementation selected in the constructor.
    protected LikelihoodCore likelihoodCore;
    // Ambiguity flag taken from the constructor argument.
    private boolean useAmbiguities;
    // Set whenever the partition tree must be rebuilt before the next likelihood evaluation.
    private boolean reconstructTree;
    // Tree extracted from the ARG for this likelihood's partition; null until first rebuilt.
    private ARGTree tree;
    // The tree in use before the most recent rebuild.
    private ARGTree oldTree;
    // Node numbers recorded by reconstructTree() so they can be carried over to the
    // next rebuilt tree; looked up via treeModel.getMirrorNode(...).
    private Map<NodeRef, Integer> mapARGNodesToInts;
    // The numbering recorded for the previous tree.
    private Map<NodeRef, Integer> oldMapARGNodesToInts;
    // Mapping returned by ARGTree.getMapping(); used to translate the node of an
    // ARG change event into a node of this partition's tree.
    private Map<NodeRef, NodeRef> mapARGNodesToTreeNodes;
    // Not referenced anywhere in the visible part of this class.
    private static final boolean DEBUG = true;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:dr/evomodel/arg/likelihood/ARGLikelihood$NegativeBranchLengthException.class */
    /**
     * Checked exception raised while traversing the tree when the product of
     * branch rate and height difference for a branch is negative. It carries
     * no message and no cause.
     */
    public class NegativeBranchLengthException extends Exception {
        /** Creates the exception without a detail message. */
        NegativeBranchLengthException() {
            super();
        }
    }

    /**
     * Creates the likelihood calculator for one partition of an ARG.
     *
     * <p>Registers this object with the ARG model (which yields the partition
     * index), picks a likelihood core based on the site model and the data
     * type of the patterns, initialises the core and loads the tip data.
     *
     * @param patternList     site patterns; every taxon of the ARG must be present in it
     * @param aRGModel        the ARG model this likelihood is attached to
     * @param siteModel       site model supplying categories, rates and the frequency model
     * @param branchRateModel branch rate model, or {@code null} to use a {@link DefaultBranchRateModel}
     * @param z               whether tip ambiguities are used: tips are loaded as partials
     *                        when {@code true} and as states otherwise
     * @param z2              whether the likelihood core state is stored and restored
     * @param z3              partial-likelihood scaling flag; in this constructor it is only
     *                        reported in the log
     * @throws RuntimeException if a taxon of the ARG is missing from {@code patternList}
     *                          (the {@code MissingTaxonException} is attached as the cause),
     *                          or if the data type is {@link Codons} while integrating
     *                          across categories
     */
    public ARGLikelihood(PatternList patternList, ARGModel aRGModel, SiteModel siteModel, BranchRateModel branchRateModel, boolean z, boolean z2, boolean z3) {
        super(ARG_LIKELIHOOD, patternList, aRGModel);
        // Field defaults, applied after the superclass constructor has run.
        this.unsetNodes = null;
        this.frequencyModel = null;
        this.siteModel = null;
        this.branchRateModel = null;
        this.storePartials = false;
        this.integrateAcrossCategories = false;
        this.siteCategories = null;
        this.rootPartials = null;
        this.patternLogLikelihoods = null;
        this.reconstructTree = true;
        this.tree = null;
        this.mapARGNodesToInts = null;
        this.mapARGNodesToTreeNodes = null;

        this.partition = aRGModel.addLikelihoodCalculator(this);
        this.storePartials = z2;
        this.useAmbiguities = z;

        final Logger logger = Logger.getLogger("dr.evomodel");
        try {
            this.siteModel = siteModel;
            addModel(siteModel);
            this.frequencyModel = siteModel.getFrequencyModel();
            addModel(this.frequencyModel);
            this.integrateAcrossCategories = siteModel.integrateAcrossCategories();
            this.categoryCount = siteModel.getCategoryCount();

            // Choose the likelihood core. Without integration across categories the
            // general core is always used; otherwise the data type decides.
            if (!this.integrateAcrossCategories) {
                logger.info("TreeLikelihood using Java general likelihood core");
                this.likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount());
            } else if (patternList.getDataType() instanceof Nucleotides) {
                if (NativeNucleotideLikelihoodCore.isAvailable()) {
                    logger.info("TreeLikelihood using native nucleotide likelihood core");
                    this.likelihoodCore = new NativeNucleotideLikelihoodCore();
                } else {
                    logger.info("TreeLikelihood using Java nucleotide likelihood core");
                    this.likelihoodCore = new NucleotideLikelihoodCore();
                }
            } else if (patternList.getDataType() instanceof AminoAcids) {
                if (NativeAminoAcidLikelihoodCore.isAvailable()) {
                    logger.info("TreeLikelihood using native amino acid likelihood core");
                    this.likelihoodCore = new NativeAminoAcidLikelihoodCore();
                } else {
                    logger.info("TreeLikelihood using java likelihood core");
                    this.likelihoodCore = new AminoAcidLikelihoodCore();
                }
            } else if (patternList.getDataType() instanceof Codons) {
                // Codon data is not supported by this class yet.
                logger.info("TreeLikelihood using Java codon likelihood core");
                this.useAmbiguities = true;
                throw new RuntimeException("Still need to merge codon likelihood core");
            } else if ((patternList.getDataType() instanceof OldHiddenNucleotides) && NativeCovarionLikelihoodCore.isAvailable()) {
                logger.info("TreeLikelihood using native covarion likelihood core");
                this.likelihoodCore = new NativeCovarionLikelihoodCore();
            } else {
                logger.info("TreeLikelihood using Java general likelihood core");
                this.likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount());
            }

            logger.info("  " + (z ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
            logger.info("  Partial likelihood scaling " + (z3 ? "on." : "off."));

            if (branchRateModel != null) {
                this.branchRateModel = branchRateModel;
                logger.info("Branch rate model used: " + branchRateModel.getModelName());
            } else {
                this.branchRateModel = new DefaultBranchRateModel();
            }
            addModel(this.branchRateModel);

            this.probabilities = new double[this.stateCount * this.stateCount];
            this.likelihoodCore.initialize(this.nodeCount, this.patternCount, this.categoryCount, this.integrateAcrossCategories);

            int externalNodeCount = aRGModel.getExternalNodeCount();
            int internalNodeCount = aRGModel.getInternalNodeCount();

            // Load the tip data: node i of the ARG receives the pattern column of its taxon.
            for (int i = 0; i < externalNodeCount; i++) {
                String taxonId = aRGModel.getTaxonId(i);
                int taxonIndex = patternList.getTaxonIndex(taxonId);
                if (taxonIndex == -1) {
                    throw new TaxonList.MissingTaxonException("Taxon, " + taxonId + ", in tree, " + aRGModel.getId() + ", is not found in patternList, " + patternList.getId());
                }
                if (z) {
                    setPartials(this.likelihoodCore, patternList, this.categoryCount, taxonIndex, i);
                } else {
                    setStates(this.likelihoodCore, patternList, taxonIndex, i);
                }
            }

            // Internal nodes are numbered directly after the external ones.
            for (int i2 = 0; i2 < internalNodeCount; i2++) {
                this.likelihoodCore.createNodePartials(externalNodeCount + i2);
            }
        } catch (TaxonList.MissingTaxonException e) {
            // Same message as before, but chain the cause so its stack trace is not lost.
            throw new RuntimeException(e.toString(), e);
        }
    }

    /**
     * Reacts to a change in one of the models this likelihood listens to by
     * marking nodes for update and, where the topology or partitioning may
     * have changed, scheduling a rebuild of the partition tree.
     *
     * <p>Events from the ARG model are inspected in detail; changes of the
     * branch rate model, the frequency model or any {@link SiteModel} mark
     * every node for update. The event is finally forwarded to the superclass.
     *
     * @param model the model that changed
     * @param obj   event payload; for the ARG model an {@code ARGTreeChangedEvent},
     *              a {@code PartitionChangedEvent} or a {@link Parameter}
     * @param i     index forwarded unchanged to the superclass
     * @throws RuntimeException for an unrecognised model or ARG event
     */
    @Override // dr.evomodel.arg.likelihood.AbstractARGLikelihood, dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
        if (model == this.treeModel) {
            if (obj instanceof ARGModel.ARGTreeChangedEvent) {
                ARGModel.ARGTreeChangedEvent treeEvent = (ARGModel.ARGTreeChangedEvent) obj;
                if (treeEvent.isSizeChanged()) {
                    updateAllNodes();
                    this.reconstructTree = true;
                } else if (treeEvent.isNodeChanged()) {
                    if (this.mapARGNodesToTreeNodes == null) {
                        // The partition tree has not been built yet, so the changed node
                        // cannot be translated. Invalidate everything instead of failing
                        // with a NullPointerException.
                        this.reconstructTree = true;
                        updateAllNodes();
                    } else {
                        NodeRef treeNode = this.mapARGNodesToTreeNodes.get(treeEvent.getNode());
                        // A null result means the mapping has no entry for the changed node;
                        // nothing is invalidated in that case.
                        if (treeNode != null) {
                            if (treeEvent.isHeightChanged() || treeEvent.isRateChanged()) {
                                updateNodeAndChildren(treeNode);
                            } else {
                                this.reconstructTree = true;
                                updateAllNodes();
                            }
                        }
                    }
                } else if (treeEvent.isTreeChanged()) {
                    this.reconstructTree = true;
                    updateAllNodes();
                } else {
                    throw new RuntimeException("Another tree event has occured (possibly a trait change).");
                }
            } else if (obj instanceof ARGPartitioningOperator.PartitionChangedEvent) {
                // Only react when this likelihood's own partition was touched.
                if (((ARGPartitioningOperator.PartitionChangedEvent) obj).getUpdatedPartitions()[this.partition]) {
                    this.reconstructTree = true;
                    updateAllNodes();
                }
            } else if (!(obj instanceof Parameter)) {
                // Guard the class lookup so a null payload yields this diagnostic
                // rather than a bare NullPointerException.
                throw new RuntimeException("Unexpected ARGModel update " + (obj == null ? "null" : obj.getClass().toString()));
            }
        } else if (model == this.branchRateModel || model == this.frequencyModel || model instanceof SiteModel) {
            updateAllNodes();
        } else {
            throw new RuntimeException("Unknown componentChangedEvent");
        }
        super.handleModelChangedEvent(model, obj, i);
    }

    /**
     * Saves the likelihood core's state when partial storing is enabled, then
     * lets the superclass store its own state.
     */
    @Override // dr.evomodel.arg.likelihood.AbstractARGLikelihood, dr.inference.model.AbstractModel
    protected void storeState() {
        final boolean cachePartials = storePartials;
        if (cachePartials) {
            likelihoodCore.storeState();
        }
        super.storeState();
    }

    /**
     * Rolls back to the stored state: the likelihood core is restored when
     * partial storing is enabled, otherwise every node is marked for update.
     * In both cases the partition tree is scheduled for a rebuild before the
     * superclass restores its own state.
     */
    @Override // dr.evomodel.arg.likelihood.AbstractARGLikelihood, dr.inference.model.AbstractModel
    protected void restoreState() {
        if (!storePartials) {
            // Nothing was cached, so everything has to be recomputed.
            updateAllNodes();
        } else {
            likelihoodCore.restoreState();
        }
        reconstructTree = true;
        super.restoreState();
    }

    /**
     * Finds the smallest node number that is not yet used as a value in
     * {@code map}, starting the search at the current tree's external node
     * count.
     *
     * @param map numbering assigned so far
     * @return the first integer at or above the external node count that does
     *         not occur among the map's values
     */
    private int getUnusedInt(Map<NodeRef, Integer> map) {
        // Snapshot the assigned numbers in a hash set so every probe is a
        // constant-time lookup instead of a linear scan over the map's values.
        final Set<Integer> taken = new HashSet<Integer>(map.values());
        int candidate = this.tree.getExternalNodeCount();
        while (taken.contains(candidate)) {
            candidate++;
        }
        return candidate;
    }

    /**
     * Rebuilds the tree of this likelihood's partition from the ARG model and
     * renumbers its internal nodes.
     *
     * <p>On the first build, each internal node's current number is simply
     * recorded. On later builds, an internal node whose mirror node was
     * numbered in the previous tree gets that number again; the remaining
     * nodes receive the lowest unused numbers and are flagged for update.
     * The numbering is always keyed by the mirror node returned by
     * {@code treeModel.getMirrorNode}, so that the next rebuild can find it.
     */
    private void reconstructTree() {
        this.oldTree = this.tree;
        this.oldMapARGNodesToInts = this.mapARGNodesToInts;
        this.tree = new ARGTree(this.treeModel, this.partition);
        this.reconstructTree = false;
        this.mapARGNodesToInts = new HashMap<NodeRef, Integer>(this.tree.getInternalNodeCount());
        this.mapARGNodesToTreeNodes = this.tree.getMapping();

        if (this.oldTree == null) {
            // First build: there is no previous numbering to preserve.
            for (int i = 0; i < this.tree.getInternalNodeCount(); i++) {
                NodeRef internalNode = this.tree.getInternalNode(i);
                this.mapARGNodesToInts.put(this.treeModel.getMirrorNode(internalNode), Integer.valueOf(internalNode.getNumber()));
            }
            return;
        }

        if (this.unsetNodes == null) {
            this.unsetNodes = new HashSet<NodeRef>();
        } else {
            this.unsetNodes.clear();
        }

        // Pass 1: carry over the numbers of nodes already known from the previous tree.
        for (int i = 0; i < this.tree.getInternalNodeCount(); i++) {
            NodeRef internalNode = this.tree.getInternalNode(i);
            NodeRef mirrorNode = this.treeModel.getMirrorNode(internalNode);
            Integer previousNumber = this.oldMapARGNodesToInts.get(mirrorNode);
            if (previousNumber != null) {
                this.treeModel.setNodeNumber(internalNode, previousNumber.intValue());
                this.mapARGNodesToInts.put(mirrorNode, previousNumber);
            } else {
                this.unsetNodes.add(internalNode);
            }
        }

        // Pass 2: number the remaining nodes and force their recomputation.
        for (NodeRef newNode : this.unsetNodes) {
            int freeNumber = getUnusedInt(this.mapARGNodesToInts);
            this.treeModel.setNodeNumber(newNode, freeNumber);
            // Key by the mirror node, as everywhere else in this map. Keying by the
            // tree node itself would make this entry unfindable on the next rebuild.
            this.mapARGNodesToInts.put(this.treeModel.getMirrorNode(newNode), Integer.valueOf(freeNumber));
            this.updateNode[freeNumber] = true;
        }
    }

    /**
     * Computes the log likelihood of the patterns on this partition's tree.
     *
     * <p>Rebuilds the tree if a rebuild is pending, allocates the working
     * buffers on first use, traverses the tree from the root, clears all
     * update flags and sums the weighted per-pattern log likelihoods.
     *
     * @return the weighted sum of pattern log likelihoods, or
     *         {@link Double#NEGATIVE_INFINITY} when the traversal meets a
     *         negative branch length
     */
    @Override // dr.evomodel.arg.likelihood.AbstractARGLikelihood
    protected double calculateLogLikelihood() {
        if (reconstructTree) {
            reconstructTree();
        }
        final NodeRef rootNode = tree.getRoot();

        // Working buffers are created lazily.
        if (rootPartials == null) {
            rootPartials = new double[patternCount * stateCount];
        }
        if (patternLogLikelihoods == null) {
            patternLogLikelihoods = new double[patternCount];
        }

        if (!integrateAcrossCategories) {
            if (siteCategories == null) {
                siteCategories = new int[patternCount];
            }
            // Refresh the category of every pattern before each evaluation.
            for (int pattern = 0; pattern < patternCount; pattern++) {
                siteCategories[pattern] = siteModel.getCategoryOfSite(pattern);
            }
        }

        try {
            traverse(tree, rootNode);
        } catch (NegativeBranchLengthException e) {
            System.err.println("Negative branch length found, trying to return 0 likelihood");
            return Double.NEGATIVE_INFINITY;
        }

        // Everything is up to date now.
        for (int node = 0; node < nodeCount; node++) {
            updateNode[node] = false;
        }

        double logLikelihoodSum = 0.0d;
        for (int pattern = 0; pattern < patternCount; pattern++) {
            logLikelihoodSum += patternLogLikelihoods[pattern] * patternWeights[pattern];
        }
        return logLikelihoodSum;
    }

    /**
     * Post-order traversal that refreshes transition matrices and partial
     * likelihoods where needed.
     *
     * <p>For a non-root node flagged for update, the transition matrices of
     * all categories are recomputed for the branch above it. For an internal
     * node, both children are visited; if either reports an update, the node's
     * partials are recalculated, and at the root the per-pattern log
     * likelihoods are computed as well.
     *
     * @param tree    tree being traversed
     * @param nodeRef current node
     * @return {@code true} if this node's matrix or partials were recomputed
     * @throws NegativeBranchLengthException if rate times height difference is negative
     * @throws RuntimeException if an internal node does not have exactly two children
     */
    private boolean traverse(Tree tree, NodeRef nodeRef) throws NegativeBranchLengthException {
        final int nodeNum = nodeRef.getNumber();
        final NodeRef parentNode = tree.getParent(nodeRef);
        boolean updated = false;

        if (parentNode != null && this.updateNode[nodeNum]) {
            final double rate = this.branchRateModel.getBranchRate(tree, nodeRef);
            final double heightSpan = tree.getNodeHeight(parentNode) - tree.getNodeHeight(nodeRef);
            final double branchTime = rate * heightSpan;
            if (branchTime < 0.0d) {
                throw new NegativeBranchLengthException();
            }
            // One transition matrix per rate category.
            for (int category = 0; category < this.categoryCount; category++) {
                final double scaledTime = this.siteModel.getRateForCategory(category) * branchTime;
                this.siteModel.getSubstitutionModel().getTransitionProbabilities(scaledTime, this.probabilities);
                this.likelihoodCore.setNodeMatrix(nodeNum, category, this.probabilities);
            }
            updated = true;
        }

        if (tree.isExternal(nodeRef)) {
            return updated;
        }
        if (tree.getChildCount(nodeRef) != 2) {
            throw new RuntimeException("binary trees only!");
        }

        // Both subtrees are always visited, whatever the first one reports.
        final NodeRef leftChild = tree.getChild(nodeRef, 0);
        final boolean leftUpdated = traverse(tree, leftChild);
        final NodeRef rightChild = tree.getChild(nodeRef, 1);
        final boolean rightUpdated = traverse(tree, rightChild);

        if (!leftUpdated && !rightUpdated) {
            return updated;
        }

        final int leftNum = leftChild.getNumber();
        final int rightNum = rightChild.getNumber();
        if (this.integrateAcrossCategories) {
            this.likelihoodCore.calculatePartials(leftNum, rightNum, nodeNum);
        } else {
            this.likelihoodCore.calculatePartials(leftNum, rightNum, nodeNum, this.siteCategories);
        }

        if (parentNode == null) {
            // Root: turn the partials into per-pattern log likelihoods.
            final double[] rootFrequencies = this.frequencyModel.getFrequencies();
            if (this.integrateAcrossCategories) {
                this.likelihoodCore.integratePartials(nodeNum, this.siteModel.getCategoryProportions(), this.rootPartials);
            } else {
                this.likelihoodCore.getPartials(nodeNum, this.rootPartials);
            }
            this.likelihoodCore.calculateLogLikelihoods(this.rootPartials, rootFrequencies, this.patternLogLikelihoods);
        }
        return true;
    }
}
