package dr.oldevomodel.treelikelihood;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.AminoAcids;
import dr.evolution.datatype.Codons;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.evoxml.util.GraphMLUtils;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.oldevomodel.sitemodel.SiteModel;
import dr.oldevomodel.substmodel.FrequencyModel;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Logger;

/**
 * Tree likelihood in which different parts of the tree may evolve under different site models.
 *
 * <p>Supported configuration (see {@link #PARSER}):
 * <ul>
 *   <li>a default {@link SiteModel} used for every branch not otherwise covered;</li>
 *   <li>an optional site model applied to all external (tip) branches;</li>
 *   <li>any number of clade site models, each applied below the MRCA of a taxon set and,
 *       optionally, to the stem branch above that MRCA;</li>
 *   <li>an optional "delta" parameter whose value is added to the length of selected tip branches
 *       (all tips when no taxon list is given).</li>
 * </ul>
 *
 * <p>Only binary trees and site models that integrate across rate categories are supported.
 *
 * @deprecated legacy implementation in {@code dr.oldevomodel}.
 */
@Deprecated
public class AdvancedTreeLikelihood extends AbstractTreeLikelihood {
    public static final String ADVANCED_TREE_LIKELIHOOD = "advancedTreeLikelihood";
    public static final String CLADE = "clade";
    public static final String INCLUDE_STEM = "includeStem";
    public static final String TIPS = "tips";
    public static final String DELTA = "delta";
    public static final String USE_AMBIGUITIES = "useAmbiguities";
    public static final String STORE_PARTIALS = "storePartials";
    public static final String USE_SCALING = "useScaling";

