package dr.evomodel.tree;

import dr.app.tools.NexusExporter;
import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evolution.io.TreeTrace;
import dr.evolution.tree.CladeSet;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.MutableTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.continuous.TopographicalMap;
import dr.evomodel.tree.randomlocalmodel.RLTVLoggerOnTree;
import dr.inferencexml.model.CompoundLikelihoodParser;
import dr.stats.DiscreteStatistics;
import dr.util.FrequencySet;
import dr.util.NumberFormatter;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Set;
import jebl.evolution.treemetrics.RobinsonsFouldMetric;
import jebl.evolution.trees.SimpleRootedTree;

/* loaded from: input_file:dr/evomodel/tree/TreeTraceAnalysis.class */
public class TreeTraceAnalysis {

    /** Number of tick characters in the textual progress bar. */
    private static final int PROGRESS_BAR_WIDTH = 60;

    /** Burn-in; every use multiplies it by a trace's step size before handing it to that trace. */
    private final int burnin;
    /** The traces being combined, in the order their trees are indexed by {@link #getTree(int)}. */
    private final TreeTrace[] traces;
    /** Clade counts over all post-burn-in trees; populated by {@link #analyze(boolean)}. */
    private CladeSet cladeSet;
    /** Counts of unique topologies keyed by {@code TreeUtils.uniqueNewick}; populated by {@link #analyze(boolean)}. */
    private FrequencySet<String> treeSet;

    /**
     * Creates the analysis and immediately runs {@link #analyze(boolean)}.
     *
     * @param traces  the traces to combine; must be non-empty
     * @param burnin  requested burn-in; a negative value, or one not smaller than the smallest
     *                {@code getMaximumState()} among the traces, selects the default
     *                {@code minMaxState / (10 * stepSizeOfFirstTrace)}
     * @param verbose whether to print progress and warnings to standard output
     * @throws IllegalArgumentException if {@code traces} is null or empty
     */
    private TreeTraceAnalysis(TreeTrace[] traces, int burnin, boolean verbose) {
        if (traces == null || traces.length == 0) {
            throw new IllegalArgumentException("At least one tree trace is required");
        }
        this.traces = traces;

        int minMaxState = Integer.MAX_VALUE;
        for (TreeTrace trace : traces) {
            minMaxState = Math.min(minMaxState, trace.getMaximumState());
        }

        if (burnin < 0 || burnin >= minMaxState) {
            // NOTE(review): the requested burn-in is validated against states, yet every use
            // multiplies the stored value by the step size; confirm which unit callers pass.
            this.burnin = minMaxState / (10 * traces[0].getStepSize());
            if (verbose) {
                System.out.println((burnin < 0 ? "Default burn-in" : "WARNING: Burn-in larger than total number of states")
                        + " - using 10% of smallest trace");
            }
        } else {
            this.burnin = burnin;
        }
        analyze(verbose);
    }

    /**
     * Computes the Robinson-Foulds distance between {@code reference} and every tree in
     * {@code trace}. No burn-in is applied (offset 0).
     *
     * @param trace     the trace whose trees are compared
     * @param reference the tree every sample is compared against
     * @return one distance per tree, in trace order
     */
    public static double[] getSymmetricTreeDistanceTrace(TreeTrace trace, Tree reference) {
        double[] distances = new double[trace.getTreeCount(0)];
        RobinsonsFouldMetric metric = new RobinsonsFouldMetric();
        SimpleRootedTree jeblReference = TreeUtils.asJeblTree(reference);
        for (int t = 0; t < distances.length; t++) {
            distances[t] = metric.getMetric(jeblReference, TreeUtils.asJeblTree(trace.getTree(t, 0)));
        }
        return distances;
    }

