package dr.evomodel.tree;

import dr.evolution.io.Importer;
import dr.evolution.io.NexusImporter;
import dr.evolution.io.TreeTrace;
import dr.evolution.tree.Clade;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.SimpleTree;
import dr.evolution.tree.Tree;
import dr.evomodel.continuous.TopographicalMap;
import dr.inference.model.Likelihood;
import dr.math.MathUtils;
import dr.util.Citable;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:dr/evomodel/tree/ConditionalCladeFrequency.class */
/**
 * Clade importance distribution backed by conditional clade frequencies.
 *
 * <p>Two tables are maintained, both keyed by a clade's {@link BitSet}:
 * <ul>
 *   <li>{@code cladeProbabilities}: one {@link Clade} per distinct clade seen in
 *       {@link #addTree(Tree)};</li>
 *   <li>{@code cladeCoProbabilities}: for each parent clade, the child clades
 *       that were seen directly below it.</li>
 * </ul>
 * Counts are read back through {@code Clade.getSampleCount()} and updated through
 * {@code Clade.addHeight(double)}.
 * NOTE(review): this presumes {@code addHeight} also increments the sample count -
 * confirm in {@code Clade}.
 *
 * <p>{@code epsilon} is a pseudo-count added to every possible split of a parent
 * clade, so that splits that were never observed do not get probability zero.
 */
public class ConditionalCladeFrequency extends AbstractCladeImportanceDistribution {

    /** Width (in characters) of the textual progress bar printed by {@link #analyzeTrace(boolean)}. */
    private static final int PROGRESS_BAR_WIDTH = 60;

    /** Factor applied to the parent's height when a child clade has to be invented in {@link #splitClade}. */
    private static final double CHILD_HEIGHT_FACTOR = 0.9999d;

    /** Pseudo-count added for every possible split of a clade. */
    private final double epsilon;

    /** Number of trees passed to the {@code addTree} overloads; incremented but not read in this class. */
    private long samples = 0;

    /** Clade bits -> clade record holding the marginal count of that clade. */
    private final HashMap<BitSet, Clade> cladeProbabilities = new HashMap<>();

    /** Parent clade bits -> (child clade bits -> clade record holding the conditional count). */
    private final HashMap<BitSet, HashMap<BitSet, Clade>> cladeCoProbabilities = new HashMap<>();

    /** Tree traces to analyze; empty when the instance was built without traces. */
    private final TreeTrace[] traces;

    /** Burn-in, multiplied by each trace's step size before being handed to the trace. */
    private final int burnin;

    /**
     * Creates an empty distribution that is filled later through the {@code addTree} overloads.
     *
     * @param tree    not used by this constructor
     * @param epsilon pseudo-count added for every possible split
     */
    public ConditionalCladeFrequency(Tree tree, double epsilon) {
        this.epsilon = epsilon;
        // No traces: analyzeTrace/getTree then fail with "Couldn't find tree" instead of an NPE.
        this.traces = new TreeTrace[0];
        this.burnin = 0;
    }

    /**
     * Creates a distribution and immediately fills it from the given traces.
     *
     * <p>If {@code burnIn} is negative or not smaller than the smallest maximum state of
     * all traces, the burn-in falls back to {@code minMaxState / (10 * stepSize)} using the
     * step size of the first trace.
     *
     * @param traces  traces to analyze
     * @param epsilon pseudo-count added for every possible split
     * @param burnIn  requested burn-in
     * @param verbose whether to print progress and warnings to {@code System.out}
     */
    public ConditionalCladeFrequency(TreeTrace[] traces, double epsilon, int burnIn, boolean verbose) {
        this.epsilon = epsilon;
        this.traces = traces;

        int minMaxState = Integer.MAX_VALUE;
        for (TreeTrace trace : traces) {
            minMaxState = Math.min(minMaxState, trace.getMaximumState());
        }

        if (burnIn < 0 || burnIn >= minMaxState) {
            this.burnin = minMaxState / (10 * traces[0].getStepSize());
            if (verbose) {
                System.out.println("WARNING: Burn-in larger than total number of states - using 10% of smallest trace");
            }
        } else {
            this.burnin = burnIn;
        }

        analyzeTrace(verbose);
    }