    /** XML parser producing an {@link AdvancedTreeLikelihood} from an {@code advancedTreeLikelihood} element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = {
            AttributeRule.newBooleanRule(AdvancedTreeLikelihood.USE_AMBIGUITIES, true),
            AttributeRule.newBooleanRule(AdvancedTreeLikelihood.USE_SCALING, true),
            new ElementRule(AdvancedTreeLikelihood.TIPS, SiteModel.class,
                    "A siteModel that will be applied only to the tips.", 0, 1),
            new ElementRule(AdvancedTreeLikelihood.DELTA, new XMLSyntaxRule[]{
                new ElementRule(TaxonList.class,
                        "A set of taxa to which to apply the delta model to", 0, 1),
                new ElementRule(Parameter.class,
                        "A parameter that specifies the amount of extra substitutions per site at each tip.", 0, 1)
            }, true),
            new ElementRule(AdvancedTreeLikelihood.CLADE, new XMLSyntaxRule[]{
                AttributeRule.newBooleanRule(AdvancedTreeLikelihood.INCLUDE_STEM, true,
                        "determines whether or not the stem branch above this clade is included in the siteModel."),
                new ElementRule(TaxonList.class,
                        "A set of taxa which defines a clade to apply a different site model to"),
                new ElementRule(SiteModel.class,
                        "A siteModel that will be applied only to this clade")
            }, 0, Integer.MAX_VALUE),
            new ElementRule(PatternList.class),
            new ElementRule(TreeModel.class),
            new ElementRule(SiteModel.class),
            new ElementRule(BranchRateModel.class, true)
        };

        @Override
        public String getParserName() {
            return AdvancedTreeLikelihood.ADVANCED_TREE_LIKELIHOOD;
        }

        /**
         * Builds the likelihood and attaches the optional tips, delta and clade sub-models.
         *
         * @throws XMLParseException if a single-taxon clade excludes its stem, a delta element has
         *     no parameter, or a clade taxon is missing from the tree
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            boolean useAmbiguities = (Boolean) xo.getAttribute(AdvancedTreeLikelihood.USE_AMBIGUITIES, false);
            boolean useScaling = (Boolean) xo.getAttribute(AdvancedTreeLikelihood.USE_SCALING, false);

            AdvancedTreeLikelihood likelihood = new AdvancedTreeLikelihood(
                    (PatternList) xo.getChild(PatternList.class),
                    (TreeModel) xo.getChild(TreeModel.class),
                    (SiteModel) xo.getChild(SiteModel.class),
                    (BranchRateModel) xo.getChild(BranchRateModel.class),
                    useAmbiguities,
                    useScaling);

            if (xo.hasChildNamed(AdvancedTreeLikelihood.TIPS)) {
                likelihood.addTipsSiteModel((SiteModel) xo.getElementFirstChild(AdvancedTreeLikelihood.TIPS));
            }

            XMLObject deltaXO = xo.getChild(AdvancedTreeLikelihood.DELTA);
            if (deltaXO != null) {
                Parameter delta = (Parameter) deltaXO.getChild(Parameter.class);
                if (delta == null) {
                    // Without a parameter the delta element is meaningless and would later
                    // register a null variable.
                    throw new XMLParseException("The " + AdvancedTreeLikelihood.DELTA + " element in "
                            + getParserName() + " requires a Parameter.");
                }
                likelihood.addDeltaParameter(delta, (TaxonList) deltaXO.getChild(TaxonList.class));
            }

            for (int i = 0; i < xo.getChildCount(); i++) {
                Object child = xo.getChild(i);
                if (!(child instanceof XMLObject)) {
                    continue;
                }
                XMLObject cladeXO = (XMLObject) child;
                if (!cladeXO.getName().equals(AdvancedTreeLikelihood.CLADE)) {
                    continue;
                }

                SiteModel cladeSiteModel = (SiteModel) cladeXO.getChild(SiteModel.class);
                TaxonList taxa = (TaxonList) cladeXO.getChild(TaxonList.class);
                boolean singleTaxon = taxa.getTaxonCount() == 1;

                boolean includeStem;
                if (cladeXO.hasAttribute(AdvancedTreeLikelihood.INCLUDE_STEM)) {
                    includeStem = cladeXO.getBooleanAttribute(AdvancedTreeLikelihood.INCLUDE_STEM);
                    if (singleTaxon && !includeStem) {
                        throw new XMLParseException("The site model is only applied to 1 taxon and therefore must include the stem branch");
                    }
                } else {
                    // A single-taxon "clade" only has its stem branch, so it must be included.
                    includeStem = singleTaxon;
                }

                try {
                    likelihood.addCladeSiteModel(cladeSiteModel, taxa, includeStem);
                } catch (TreeUtils.MissingTaxonException e) {
                    throw new XMLParseException("Taxon, " + e.getMessage() + ", in " + getParserName()
                            + " was not found in the tree.");
                }
            }
            return likelihood;
        }

        @Override
        public String getParserDescription() {
            return "This element represents the likelihood of a patternlist on a tree given the site model.";
        }

        @Override
        public Class getReturnType() {
            return Likelihood.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /** Frequency model of the default site model; supplies root frequencies. */
    protected FrequencyModel frequencyModel;
    /** Default site model for branches not covered by a tips or clade model. */
    protected SiteModel siteModel;
    /** Branch rate model (a {@link DefaultBranchRateModel} when none was supplied). */
    protected BranchRateModel branchRateModel;
    /** Optional site model used for every external branch; {@code null} if absent. */
    protected SiteModel tipsSiteModel;
    /** Optional extra length added to selected tip branches; {@code null} if absent. */
    protected Parameter deltaParameter;
    /** Node numbers of tips receiving the delta; empty means "all tips". */
    protected Set<Integer> deltaTips;
    /** Clade-specific site models, searched in insertion order. */
    protected ArrayList<Clade> cladeSiteModels = new ArrayList<>();
    /** Whether every clade's cached MRCA node number matches the current tree topology. */
    private boolean commonAncestorsKnown = true;
    /** Category-integrated partials at the root, allocated lazily. */
    protected double[] rootPartials;
    /** Per-pattern log likelihoods, allocated lazily. */
    protected double[] patternLogLikelihoods;
    protected int categoryCount;
    /** Scratch buffer for one transition-probability matrix (stateCount x stateCount). */
    protected double[] probabilities;
    protected LikelihoodCore likelihoodCore;

    /** A site model that applies to the subtree below the MRCA of a set of taxa. */
    public class Clade {
        SiteModel siteModel;
        Set<String> leafSet;
        /** Node number of the MRCA of {@link #leafSet}; refreshed by {@link #findMRCA()}. */
        int node;
        boolean includeStem;

        /**
         * @param siteModel site model applied inside the clade
         * @param taxonList taxa defining the clade
         * @param includeStem whether the branch above the MRCA also uses {@code siteModel};
         *     forced to {@code true} for a single-taxon clade
         * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
         */
        Clade(SiteModel siteModel, TaxonList taxonList, boolean includeStem) throws TreeUtils.MissingTaxonException {
            this.siteModel = siteModel;
            this.leafSet = TreeUtils.getLeavesForTaxa(AdvancedTreeLikelihood.this.treeModel, taxonList);
            this.includeStem = includeStem || taxonList.getTaxonCount() == 1;
            findMRCA();
        }

