package dr.evolution.tree;

import dr.evolution.util.TaxonList;
import dr.evoxml.util.GraphMLUtils;
import dr.util.FrequencySet;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;

/* loaded from: input_file:dr/evolution/tree/CladeSet.class */
/**
 * A frequency set of clades observed across one or more trees.
 *
 * <p>Each clade is encoded as a {@link BitSet} whose set bits are tip indices. A tip's index
 * is the index of its taxon in {@link #taxonList} when a taxon list is set; otherwise it is
 * the tip node's own number. In addition to the clade counts kept by {@link FrequencySet},
 * this class accumulates the summed node height of every clade, so mean clade heights can
 * be reported.
 */
public class CladeSet extends FrequencySet<BitSet> {

    /** Maps tip taxa to bit indices; when null, tip node numbers are used instead. */
    private TaxonList taxonList;

    /** Sum of the node heights recorded for each clade over all added trees. */
    private final Map<BitSet, Double> totalNodeHeight = new HashMap<BitSet, Double>();

    /** Number of trees passed to {@link #add(Tree)}. */
    private int totalTrees = 0;

    /** Creates an empty clade set; the taxon list is taken from the first tree added. */
    public CladeSet() {
        this.taxonList = null;
    }

    /**
     * Creates a clade set containing the clades of {@code tree}, using the tree itself as the
     * taxon list.
     *
     * @param tree the first tree to add
     */
    public CladeSet(Tree tree) {
        this(tree, tree);
    }

    /**
     * Creates a clade set containing the clades of {@code tree}.
     *
     * @param tree       the first tree to add
     * @param taxonList  list used to map tip taxa to bit indices; if null, {@code tree}
     *                   becomes the taxon list when it is added
     */
    public CladeSet(Tree tree, TaxonList taxonList) {
        this.taxonList = taxonList;
        add(tree);
    }

    /** Returns the number of distinct clades recorded. */
    public int getCladeCount() {
        return size();
    }

    /**
     * Returns a printable form of clade {@code i}.
     *
     * <p>The form is the clade's taxon IDs in sorted order, separated by {@code ", "} and
     * wrapped in {@code GraphMLUtils.START_SECTION} / {@code GraphMLUtils.END_SECTION}.
     * Requires a taxon list to be set.
     *
     * @param i index of the clade in this set
     * @return the formatted clade
     */
    public String getClade(int i) {
        StringBuilder buffer = new StringBuilder(GraphMLUtils.START_SECTION);
        String separator = "";
        for (String taxonId : getTaxaSet(get(i))) {
            buffer.append(separator).append(taxonId);
            separator = ", ";
        }
        buffer.append(GraphMLUtils.END_SECTION);
        return buffer.toString();
    }

    /**
     * Returns the sorted taxon IDs for the set bits of {@code bits}.
     * Dereferences {@link #taxonList}, so it must not be null here.
     */
    private SortedSet<String> getTaxaSet(BitSet bits) {
        SortedSet<String> taxa = new TreeSet<String>();
        for (int i = bits.nextSetBit(0); i >= 0; i = bits.nextSetBit(i + 1)) {
            taxa.add(taxonList.getTaxonId(i));
        }
        return taxa;
    }

    /** Returns how many times clade {@code i} has been observed. */
    int getCladeFrequency(int i) {
        return getFrequency(i);
    }

    /**
     * Records every clade (internal node) of {@code tree} and its node height.
     * If no taxon list has been set yet, {@code tree} becomes the taxon list.
     *
     * @param tree the tree whose clades are counted
     */
    public void add(Tree tree) {
        if (taxonList == null) {
            taxonList = tree;
        }
        totalTrees++;
        addClades(tree, tree.getRoot(), null);
    }

    /**
     * Recursively records the clade below {@code node}.
     *
     * <p>For an internal node, the node's clade is counted once and its height is added.
     * For any node, its tip bits are OR-ed into {@code parentBits} when that is non-null.
     *
     * @param tree       the tree being traversed
     * @param node       current node
     * @param parentBits the parent's clade under construction, or null at the root
     */
    private void addClades(Tree tree, NodeRef node, BitSet parentBits) {
        if (tree.isExternal(node)) {
            // A root that is itself a tip (single-taxon tree) has no clades to record;
            // without this guard parentBits would be dereferenced while null.
            if (parentBits != null) {
                parentBits.set(tipIndex(tree, node));
            }
            return;
        }
        BitSet bits = new BitSet();
        for (int i = 0; i < tree.getChildCount(node); i++) {
            addClades(tree, tree.getChild(node, i), bits);
        }
        // bits is complete at this point and is not mutated afterwards, so it is safe
        // to use as a map key.
        add(bits, 1);
        addNodeHeight(bits, tree.getNodeHeight(node));
        if (parentBits != null) {
            parentBits.or(bits);
        }
    }