    /**
     * Feeds the post-burn-in trees of every trace into {@link #addTree(Tree)}.
     *
     * <p>NOTE(review): the per-trace loop starts at index 1, so the tree at index 0 of each
     * trace is never added. This is kept as found - confirm whether it is intentional.
     *
     * @param verbose whether to print a progress bar to {@code System.out}
     * @throws RuntimeException if there is no tree at index 0 (see {@link #getTree(int)})
     */
    public void analyzeTrace(boolean verbose) {
        if (verbose && traces.length > 1) {
            System.out.println("Combining " + traces.length + " traces.");
        }

        // The returned tree is not used; the call is kept because it throws when no tree exists.
        getTree(0);

        for (TreeTrace trace : traces) {
            final int burninState = burnin * trace.getStepSize();
            final int treeCount = trace.getTreeCount(burninState);
            final double treesPerTick = treeCount / (double) PROGRESS_BAR_WIDTH;
            int nextTick = 1;

            if (verbose) {
                System.out.println("Analyzing " + treeCount + " trees...");
                System.out.println("0              25             50             75            100");
                System.out.println("|--------------|--------------|--------------|--------------|");
                // NOTE(review): this constant is used as the progress-bar tick mark; the reference
                // looks like a decompiler constant substitution - confirm its value.
                System.out.print(TopographicalMap.defaultInvalidString);
            }

            for (int treeIndex = 1; treeIndex < treeCount; treeIndex++) {
                addTree(trace.getTree(treeIndex, burninState));

                if (treeIndex >= ((int) Math.round(nextTick * treesPerTick)) && nextTick <= PROGRESS_BAR_WIDTH) {
                    if (verbose) {
                        System.out.print(TopographicalMap.defaultInvalidString);
                        System.out.flush();
                    }
                    nextTick++;
                }
            }

            if (verbose) {
                System.out.println(TopographicalMap.defaultInvalidString);
            }
        }
    }

    /**
     * Reads reference trees in Nexus format and prints, for each of them, the value of
     * {@link #getTreeProbability(Tree)} followed by the tree itself.
     *
     * @param reader source of the Nexus file
     * @throws IOException              if reading fails
     * @throws Importer.ImportException if the trees cannot be imported
     * @throws RuntimeException         if the input is empty or its first line does not start
     *                                  with {@code #NEXUS} (case-insensitive)
     */
    public void report(Reader reader) throws IOException, Importer.ImportException {
        System.err.println("making report");

        BufferedReader bufferedReader = new BufferedReader(reader);
        // readLine() returns null on empty input; treat that like any other non-Nexus input.
        String header = bufferedReader.readLine();
        if (header == null || !header.toUpperCase().startsWith("#NEXUS")) {
            throw new RuntimeException("Could not read reference tree. Only Nexus format is supported.");
        }

        for (Tree tree : new NexusImporter(bufferedReader).importTrees(null)) {
            SimpleTree simpleTree = new SimpleTree(tree);
            System.out.println("Estimated marginal posterior by condiational clade frequencies:");
            System.out.println(getTreeProbability(simpleTree) + Citable.Utils.DEFAULT_PREPEND + simpleTree);
        }
        System.out.flush();
    }

    /**
     * Returns the log probability of the tree's topology under the conditional clade
     * frequencies collected so far.
     *
     * @param tree tree to evaluate
     * @return sum over the tree's (parent, child) clade pairs of
     *         {@code log((childGivenParentCount + epsilon) / (parentCount + epsilon * (2^(parentSize-1) - 1)))}
     */
    @Override
    public double getTreeProbability(Tree tree) {
        ArrayList<Clade> clades = new ArrayList<>();
        ArrayList<Clade> parentClades = new ArrayList<>();
        getNonComplementaryClades(tree, tree.getRoot(), parentClades, clades);
        return sumConditionalLogProbabilities(parentClades, clades);
    }

    /**
     * Same as {@link #getTreeProbability(Tree)}, but forwards {@code hashMap} to the clade
     * extraction of the superclass.
     *
     * @param tree    tree to evaluate
     * @param hashMap passed through to {@code getNonComplementaryClades}
     *                (presumably taxon name -> bit index; verify against the superclass)
     * @return the log probability, computed as in {@link #getTreeProbability(Tree)}
     */
    public double getTreeProbability(Tree tree, HashMap<String, Integer> hashMap) {
        ArrayList<Clade> clades = new ArrayList<>();
        ArrayList<Clade> parentClades = new ArrayList<>();
        getNonComplementaryClades(tree, tree.getRoot(), parentClades, clades, hashMap);
        return sumConditionalLogProbabilities(parentClades, clades);
    }