        /** Recomputes the MRCA node number against the current tree. */
        void findMRCA() {
            this.node = TreeUtils.getCommonAncestorNode(AdvancedTreeLikelihood.this.treeModel, this.leafSet).getNumber();
        }

        int getNode() {
            return this.node;
        }

        boolean includeStem() {
            return this.includeStem;
        }

        SiteModel getSiteModel() {
            return this.siteModel;
        }
    }

    /**
     * Creates the likelihood and loads tip data into a data-type-specific likelihood core.
     *
     * @param patternList site patterns; every tree taxon must be present
     * @param treeModel the tree
     * @param siteModel default site model; must integrate across categories
     * @param branchRateModel branch rates, or {@code null} for {@link DefaultBranchRateModel}
     * @param useAmbiguities if true, tips are stored as partials (always true for codons)
     * @param useScaling only reported in the log by this class
     * @throws RuntimeException if the site model does not integrate across categories or a tree
     *     taxon is missing from {@code patternList}
     */
    public AdvancedTreeLikelihood(PatternList patternList, TreeModel treeModel, SiteModel siteModel,
                                  BranchRateModel branchRateModel, boolean useAmbiguities, boolean useScaling) {
        super(ADVANCED_TREE_LIKELIHOOD, patternList, treeModel);
        try {
            this.siteModel = siteModel;
            addModel(siteModel);
            this.frequencyModel = siteModel.getFrequencyModel();
            addModel(this.frequencyModel);
            if (!siteModel.integrateAcrossCategories()) {
                throw new RuntimeException("AdvancedTreeLikelihood can only use SiteModels that require integration across categories");
            }
            this.categoryCount = siteModel.getCategoryCount();

            boolean ambiguities = useAmbiguities;
            Logger logger = Logger.getLogger("dr.evomodel");
            if (patternList.getDataType() instanceof Nucleotides) {
                if (NativeNucleotideLikelihoodCore.isAvailable()) {
                    logger.info("AdvancedTreeLikelihood using native nucleotide likelihood core.");
                    this.likelihoodCore = new NativeNucleotideLikelihoodCore();
                } else {
                    logger.info("AdvancedTreeLikelihood Java nucleotide likelihood core.");
                    this.likelihoodCore = new NucleotideLikelihoodCore();
                }
            } else if (patternList.getDataType() instanceof AminoAcids) {
                logger.info("AdvancedTreeLikelihood Java amino acid likelihood core.");
                this.likelihoodCore = new AminoAcidLikelihoodCore();
            } else if (patternList.getDataType() instanceof Codons) {
                logger.info("AdvancedTreeLikelihood using Java general likelihood core");
                this.likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount());
                // Codon tips are always loaded as partials.
                ambiguities = true;
            } else {
                logger.info("AdvancedTreeLikelihood using Java general likelihood core");
                this.likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount());
            }
            logger.info("  " + (ambiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
            // NOTE(review): useScaling is only logged here and never passed to the likelihood
            // core — confirm whether scaling was meant to be enabled on the core.
            logger.info("  Partial likelihood scaling " + (useScaling ? "on." : "off."));

            if (branchRateModel != null) {
                this.branchRateModel = branchRateModel;
                logger.info("Branch rate model used: " + branchRateModel.getModelName());
            } else {
                this.branchRateModel = new DefaultBranchRateModel();
            }
            addModel(this.branchRateModel);

            this.probabilities = new double[this.stateCount * this.stateCount];
            this.likelihoodCore.initialize(this.nodeCount, this.patternCount, this.categoryCount, true);

            int externalNodeCount = treeModel.getExternalNodeCount();
            int internalNodeCount = treeModel.getInternalNodeCount();
            for (int i = 0; i < externalNodeCount; i++) {
                String taxonId = treeModel.getTaxonId(i);
                int taxonIndex = patternList.getTaxonIndex(taxonId);
                if (taxonIndex == -1) {
                    throw new TaxonList.MissingTaxonException("Taxon, " + taxonId + ", in tree, " + treeModel.getId()
                            + ", is not found in patternList, " + patternList.getId());
                }
                if (ambiguities) {
                    setPartials(this.likelihoodCore, patternList, this.categoryCount, taxonIndex, i);
                } else {
                    setStates(this.likelihoodCore, patternList, taxonIndex, i);
                }
            }
            for (int i = 0; i < internalNodeCount; i++) {
                this.likelihoodCore.createNodePartials(externalNodeCount + i);
            }
        } catch (TaxonList.MissingTaxonException e) {
            // Keep the original message but preserve the cause for diagnostics.
            throw new RuntimeException(e.toString(), e);
        }
    }

