package dr.app.tools;

import dr.app.beast.BeastVersion;
import dr.app.util.Arguments;
import dr.app.util.Utils;
import dr.evolution.io.Importer;
import dr.evolution.io.NexusImporter;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.arg.ARGModel;
import dr.evomodel.continuous.TopographicalMap;
import dr.inference.trace.TraceException;
import dr.util.Version;
import dr.xml.XMLObject;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.StringTokenizer;

/* loaded from: input_file:dr/app/tools/GetNSCountsFromTrees.class */
public class GetNSCountsFromTrees {
    public static final String BURNIN = "burnin";
    public static final String totalcN = "N";
    public static final String totalcS = "S";
    public static final String totaluN = "b_u_N";
    public static final String totaluS = "b_u_S";
    public static final String historyN = "all_N";
    public static final String historyS = "all_S";
    public static final String SEP = "\t";
    public static final String BRANCHINFO = "branchInfo";
    public static final String BRANCHSET = "branchSet";
    public static final String INCLUDECLADES = "includeClades";
    public static final String CLADESTEM = "cladeStem";
    public static final String EXCLUDECLADESTEM = "excludeCladeStem";
    public static final String BACKBONETAXA = "backboneTaxa";
    public static final String ZEROBRANCHES = "zeroBranches";
    public static final String SUMMARY = "summary";
    public static final String SITESUM = "siteSum";
    public static final String CODONSITELIST = "codonSiteList";
    public static final String MRSD = "mrsd";
    public static final String EXCLUDECLADES = "excludeClades";
    private boolean branchInfo;
    private boolean zeroBranches;
    private boolean summary;
    private int sites;
    private double mrsd;
    private boolean cladeStem;
    private boolean excludeCladeStems;
    private PrintStream resultsStream;
    private static final Version version = new BeastVersion();
    public static final String[] falseTrue = {"false", ARGModel.IS_REASSORTMENT};
    private static PrintStream progressStream = System.err;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:dr/app/tools/GetNSCountsFromTrees$BranchSet.class */
    public enum BranchSet {
        ALL,
        INT,
        EXT,
        BACKBONE,
        CLADE
    }

    public GetNSCountsFromTrees(int i, String str, String str2, boolean z, BranchSet branchSet, List<Set> list, boolean z2, boolean z3, boolean z4, int i2, double d, List<Set> list2, boolean z5, double[] dArr) throws IOException {
        File file = new File(str);
        if (file.isFile()) {
            System.out.println("Analysing tree file " + str + " with a burn-in of " + i + " trees...");
        } else {
            progressStream.println("cannot find " + str);
            System.exit(0);
        }
        this.branchInfo = z;
        this.zeroBranches = z3;
        this.summary = z4;
        this.mrsd = d;
        this.cladeStem = z2;
        this.excludeCladeStems = z5;
        this.sites = i2;
        if (dArr != null) {
            this.sites = dArr.length * 3;
            progressStream.println("number of sites set based on site list provided");
        }
        if (str2 != null) {
            try {
                this.resultsStream = new PrintStream(new File(str2));
            } catch (IOException e) {
                progressStream.println("Error opening file: " + str2);
                System.exit(-1);
            }
        } else {
            this.resultsStream = new PrintStream(new File(str + ".NSout.txt"));
        }
        analyze(file, i, branchSet, list, list2, dArr);
    }

    private void analyze(File file, int i, BranchSet branchSet, List<Set> list, List<Set> list2, double[] dArr) {
        if (this.summary) {
            this.resultsStream.print("tree\tcN\tuN\tcS\tuS\tcNrate\tcSrate\tdN/dS\n");
        } else {
            this.resultsStream.print("tree\tbranch\tN/S\tsite\theight/date\tfromState\ttoState");
            if (this.branchInfo) {
                this.resultsStream.print("\tbranchLength\tbranchCN/S\tbranchUN/S\n");
            } else {
                this.resultsStream.print("\n");
            }
        }
        int i2 = 10000;
        int i3 = 0;
        System.out.println("Reading and analyzing trees (bar assumes 10,000 trees)...");
        System.out.println("0              25             50             75            100");
        System.out.println("|--------------|--------------|--------------|--------------|");
        int i4 = 10000 / 60;
        if (i4 < 1) {
            i4 = 1;
        }
        int i5 = 0;
        int i6 = 1;
        try {
            NexusImporter nexusImporter = new NexusImporter(new FileReader(file));
            while (nexusImporter.hasTree()) {
                Tree importNextTree = nexusImporter.importNextTree();
                if (i5 >= i) {
                    getNSCounts(importNextTree, i6, branchSet, list, list2, dArr);
                    i6++;
                }
                i5++;
                if (i2 > 0 && i2 % i4 == 0) {
                    System.out.print(TopographicalMap.defaultInvalidString);
                    i3++;
                    if (i3 % 61 == 0) {
                        System.out.print("\n");
                    }
                    System.out.flush();
                }
                i2++;
            }
            System.out.print("\n");
        } catch (Importer.ImportException e) {
            progressStream.println("Error Parsing Input Tree: " + e.getMessage());
        } catch (IOException e2) {
            progressStream.println("Error Parsing Input Tree: " + e2.getMessage());
        }
    }