    /**
     * Tallies topologies and clades over every post-burn-in tree of every trace.
     * If any tree carries the random-local-clock indicator trait, it also prints the
     * per-node rate sums and the node-by-node co-change counts.
     *
     * @param verbose whether to print progress bars to standard output
     */
    void analyze(boolean verbose) {
        if (verbose && traces.length > 1) {
            System.out.println("Combining " + traces.length + " traces.");
        }

        Tree firstTree = getTree(0);
        int nodeCount = firstTree.getNodeCount();
        double[][] coChangeCounts = new double[nodeCount][nodeCount];
        double[] changedRateSums = new double[nodeCount];

        // Tree 0 of the first trace seeds both sets here. The main loop therefore starts
        // that trace at index 1 (presumably CladeSet(Tree) counts its argument — confirm).
        cladeSet = new CladeSet(firstTree);
        treeSet = new FrequencySet<>();
        treeSet.add(TreeUtils.uniqueNewick(firstTree, firstTree.getRoot()));
        boolean hasRateChanges = accumulateRateChanges(firstTree, coChangeCounts, changedRateSums);

        for (int traceIndex = 0; traceIndex < traces.length; traceIndex++) {
            TreeTrace trace = traces[traceIndex];
            int burninState = burnin * trace.getStepSize();
            int treeCount = trace.getTreeCount(burninState);
            double treesPerTick = treeCount / (double) PROGRESS_BAR_WIDTH;
            int ticks = 1;
            if (verbose) {
                System.out.println("Analyzing " + treeCount + " trees...");
                printProgressHeader();
            }
            // Only the very first tree overall was counted above; later traces start at 0.
            for (int k = (traceIndex == 0 ? 1 : 0); k < treeCount; k++) {
                Tree tree = trace.getTree(k, burninState);
                hasRateChanges |= accumulateRateChanges(tree, coChangeCounts, changedRateSums);
                cladeSet.add(tree);
                treeSet.add(TreeUtils.uniqueNewick(tree, tree.getRoot()));
                if (verbose && k >= (int) Math.round(ticks * treesPerTick) && ticks <= PROGRESS_BAR_WIDTH) {
                    printProgressTick();
                    ticks++;
                }
            }
            if (verbose) {
                System.out.println(TopographicalMap.defaultInvalidString);
            }
        }

        if (hasRateChanges) {
            printRateChangeStatistics(changedRateSums, coChangeCounts);
        }
    }

    /**
     * Adds one tree's random-local-clock indicator data to the running totals.
     *
     * <p>The totals are updated for every non-root node {@code n} that carries the
     * {@code RLTVLoggerOnTree.TRAIT_NAME} attribute. If {@code n}'s indicator is set, its
     * {@code "rate"} attribute is added to {@code rateSums[n]}. {@code coChanges[n][m]} is
     * incremented for every non-root node {@code m} whose indicator is also set.
     *
     * @return true if any non-root node carried the trait
     */
    private boolean accumulateRateChanges(Tree tree, double[][] coChanges, double[] rateSums) {
        boolean found = false;
        int nodeCount = tree.getNodeCount();
        for (int n = 0; n < nodeCount; n++) {
            NodeRef node = tree.getNode(n);
            if (node == tree.getRoot() || tree.getNodeAttribute(node, RLTVLoggerOnTree.TRAIT_NAME) == null) {
                continue;
            }
            found = true;
            boolean changed = getChanged(tree, n);
            if (changed) {
                rateSums[n] += ((Number) tree.getNodeAttribute(node, "rate")).doubleValue();
            }
            for (int m = 0; m < nodeCount; m++) {
                if (tree.getNode(m) != tree.getRoot() && changed && getChanged(tree, m)) {
                    coChanges[n][m] += 1.0d;
                }
            }
        }
        return found;
    }

    /** Prints per-node rate sums, a blank line, then the co-change matrix row by row. */
    private static void printRateChangeStatistics(double[] rateSums, double[][] coChanges) {
        for (int n = 0; n < rateSums.length; n++) {
            System.out.println(n + "\t" + rateSums[n]);
        }
        System.out.println();
        for (double[] row : coChanges) {
            for (double value : row) {
                System.out.print(value + "\t");
            }
            System.out.println();
        }
    }

