package dr.app.tools;

import dr.app.util.Arguments;
import dr.evolution.coalescent.CoalescentSimulator;
import dr.evolution.coalescent.ConstantPopulation;
import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.coalescent.ExponentialGrowth;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.SimpleNode;
import dr.evolution.tree.SimpleTree;
import dr.evolution.tree.Tree;
import dr.evolution.util.Date;
import dr.evolution.util.Taxon;
import dr.evolution.util.Units;
import dr.evomodel.epidemiology.LogisticGrowthN0;
import dr.xml.XMLObject;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;

/* loaded from: input_file:dr/app/tools/TransmissionTreeToVirusTree.class */
public class TransmissionTreeToVirusTree {
    /** Command-line option name checked in {@code main} to print the usage message. */
    public static final String HELP = "help";
    /** Command-line option name selecting the within-host demographic model. */
    public static final String DEMOGRAPHIC_MODEL = "demoModel";
    /** Command-line option name for the effective population size at time zero. */
    public static final String STARTING_POPULATION_SIZE = "N0";
    /** Command-line option name for the population growth rate. */
    public static final String GROWTH_RATE = "growthRate";
    /** Command-line option name for the logistic model's t50 parameter. */
    public static final String T50 = "t50";
    /** Demographic function handed to the coalescent simulation of every unit. */
    private DemographicFunction demFunct;
    /** Every infected unit read from the infections file, in file order. */
    private ArrayList<InfectedUnit> units;
    /** Lookup from unit identifier ({@code "ID_" + <second CSV column>}) to its unit. */
    private HashMap<String, InfectedUnit> idMap;
    /** Prefix prepended to the names of the Nexus files written by {@code run}. */
    private String outputFileRoot;
    /**
     * Running product updated by each call to {@code simulateCoalescent}; reset to 1 in
     * {@code makeTrees} before each first case is processed.
     */
    private double coalescentProbability;
    /** Stream that receives warnings and progress messages. */
    protected static PrintStream progressStream = System.out;
    /** Values accepted by the {@code demoModel} command-line option. */
    public static final String[] demographics = {"Constant", "Exponential", "Logistic"};

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dr/app/tools/TransmissionTreeToVirusTree$Event.class */
    /**
     * A timestamped event on an infected unit's timeline: either a transmission
     * (with infector and infectee) or a sampling (with neither). Events order
     * themselves by time, earliest first.
     */
    public class Event implements Comparable<Event> {
        private EventType type;
        private double time;
        private InfectedUnit infector;
        private InfectedUnit infectee;

        /** Creates an event with no associated infector or infectee. */
        private Event(EventType kind, double when) {
            this(kind, when, null, null);
        }

        /** Creates an event linking {@code source} (may be null) to {@code target}. */
        private Event(EventType kind, double when, InfectedUnit source, InfectedUnit target) {
            type = kind;
            time = when;
            infector = source;
            infectee = target;
        }