    /** Sums the log conditional probability of each clade given the parent at the same index. */
    private double sumConditionalLogProbabilities(ArrayList<Clade> parentClades, ArrayList<Clade> clades) {
        double logProbability = 0.0d;
        final int size = clades.size();

        for (int i = 0; i < size; i++) {
            final Clade clade = clades.get(i);
            final Clade parent = parentClades.get(i);
            final BitSet parentBits = parent.getBits();

            // Marginal count of the parent clade, zero when it was never observed.
            double parentCount = 0.0d;
            final Clade observedParent = cladeProbabilities.get(parentBits);
            if (observedParent != null) {
                parentCount += observedParent.getSampleCount();
            }

            // Count of this clade directly below that parent, starting from the pseudo-count.
            double conditionalCount = epsilon;
            final HashMap<BitSet, Clade> observedChildren = cladeCoProbabilities.get(parentBits);
            if (observedChildren != null) {
                final Clade observedChild = observedChildren.get(clade.getBits());
                if (observedChild != null) {
                    conditionalCount += observedChild.getSampleCount();
                }
            }

            // One pseudo-count per possible split of the parent: 2^(size-1) - 1.
            final double possibleSplits = Math.pow(2.0d, parent.getSize() - 1) - 1.0d;
            logProbability += Math.log(conditionalCount / (parentCount + (epsilon * possibleSplits)));
        }
        return logProbability;
    }

    /**
     * Randomly splits {@code clade} into two child clades and stores them in
     * {@code cladeArr[0]} and {@code cladeArr[1]}.
     *
     * <p>If child clades were observed below {@code clade}, one of them is drawn with weight
     * {@code sampleCount + epsilon} (halved when the clade is more than one taxon larger than
     * the child), and the remaining weight {@code epsilon * (possibleSplits - observed)} leads
     * to a uniformly drawn split that was not observed. If {@code clade} was never observed as
     * a parent, a random non-trivial split is drawn.
     *
     * <p>NOTE(review): if every possible split of {@code clade} has been observed and the draw
     * falls into the "unobserved" remainder, the search for an unobserved split does not
     * terminate. Kept as found.
     *
     * @param clade    parent clade to split
     * @param cladeArr output array of length at least 2 receiving the two children
     * @return natural log of the probability assigned to the chosen split
     */
    @Override
    public double splitClade(Clade clade, Clade[] cladeArr) {
        final double possibleSplits = Math.pow(2.0d, clade.getSize() - 1) - 1.0d;
        final HashMap<BitSet, Clade> observedChildren = cladeCoProbabilities.get(clade.getBits());
        double probability = 0.0d;

        if (observedChildren != null) {
            double observedSplits = 0.0d;
            double totalWeight = 0.0d;
            for (Clade child : observedChildren.values()) {
                totalWeight += splitWeight(clade, child);
                observedSplits += hasHalvedWeight(clade, child) ? 0.5d : 1.0d;
            }
            // Add the pseudo-count for every split that was not observed.
            totalWeight += epsilon * (possibleSplits - observedSplits);

            // Roulette wheel over the observed children; iteration order matches the sum above.
            double remaining = MathUtils.nextDouble() * totalWeight;
            for (Clade child : observedChildren.values()) {
                remaining -= splitWeight(clade, child);
                if (remaining < 0.0d) {
                    cladeArr[0] = child;
                    probability = (child.getSampleCount() + epsilon) / totalWeight;
                    break;
                }
            }

            if (remaining >= 0.0d) {
                // No observed child was hit: draw a split that has not been observed.
                probability = epsilon / totalWeight;
                BitSet firstBits;
                BitSet complementBits;
                do {
                    firstBits = randomProperSubset(clade);
                    complementBits = complementWithin(firstBits, clade);
                } while (observedChildren.containsKey(firstBits) || observedChildren.containsKey(complementBits));

                cladeArr[0] = new Clade(firstBits, CHILD_HEIGHT_FACTOR * clade.getHeight());
                cladeArr[1] = new Clade(complementWithin(cladeArr[0].getBits(), clade),
                        CHILD_HEIGHT_FACTOR * clade.getHeight());
            } else {
                // The sibling is the rest of the parent; reuse its record when it was observed.
                final BitSet siblingBits = complementWithin(cladeArr[0].getBits(), clade);
                cladeArr[1] = observedChildren.get(siblingBits);
                if (cladeArr[1] == null) {
                    cladeArr[1] = new Clade(siblingBits, CHILD_HEIGHT_FACTOR * clade.getHeight());
                }
            }
        } else {
            // Parent never observed: every split is equally likely.
            probability = 1.0d / possibleSplits;

            final Clade randomChild = new Clade(randomProperSubset(clade), CHILD_HEIGHT_FACTOR * clade.getHeight());
            randomChild.addHeight(CHILD_HEIGHT_FACTOR * clade.getHeight());
            cladeArr[0] = randomChild;
            cladeArr[1] = new Clade(complementWithin(cladeArr[0].getBits(), clade),
                    CHILD_HEIGHT_FACTOR * clade.getHeight());
            // NOTE(review): the height is added to the first child a second time rather than to
            // cladeArr[1]; kept as found - confirm which clade was intended.
            randomChild.addHeight(CHILD_HEIGHT_FACTOR * clade.getHeight());
        }
        return Math.log(probability);
    }