    /**
     * Reads a node's random-local-clock indicator.
     *
     * @return true if the attribute is the Integer 1 or Boolean true; false when it is absent
     *         (previously this case threw a NullPointerException)
     */
    private boolean getChanged(Tree tree, int nodeIndex) {
        Object indicator = tree.getNodeAttribute(tree.getNode(nodeIndex), RLTVLoggerOnTree.TRAIT_NAME);
        if (indicator == null) {
            return false;
        }
        if (indicator instanceof Integer) {
            return ((Integer) indicator).intValue() == 1;
        }
        return ((Boolean) indicator).booleanValue();
    }

    /**
     * Builds a summary tree for one topology.
     *
     * <p>It copies the first post-burn-in tree whose {@code TreeUtils.uniqueNewick} string
     * equals {@code targetNewick}. Each internal node's height is then set to the mean
     * height, across all post-burn-in trees, of the common ancestor of the same leaf set.
     * The 2.5% and 97.5% quantiles are stored as {@code "lower"} and {@code "upper"}, and
     * together as {@code "range"}.
     *
     * @param targetNewick the unique-Newick string of the topology to summarise
     * @return the annotated copy
     * @throws RuntimeException if no tree in the traces has that topology
     */
    final MutableTree analyzeTree(String targetNewick) {
        int treeCount = getTreeCount();
        FlexibleTree summary = null;
        for (int t = 0; t < treeCount && summary == null; t++) {
            Tree candidate = getTree(t);
            if (TreeUtils.uniqueNewick(candidate, candidate.getRoot()).equals(targetNewick)) {
                summary = new FlexibleTree(candidate);
            }
        }
        if (summary == null) {
            throw new RuntimeException("No target tree in trace");
        }

        int internalNodeCount = summary.getInternalNodeCount();
        for (int n = 0; n < internalNodeCount; n++) {
            NodeRef internalNode = summary.getInternalNode(n);
            Set<String> leaves = TreeUtils.getDescendantLeaves(summary, internalNode);
            double[] heights = new double[treeCount];
            for (int t = 0; t < treeCount; t++) {
                Tree sample = getTree(t);
                heights[t] = sample.getNodeHeight(TreeUtils.getCommonAncestorNode(sample, leaves));
            }
            summary.setNodeHeight(internalNode, DiscreteStatistics.mean(heights));
            double upper = DiscreteStatistics.quantile(0.975d, heights);
            summary.setNodeAttribute(internalNode, "upper", Double.valueOf(upper));
            double lower = DiscreteStatistics.quantile(0.025d, heights);
            summary.setNodeAttribute(internalNode, "lower", Double.valueOf(lower));
            summary.setNodeAttribute(internalNode, "range", new Double[]{Double.valueOf(lower), Double.valueOf(upper)});
        }
        return summary;
    }

    /** Returns the total number of post-burn-in trees across all traces. */
    final int getTreeCount() {
        int total = 0;
        for (TreeTrace trace : traces) {
            total += trace.getTreeCount(burnin * trace.getStepSize());
        }
        return total;
    }

    /**
     * Returns the post-burn-in tree at a global index. Indexes run over the traces
     * concatenated in array order.
     *
     * @throws RuntimeException if {@code index} is not below {@link #getTreeCount()}
     */
    final Tree getTree(int index) {
        int traceStart = 0;
        int traceEnd = 0;
        for (TreeTrace trace : traces) {
            int burninState = burnin * trace.getStepSize();
            traceEnd += trace.getTreeCount(burninState);
            if (index < traceEnd) {
                return trace.getTree(index - traceStart, burninState);
            }
            traceStart = traceEnd;
        }
        throw new RuntimeException("Couldn't find tree " + index);
    }

    /** Runs {@link #report(double, double, int)} with a 50% clade rule and a 95% credible set. */
    public void report(int minCount) throws IOException {
        report(0.5d, 0.95d, minCount);
    }

    /** Runs {@link #report(double, double, int)} with a 95% credible set. */
    public void report(double cladeThreshold, int minCount) throws IOException {
        report(cladeThreshold, 0.95d, minCount);
    }