    /**
     * Adds a site model for the clade defined by {@code taxonList}.
     *
     * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
     */
    public void addCladeSiteModel(SiteModel siteModel, TaxonList taxonList, boolean z) throws TreeUtils.MissingTaxonException {
        Logger.getLogger("dr.evomodel").info("SiteModel added for clade.");
        // The Clade constructor computes its own MRCA. Do not touch commonAncestorsKnown here:
        // setting it to true could hide stale MRCAs of clades added earlier.
        this.cladeSiteModels.add(new Clade(siteModel, taxonList, z));
        addModel(siteModel);
    }

    /** Sets the site model used for all external (tip) branches. */
    public void addTipsSiteModel(SiteModel siteModel) {
        Logger.getLogger("dr.evomodel").info("SiteModel added for tips.");
        this.tipsSiteModel = siteModel;
        addModel(siteModel);
    }

    /**
     * Adds a parameter whose first value is added to the length of selected tip branches.
     *
     * @param parameter the delta parameter (registered as a variable of this model)
     * @param taxonList tips to which delta applies, or {@code null} for all tips
     */
    public void addDeltaParameter(Parameter parameter, TaxonList taxonList) {
        this.deltaParameter = parameter;
        this.deltaTips = new HashSet<>();
        Logger logger = Logger.getLogger("dr.evomodel");
        if (taxonList != null) {
            boolean first = true;
            StringBuilder message = new StringBuilder("Delta parameter added for tips: {");
            for (int i = 0; i < this.treeModel.getExternalNodeCount(); i++) {
                NodeRef tip = this.treeModel.getExternalNode(i);
                Taxon taxon = this.treeModel.getNodeTaxon(tip);
                if (taxonList.getTaxonIndex(taxon) != -1) {
                    if (first) {
                        first = false;
                    } else {
                        message.append(", ");
                    }
                    message.append(taxon.getId());
                    this.deltaTips.add(tip.getNumber());
                }
            }
            message.append(GraphMLUtils.END_SECTION);
            logger.info(message.toString());
        } else {
            logger.info("Delta parameter added for all tips.");
        }
        addVariable(parameter);
    }