    private void getNSCounts(Tree tree, int i, BranchSet branchSet, List<Set> list, List<Set> list2, double[] dArr) {
        int i2 = 0;
        double d = 0.0d;
        double d2 = 0.0d;
        double d3 = 0.0d;
        double d4 = 0.0d;
        double d5 = 0.0d;
        for (int i3 = 0; i3 < tree.getNodeCount(); i3++) {
            NodeRef node = tree.getNode(i3);
            if (!tree.isRoot(node)) {
                i2++;
                boolean z = false;
                if (branchSet == BranchSet.ALL) {
                    z = true;
                } else if (branchSet == BranchSet.EXT && tree.isExternal(node)) {
                    z = true;
                } else if (branchSet == BranchSet.INT && !tree.isExternal(node)) {
                    z = true;
                } else if (branchSet == BranchSet.BACKBONE) {
                    Iterator<Set> it = list.iterator();
                    while (it.hasNext()) {
                        if (onBackbone(tree, node, it.next())) {
                            z = true;
                        }
                    }
                } else if (branchSet == BranchSet.CLADE) {
                    Iterator<Set> it2 = list.iterator();
                    while (it2.hasNext()) {
                        if (inClade(tree, node, it2.next(), this.cladeStem)) {
                            z = true;
                        }
                    }
                }
                if (z && list2.size() > 0) {
                    Iterator<Set> it3 = list2.iterator();
                    while (true) {
                        if (it3.hasNext()) {
                            if (inClade(tree, node, it3.next(), this.excludeCladeStems)) {
                                z = false;
                                break;
                            }
                        } else {
                            break;
                        }
                    }
                }
                if (z) {
                    double branchLength = tree.getBranchLength(node);
                    d5 += branchLength;
                    Object nodeAttribute = tree.getNodeAttribute(node, "N");
                    Object nodeAttribute2 = tree.getNodeAttribute(node, totalcS);
                    if (nodeAttribute != null && nodeAttribute2 != null) {
                        double doubleValue = ((Double) nodeAttribute).doubleValue();
                        double doubleValue2 = ((Double) nodeAttribute2).doubleValue();
                        double doubleValue3 = ((Double) tree.getNodeAttribute(node, totaluN)).doubleValue();
                        double doubleValue4 = ((Double) tree.getNodeAttribute(node, totaluS)).doubleValue();
                        if (dArr == null) {
                            d += doubleValue;
                            d3 += doubleValue2;
                        }
                        d2 += doubleValue3;
                        d4 += doubleValue4;
                        if (doubleValue > 0.0d) {
                            for (Object obj : (Object[]) tree.getNodeAttribute(node, historyN)) {
                                Object[] objArr = (Object[]) obj;
                                boolean z2 = false;
                                if (dArr == null) {
                                    z2 = true;
                                } else if (inSiteList((Integer) objArr[0], dArr)) {
                                    z2 = true;
                                    d += 1.0d;
                                }
                                if (!this.summary && z2) {
                                    this.resultsStream.print(i + "\t" + i2 + "\tN\t");
                                    this.resultsStream.print(objArr[0] + "\t");
                                    if (this.mrsd > 0.0d) {
                                        this.resultsStream.print((this.mrsd - ((Double) objArr[1]).doubleValue()) + "\t");
                                    } else {
                                        this.resultsStream.print(objArr[1] + "\t");
                                    }
                                    this.resultsStream.print(objArr[2] + "\t" + objArr[3] + "\t");
                                    if (this.branchInfo) {
                                        this.resultsStream.print(branchLength + "\t" + doubleValue + "\t" + doubleValue3 + "\n");
                                    } else {
                                        this.resultsStream.print("\n");
                                    }
                                }
                            }
                        }
                        if (doubleValue2 > 0.0d) {
                            for (Object obj2 : (Object[]) tree.getNodeAttribute(node, historyS)) {
                                Object[] objArr2 = (Object[]) obj2;
                                boolean z3 = false;
                                if (dArr == null) {
                                    z3 = true;
                                } else if (inSiteList((Integer) objArr2[0], dArr)) {
                                    z3 = true;
                                    d3 += 1.0d;
                                }
                                if (!this.summary && z3) {
                                    this.resultsStream.print(i + "\t" + i2 + "\t" + totalcS + "\t");
                                    this.resultsStream.print(objArr2[0] + "\t");
                                    if (this.mrsd > 0.0d) {
                                        this.resultsStream.print((this.mrsd - ((Double) objArr2[1]).doubleValue()) + "\t");
                                    } else {
                                        this.resultsStream.print(objArr2[1] + "\t");
                                    }
                                    this.resultsStream.print(objArr2[2] + "\t" + objArr2[3] + "\t");
                                    if (this.branchInfo) {
                                        this.resultsStream.print(branchLength + "\t" + doubleValue2 + "\t" + doubleValue4 + "\n");
                                    } else {
                                        this.resultsStream.print("\n");
                                    }
                                }
                            }
                        }
                        if (doubleValue + doubleValue2 == 0.0d) {
                            if (!this.zeroBranches) {
                                d5 -= branchLength;
                            } else if (!this.summary) {
                                this.resultsStream.print(i + "\t" + i2 + "\t" + XMLObject.missingValue + "\t");
                                this.resultsStream.print("NA\tNA\tNA\tNA\t");
                                if (this.branchInfo) {
                                    this.resultsStream.print(branchLength + "\t" + XMLObject.missingValue + "\tNA\n");
                                } else {
                                    this.resultsStream.print("\n");
                                }
                            }
                        }
                    }
                }
            }
        }
        if (this.summary) {
            this.resultsStream.print(i + "\t" + d + "\t" + d2 + "\t" + d3 + "\t" + d4 + "\t" + (d / (d5 * this.sites)) + "\t" + (d3 / (d5 * this.sites)) + "\t" + ((d / d2) / (d3 / d4)) + "\n");
        }
    }