    /**
     * Tells whether {@code child} contributes only half of its weight when splitting {@code parent}.
     * Presumably this is because the complement is then a stored child clade too, so the same
     * split is listed twice - TODO confirm.
     */
    private static boolean hasHalvedWeight(Clade parent, Clade child) {
        return parent.getSize() > child.getSize() + 1;
    }

    /** Returns the roulette-wheel weight of {@code child} as a split of {@code parent}. */
    private double splitWeight(Clade parent, Clade child) {
        final double weight = child.getSampleCount() + epsilon;
        return hasHalvedWeight(parent, child) ? weight / 2.0d : weight;
    }

    /** Draws a random subset of the parent's bits that is neither empty nor the complete parent. */
    private static BitSet randomProperSubset(Clade parent) {
        BitSet subset;
        do {
            subset = (BitSet) parent.getBits().clone();
            // Drop each set bit with probability one half.
            for (int bit = subset.nextSetBit(0); bit > -1; bit = subset.nextSetBit(bit + 1)) {
                if (MathUtils.nextBoolean()) {
                    subset.clear(bit);
                }
            }
        } while (subset.cardinality() == 0 || subset.cardinality() == parent.getSize());
        return subset;
    }

    /** Returns a new bit set holding the parent's bits XOR-ed with {@code childBits}. */
    private static BitSet complementWithin(BitSet childBits, Clade parent) {
        final BitSet complement = (BitSet) childBits.clone();
        complement.xor(parent.getBits());
        return complement;
    }

    /**
     * Currently a stub: looks up the root clade, walks the root's children without acting on
     * them, and returns {@code 0.0}.
     *
     * @param treeModel  tree whose root is inspected
     * @param likelihood not used
     * @return always {@code 0.0}
     */
    @Override
    public double getChanceForNodeHeights(TreeModel treeModel, Likelihood likelihood) {
        final NodeRef root = treeModel.getRoot();
        getClade(treeModel, root);
        final int childCount = treeModel.getChildCount(root);
        for (int i = 0; i < childCount; i++) {
            // The calls are kept as found; nothing is done with their result.
            treeModel.isExternal(treeModel.getChild(root, i));
        }
        return 0.0d;
    }

    /**
     * Currently a stub: looks up the root clade, walks the root's children without acting on
     * them, and returns {@code 0.0}. No node height is modified here.
     *
     * @param treeModel  tree whose root is inspected
     * @param likelihood not used
     * @return always {@code 0.0}
     */
    @Override
    public double setNodeHeights(TreeModel treeModel, Likelihood likelihood) {
        final NodeRef root = treeModel.getRoot();
        getClade(treeModel, root);
        final int childCount = treeModel.getChildCount(root);
        for (int i = 0; i < childCount; i++) {
            // The calls are kept as found; nothing is done with their result.
            treeModel.isExternal(treeModel.getChild(root, i));
        }
        return 0.0d;
    }

