package dr.app.tools;

import dr.evolution.io.Importer;
import dr.evolution.io.NexusImporter;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;

/* loaded from: input_file:dr/app/tools/TaxaOriginTrait.class */
public class TaxaOriginTrait {
    private FlexibleTree[] trees;
    private String[] taxaNames;
    private String traitName = getTraitName();
    private String attributeName;
    private String fileNameRoot;

    private TaxaOriginTrait(String[] strArr, FlexibleTree[] flexibleTreeArr, String str, String str2) {
        this.taxaNames = strArr;
        this.trees = flexibleTreeArr;
        this.attributeName = str;
        this.fileNameRoot = str2;
    }

    private FlexibleNode findCommonAncestor(FlexibleTree flexibleTree, FlexibleNode[] flexibleNodeArr) {
        HashSet hashSet = new HashSet();
        FlexibleNode flexibleNode = flexibleNodeArr[0];
        for (FlexibleNode flexibleNode2 : flexibleNodeArr) {
            hashSet.add(flexibleNode2);
            flexibleNode = flexibleNode2;
            boolean z = getTipSet(flexibleTree, flexibleNode).containsAll(hashSet);
            while (!z) {
                flexibleNode = (FlexibleNode) flexibleTree.getParent(flexibleNode);
                if (getTipSet(flexibleTree, flexibleNode).containsAll(hashSet)) {
                    z = true;
                }
            }
        }
        return flexibleNode;
    }

    private String getTraitName() {
        FlexibleTree flexibleTree = this.trees[0];
        String str = null;
        for (int i = 0; i < flexibleTree.getExternalNodeCount(); i++) {
            NodeRef externalNode = flexibleTree.getExternalNode(i);
            for (String str2 : this.taxaNames) {
                if (flexibleTree.getNodeTaxon(externalNode).getId().equals(str2)) {
                    String str3 = (String) flexibleTree.getNodeAttribute(externalNode, this.attributeName);
                    if (str != null && !str.equals(str3)) {
                        throw new RuntimeException("Not all taxa given have the same trait value");
                    }
                    str = str3;
                }
            }
        }
        return str;
    }

    private boolean branchNode(FlexibleTree flexibleTree, FlexibleNode flexibleNode) {
        return flexibleTree.getChildCount(flexibleNode) == 2;
    }

    private FlexibleNode[] getTipsOfInterest(FlexibleTree flexibleTree) {
        HashSet hashSet = new HashSet();
        for (String str : this.taxaNames) {
            for (int i = 0; i < flexibleTree.getExternalNodeCount(); i++) {
                if (flexibleTree.getNodeTaxon(flexibleTree.getExternalNode(i)).toString().equals(str)) {
                    hashSet.add((FlexibleNode) flexibleTree.getExternalNode(i));
                }
            }
        }
        return (FlexibleNode[]) hashSet.toArray(new FlexibleNode[hashSet.size()]);
    }

    private HashMap<String, String> getIncomingJumpOrigins(FlexibleTree flexibleTree) {
        HashMap<String, String> hashMap = new HashMap<>();
        FlexibleNode[] tipsOfInterest = getTipsOfInterest(flexibleTree);
        FlexibleNode findCommonAncestor = findCommonAncestor(flexibleTree, tipsOfInterest);
        if (!new HashSet(Arrays.asList(tipsOfInterest)).containsAll(getTipSet(flexibleTree, findCommonAncestor))) {
            System.out.println("WARNING: mixed traits in a clade");
        }
        if (findCommonAncestor.getAttribute(this.attributeName).equals(this.traitName)) {
            boolean z = true;
            FlexibleNode flexibleNode = findCommonAncestor;
            while (true) {
                if (!z) {
                    break;
                }
                flexibleNode = (FlexibleNode) flexibleTree.getParent(flexibleNode);
                if (flexibleNode == null) {
                    hashMap.put(this.traitName, "root");
                    break;
                }
                String str = (String) flexibleNode.getAttribute(this.attributeName);
                if (!str.equals(this.traitName)) {
                    z = false;
                    hashMap.put(this.traitName, str);
                }
            }
        } else {
            hashMap.put(this.traitName, "Multiple");
            System.out.println("Multiple origin found.");
        }
        return hashMap;
    }

    private HashSet<FlexibleNode> getTipSet(FlexibleTree flexibleTree, FlexibleNode flexibleNode) {
        HashSet<FlexibleNode> hashSet = new HashSet<>();
        if (flexibleTree.isExternal(flexibleNode)) {
            hashSet.add(flexibleNode);
            return hashSet;
        }
        for (int i = 0; i < flexibleTree.getChildCount(flexibleNode); i++) {
            hashSet.addAll(getTipSet(flexibleTree, (FlexibleNode) flexibleTree.getChild(flexibleNode, i)));
        }
        return hashSet;
    }

    private void tabulateOrigins() {
        HashMap hashMap = new HashMap();
        int i = 0;
        for (FlexibleTree flexibleTree : this.trees) {
            if (i % 1 == 0) {
                System.out.println("Doing tree " + i);
            }
            HashMap<String, String> incomingJumpOrigins = getIncomingJumpOrigins(flexibleTree);
            for (String str : incomingJumpOrigins.keySet()) {
                hashMap.put(incomingJumpOrigins.get(str), Integer.valueOf((hashMap.containsKey(incomingJumpOrigins.get(str)) ? ((Integer) hashMap.get(incomingJumpOrigins.get(str))).intValue() : 0) + 1));
            }
            i++;
        }
        try {
            BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(this.fileNameRoot + ".csv"));
            for (String str2 : hashMap.keySet()) {
                bufferedWriter.write(str2 + "," + hashMap.get(str2) + "\n");
            }
            bufferedWriter.flush();
        } catch (IOException e) {
            System.out.println("Failed to write to file");
        }
    }

    public static void main(String[] strArr) {
        try {
            BufferedReader bufferedReader = new BufferedReader(new FileReader(strArr[1]));
            HashSet hashSet = new HashSet();
            while (true) {
                String readLine = bufferedReader.readLine();
                if (readLine == null) {
                    break;
                } else {
                    hashSet.add(readLine);
                }
            }
            String[] strArr2 = (String[]) hashSet.toArray(new String[hashSet.size()]);
            NexusImporter nexusImporter = new NexusImporter(new FileReader(strArr[2]));
            NexusImporter.setSuppressWarnings(true);
            ArrayList arrayList = new ArrayList();
            int i = 0;
            while (nexusImporter.hasTree()) {
                if (i % 100 == 0) {
                    System.out.println("Loaded " + i + " trees");
                }
                arrayList.add((FlexibleTree) nexusImporter.importNextTree());
                i++;
            }
            new TaxaOriginTrait(strArr2, (FlexibleTree[]) arrayList.toArray(new FlexibleTree[arrayList.size()]), strArr[0], strArr[3]).tabulateOrigins();
        } catch (Importer.ImportException e) {
            System.out.println("Failed to import trees");
        } catch (IOException e2) {
            System.out.println("Failed to read files");
        }
    }
}