    /**
     * Prints the tree credible set, the clades above {@code cladeThreshold}, and the clade
     * counts of the 5% and 50% credible sets to standard output.
     *
     * @param cladeThreshold minimum posterior frequency for a clade to be listed
     * @param credibility    cumulative probability mass that defines the tree credible set
     * @param minCount       when positive, trees seen at most this many times are not printed
     *                       individually but are still counted
     * @throws IOException declared for compatibility; not thrown by the current body
     */
    public void report(double cladeThreshold, double credibility, int minCount) throws IOException {
        System.err.println("making report");
        int uniqueTrees = treeSet.size();
        int totalTrees = treeSet.getSumFrequency();
        System.out.println();
        System.out.println("burnIn=" + burnin);
        System.out.println("total trees used =" + totalTrees);
        System.out.println();
        System.out.println(Math.round(credibility * 100.0d) + "% credible set (" + uniqueTrees + " unique trees, " + totalTrees + " total):");
        System.out.println("Count\tPercent\tTree");

        NumberFormatter formatter = new NumberFormatter(8);
        printTreeCredibleSet(credibility, minCount, uniqueTrees, totalTrees, formatter);
        printCladeFrequencies(cladeThreshold, totalTrees, formatter);
        System.out.flush();
        printCladeCredibleSets(uniqueTrees, totalTrees);
        System.out.flush();
    }

    /**
     * Prints unique trees in treeSet order until their cumulative count reaches
     * {@code credibility * totalTrees}. The computation assumes treeSet lists the most
     * frequent trees first (TODO confirm FrequencySet ordering).
     */
    private void printTreeCredibleSet(double credibility, int minCount, int uniqueTrees, int totalTrees,
                                      NumberFormatter formatter) {
        int target = (int) (credibility * totalTrees);
        int cumulative = 0;
        int hidden = 0;
        for (int t = 0; t < uniqueTrees; t++) {
            int count = treeSet.getFrequency(t);
            boolean show = minCount <= 0 || count > minCount;
            if (!show) {
                hidden++;
            }
            cumulative += count;
            if (show) {
                double probability = (double) count / totalTrees;
                double cumulativeProbability = (double) cumulative / totalTrees;
                System.out.print(count);
                System.out.print("\t" + formatter.formatDecimal(probability * 100.0d, 2) + "%");
                System.out.print("\t" + formatter.formatDecimal(cumulativeProbability * 100.0d, 2) + "%");
                String newick = treeSet.get(t);
                // Trees sampled more than 100 times get node heights summarised by analyzeTree.
                if (count > 100) {
                    System.out.println("\t" + TreeUtils.newick(analyzeTree(newick)));
                } else {
                    System.out.println("\t" + newick);
                }
            }
            if (cumulative >= target) {
                if (hidden > 0) {
                    System.out.println();
                    System.out.println("... (" + hidden + ") trees.");
                }
                System.out.println();
                System.out.println(Math.round(credibility * 100.0d) + "% credible set has " + (t + 1) + " trees.");
                break; // the credible set is complete
            }
        }
    }

    /** Prints every clade whose posterior frequency is at least {@code threshold}. */
    private void printCladeFrequencies(double threshold, int totalTrees, NumberFormatter formatter) {
        System.out.println();
        System.out.println(Math.round(threshold * 100.0d) + "%-rule clades (" + cladeSet.size() + " unique clades):");
        int cladeCount = cladeSet.size();
        for (int c = 0; c < cladeCount; c++) {
            int count = cladeSet.getFrequency(c);
            double probability = (double) count / totalTrees;
            if (probability >= threshold) {
                System.out.print(count);
                System.out.print("\t" + formatter.formatDecimal(probability * 100.0d, 2) + "%");
                System.out.print("\t" + cladeSet.getMeanNodeHeight(c));
                System.out.println("\t" + cladeSet.getClade(c));
            }
        }
    }