    private boolean inSiteList(Integer num, double[] dArr) {
        boolean z = false;
        int length = dArr.length;
        int i = 0;
        while (true) {
            if (i >= length) {
                break;
            }
            if (num.intValue() == dArr[i]) {
                z = true;
                break;
            }
            i++;
        }
        return z;
    }

    private static Set getTargetSet(String str) {
        HashSet hashSet = new HashSet();
        try {
            BufferedReader bufferedReader = new BufferedReader(new FileReader(str));
            try {
                String trim = bufferedReader.readLine().trim();
                while (trim != null) {
                    if (trim.equals("")) {
                        break;
                    }
                    hashSet.add(trim);
                    trim = bufferedReader.readLine();
                    if (trim != null) {
                        trim = trim.trim();
                    }
                }
            } catch (IOException e) {
                progressStream.println("Error reading " + str);
            }
        } catch (FileNotFoundException e2) {
            progressStream.println("Error finding " + str);
        }
        return hashSet;
    }

    private static boolean onBackbone(Tree tree, NodeRef nodeRef, Set set) {
        if (tree.isExternal(nodeRef)) {
            return false;
        }
        Set<String> descendantLeaves = TreeUtils.getDescendantLeaves(tree, nodeRef);
        int size = descendantLeaves.size();
        descendantLeaves.retainAll(set);
        if (descendantLeaves.size() <= 0) {
            return false;
        }
        if (descendantLeaves.size() != size) {
            return true;
        }
        Set<String> descendantLeaves2 = TreeUtils.getDescendantLeaves(tree, tree.getParent(nodeRef));
        descendantLeaves2.removeAll(set);
        return descendantLeaves2.size() > 0;
    }

    private static boolean inClade(Tree tree, NodeRef nodeRef, Set set, boolean z) {
        Set<String> descendantLeaves = TreeUtils.getDescendantLeaves(tree, nodeRef);
        descendantLeaves.removeAll(set);
        if (descendantLeaves.size() != 0) {
            return false;
        }
        if (z) {
            return true;
        }
        Set<String> descendantLeaves2 = TreeUtils.getDescendantLeaves(tree, tree.getParent(nodeRef));
        descendantLeaves2.removeAll(set);
        return descendantLeaves2.size() == 0;
    }