    /**
     * Returns the bit index for tip {@code node} of {@code tree}. This is the index of its
     * taxon in {@link #taxonList} when a taxon list is set, otherwise the node's own number.
     */
    private int tipIndex(Tree tree, NodeRef node) {
        if (taxonList != null) {
            return taxonList.getTaxonIndex(tree.getNodeTaxon(node).getId());
        }
        return node.getNumber();
    }

    /**
     * Returns the summed node height of clade {@code i} divided by its frequency.
     *
     * @param i index of the clade in this set
     * @return mean node height of the clade
     */
    public double getMeanNodeHeight(int i) {
        return getTotalNodeHeight(get(i)) / getFrequency(i);
    }

    /** Returns the summed node height recorded for {@code bits}, or 0 if none. */
    private double getTotalNodeHeight(BitSet bits) {
        Double total = totalNodeHeight.get(bits);
        if (total == null) {
            return 0.0d;
        }
        return total.doubleValue();
    }

    /** Adds {@code height} to the running node-height total of {@code bits}. */
    private void addNodeHeight(BitSet bits, double height) {
        totalNodeHeight.put(bits, Double.valueOf(getTotalNodeHeight(bits) + height));
    }

    /**
     * Recursively sets attribute {@code attributeName} on each internal node below
     * {@code node}.
     *
     * <p>The attribute value is the fraction of added trees in which that node's clade was
     * observed.
     *
     * @return the tip bitset of the subtree rooted at {@code node}
     */
    private BitSet annotate(MutableTree tree, NodeRef node, String attributeName) {
        if (tree.isExternal(node)) {
            int index;
            if (taxonList != null) {
                index = taxonList.getTaxonIndex(tree.getNodeTaxon(node).getId());
            } else {
                index = node.getNumber();
            }
            BitSet tip = new BitSet(tree.getExternalNodeCount());
            tip.set(index);
            return tip;
        }
        BitSet bits = null;
        for (int i = 0; i < tree.getChildCount(node); i++) {
            BitSet childBits = annotate(tree, tree.getChild(node, i), attributeName);
            if (i == 0) {
                bits = childBits;
            } else {
                bits.or(childBits);
            }
        }
        int frequency = getFrequency(bits);
        if (frequency >= 0) {
            // Floating-point division: int / int would truncate every proportion below 1
            // to 0. If no trees have been added, this yields NaN.
            double proportion = (double) frequency / totalTrees;
            tree.setNodeAttribute(node, attributeName, Double.valueOf(proportion));
        }
        return bits;
    }

    /**
     * Annotates every internal node of {@code tree} with the proportion of added trees
     * containing its clade.
     *
     * @param tree          tree to annotate in place
     * @param attributeName name of the node attribute to set
     * @return the sum over internal nodes of the natural log of those proportions
     */
    public double annotate(MutableTree tree, String attributeName) {
        annotate(tree, tree.getRoot(), attributeName);
        double logSum = 0.0d;
        for (int i = 0; i < tree.getInternalNodeCount(); i++) {
            Double proportion = (Double) tree.getNodeAttribute(tree.getInternalNode(i), attributeName);
            logSum += Math.log(proportion.doubleValue());
        }
        return logSum;
    }

    /**
     * Returns whether {@code tree} has a node whose tip set is exactly clade {@code i}.
     *
     * @param i    index of the clade in this set
     * @param tree tree to search
     */
    public boolean hasClade(int i, Tree tree) {
        NodeRef[] found = new NodeRef[1];
        findClade(get(i), tree, tree.getRoot(), found);
        return found[0] != null;
    }

    /**
     * Counts the tips below {@code node} that belong to {@code clade}.
     *
     * <p>Returns -1 if any tip below {@code node} is outside the clade. When the count
     * equals the clade's size, {@code node} is stored in {@code found[0]}.
     */
    private int findClade(BitSet clade, Tree tree, NodeRef node, NodeRef[] found) {
        if (tree.isExternal(node)) {
            return clade.get(tipIndex(tree, node)) ? 1 : -1;
        }
        int count = 0;
        for (int i = 0; i < tree.getChildCount(node); i++) {
            // Keep visiting children after a mismatch: the clade may still lie entirely
            // within a later child's subtree.
            int childCount = findClade(clade, tree, tree.getChild(node, i), found);
            count = (childCount == -1 || count == -1) ? -1 : count + childCount;
        }
        if (count == clade.cardinality()) {
            found[0] = node;
        }
        return count;
    }
}