    /**
     * Accumulates the clades of unique trees in treeSet order. It reports the clade count
     * once 5% and once 50% of all samples are covered.
     */
    private void printCladeCredibleSets(int uniqueTrees, int totalTrees) {
        System.out.println("Clade credible sets:");
        int fivePercent = (5 * totalTrees) / 100;
        int fiftyPercent = (50 * totalTrees) / 100;
        int cumulative = 0;
        assert uniqueTrees == treeSet.size();
        CladeSet credibleClades = new CladeSet();
        for (int t = 0; t < uniqueTrees; t++) {
            cumulative += treeSet.getFrequency(t);
            try {
                credibleClades.add(new NewickImporter(new StringReader(treeSet.get(t))).importNextTree());
            } catch (Importer.ImportException e) {
                // Best effort: an unparsable tree is reported and skipped, not fatal.
                System.err.println("Err: could not parse tree " + t + ": " + e.getMessage());
            }
            if (cumulative >= fivePercent) {
                System.out.println();
                System.out.println("5% credible set has " + credibleClades.getCladeCount() + " clades.");
                fivePercent = totalTrees + 1; // report only once
            }
            if (cumulative >= fiftyPercent) {
                System.out.println();
                System.out.println("50% credible set has " + credibleClades.getCladeCount() + " clades.");
                fiftyPercent = totalTrees + 1; // report only once
            }
        }
    }

    /** Runs {@link #shortReport(String, Tree, boolean, double)} with a 95% credible set. */
    public void shortReport(String name, Tree trueTree, boolean printHeader) {
        shortReport(name, trueTree, printHeader, 0.95d);
    }

    /**
     * Prints one tab-separated summary line, optionally preceded by a header line. The
     * fields are: name, total trees, unique trees, probability and Newick of treeSet entry 0
     * (labelled MAP; assumes entry 0 is the most frequent), credible-set size, and the
     * 1-based rank, probability and cumulative probability of {@code trueTree}.
     * When {@code trueTree} is null or absent, these last three are -1, 0.0 and 1.0.
     *
     * @param name        label printed in the first column
     * @param trueTree    the reference tree to locate, or null
     * @param printHeader whether to print the column header first
     * @param credibility cumulative probability mass that defines the credible set
     */
    public void shortReport(String name, Tree trueTree, boolean printHeader, double credibility) {
        String trueNewick = trueTree != null ? TreeUtils.uniqueNewick(trueTree, trueTree.getRoot()) : "";
        int uniqueTrees = treeSet.size();
        int totalTrees = treeSet.getSumFrequency();
        double mapProbability = (double) treeSet.getFrequency(0) / totalTrees;
        String mapTree = treeSet.get(0);
        if (printHeader) {
            System.out.println("file\ttrees\tuniqueTrees\tp(MAP)\tMAP tree\t" + Math.round(credibility * 100.0d) + "credSize\ttrue_I\tp(true)\tcum(true)");
        }
        System.out.print(name + "\t");
        System.out.print(totalTrees + "\t");
        System.out.print(uniqueTrees + "\t");
        System.out.print(mapProbability + "\t");
        System.out.print(mapTree + "\t");

        int target = (int) (credibility * totalTrees);
        int cumulative = 0;
        int credibleSetSize = -1;
        int trueRank = -1;
        double trueProbability = 0.0d;
        double trueCumulative = 1.0d;
        for (int t = 0; t < uniqueTrees; t++) {
            int count = treeSet.getFrequency(t);
            cumulative += count;
            if (treeSet.get(t).equals(trueNewick)) {
                trueRank = t + 1;
                trueProbability = (double) count / totalTrees;
                trueCumulative = (double) cumulative / totalTrees;
            }
            if (credibleSetSize == -1 && cumulative >= target) {
                credibleSetSize = t + 1;
            }
        }
        System.out.print(credibleSetSize + "\t");
        System.out.print(trueRank + "\t");
        System.out.print(trueProbability + "\t");
        System.out.println(trueCumulative);
    }