        /** Compares two events by their times only. */
        @Override // java.lang.Comparable
        public int compareTo(Event other) {
            double mine = time;
            double theirs = other.time;
            return Double.compare(mine, theirs);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dr/app/tools/TransmissionTreeToVirusTree$EventType.class */
    /** Kinds of event that can be recorded for an infected unit. */
    public enum EventType {
        /** A transmission; these events are created with an infector and an infectee. */
        INFECTION,
        /** A sampling of a unit; these events are created with only a type and a time. */
        SAMPLE
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dr/app/tools/TransmissionTreeToVirusTree$InfectedUnit.class */
    /**
     * One infected case: its identifier, the event that infected it, the unit
     * that infected it (null for a first case) and the events that happened to
     * it afterwards (samplings and onward transmissions).
     */
    public class InfectedUnit {
        private String id;
        private ArrayList<Event> childEvents;
        private Event infectionEvent;
        private InfectedUnit parent;

        /** Creates a unit with no parent, no infection event and no child events. */
        private InfectedUnit(String unitId) {
            id = unitId;
            childEvents = new ArrayList<>();
            parent = null;
        }

        /**
         * Records a sampling of this unit at time {@code d}. The infection event
         * must already be set; a sample earlier than it is rejected.
         */
        public void addSamplingEvent(double d) {
            boolean beforeInfection = d < infectionEvent.time;
            if (beforeInfection) {
                throw new RuntimeException("Adding an event to case " + this.id + " before its infection time");
            }
            Event sampling = new Event(EventType.SAMPLE, d);
            childEvents.add(sampling);
        }

        /** Builds an infection event at time {@code d} from {@code infectedUnit} and installs it. */
        private void setInfectionEvent(double d, InfectedUnit infectedUnit) {
            Event infection = new Event(EventType.INFECTION, d, infectedUnit, this);
            setInfectionEvent(infection);
        }

        /**
         * Installs the event that infected this unit, rejecting it if any
         * already-recorded child event is earlier.
         */
        public void setInfectionEvent(Event event) {
            for (Event existing : childEvents) {
                if (existing.time < event.time) {
                    throw new RuntimeException("Setting infection time for case " + this.id + " after an existing child event");
                }
            }
            infectionEvent = event;
        }

        /** Builds an onward-transmission event to {@code infectedUnit} at time {@code d} and records it. */
        private void addChildInfectionEvent(double d, InfectedUnit infectedUnit) {
            Event transmission = new Event(EventType.INFECTION, d, this, infectedUnit);
            addInfectionEvent(transmission);
        }

        /**
         * Records an onward transmission. The timing check is skipped while this
         * unit's own infection event is still unset.
         */
        public void addInfectionEvent(Event event) {
            if (infectionEvent != null) {
                if (event.time < infectionEvent.time) {
                    throw new RuntimeException("Adding an infection event to case " + this.id + " at " + event.time + " before its infection time at " + this.infectionEvent.time);
                }
            }
            childEvents.add(event);
        }

        /** Orders the child events from latest to earliest (ascending sort, then reversal). */
        public void sortEvents() {
            ArrayList<Event> events = childEvents;
            Collections.sort(events);
            Collections.reverse(events);
        }
    }

    /* loaded from: input_file:dr/app/tools/TransmissionTreeToVirusTree$ModelType.class */
    /**
     * Within-host demographic models selectable from the command line; {@code main}
     * picks one from the first letter of the {@code demoModel} option value.
     */
    private enum ModelType {
        /** Chosen when the option value starts with "c"; also the default. */
        CONSTANT,
        /** Chosen when the option value starts with "e". */
        EXPONENTIAL,
        /** Chosen when the option value starts with "l". */
        LOGISTIC
    }

    /**
     * Builds a converter that reads both infection and sampling events from a
     * single CSV file.
     *
     * @param str                 path of the CSV file holding infection and sampling data
     * @param demographicFunction demographic function used for the within-unit coalescent
     * @param str2                prefix for the names of the output files
     */
    public TransmissionTreeToVirusTree(String str, DemographicFunction demographicFunction, String str2) {
        demFunct = demographicFunction;
        outputFileRoot = str2;
        coalescentProbability = 1.0d;
        units = new ArrayList<>();
        idMap = new HashMap<>();
        try {
            readInfectionEvents(str);
            readSamplingEvents(str);
        } catch (IOException readFailure) {
            // Read problems are only reported; the object is left with whatever was parsed.
            readFailure.printStackTrace();
        }
    }

    /**
     * Builds a converter that reads infection events and sampling events from
     * two separate CSV files.
     *
     * @param str                 path of the CSV file holding sampling data
     * @param str2                path of the CSV file holding infection data
     * @param demographicFunction demographic function used for the within-unit coalescent
     * @param str3                prefix for the names of the output files
     */
    public TransmissionTreeToVirusTree(String str, String str2, DemographicFunction demographicFunction, String str3) {
        this.demFunct = demographicFunction;
        this.units = new ArrayList<>();
        this.idMap = new HashMap<>();
        this.outputFileRoot = str3;
        // Start from the same state as the single-file constructor rather than the implicit 0.0.
        this.coalescentProbability = 1.0d;
        try {
            // Infections must be read first: sampling events are validated against infection times.
            readInfectionEvents(str2);
            readSamplingEvents(str);
        } catch (IOException e) {
            // Read problems are only reported; the object is left with whatever was parsed.
            e.printStackTrace();
        }
    }

    /**
     * Builds one tree per first case and writes each twice: the detailed tree to
     * {@code <root><firstCase>_detailed.nex} and a simplified copy (from
     * {@code makeWellBehavedTree}) to {@code <root><firstCase>_simple.nex}.
     *
     * @throws IOException if an output file cannot be opened or written
     */
    private void run() throws IOException {
        ArrayList<FlexibleTree> detailedTrees = makeTrees();
        ArrayList<FlexibleTree> simpleTrees = new ArrayList<>();
        for (FlexibleTree detailed : detailedTrees) {
            FlexibleTree simple = makeWellBehavedTree(detailed);
            // The simplified copy is built from scratch, so carry the label across explicitly.
            simple.setAttribute("firstCase", detailed.getAttribute("firstCase"));
            simpleTrees.add(simple);
        }
        for (FlexibleTree detailed : detailedTrees) {
            writeNexusFile(detailed, "_detailed.nex");
        }
        for (FlexibleTree simple : simpleTrees) {
            writeNexusFile(simple, "_simple.nex");
        }
    }

    /** Exports {@code tree} to the file named by the output root, its "firstCase" attribute and {@code suffix}. */
    private void writeNexusFile(FlexibleTree tree, String suffix) throws IOException {
        String fileName = this.outputFileRoot + tree.getAttribute("firstCase") + suffix;
        // try-with-resources guarantees the file handle is released even if the export fails.
        try (PrintStream out = new PrintStream(fileName)) {
            new NexusExporter(out).exportTree(tree);
            // PrintStream swallows write errors; surface them instead of leaving a truncated file.
            if (out.checkError()) {
                throw new IOException("Error while writing " + fileName);
            }
        }
    }

    /**
     * Reads the infections CSV file and builds the units and their infection events.
     * The first line is skipped as a header. In each remaining line, column 1 is the
     * unit identifier, column 2 the infector's identifier ("-1" for a first case) and
     * column 3 the infection time. Blank lines are ignored.
     *
     * @param str path of the infections CSV file
     * @throws IOException      if the file cannot be read
     * @throws RuntimeException if a line names an infector that is not defined in the
     *                          file, or if event times are inconsistent
     */
    private void readInfectionEvents(String str) throws IOException {
        ArrayList<String[]> rows = new ArrayList<>();
        // First pass: create every unit, so infectors may appear after their infectees.
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            reader.readLine(); // discard the header line
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines, e.g. a trailing newline
                }
                String[] fields = line.split(",");
                rows.add(fields);
                String unitId = "ID_" + fields[1];
                InfectedUnit unit = new InfectedUnit(unitId);
                this.units.add(unit);
                this.idMap.put(unitId, unit);
            }
        }
        // Second pass: link each unit to its infector through a shared infection event.
        for (String[] fields : rows) {
            InfectedUnit infectee = this.idMap.get("ID_" + fields[1]);
            if (fields[2].equals("-1")) {
                // First case: infection event without an infector; parent stays null.
                infectee.setInfectionEvent(new Event(EventType.INFECTION, Double.parseDouble(fields[3]), null, infectee));
            } else {
                InfectedUnit infector = this.idMap.get("ID_" + fields[2]);
                if (infector == null) {
                    throw new RuntimeException("Unit " + fields[1] + " is infected by unit " + fields[2] + " but this unit is not defined");
                }
                Event infection = new Event(EventType.INFECTION, Double.parseDouble(fields[3]), infector, infectee);
                infector.addInfectionEvent(infection);
                infectee.setInfectionEvent(infection);
                infectee.parent = infector;
            }
        }
    }