    /**
     * Returns the tree at the given index, counting post-burn-in trees across all traces in
     * order.
     *
     * @param i global index of the tree
     * @return the tree at that index
     * @throws RuntimeException if {@code i} is not smaller than the total number of trees
     */
    public final Tree getTree(int i) {
        int firstIndexOfTrace = 0;
        int endIndexOfTrace = 0;
        for (TreeTrace trace : traces) {
            final int burninState = burnin * trace.getStepSize();
            endIndexOfTrace += trace.getTreeCount(burninState);
            if (i < endIndexOfTrace) {
                return trace.getTree(i - firstIndexOfTrace, burninState);
            }
            firstIndexOfTrace = endIndexOfTrace;
        }
        throw new RuntimeException("Couldn't find tree " + i);
    }

    /**
     * Adds the clades of the tree to the marginal and conditional clade tables.
     *
     * @param tree tree whose clades are counted
     */
    @Override
    public void addTree(Tree tree) {
        samples++;
        ArrayList<Clade> clades = new ArrayList<>();
        ArrayList<Clade> parentClades = new ArrayList<>();
        getClades(tree, tree.getRoot(), parentClades, clades);
        recordClades(parentClades, clades);
    }

    /**
     * Same as {@link #addTree(Tree)}, but forwards {@code hashMap} to the clade extraction of
     * the superclass.
     *
     * @param tree    tree whose clades are counted
     * @param hashMap passed through to {@code getClades}
     *                (presumably taxon name -> bit index; verify against the superclass)
     */
    public void addTree(Tree tree, HashMap<String, Integer> hashMap) {
        samples++;
        ArrayList<Clade> clades = new ArrayList<>();
        ArrayList<Clade> parentClades = new ArrayList<>();
        getClades(tree, tree.getRoot(), parentClades, clades, hashMap);
        recordClades(parentClades, clades);
    }

    /** Updates both count tables from parallel lists of parent clades and clades. */
    private void recordClades(ArrayList<Clade> parentClades, ArrayList<Clade> clades) {
        // Append the last parent clade as a clade of its own, paired with itself, so that it
        // only gets a marginal count (presumably this is the clade of all taxa - TODO confirm).
        clades.add(parentClades.get(parentClades.size() - 1));
        parentClades.add(clades.get(clades.size() - 1));

        final int size = clades.size();
        for (int i = 0; i < size; i++) {
            final Clade clade = clades.get(i);
            final Clade parent = parentClades.get(i);

            // Marginal count of the clade.
            final Clade knownClade = cladeProbabilities.get(clade.getBits());
            if (knownClade != null) {
                knownClade.addHeight(clade.getHeight());
            } else {
                clade.addHeight(clade.getHeight());
                cladeProbabilities.put(clade.getBits(), clade);
            }

            if (parent.equals(clade)) {
                continue;
            }

            // Count of the clade conditional on its parent.
            HashMap<BitSet, Clade> children = cladeCoProbabilities.get(parent.getBits());
            if (children == null) {
                children = new HashMap<>();
                cladeCoProbabilities.put(parent.getBits(), children);
            }
            final Clade knownChild = children.get(clade.getBits());
            if (knownChild != null) {
                knownChild.addHeight(clade.getHeight());
            } else {
                // A separate record keeps the conditional count apart from the marginal one.
                final Clade childRecord = new Clade((BitSet) clade.getBits().clone(), clade.getHeight());
                childRecord.addHeight(clade.getHeight());
                children.put(clade.getBits(), childRecord);
            }
        }
    }

    /**
     * Loads one tree trace per reader and builds a distribution from them. Each reader is
     * closed after it has been read, also when loading it fails.
     *
     * @param readerArr one reader per tree log
     * @param d         pseudo-count (epsilon) added for every possible split
     * @param i         requested burn-in
     * @param z         whether to print progress to {@code System.out}
     * @return the distribution filled from all traces
     * @throws IOException      if reading or closing a reader fails
     * @throws RuntimeException if a trace cannot be imported; the import exception is the cause
     */
    public static ConditionalCladeFrequency analyzeLogFile(Reader[] readerArr, double d, int i, boolean z) throws IOException {
        TreeTrace[] loadedTraces = new TreeTrace[readerArr.length];
        for (int index = 0; index < readerArr.length; index++) {
            try (Reader reader = readerArr[index]) {
                loadedTraces[index] = TreeTrace.loadTreeTrace(reader);
            } catch (Importer.ImportException e) {
                throw new RuntimeException(e.toString(), e);
            }
        }
        return new ConditionalCladeFrequency(loadedTraces, d, i, z);
    }
}