    /**
     * Writes summary trees for the unique topologies to {@code out} in NEXUS format.
     *
     * <p>Topologies are taken in treeSet order. Only those whose posterior probability is at
     * least {@code minProbability} are exported, and at most {@code maxTrees} of them. Each
     * exported tree is built by {@link #analyzeTree(String)} and carries a {@code "weight"}
     * attribute (its probability). Its root also gets a posterior attribute derived from
     * {@code cladeSet.annotate}.
     *
     * @param out            destination stream; it is not closed
     * @param minProbability minimum posterior probability for a topology to be exported
     * @param maxTrees       maximum number of trees to export; negative means no limit
     * @param verbose        whether to print a progress bar when more than 60 trees are expected
     */
    public void export(PrintStream out, double minProbability, int maxTrees, boolean verbose) {
        NexusExporter exporter = new NexusExporter(out);
        int uniqueTrees = treeSet.size();
        if (maxTrees < 0) {
            maxTrees = uniqueTrees;
        }
        int totalTrees = treeSet.getSumFrequency();
        ArrayList<Tree> exported = new ArrayList<>();
        int expected = Math.min(maxTrees, uniqueTrees);
        boolean showProgress = verbose && expected > PROGRESS_BAR_WIDTH;
        if (showProgress) {
            System.out.println("Exporting " + expected + " trees...");
            printProgressHeader();
        }
        int treesPerTick = expected / PROGRESS_BAR_WIDTH; // >= 1 whenever showProgress is true

        // Checking the limit before adding makes maxTrees == 0 export nothing.
        for (int t = 0; t < uniqueTrees && exported.size() < maxTrees; t++) {
            double probability = (double) treeSet.getFrequency(t) / totalTrees;
            if (probability < minProbability) {
                continue;
            }
            MutableTree tree = analyzeTree(treeSet.get(t));
            tree.setAttribute("weight", Double.valueOf(probability));
            tree.setNodeAttribute(tree.getRoot(), CompoundLikelihoodParser.POSTERIOR,
                    Double.valueOf(Math.exp(cladeSet.annotate(tree, CompoundLikelihoodParser.POSTERIOR) / tree.getInternalNodeCount())));
            exported.add(tree);
            if (showProgress && (t + 1) % treesPerTick == 0) {
                printProgressTick();
            }
        }
        if (showProgress) {
            System.out.println();
        }
        if (!exported.isEmpty()) {
            exporter.exportTrees(exported.toArray(new Tree[0]), true);
        }
    }

    /** Returns the burn-in in effect, after any defaulting done by the constructor. */
    public int getBurnin() {
        return burnin;
    }

    /**
     * Loads one tree trace per reader and analyses them together. Each reader is closed
     * whether or not loading succeeds.
     *
     * @param readers sources of the tree logs
     * @param burnin  requested burn-in (see the constructor for the defaulting rule)
     * @param verbose whether to print progress to standard output
     * @return the completed analysis
     * @throws IOException      if reading or closing a reader fails
     * @throws RuntimeException wrapping the {@code Importer.ImportException} if a log cannot be parsed
     */
    public static TreeTraceAnalysis analyzeLogFile(Reader[] readers, int burnin, boolean verbose) throws IOException {
        TreeTrace[] loaded = new TreeTrace[readers.length];
        for (int r = 0; r < readers.length; r++) {
            try (Reader reader = readers[r]) {
                loaded[r] = TreeTrace.loadTreeTrace(reader);
            } catch (Importer.ImportException e) {
                throw new RuntimeException(e.toString(), e);
            }
        }
        return new TreeTraceAnalysis(loaded, burnin, verbose);
    }

    /** Prints the 0-100 scale, the ruler and the first tick of a progress bar. */
    private static void printProgressHeader() {
        System.out.println("0              25             50             75            100");
        System.out.println("|--------------|--------------|--------------|--------------|");
        System.out.print(TopographicalMap.defaultInvalidString);
    }

    /**
     * Prints one progress-bar tick.
     * NOTE(review): in this decompiled class the tick string resolves to an unrelated
     * constant. The original was probably a plain literal, so confirm its value.
     */
    private static void printProgressTick() {
        System.out.print(TopographicalMap.defaultInvalidString);
        System.out.flush();
    }
}