    /** Any variable change (currently only the delta parameter) invalidates every node. */
    @Override
    public final void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        updateAllNodes();
        super.handleVariableChangedEvent(variable, i, changeType);
    }

    /**
     * Marks the nodes affected by a change in one of the component models.
     *
     * @throws RuntimeException for an event from a model this likelihood does not own
     */
    @Override
    public void handleModelChangedEvent(Model model, Object obj, int i) {
        if (model == this.treeModel) {
            if (obj instanceof TreeChangedEvent) {
                TreeChangedEvent event = (TreeChangedEvent) obj;
                if (event.isNodeChanged()) {
                    updateNodeAndChildren(event.getNode());
                } else {
                    // Whole-tree change: topology may differ, so clade MRCAs must be recomputed.
                    updateAllNodes();
                    this.commonAncestorsKnown = false;
                }
            }
        } else if (model == this.branchRateModel || model == this.frequencyModel) {
            updateAllNodes();
        } else if (model instanceof SiteModel) {
            if (model == this.siteModel || model == this.tipsSiteModel) {
                updateAllNodes();
            } else {
                updateCladesUsing(model);
            }
        } else {
            throw new RuntimeException("Unknown componentChangedEvent");
        }
        super.handleModelChangedEvent(model, obj, i);
    }

    /** Marks the subtree of every clade whose site model is {@code model}. */
    private void updateCladesUsing(Model model) {
        ensureCommonAncestorsKnown();
        boolean found = false;
        for (Clade clade : this.cladeSiteModels) {
            if (clade.getSiteModel() == model) {
                updateNodeAndDescendents(this.treeModel.getNode(clade.getNode()));
                found = true;
            }
        }
        if (!found) {
            // Unexpected site model: recompute everything rather than pass a null node along.
            updateAllNodes();
        }
    }

    /** Recomputes cached clade MRCAs if the topology may have changed since they were computed. */
    private void ensureCommonAncestorsKnown() {
        if (!this.commonAncestorsKnown) {
            for (Clade clade : this.cladeSiteModels) {
                clade.findMRCA();
            }
            this.commonAncestorsKnown = true;
        }
    }

    @Override
    public void storeState() {
        super.storeState();
    }

    /**
     * Forces full recomputation after a restore. The tree may have been restored to a topology
     * other than the one the cached clade MRCAs were computed for, so they are invalidated.
     */
    @Override
    public void restoreState() {
        updateAllNodes();
        this.commonAncestorsKnown = false;
        super.restoreState();
    }

    /** Recomputes dirty partials from the root and returns the weighted pattern log likelihood. */
    @Override
    protected double calculateLogLikelihood() {
        if (this.rootPartials == null) {
            this.rootPartials = new double[this.patternCount * this.stateCount];
        }
        if (this.patternLogLikelihoods == null) {
            this.patternLogLikelihoods = new double[this.patternCount];
        }
        ensureCommonAncestorsKnown();

        traverse(this.treeModel, this.treeModel.getRoot(), this.siteModel);
        for (int i = 0; i < this.nodeCount; i++) {
            this.updateNode[i] = false;
        }

        double logL = 0.0;
        for (int i = 0; i < this.patternCount; i++) {
            logL += this.patternLogLikelihoods[i] * this.patternWeights[i];
        }
        return logL;
    }

    /**
     * Post-order traversal updating transition matrices and partials of dirty nodes.
     *
     * @param tree the tree
     * @param node current node
     * @param branchSiteModel site model inherited from the parent for the branch above
     *     {@code node}. It is replaced by the tips model on external nodes, or by a clade model
     *     when {@code node} is an MRCA whose clade includes its stem.
     * @return true if this node's matrix or partials were recomputed
     */
    private boolean traverse(Tree tree, NodeRef node, SiteModel branchSiteModel) {
        boolean updated = false;
        int nodeNum = node.getNumber();
        boolean external = tree.isExternal(node);

        // Site model handed down to the children; switches to a clade model at that clade's MRCA.
        SiteModel childSiteModel = branchSiteModel;
        if (this.tipsSiteModel != null && external) {
            branchSiteModel = this.tipsSiteModel;
        } else {
            for (Clade clade : this.cladeSiteModels) {
                if (clade.getNode() == nodeNum) {
                    childSiteModel = clade.getSiteModel();
                    if (clade.includeStem()) {
                        branchSiteModel = childSiteModel;
                    }
                    // First matching clade wins; without this break the search never terminated.
                    break;
                }
            }
        }

        NodeRef parent = tree.getParent(node);
        if (parent != null && this.updateNode[nodeNum]) {
            double branchLength = this.branchRateModel.getBranchRate(tree, node)
                    * (tree.getNodeHeight(parent) - tree.getNodeHeight(node));
            if (branchLength < 0.0) {
                throw new RuntimeException("Negative branch length: " + branchLength);
            }
            this.likelihoodCore.setNodeMatrixForUpdate(nodeNum);
            if (external && this.deltaParameter != null
                    && (this.deltaTips.isEmpty() || this.deltaTips.contains(nodeNum))) {
                branchLength += this.deltaParameter.getParameterValue(0);
            }
            // NOTE(review): categoryCount comes from the default site model; assumes tips/clade
            // site models have the same number of rate categories — confirm.
            for (int c = 0; c < this.categoryCount; c++) {
                branchSiteModel.getSubstitutionModel().getTransitionProbabilities(
                        branchSiteModel.getRateForCategory(c) * branchLength, this.probabilities);
                this.likelihoodCore.setNodeMatrix(nodeNum, c, this.probabilities);
            }
            updated = true;
        }

        if (!external) {
            if (tree.getChildCount(node) != 2) {
                throw new RuntimeException("binary trees only!");
            }
            // Both children must be visited (no short-circuit) so each refreshes its own state.
            NodeRef child1 = tree.getChild(node, 0);
            boolean updated1 = traverse(tree, child1, childSiteModel);
            NodeRef child2 = tree.getChild(node, 1);
            boolean updated2 = traverse(tree, child2, childSiteModel);

            if (updated1 || updated2) {
                this.likelihoodCore.setNodePartialsForUpdate(nodeNum);
                this.likelihoodCore.calculatePartials(child1.getNumber(), child2.getNumber(), nodeNum);
                if (parent == null) {
                    double[] frequencies = this.frequencyModel.getFrequencies();
                    this.likelihoodCore.integratePartials(nodeNum, branchSiteModel.getCategoryProportions(), this.rootPartials);
                    this.likelihoodCore.calculateLogLikelihoods(this.rootPartials, frequencies, this.patternLogLikelihoods);
                }
                updated = true;
            }
        }
        return updated;
    }
}