    /**
     * Reads the sampling CSV file and adds a sampling event to each sampled unit.
     * The first line is skipped as a header. In each remaining line, column 1 is the
     * unit identifier and column 7 the sampling time; lines whose column 7 equals
     * {@code XMLObject.missingValue} are skipped, as are blank lines.
     *
     * @param str path of the sampling CSV file
     * @throws IOException      if the file cannot be read
     * @throws RuntimeException if a sampled unit was not defined by the infections file,
     *                          or the sample precedes the unit's infection
     */
    private void readSamplingEvents(String str) throws IOException {
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            reader.readLine(); // discard the header line
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines, e.g. a trailing newline
                }
                String[] fields = line.split(",");
                if (fields[7].equals(XMLObject.missingValue)) {
                    continue; // unit was never sampled
                }
                InfectedUnit unit = this.idMap.get("ID_" + fields[1]);
                if (unit == null) {
                    // Report the identifier that was actually looked up (column 1).
                    throw new RuntimeException("Trying to add a sampling event to unit " + fields[1] + " but this unit not previously defined");
                }
                unit.addSamplingEvent(Double.parseDouble(fields[7]));
            }
        }
    }

    /**
     * Builds the within-unit tree for one infected unit. Each relevant event becomes
     * a tip; two or more tips are joined by a simulated coalescent, a single tip is
     * used as is. A root node carrying the unit's infection event is placed above
     * the result, and every node is tagged with the unit's id.
     *
     * @param infectedUnit the unit whose treelet is built
     * @param arrayList    the events of that unit which should appear as tips
     * @return the treelet, or null when there are no events
     */
    private FlexibleTree makeTreelet(InfectedUnit infectedUnit, ArrayList<Event> arrayList) {
        if (arrayList.isEmpty()) {
            return null;
        }
        ArrayList<SimpleNode> tips = new ArrayList<>();
        infectedUnit.sortEvents();

        // Time of the latest event among the tips.
        double latestTime = Double.NEGATIVE_INFINITY;
        for (Event candidate : arrayList) {
            if (candidate.time > latestTime) {
                latestTime = candidate.time;
            }
        }
        double infectionTime = infectedUnit.infectionEvent.time;
        double span = latestTime - infectionTime;

        for (Event tipEvent : arrayList) {
            Taxon tipTaxon;
            if (tipEvent.type == EventType.INFECTION) {
                tipTaxon = new Taxon(tipEvent.infectee.id + "_infected_by_" + tipEvent.infector.id + "_" + tipEvent.time);
            } else {
                tipTaxon = new Taxon(infectedUnit.id + "_sampled_" + tipEvent.time);
            }
            tipTaxon.setDate(new Date(tipEvent.time - infectedUnit.infectionEvent.time, Units.Type.YEARS, false));
            SimpleNode tip = new SimpleNode();
            tip.setTaxon(tipTaxon);
            tips.add(tip);
            // Provisional height relative to the infection time; simulateCoalescent shifts it by span.
            tip.setHeight(infectedUnit.infectionEvent.time - tipEvent.time);
            tip.setAttribute("Event", tipEvent);
        }

        FlexibleNode crown;
        if (tips.size() > 1) {
            crown = simulateCoalescent(tips, this.demFunct, span);
        } else {
            SimpleNode onlyTip = tips.get(0);
            crown = new FlexibleNode(new SimpleTree(onlyTip), onlyTip, true);
            crown.setHeight(0.0d);
        }

        // Root of the treelet, positioned at the unit's infection.
        FlexibleNode infectionNode = new FlexibleNode();
        infectionNode.setHeight(span);
        infectionNode.addChild(crown);
        crown.setLength(span - crown.getHeight());
        infectionNode.setAttribute("Event", infectedUnit.infectionEvent);

        FlexibleTree treelet = new FlexibleTree(infectionNode);
        for (int index = 0; index < treelet.getNodeCount(); index++) {
            FlexibleNode node = (FlexibleNode) treelet.getNode(index);
            node.setAttribute("Unit", infectedUnit.id);
        }
        return treelet;
    }

    /**
     * Builds one tree for every first case (a unit without a parent) that has
     * sampled descendants. Each tree carries the first case's id in its
     * "firstCase" attribute. A warning is printed when the accumulated
     * coalescent probability for a first case falls below 0.9.
     *
     * @return the trees, in the order the first cases appear in {@code units}
     * @throws RuntimeException if no unit lacks a parent
     */
    private ArrayList<FlexibleTree> makeTrees() {
        ArrayList<InfectedUnit> firstCases = new ArrayList<>();
        for (InfectedUnit unit : this.units) {
            if (unit.parent == null) {
                firstCases.add(unit);
            }
        }
        if (firstCases.isEmpty()) {
            throw new RuntimeException("Can't find a first case");
        }

        ArrayList<FlexibleTree> trees = new ArrayList<>();
        for (InfectedUnit firstCase : firstCases) {
            // The probability is tracked separately for each first case.
            this.coalescentProbability = 1.0d;
            System.out.println("Building tree for descendants of " + firstCase.id);
            FlexibleNode subtreeRoot = makeSubtree(firstCase);
            if (subtreeRoot == null) {
                progressStream.println("This individual has no sampled descendants");
            } else {
                FlexibleTree tree = new FlexibleTree(subtreeRoot, false, true);
                tree.setAttribute("firstCase", firstCase.id);
                trees.add(tree);
                if (this.coalescentProbability < 0.9d) {
                    progressStream.println("WARNING: any phylogeny for descendants of " + firstCase.id + " is quite improbable (p<" + this.coalescentProbability + ") given this demographic function. Consider another.");
                }
            }
            System.out.println();
        }
        return trees;
    }

    /**
     * Recursively builds the tree below one unit. Subtrees are built first for
     * every unit it infected; only transmissions leading to a non-null subtree,
     * plus the unit's own sampling events, become tips of its treelet. Each
     * transmission tip then adopts the first child of the corresponding subtree root.
     *
     * @param infectedUnit the unit at the top of the subtree
     * @return the root node of the subtree, or null when the unit has no sampled descendants
     */
    private FlexibleNode makeSubtree(InfectedUnit infectedUnit) {
        HashMap<Event, FlexibleNode> subtreeByTransmission = new HashMap<>();
        ArrayList<Event> tipEvents = new ArrayList<>();

        for (Event childEvent : infectedUnit.childEvents) {
            if (childEvent.type == EventType.INFECTION) {
                FlexibleNode descendantRoot = makeSubtree(childEvent.infectee);
                if (descendantRoot == null) {
                    continue; // nothing sampled down this transmission
                }
                tipEvents.add(childEvent);
                subtreeByTransmission.put(childEvent, descendantRoot);
            } else if (childEvent.type == EventType.SAMPLE) {
                tipEvents.add(childEvent);
            }
        }

        FlexibleTree treelet = makeTreelet(infectedUnit, tipEvents);
        if (treelet == null) {
            return null;
        }

        for (int index = 0; index < treelet.getExternalNodeCount(); index++) {
            FlexibleNode tip = (FlexibleNode) treelet.getExternalNode(index);
            Event tipEvent = (Event) treelet.getNodeAttribute(tip, "Event");
            if (tipEvent.type != EventType.INFECTION) {
                continue;
            }
            // Graft the infectee's tree onto this tip, dropping the infectee's own root node.
            FlexibleNode infecteeRoot = subtreeByTransmission.get(tipEvent);
            FlexibleNode graft = infecteeRoot.getChild(0);
            infecteeRoot.removeChild(graft);
            tip.addChild(graft);
        }
        return (FlexibleNode) treelet.getRoot();
    }

    /**
     * Joins the given tips into a single tree by repeatedly running the coalescent
     * simulator until exactly one lineage remains, then raises every node height
     * by {@code d}. Also multiplies {@code coalescentProbability} by
     * {@code 1 - exp(intensity(maxTipHeight))}.
     *
     * @param arrayList           the tip nodes to join
     * @param demographicFunction demographic function driving the simulation
     * @param d                   offset added to all node heights afterwards; its negation is
     *                            passed to the simulator
     * @return the root of the simulated tree as a FlexibleNode
     */
    private FlexibleNode simulateCoalescent(ArrayList<SimpleNode> arrayList, DemographicFunction demographicFunction, double d) {
        double maxTipHeight = Double.NEGATIVE_INFINITY;
        for (SimpleNode tip : arrayList) {
            if (tip.getHeight() > maxTipHeight) {
                maxTipHeight = tip.getHeight();
            }
        }
        double intensity = demographicFunction.getIntensity(maxTipHeight);
        this.coalescentProbability *= 1.0d - Math.exp(intensity);

        CoalescentSimulator simulator = new CoalescentSimulator();
        int failedAttempts = 0;
        SimpleNode[] lineages;
        while (true) {
            // A fresh array is handed to the simulator on every attempt.
            SimpleNode[] tips = (SimpleNode[]) arrayList.toArray(new SimpleNode[arrayList.size()]);
            lineages = simulator.simulateCoalescent(tips, demographicFunction, -d, 0.0d, true);
            if (lineages.length > 1) {
                failedAttempts++;
                System.out.println("Failed to coalesce lineages: " + failedAttempts);
            }
            if (lineages.length == 1) {
                break;
            }
        }

        SimpleNode root = lineages[0];
        SimpleTree simulated = new SimpleTree(root);
        for (int index = 0; index < simulated.getNodeCount(); index++) {
            SimpleNode node = (SimpleNode) simulated.getNode(index);
            node.setHeight(node.getHeight() + d);
        }
        return new FlexibleNode(simulated, root, true);
    }

    /**
     * Returns a copy of the given tree in which internal nodes that have exactly
     * one child have been spliced out, so the child hangs directly from its
     * grandparent (or becomes the root when the spliced node had no parent).
     * The input tree itself is not edited; all changes are made on a copy.
     *
     * @param flexibleTree the tree to simplify
     * @return a new FlexibleTree built from the edited copy
     */
    private FlexibleTree makeWellBehavedTree(FlexibleTree flexibleTree) {
        // Work on a copy so the detailed tree stays intact for separate export.
        FlexibleTree flexibleTree2 = new FlexibleTree((Tree) flexibleTree, false);
        flexibleTree2.beginTreeEdit();
        // NOTE(review): the tree is modified while its internal nodes are visited by index;
        // this presumably relies on the node list staying fixed until endTreeEdit — confirm.
        for (int i = 0; i < flexibleTree2.getInternalNodeCount(); i++) {
            FlexibleNode flexibleNode = (FlexibleNode) flexibleTree2.getInternalNode(i);
            if (flexibleTree2.getChildCount(flexibleNode) == 1) {
                FlexibleNode flexibleNode2 = (FlexibleNode) flexibleTree2.getParent(flexibleNode);
                FlexibleNode flexibleNode3 = (FlexibleNode) flexibleTree2.getChild(flexibleNode, 0);
                if (flexibleNode2 != null) {
                    // Remember the child's height and restore it after re-attaching it to the grandparent.
                    double nodeHeight = flexibleTree2.getNodeHeight(flexibleNode3);
                    flexibleTree2.removeChild(flexibleNode2, flexibleNode);
                    flexibleTree2.addChild(flexibleNode2, flexibleNode3);
                    flexibleTree2.setNodeHeight(flexibleNode3, nodeHeight);
                } else {
                    // The single-child node is the root: promote its only child to root.
                    flexibleNode3.setParent(null);
                    flexibleTree2.setRoot(flexibleNode3);
                }
            }
        }
        flexibleTree2.endTreeEdit();
        // NOTE(review): the meaning of the boolean flag (true here, false above) is not visible in this file.
        return new FlexibleTree((Tree) flexibleTree2, true);
    }

    /**
     * Prints the command-line usage message for this tool.
     *
     * @param arguments the parsed-argument object that knows the available options
     */
    public static void printUsage(Arguments arguments) {
        final String toolName = "virusTreeBuilder";
        final String positionalArguments = "<infections-file-name> <sample-file-name> <output-file-name-root>";
        arguments.printUsage(toolName, positionalArguments);
    }

    /**
     * Command-line entry point. Parses the demographic-model options, builds the
     * matching demographic function, then converts the transmission tree described
     * by the infections and sampling files into virus trees written under the given
     * output root.
     *
     * @param strArr options followed by exactly three positional arguments:
     *               infections file, sampling file, output file root
     */
    public static void main(String[] strArr) {
        ModelType modelType = ModelType.CONSTANT;
        double startingPopulationSize = 1.0d;
        double growthRate = 0.0d;
        double t50 = 0.0d;
        Arguments arguments = new Arguments(new Arguments.Option[]{
                new Arguments.StringOption(DEMOGRAPHIC_MODEL, demographics, false, "The type of within-host demographic function to use, default = constant"),
                new Arguments.RealOption(STARTING_POPULATION_SIZE, "The effective population size at time zero (used in all models), default = 1"),
                new Arguments.RealOption(GROWTH_RATE, "The effective population size growth rate (used in exponential and logistic models), default = 0"),
                new Arguments.RealOption(T50, "The time point, relative to the time of infection in backwards time, at which the population is equal to half its final asymptotic value, in the logistic model default = 0")});
        try {
            arguments.parseArguments(strArr);
        } catch (Arguments.ArgumentException e) {
            System.out.println(e);
            printUsage(arguments);
            System.exit(1);
        }
        // NOTE(review): "help" is not among the options registered above — confirm that
        // Arguments.hasOption tolerates an unregistered option name.
        if (arguments.hasOption(HELP)) {
            printUsage(arguments);
            System.exit(0);
        }
        if (arguments.hasOption(DEMOGRAPHIC_MODEL)) {
            // Only the first letter of the model name is significant.
            String modelName = arguments.getStringOption(DEMOGRAPHIC_MODEL).toLowerCase();
            if (modelName.startsWith("c")) {
                modelType = ModelType.CONSTANT;
            } else if (modelName.startsWith("e")) {
                modelType = ModelType.EXPONENTIAL;
            } else if (modelName.startsWith("l")) {
                modelType = ModelType.LOGISTIC;
            } else {
                progressStream.print("Unrecognised demographic model type");
                System.exit(1);
            }
        }
        if (arguments.hasOption(STARTING_POPULATION_SIZE)) {
            startingPopulationSize = arguments.getRealOption(STARTING_POPULATION_SIZE);
        }
        // The growth rate is ignored for the constant model.
        if (arguments.hasOption(GROWTH_RATE) && modelType != ModelType.CONSTANT) {
            growthRate = arguments.getRealOption(GROWTH_RATE);
        }
        // t50 only applies to the logistic model.
        if (arguments.hasOption(T50) && modelType == ModelType.LOGISTIC) {
            t50 = arguments.getRealOption(T50);
        }

        // Each case builds its own model and breaks; previously the cases fell through,
        // so the logistic model was used whatever the user selected.
        DemographicFunction demographicFunction;
        switch (modelType) {
            case EXPONENTIAL: {
                ExponentialGrowth exponentialGrowth = new ExponentialGrowth(Units.Type.YEARS);
                exponentialGrowth.setN0(startingPopulationSize);
                exponentialGrowth.setGrowthRate(growthRate);
                demographicFunction = exponentialGrowth;
                break;
            }
            case LOGISTIC: {
                LogisticGrowthN0 logisticGrowthN0 = new LogisticGrowthN0(Units.Type.YEARS);
                logisticGrowthN0.setN0(startingPopulationSize);
                logisticGrowthN0.setGrowthRate(growthRate);
                logisticGrowthN0.setT50(t50);
                demographicFunction = logisticGrowthN0;
                break;
            }
            case CONSTANT:
            default: {
                ConstantPopulation constantPopulation = new ConstantPopulation(Units.Type.YEARS);
                constantPopulation.setN0(startingPopulationSize);
                demographicFunction = constantPopulation;
                break;
            }
        }

        String[] leftoverArguments = arguments.getLeftoverArguments();
        if (leftoverArguments.length != 3) {
            printUsage(arguments);
            System.exit(1);
        }
        try {
            // Constructor order is (samples file, infections file, model, output root).
            new TransmissionTreeToVirusTree(leftoverArguments[1], leftoverArguments[0], demographicFunction, leftoverArguments[2]).run();
        } catch (IOException e2) {
            e2.printStackTrace();
        }
    }
}