    private static String[] parseVariableLengthStringArray(String str) {
        ArrayList arrayList = new ArrayList();
        StringTokenizer stringTokenizer = new StringTokenizer(str, ",");
        while (stringTokenizer.hasMoreTokens()) {
            arrayList.add(stringTokenizer.nextToken());
        }
        if (arrayList.size() > 0) {
            return (String[]) arrayList.toArray(new String[arrayList.size()]);
        }
        return null;
    }

    public static double[] parseVariableLengthDoubleArray(String str) {
        ArrayList arrayList = new ArrayList();
        StringTokenizer stringTokenizer = new StringTokenizer(str, ",");
        while (stringTokenizer.hasMoreTokens()) {
            arrayList.add(Double.valueOf(Double.parseDouble(stringTokenizer.nextToken())));
        }
        if (arrayList.size() <= 0) {
            return null;
        }
        double[] dArr = new double[arrayList.size()];
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = ((Double) arrayList.get(i)).doubleValue();
        }
        return dArr;
    }

    public static void centreLine(String str, int i) {
        int length = (i - str.length()) / 2;
        for (int i2 = 0; i2 < length; i2++) {
            System.out.print(" ");
        }
        System.out.println(str);
    }

    public static void printUsage(Arguments arguments) {
        arguments.printUsage("GetNSCountsFromTrees", "[-burnin <burnin>][<input-file-name> [<output-file-name>]]");
        progressStream.println();
        progressStream.println("  Example: GetNSCountsFromTrees -burnin 10000 trees.log out.txt");
        progressStream.println();
    }

    public static void printTitle() {
        System.out.println();
        centreLine("GetNSCountsFromTrees " + version.getVersionString() + ", " + version.getDateString(), 60);
        centreLine("MCMC Output analysis", 60);
        centreLine("by", 60);
        centreLine("Philippe Lemey and Marc Suchard", 60);
        System.out.println();
        centreLine("Department of Immunology and Microbiology", 60);
        centreLine("KU Leuven -- University of Leuven", 60);
        centreLine("philippe.lemey@kuleuven.be", 60);
        System.out.println();
        centreLine("Department of Biomathematics", 60);
        centreLine("University of Califormia, Los Angeles", 60);
        centreLine("msuchard@ucla.edu", 60);
        System.out.println();
        System.out.println();
    }

    public static void main(String[] strArr) throws IOException, TraceException {
        printTitle();
        Arguments arguments = new Arguments(new Arguments.Option[]{new Arguments.IntegerOption("burnin", "the number of states to be considered as 'burn-in' [default = 0]"), new Arguments.StringOption(BRANCHINFO, falseTrue, false, "include a summary for the root [default=off]"), new Arguments.StringOption("branchSet", TimeSlicer.enumNamesToStringArray(BranchSet.values()), false, "branch set [default = all]"), new Arguments.StringOption(INCLUDECLADES, "clade exclusion files", "specifies files with taxa that define clades to be excluded"), new Arguments.StringOption(CLADESTEM, falseTrue, false, "include clade stem [default=false]"), new Arguments.StringOption(EXCLUDECLADESTEM, falseTrue, true, "include clade stem in the exclusion [default=true]"), new Arguments.StringOption(BACKBONETAXA, "Backbone taxa file", "specifies a file with taxa that define the backbone"), new Arguments.StringOption(ZEROBRANCHES, falseTrue, true, "include branches with 0 N and S subtitutions [default=included]"), new Arguments.StringOption("summary", falseTrue, true, "provide a summary of the N and S counts per tree [default=detailed output]"), new Arguments.RealOption("mrsd", "specifies the most recent sampling data in fractional years to rescale time [default=0]"), new Arguments.StringOption(EXCLUDECLADES, "clade exclusion files", "specifies files with taxa that define clades to be excluded"), new Arguments.IntegerOption(SITESUM, "the number of nucleotide sites to summarize rates in per site per time unit [default = 1]"), new Arguments.StringOption(CODONSITELIST, "list of sites", "sites for which the summary is restricted to"), new Arguments.Option("help", "option to print this message")});
        try {
            arguments.parseArguments(strArr);
        } catch (Arguments.ArgumentException e) {
            System.out.println(e);
            printUsage(arguments);
            System.exit(1);
        }
        if (arguments.hasOption("help")) {
            printUsage(arguments);
            System.exit(0);
        }
        int integerOption = arguments.hasOption("burnin") ? arguments.getIntegerOption("burnin") : -1;
        boolean z = false;
        String stringOption = arguments.getStringOption(BRANCHINFO);
        if (stringOption != null && stringOption.compareToIgnoreCase(ARGModel.IS_REASSORTMENT) == 0) {
            z = true;
        }
        BranchSet branchSet = BranchSet.ALL;
        String stringOption2 = arguments.getStringOption("branchSet");
        ArrayList arrayList = new ArrayList();
        if (stringOption2 != null) {
            branchSet = BranchSet.valueOf(stringOption2.toUpperCase());
            progressStream.println("Using the branch set: " + branchSet.name());
        }
        if (branchSet == BranchSet.BACKBONE) {
            if (arguments.hasOption(BACKBONETAXA)) {
                for (String str : parseVariableLengthStringArray(arguments.getStringOption(BACKBONETAXA))) {
                    arrayList.add(getTargetSet(str));
                    progressStream.println("getting target set for backbone inclusion: " + str);
                }
            } else {
                progressStream.println("you want to get summaries for (a) backbone(s), but no files with taxa to define it are provided??");
            }
        }
        if (branchSet == BranchSet.CLADE) {
            if (arguments.hasOption(INCLUDECLADES)) {
                for (String str2 : parseVariableLengthStringArray(arguments.getStringOption(INCLUDECLADES))) {
                    arrayList.add(getTargetSet(str2));
                    progressStream.println("getting target set for clade inclusion: " + str2);
                }
            } else {
                progressStream.println("you want to get summaries for one or more clades, but no files with taxa to define it are provided??");
            }
        }
        boolean z2 = false;
        String stringOption3 = arguments.getStringOption(CLADESTEM);
        if (stringOption3 != null && stringOption3.compareToIgnoreCase(ARGModel.IS_REASSORTMENT) == 0) {
            z2 = true;
        }
        boolean z3 = true;
        String stringOption4 = arguments.getStringOption(CLADESTEM);
        if (stringOption4 != null && stringOption4.compareToIgnoreCase("false") == 0) {
            z3 = false;
        }
        boolean z4 = true;
        String stringOption5 = arguments.getStringOption(ZEROBRANCHES);
        if (stringOption5 != null && stringOption5.compareToIgnoreCase("false") == 0) {
            z4 = false;
        }
        boolean z5 = false;
        String stringOption6 = arguments.getStringOption("summary");
        if (stringOption6 != null && stringOption6.compareToIgnoreCase(ARGModel.IS_REASSORTMENT) == 0) {
            z5 = true;
        }
        int integerOption2 = arguments.hasOption(SITESUM) ? arguments.getIntegerOption(SITESUM) : 1;
        double realOption = arguments.hasOption("mrsd") ? arguments.getRealOption("mrsd") : 0.0d;
        ArrayList arrayList2 = new ArrayList();
        if (arguments.hasOption(EXCLUDECLADES)) {
            for (String str3 : parseVariableLengthStringArray(arguments.getStringOption(EXCLUDECLADES))) {
                arrayList2.add(getTargetSet(str3));
                progressStream.println("getting target set for clade exclusion: " + str3);
            }
        }
        double[] dArr = null;
        if (arguments.hasOption(CODONSITELIST)) {
            dArr = parseVariableLengthDoubleArray(arguments.getStringOption(CODONSITELIST));
            progressStream.println("site list provided: note that dN/dS will not be accurately estimated because the neutral expectation to get dN/dS (uN and uS) is for all sites along a branch.");
        }
        String[] leftoverArguments = arguments.getLeftoverArguments();
        if (leftoverArguments.length > 2) {
            progressStream.println("Unknown option: " + leftoverArguments[2]);
            System.err.println();
            printUsage(arguments);
            System.exit(1);
        }
        String str4 = leftoverArguments.length > 0 ? leftoverArguments[0] : null;
        String str5 = leftoverArguments.length > 1 ? leftoverArguments[1] : null;
        if (str4 == null) {
            str4 = Utils.getLoadFileName("GetNSCountsFromTrees " + version.getVersionString() + " - Select tree file to analyse");
        }
        if (integerOption == -1) {
            System.err.println("Enter number of trees to burn-in (integer): ");
            integerOption = Integer.parseInt(new BufferedReader(new InputStreamReader(System.in)).readLine());
        }
        new GetNSCountsFromTrees(integerOption, str4, str5, z, branchSet, arrayList, z2, z4, z5, integerOption2, realOption, arrayList2, z3, dArr);
        System.exit(0);
    }
}
