package dr.evomodel.epidemiology.casetocase;

import dr.app.tools.NexusExporter;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.evomodel.epidemiology.casetocase.BranchMapModel;
import dr.evomodel.epidemiology.casetocase.PartitionedTreeModel;
import dr.evomodel.epidemiology.casetocase.periodpriors.AbstractPeriodPriorDistribution;
import dr.evomodel.tree.TreeModel;
import dr.evoxml.util.GraphMLUtils;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.oldevomodel.treelikelihood.AbstractTreeLikelihood;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math.stat.descriptive.moment.Mean;
import org.apache.commons.math.stat.descriptive.moment.Variance;
import org.apache.commons.math.stat.descriptive.rank.Median;

/* loaded from: input_file:dr/evomodel/epidemiology/casetocase/CaseToCaseTreeLikelihood.class */
public abstract class CaseToCaseTreeLikelihood extends AbstractTreeLikelihood implements Loggable, Citable, TreeTraitProvider {
    // Compile-time debug switch; not referenced in the visible part of this class.
    protected static final boolean DEBUG = false;
    // Numerical tolerance. NOTE(review): not referenced in the visible part of this class —
    // presumably used further down or by subclasses; confirm before relying on it.
    protected static double tolerance = 1.0E-10d;
    // Number of external (tip) nodes of the partitioned tree; set once in the constructor.
    protected int noTips;
    // Number of cases in the outbreak (outbreak.getCases().size()); bound of the per-case loops.
    protected int noCases;
    // Latest tip date in the tree: the reference point for converting node heights to times and back.
    private double estimatedLastSampleTime;
    // Tree traits exposed through TreeTraitProvider; the constructor registers the "partition" trait.
    protected TreeTraitProvider.Helper treeTraits;
    // The outbreak whose cases label the tips; also registered as a sub-model.
    protected AbstractOutbreak outbreak;
    // Cached infection time per case, indexed by outbreak.getCaseIndex(case).
    protected double[] infectionTimes;
    // Snapshot of infectionTimes taken by storeState() and adopted by restoreState().
    private double[] storedInfectionTimes;
    // Cached infectious period per case (end time minus infectious, or infection, time).
    protected double[] infectiousPeriods;
    // Snapshot of infectiousPeriods taken by storeState() and adopted by restoreState().
    private double[] storedInfectiousPeriods;
    // Cached time each case becomes infectious; only allocated when hasLatentPeriods is true.
    protected double[] infectiousTimes;
    // Snapshot of infectiousTimes (only maintained when hasLatentPeriods is true).
    private double[] storedInfectiousTimes;
    // Cached latent period per case; only allocated when hasLatentPeriods is true.
    protected double[] latentPeriods;
    // Snapshot of latentPeriods (only maintained when hasLatentPeriods is true).
    private double[] storedLatentPeriods;
    // Per-case dirty flags: true means the cached timings above must be recomputed.
    // Set by recalculateCase()/makeDirty(), cleared by prepareTimings().
    protected boolean[] recalculateCaseFlags;
    // Case -> Treelet cache filled by explodeTree(). NOTE(review): not initialised in the visible
    // code; confirm it is assigned before explodeTree() runs.
    protected HashMap<AbstractCase, Treelet> elementsAsTrees;
    // NOTE(review): not referenced in the visible part of this class.
    protected HashMap<AbstractCase, Treelet> storedElementsAsTrees;
    // Scale of the notional branch above the root on which the root case's infection is placed.
    protected Parameter maxFirstInfToRoot;
    // Cached result of outbreak.hasLatentPeriods().
    protected boolean hasLatentPeriods;
    // Default model name used by the three-argument constructor.
    public static final String CASE_TO_CASE_TREE_LIKELIHOOD = "caseToCaseTreeLikelihood";
    // Key used for the per-node partition trait and node attribute.
    public static final String PARTITIONS_KEY = "partition";

    /**
     * A sub-tree holding the part of the phylogeny that belongs to one case, together with a
     * "zero height": the treelet's root height plus the offset of the case's infection along
     * the branch above it (as computed in explodeTree()).
     */
    protected class Treelet extends FlexibleTree {
        private double zeroHeight;

        /**
         * Copies {@code flexibleTree} and records its zero height.
         *
         * @param flexibleTree the tree to copy
         * @param d            the zero height of the treelet
         */
        public Treelet(FlexibleTree flexibleTree, double d) {
            super(flexibleTree);
            this.zeroHeight = d;
        }

        /**
         * @return the zero height of this treelet
         */
        public double getZeroHeight() {
            return this.zeroHeight;
        }

        /**
         * Replaces the zero height of this treelet.
         *
         * @param d the new zero height
         */
        protected void setZeroHeight(double d) {
            // Assign the argument; a self-assignment here would silently drop the new value.
            this.zeroHeight = d;
        }
    }

    /**
     * Creates the likelihood under the default model name
     * {@link #CASE_TO_CASE_TREE_LIKELIHOOD}.
     *
     * @param tree                the partitioned tree
     * @param outbreakData        the outbreak whose cases label the tips
     * @param firstInfToRootLimit scale of the notional branch above the root
     */
    public CaseToCaseTreeLikelihood(PartitionedTreeModel tree, AbstractOutbreak outbreakData, Parameter firstInfToRootLimit) throws TaxonList.MissingTaxonException {
        this(CaseToCaseTreeLikelihood.CASE_TO_CASE_TREE_LIKELIHOOD, tree, outbreakData, firstInfToRootLimit);
    }

    /**
     * Creates the likelihood, registers the outbreak and branch map as sub-models, allocates
     * the per-case timing caches (marking every case dirty) and registers the "partition"
     * node trait.
     *
     * @param name                model name
     * @param tree                the partitioned tree
     * @param outbreakData        the outbreak whose cases label the tips
     * @param firstInfToRootLimit scale of the notional branch above the root
     * @throws RuntimeException if the outbreak's state count differs from the tip count
     */
    public CaseToCaseTreeLikelihood(String name, PartitionedTreeModel tree, AbstractOutbreak outbreakData, Parameter firstInfToRootLimit) {
        super(name, outbreakData, tree);
        treeTraits = new TreeTraitProvider.Helper();

        // Each tip must correspond to its own state (case).
        final int externalCount = treeModel.getExternalNodeCount();
        if (stateCount != externalCount) {
            throw new RuntimeException("There are duplicate tip outbreak.");
        }

        noTips = tree.getExternalNodeCount();
        outbreak = outbreakData;
        noCases = outbreak.getCases().size();
        addModel(outbreak);
        estimatedLastSampleTime = getLatestTaxonTime();
        addModel(tree.getBranchMap());
        hasLatentPeriods = outbreak.hasLatentPeriods();

        // Per-case caches; latent-period caches exist only when the outbreak models latency.
        infectionTimes = new double[outbreak.size()];
        infectiousPeriods = new double[outbreak.size()];
        if (hasLatentPeriods) {
            infectiousTimes = new double[outbreak.size()];
            latentPeriods = new double[outbreak.size()];
        }
        recalculateCaseFlags = new boolean[outbreak.size()];
        Arrays.fill(recalculateCaseFlags, true);
        maxFirstInfToRoot = firstInfToRootLimit;

        treeTraits.addTrait(CaseToCaseTreeLikelihood.PARTITIONS_KEY, new TreeTrait.S() {
            @Override
            public String getTraitName() {
                return CaseToCaseTreeLikelihood.PARTITIONS_KEY;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.NODE;
            }

            @Override
            public String getTrait(Tree t, NodeRef node) {
                return CaseToCaseTreeLikelihood.this.getNodePartition(t, node);
            }
        });
        likelihoodKnown = false;
    }

    /**
     * @return the outbreak this likelihood was built on
     */
    public AbstractOutbreak getOutbreak() {
        final AbstractOutbreak current = outbreak;
        return current;
    }

    /**
     * @return whether the outbreak models latent periods (cached at construction)
     */
    public boolean hasLatentPeriods() {
        final boolean latent = hasLatentPeriods;
        return latent;
    }

    /**
     * Finds the latest date among the tip taxa.
     *
     * @return the largest tip time value, or negative infinity if there are no tips
     */
    private double getLatestTaxonTime() {
        double latest = Double.NEGATIVE_INFINITY;
        final int tipCount = treeModel.getExternalNodeCount();
        for (int tip = 0; tip < tipCount; tip++) {
            final Taxon taxon = treeModel.getNodeTaxon(treeModel.getExternalNode(tip));
            final double tipTime = taxon.getDate().getTimeValue();
            // Strict '>' keeps the original handling of NaN dates (they never win).
            if (tipTime > latest) {
                latest = tipTime;
            }
        }
        return latest;
    }

    /**
     * @param nodeRef a node of the tree model
     * @return the node's children, in child-index order
     */
    private NodeRef[] getChildren(NodeRef nodeRef) {
        final int childCount = treeModel.getChildCount(nodeRef);
        final NodeRef[] children = new NodeRef[childCount];
        int idx = 0;
        while (idx < childCount) {
            children[idx] = treeModel.getChild(nodeRef, idx);
            idx++;
        }
        return children;
    }

    /**
     * Builds a {@link Treelet} for every ever-infected case that has no entry in
     * {@link #elementsAsTrees} yet. The treelet contains the case's part of the tree, rooted
     * above its earliest node. Its zero height is the treelet root height plus the infection's
     * offset along the branch above that node (the notional root branch, scaled by
     * {@link #maxFirstInfToRoot}, when that node is the root).
     * NOTE(review): requires {@link #elementsAsTrees} to be non-null.
     */
    public void explodeTree() {
        for (int caseIndex = 0; caseIndex < outbreak.size(); caseIndex++) {
            final AbstractCase thisCase = outbreak.getCase(caseIndex);
            if (!thisCase.wasEverInfected() || elementsAsTrees.get(thisCase) != null) {
                continue;
            }
            final NodeRef topNode = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(thisCase);
            final double branchScale = treeModel.isRoot(topNode)
                    ? maxFirstInfToRoot.getParameterValue(0)
                    : treeModel.getBranchLength(topNode);
            final double infectionOffset = branchScale * thisCase.getInfectionBranchPosition().getParameterValue(0);

            final FlexibleNode treeletRoot = new FlexibleNode();
            final FlexibleTree treelet = new FlexibleTree(treeletRoot);
            treelet.beginTreeEdit();
            if (!treeModel.isExternal(topNode)) {
                for (int child = 0; child < treeModel.getChildCount(topNode); child++) {
                    copyElementToTreelet(treelet, treeModel.getChild(topNode, child), treeletRoot, thisCase);
                }
            }
            treelet.endTreeEdit();
            treelet.resolveTree();
            elementsAsTrees.put(thisCase, new Treelet(treelet, treelet.getRootHeight() + infectionOffset));
        }
    }

    /**
     * Recursively copies the subtree under {@code nodeRef} that belongs to {@code abstractCase}
     * into {@code flexibleTree}, beneath {@code treeletParent}. When the walk reaches a node
     * belonging to a different case, it stops there and adds a leaf named
     * "Transmission_&lt;case&gt;". That leaf's branch length runs from the parent node's time
     * to the other case's infection time. Does nothing for never-infected cases.
     */
    private void copyElementToTreelet(FlexibleTree flexibleTree, NodeRef nodeRef, NodeRef treeletParent, AbstractCase abstractCase) {
        if (!abstractCase.wasEverInfected()) {
            return;
        }
        final AbstractCase owner = getBranchMap().get(nodeRef.getNumber());
        if (owner != abstractCase) {
            final FlexibleNode transmission = new FlexibleNode(new Taxon("Transmission_" + owner.getName()));
            final double parentTime = getNodeTime(treeModel.getParent(nodeRef));
            final double ownerInfection = getInfectionTime(owner);
            flexibleTree.addChild(treeletParent, transmission);
            flexibleTree.setBranchLength(transmission, ownerInfection - parentTime);
            return;
        }

        final boolean isTip = treeModel.isExternal(nodeRef);
        final FlexibleNode copy = isTip
                ? new FlexibleNode(new Taxon(treeModel.getNodeTaxon(nodeRef).getId()))
                : new FlexibleNode();
        flexibleTree.addChild(treeletParent, copy);
        flexibleTree.setBranchLength(copy, treeModel.getBranchLength(nodeRef));
        if (!isTip) {
            for (int child = 0; child < treeModel.getChildCount(nodeRef); child++) {
                copyElementToTreelet(flexibleTree, treeModel.getChild(nodeRef, child), copy, abstractCase);
            }
        }
    }

    /**
     * Collects the cases assigned (via the branch map) to the tips below {@code nodeRef}.
     *
     * @param nodeRef the subtree root
     * @param hashMap optional memo; when non-null, every visited node number is mapped to the
     *                set computed for it
     * @return the set of tip cases in the subtree
     */
    public HashSet<AbstractCase> descendantTipPartitions(NodeRef nodeRef, HashMap<Integer, HashSet<AbstractCase>> hashMap) {
        final HashSet<AbstractCase> partitions = new HashSet<>();
        if (treeModel.isExternal(nodeRef)) {
            partitions.add(getBranchMap().get(nodeRef.getNumber()));
        } else {
            for (int child = 0; child < treeModel.getChildCount(nodeRef); child++) {
                partitions.addAll(descendantTipPartitions(treeModel.getChild(nodeRef, child), hashMap));
            }
        }
        if (hashMap != null) {
            hashMap.put(nodeRef.getNumber(), partitions);
        }
        return partitions;
    }

    /**
     * Flags a node, its immediate children, and its ancestors in {@code zArr}.
     * The upward walk stops at the root or at the first ancestor that is already flagged.
     * NOTE(review): stopping early assumes an already-flagged ancestor has its own ancestors
     * flagged as well; confirm this holds for all writers of the flag array.
     *
     * @param treeModel the tree
     * @param nodeRef   the node whose subtree must be recalculated
     * @param zArr      per-node flags, indexed by node number
     */
    protected static void flagForDescendantRecalculation(TreeModel treeModel, NodeRef nodeRef, boolean[] zArr) {
        zArr[nodeRef.getNumber()] = true;
        for (int i = 0; i < treeModel.getChildCount(nodeRef); i++) {
            zArr[treeModel.getChild(nodeRef, i).getNumber()] = true;
        }
        // Test each *parent's* flag. Testing the current node's flag would stop immediately,
        // because the start node was flagged just above, and no ancestor would be flagged.
        NodeRef current = nodeRef;
        while (!treeModel.isRoot(current)) {
            current = treeModel.getParent(current);
            if (zArr[current.getNumber()]) {
                break;
            }
            zArr[current.getNumber()] = true;
        }
    }

    /**
     * Applies the static flagging routine to this likelihood's {@code updateNode} flags.
     *
     * @param treeModel the tree
     * @param nodeRef   the node whose subtree must be recalculated
     */
    public void flagForDescendantRecalculation(TreeModel treeModel, NodeRef nodeRef) {
        final boolean[] nodeFlags = updateNode;
        CaseToCaseTreeLikelihood.flagForDescendantRecalculation(treeModel, nodeRef, nodeFlags);
    }

    /**
     * Marks the cases affected by a sub-model change as dirty, then notifies listeners and
     * invalidates the cached likelihood. Changes from period-prior distributions are ignored.
     *
     * @throws RuntimeException if the branch map sends an event that is not an ArrayList
     */
    @Override
    public void handleModelChangedEvent(Model model, Object obj, int i) {
        if (model instanceof AbstractPeriodPriorDistribution) {
            return;
        }
        if (model == treeModel) {
            if (obj instanceof PartitionedTreeModel.PartitionsChangedEvent) {
                final PartitionedTreeModel.PartitionsChangedEvent event = (PartitionedTreeModel.PartitionsChangedEvent) obj;
                for (Iterator<AbstractCase> cases = event.getCasesToRecalculate().iterator(); cases.hasNext(); ) {
                    recalculateCase(cases.next());
                }
            }
        } else if (model == getBranchMap()) {
            if (!(obj instanceof ArrayList)) {
                throw new RuntimeException("Unanticipated model changed event from BranchMapModel");
            }
            final ArrayList<?> events = (ArrayList<?>) obj;
            for (int k = 0; k < events.size(); k++) {
                final BranchMapModel.BranchMapChangedEvent change = (BranchMapModel.BranchMapChangedEvent) events.get(k);
                recalculateCase(change.getOldCase());
                recalculateCase(change.getNewCase());
                // The case owning the parent branch is affected too.
                final NodeRef parentNode = treeModel.getParent(treeModel.getNode(change.getNodeToRecalculate()));
                if (parentNode != null) {
                    recalculateCase(getBranchMap().get(parentNode.getNumber()));
                }
            }
        } else if (model == outbreak) {
            if (obj instanceof AbstractCase) {
                recalculateCase((AbstractCase) obj);
            } else {
                for (Iterator<AbstractCase> cases = outbreak.getCases().iterator(); cases.hasNext(); ) {
                    recalculateCase(cases.next());
                }
            }
        }
        fireModelChanged(model);
        likelihoodKnown = false;
    }

    /**
     * Marks the case at outbreak index {@code caseIndex} as needing its timings recomputed.
     */
    protected void recalculateCase(int caseIndex) {
        final boolean[] flags = recalculateCaseFlags;
        flags[caseIndex] = true;
    }

    /**
     * Marks {@code abstractCase} as needing its timings recomputed; never-infected cases are
     * ignored.
     */
    protected void recalculateCase(AbstractCase abstractCase) {
        if (!abstractCase.wasEverInfected()) {
            return;
        }
        recalculateCase(outbreak.getCaseIndex(abstractCase));
    }

    /**
     * Any variable change notifies listeners and then invalidates the cached likelihood.
     */
    @Override
    public void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        // Listener notification precedes invalidation, as in the original ordering.
        fireModelChanged();
        likelihoodKnown = false;
    }

    /**
     * Snapshots the per-case timing caches (fresh copies on every call) after the superclass
     * state. The latent-period caches are only snapshotted when latency is modelled.
     */
    @Override
    public void storeState() {
        super.storeState();
        this.storedInfectionTimes = Arrays.copyOf(this.infectionTimes, this.infectionTimes.length);
        this.storedInfectiousPeriods = Arrays.copyOf(this.infectiousPeriods, this.infectiousPeriods.length);
        if (this.hasLatentPeriods) {
            // Size each copy from its own source array.
            this.storedInfectiousTimes = Arrays.copyOf(this.infectiousTimes, this.infectiousTimes.length);
            this.storedLatentPeriods = Arrays.copyOf(this.latentPeriods, this.latentPeriods.length);
        }
    }

    /**
     * Restores the superclass state, then reinstates the snapshots taken by storeState().
     * The stored arrays are adopted by reference. This is safe as long as storeState() makes
     * fresh copies before the next modification.
     */
    @Override
    public void restoreState() {
        super.restoreState();
        infectionTimes = storedInfectionTimes;
        infectiousPeriods = storedInfectiousPeriods;
        if (!hasLatentPeriods) {
            return;
        }
        infectiousTimes = storedInfectiousTimes;
        latentPeriods = storedLatentPeriods;
    }

    /**
     * Intentionally empty: the snapshots made by storeState() are simply overwritten by the
     * next storeState() call.
     */
    @Override // dr.oldevomodel.treelikelihood.AbstractTreeLikelihood, dr.inference.model.AbstractModel
    protected final void acceptState() {
    }

    /**
     * @return the branch-to-case map of the partitioned tree
     */
    public final BranchMapModel getBranchMap() {
        final PartitionedTreeModel tree = (PartitionedTreeModel) treeModel;
        return tree.getBranchMap();
    }

    /**
     * @return the tree model, viewed as the partitioned tree supplied at construction
     */
    @Override
    public final PartitionedTreeModel getTreeModel() {
        final PartitionedTreeModel tree = (PartitionedTreeModel) treeModel;
        return tree;
    }

    /**
     * Marks every case's timings as stale and invalidates the cached likelihood.
     */
    @Override
    public void makeDirty() {
        Arrays.fill(recalculateCaseFlags, true);
        likelihoodKnown = false;
    }

    /**
     * Refreshes every dirty per-case timing cache, then clears all dirty flags.
     * The latent-period caches are refreshed only when latency is modelled.
     */
    public void prepareTimings() {
        final boolean latent = hasLatentPeriods;
        infectionTimes = getInfectionTimes(true);
        if (latent) {
            infectiousTimes = getInfectiousTimes(true);
        }
        infectiousPeriods = getInfectiousPeriods(true);
        if (latent) {
            latentPeriods = getLatentPeriods(true);
        }
        Arrays.fill(recalculateCaseFlags, false);
    }

    /**
     * Computes the log likelihood; supplied by concrete subclasses.
     */
    @Override // dr.oldevomodel.treelikelihood.AbstractTreeLikelihood
    protected abstract double calculateLogLikelihood();

    /**
     * @return whether every transmission in the tree is timing-compatible (checked from the root)
     */
    protected boolean isAllowed() {
        final NodeRef root = treeModel.getRoot();
        return isAllowed(root);
    }

    /**
     * Recursively checks the subtree at {@code nodeRef}. Wherever a branch changes case, the
     * child case's cached infection time must not be after the parent case's end time. When
     * latency is modelled, it also must not be before the parent case's cached infectious time.
     * Assumes internal nodes have exactly two children.
     */
    private boolean isAllowed(NodeRef nodeRef) {
        if (!treeModel.isRoot(nodeRef)) {
            final AbstractCase childCase = getBranchMap().get(nodeRef.getNumber());
            final AbstractCase parentCase = getBranchMap().get(treeModel.getParent(nodeRef).getNumber());
            if (childCase != parentCase) {
                final double infection = infectionTimes[outbreak.getCaseIndex(childCase)];
                if (infection > parentCase.getEndTime()) {
                    return false;
                }
                if (hasLatentPeriods && infection < infectiousTimes[outbreak.getCaseIndex(parentCase)]) {
                    return false;
                }
            }
        }
        if (treeModel.isExternal(nodeRef)) {
            return true;
        }
        return isAllowed(treeModel.getChild(nodeRef, 0)) && isAllowed(treeModel.getChild(nodeRef, 1));
    }

    /**
     * @return the node's time: the latest tip date minus the node's height
     */
    public double getNodeTime(NodeRef nodeRef) {
        final double height = treeModel.getNodeHeight(nodeRef);
        return estimatedLastSampleTime - height;
    }

    /**
     * Converts a height (measured backwards from the latest tip) to a time.
     */
    public double heightToTime(double height) {
        final double time = estimatedLastSampleTime - height;
        return time;
    }

    /**
     * Converts a time to a height (measured backwards from the latest tip). The conversion is
     * the same subtraction as heightToTime, so it is its own inverse.
     */
    public double timeToHeight(double time) {
        final double height = estimatedLastSampleTime - time;
        return height;
    }

    /**
     * @return the node's height in the tree model
     */
    private double getHeight(NodeRef node) {
        final double height = treeModel.getNodeHeight(node);
        return height;
    }

    /**
     * Returns the infection time of a case. Clean cases return the cached value. For dirty
     * cases the time is recomputed (without updating the cache) and is:
     * positive infinity for never-infected cases; otherwise a point on the branch above the
     * case's earliest node, or on the notional root branch when that node is the root.
     */
    public double getInfectionTime(AbstractCase abstractCase) {
        final int index = outbreak.getCaseIndex(abstractCase);
        if (!recalculateCaseFlags[index]) {
            return infectionTimes[index];
        }
        if (!abstractCase.wasEverInfected()) {
            return Double.POSITIVE_INFINITY;
        }
        final NodeRef top = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(abstractCase);
        final NodeRef above = treeModel.getParent(top);
        if (above == null) {
            return getRootInfectionTime(getBranchMap());
        }
        final double parentTime = heightToTime(treeModel.getNodeHeight(above));
        final double nodeTime = heightToTime(treeModel.getNodeHeight(top));
        return getInfectionTime(parentTime, nodeTime, abstractCase);
    }

    /**
     * Interpolates along a branch. A branch position of 0 gives {@code childTime} and a
     * position of 1 gives {@code parentTime}.
     */
    private double getInfectionTime(double parentTime, double childTime, AbstractCase abstractCase) {
        final double towardChild = 1.0d - abstractCase.getInfectionBranchPosition().getParameterValue(0);
        return parentTime + ((childTime - parentTime) * towardChild);
    }

    /**
     * @param z when true, refresh the cached value of every dirty case first
     * @return the (live, not copied) per-case infection-time cache
     */
    public double[] getInfectionTimes(boolean z) {
        if (!z) {
            return infectionTimes;
        }
        for (int caseIndex = 0; caseIndex < noCases; caseIndex++) {
            if (!recalculateCaseFlags[caseIndex]) {
                continue;
            }
            infectionTimes[caseIndex] = getInfectionTime(outbreak.getCase(caseIndex));
        }
        return infectionTimes;
    }

    /**
     * Sets a case's infection time by converting it to a height.
     */
    public void setInfectionTime(AbstractCase abstractCase, double d) {
        final double height = timeToHeight(d);
        setInfectionHeight(abstractCase, height);
    }

    /**
     * Places a case's infection at height {@code d} by setting its infection branch position
     * to the fraction of the way from its earliest node up to that node's parent. When the
     * earliest node is the root, the upper bound is the root height plus
     * {@link #maxFirstInfToRoot}. Never-infected cases are ignored.
     *
     * @throws RuntimeException if {@code d} lies outside that branch
     */
    public void setInfectionHeight(AbstractCase abstractCase, double d) {
        if (!abstractCase.wasEverInfected()) {
            return;
        }
        final NodeRef top = ((PartitionedTreeModel) treeModel).getEarliestNodeInElement(abstractCase);
        final NodeRef above = treeModel.getParent(top);
        final double lower = treeModel.getNodeHeight(top);
        final double upper;
        if (above != null) {
            upper = treeModel.getNodeHeight(above);
        } else {
            upper = lower + maxFirstInfToRoot.getParameterValue(0);
        }
        if (d < lower || d > upper) {
            throw new RuntimeException("Trying to set an infection time outside the branch on which it must occur");
        }
        abstractCase.setInfectionBranchPosition((d - lower) / (upper - lower));
    }

    /**
     * Returns the time a case becomes infectious. Without latency this is the infection time.
     * With latency, dirty cases are recomputed into the cache first: infection time plus the
     * latent period of the case's latent category, or positive infinity if never infected.
     */
    public double getInfectiousTime(AbstractCase abstractCase) {
        if (!hasLatentPeriods) {
            return getInfectionTime(abstractCase);
        }
        final int index = outbreak.getCaseIndex(abstractCase);
        if (recalculateCaseFlags[index]) {
            final double value;
            if (abstractCase.wasEverInfected()) {
                final double infection = getInfectionTime(abstractCase);
                final CategoryOutbreak categorised = (CategoryOutbreak) outbreak;
                value = infection + categorised.getLatentPeriod(categorised.getLatentCategory(abstractCase)).getParameterValue(0);
            } else {
                value = Double.POSITIVE_INFINITY;
            }
            infectiousTimes[index] = value;
        }
        return infectiousTimes[index];
    }

    /**
     * @param z when true, refresh the cached value of every dirty case first
     * @return the (live) per-case infectious-time cache; null unless latency is modelled
     */
    public double[] getInfectiousTimes(boolean z) {
        if (!z) {
            return infectiousTimes;
        }
        for (int caseIndex = 0; caseIndex < noCases; caseIndex++) {
            if (!recalculateCaseFlags[caseIndex]) {
                continue;
            }
            infectiousTimes[caseIndex] = getInfectiousTime(outbreak.getCase(caseIndex));
        }
        return infectiousTimes;
    }

    /**
     * Returns a case's infectious period: its end time minus its infectious time (with
     * latency) or its infection time (without). The value is 0 for never-infected cases and is
     * recomputed into the cache when the case is dirty.
     */
    public double getInfectiousPeriod(AbstractCase abstractCase) {
        final int index = outbreak.getCaseIndex(abstractCase);
        if (recalculateCaseFlags[index]) {
            double period = 0.0d;
            if (abstractCase.wasEverInfected()) {
                final double start = hasLatentPeriods
                        ? getInfectiousTime(abstractCase)
                        : getInfectionTime(abstractCase);
                period = abstractCase.getEndTime() - start;
            }
            infectiousPeriods[index] = period;
        }
        return infectiousPeriods[index];
    }

    /**
     * @param z when true, refresh the cached value of every dirty case first
     * @return the (live) per-case infectious-period cache
     */
    public double[] getInfectiousPeriods(boolean z) {
        if (!z) {
            return infectiousPeriods;
        }
        for (int caseIndex = 0; caseIndex < noCases; caseIndex++) {
            if (!recalculateCaseFlags[caseIndex]) {
                continue;
            }
            infectiousPeriods[caseIndex] = getInfectiousPeriod(outbreak.getCase(caseIndex));
        }
        return infectiousPeriods;
    }

    /**
     * Collects the infectious periods of all ever-infected cases, in case-index order.
     * (Despite the name, a value of zero is kept if an infected case has one.)
     *
     * @return a new array of boxed periods
     */
    public Double[] getNonzeroInfectiousPeriods() {
        final List<Double> periods = new ArrayList<>();
        for (int i = 0; i < this.noCases; i++) {
            final AbstractCase abstractCase = this.outbreak.getCase(i);
            if (abstractCase.wasEverInfected()) {
                periods.add(getInfectiousPeriod(abstractCase));
            }
        }
        return periods.toArray(new Double[periods.size()]);
    }

    /**
     * Returns a case's latent period: its infectious time minus its infection time. The value
     * is 0 when latency is not modelled or the case was never infected, and is recomputed into
     * the cache when the case is dirty.
     */
    public double getLatentPeriod(AbstractCase abstractCase) {
        final boolean applicable = hasLatentPeriods && abstractCase.wasEverInfected();
        if (!applicable) {
            return 0.0d;
        }
        final int index = outbreak.getCaseIndex(abstractCase);
        if (recalculateCaseFlags[index]) {
            latentPeriods[index] = getInfectiousTime(abstractCase) - getInfectionTime(abstractCase);
        }
        return latentPeriods[index];
    }

    /**
     * @param z when true, refresh the cached value of every dirty case first
     * @return the (live) per-case latent-period cache; null unless latency is modelled
     */
    public double[] getLatentPeriods(boolean z) {
        if (!z) {
            return latentPeriods;
        }
        for (int caseIndex = 0; caseIndex < noCases; caseIndex++) {
            if (!recalculateCaseFlags[caseIndex]) {
                continue;
            }
            latentPeriods[caseIndex] = getLatentPeriod(outbreak.getCase(caseIndex));
        }
        return latentPeriods;
    }

    /**
     * Collects the latent periods of all ever-infected cases, in case-index order.
     *
     * @return a new array of boxed periods
     */
    public Double[] getNonzeroLatentPeriods() {
        final List<Double> periods = new ArrayList<>();
        for (int i = 0; i < this.noCases; i++) {
            final AbstractCase abstractCase = this.outbreak.getCase(i);
            if (abstractCase.wasEverInfected()) {
                periods.add(getLatentPeriod(abstractCase));
            }
        }
        return periods.toArray(new Double[periods.size()]);
    }

    /**
     * Without latency, returns the infectious-period cache (refreshed when {@code z} is true).
     * With latency, returns a new array with each case's end time minus its infection time.
     */
    public double[] getInfectedPeriods(boolean z) {
        if (hasLatentPeriods) {
            final double[] periods = new double[noCases];
            for (int caseIndex = 0; caseIndex < noCases; caseIndex++) {
                periods[caseIndex] = getInfectedPeriod(outbreak.getCase(caseIndex));
            }
            return periods;
        }
        return getInfectiousPeriods(z);
    }

    /**
     * Collects the infected periods (end time minus infection time) of all ever-infected
     * cases, in case-index order.
     *
     * @return a new array of boxed periods
     */
    public Double[] getNonzeroInfectedPeriods() {
        final List<Double> periods = new ArrayList<>();
        for (int i = 0; i < this.noCases; i++) {
            final AbstractCase abstractCase = this.outbreak.getCase(i);
            if (abstractCase.wasEverInfected()) {
                periods.add(getInfectedPeriod(abstractCase));
            }
        }
        return periods.toArray(new Double[periods.size()]);
    }

    /**
     * @return the case's end time minus its infection time, or 0 if it was never infected
     */
    public double getInfectedPeriod(AbstractCase abstractCase) {
        // Use the accessor, as every other method in this class does, instead of the raw field.
        if (abstractCase.wasEverInfected()) {
            return abstractCase.getEndTime() - getInfectionTime(abstractCase);
        }
        return 0.0d;
    }

    /**
     * Computes summary statistics of the given values.
     *
     * @param dArr the values (must not contain nulls)
     * @return {mean, median, variance, standard deviation}, using commons-math's default
     *         (bias-corrected) {@code Variance}
     */
    public static Double[] getSummaryStatistics(Double[] dArr) {
        final double[] values = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            values[i] = dArr[i].doubleValue();
        }
        final double mean = new Mean().evaluate(values);
        final double median = new Median().evaluate(values);
        final double variance = new Variance().evaluate(values);
        // Compute the variance first: an array initializer cannot read its own variable.
        return new Double[] {
            Double.valueOf(mean),
            Double.valueOf(median),
            Double.valueOf(variance),
            Double.valueOf(Math.sqrt(variance))
        };
    }

    /**
     * Places the root case's infection above the root: at the root height plus
     * {@link #maxFirstInfToRoot} scaled by the case's infection branch position, converted
     * to a time.
     */
    private double getRootInfectionTime(BranchMapModel branchMapModel) {
        final NodeRef root = treeModel.getRoot();
        final AbstractCase rootCase = branchMapModel.get(root.getNumber());
        final double rootHeight = treeModel.getNodeHeight(root);
        final double offset = maxFirstInfToRoot.getParameterValue(0) * rootCase.getInfectionBranchPosition().getParameterValue(0);
        return heightToTime(rootHeight + offset);
    }

    /**
     * @return the infection time of the case that owns the root
     */
    protected double getRootInfectionTime() {
        final AbstractCase rootCase = getBranchMap().get(treeModel.getRoot().getNumber());
        return getInfectionTime(rootCase);
    }

    /**
     * Writes the tree to a NEXUS file using the current branch map.
     *
     * @param str output file path
     * @param z   whether to insert transmission nodes
     */
    public void outputTreeToFile(String str, boolean z) {
        final BranchMapModel currentMap = getBranchMap();
        outputTreeToFile(currentMap, str, z);
    }

    /**
     * Writes the tree to a NEXUS file. Failures to write are reported and otherwise ignored
     * (best effort).
     *
     * @param branchMapModel branch map used for the "partition" node attribute when
     *                       {@code z} is false. NOTE(review): ignored when {@code z} is true,
     *                       because addTransmissionNodes uses the current branch map.
     * @param str            output file path
     * @param z              whether to insert transmission nodes
     */
    public void outputTreeToFile(BranchMapModel branchMapModel, String str, boolean z) {
        final FlexibleTree treeToWrite;
        if (z) {
            treeToWrite = addTransmissionNodes(this.treeModel);
        } else {
            treeToWrite = new FlexibleTree(this.treeModel);
            for (int i = 0; i < treeToWrite.getNodeCount(); i++) {
                FlexibleNode flexibleNode = (FlexibleNode) treeToWrite.getNode(i);
                flexibleNode.setAttribute("Number", Integer.valueOf(flexibleNode.getNumber()));
                flexibleNode.setAttribute("Time", Double.valueOf(heightToTime(flexibleNode.getHeight())));
                flexibleNode.setAttribute("partition", branchMapModel.get(flexibleNode.getNumber()));
            }
        }
        // try-with-resources closes (and flushes) the file even if export fails.
        try (PrintStream out = new PrintStream(str)) {
            new NexusExporter(out).exportTree(treeToWrite);
        } catch (IOException e) {
            System.out.println("IOException writing tree to " + str + ": " + e.getMessage());
        }
    }

    /**
     * Returns a copy of {@code tree} with an extra node inserted at every ever-infected
     * case's infection time. The infection node sits on the branch above the case's earliest
     * node; for the root case it becomes a new "Origin" root above the old root. Every
     * original node is annotated with "Number", "Time" and "partition" attributes.
     * Refreshes the timing caches first via prepareTimings().
     * NOTE(review): node positions are looked up in this likelihood's own tree model, so
     * {@code tree} is presumably expected to be that same tree — confirm.
     *
     * @throws RuntimeException if the rewired tree has a child higher than its parent (the
     *         offending tree is dumped to "fancyProblem.nex" first)
     */
    public FlexibleTree addTransmissionNodes(Tree tree) {
        prepareTimings();
        FlexibleTree flexibleTree = new FlexibleTree(tree, true);
        for (int i = 0; i < flexibleTree.getNodeCount(); i++) {
            FlexibleNode flexibleNode = (FlexibleNode) flexibleTree.getNode(i);
            flexibleNode.setAttribute("Number", Integer.valueOf(flexibleNode.getNumber()));
            flexibleNode.setAttribute("Time", Double.valueOf(heightToTime(flexibleNode.getHeight())));
            flexibleNode.setAttribute("partition", getBranchMap().get(flexibleNode.getNumber()));
        }
        Iterator<AbstractCase> it = this.outbreak.getCases().iterator();
        while (it.hasNext()) {
            AbstractCase next = it.next();
            if (next.wasEverInfected()) {
                NodeRef earliestNodeInElement = ((PartitionedTreeModel) this.treeModel).getEarliestNodeInElement(next);
                int number = earliestNodeInElement.getNumber();
                if (this.treeModel.isRoot(earliestNodeInElement)) {
                    // Root case: add a new "Origin" root above the old root, at the infection height.
                    double height = getHeight(earliestNodeInElement) + (getNodeTime(earliestNodeInElement) - getInfectionTime(next));
                    FlexibleNode flexibleNode2 = (FlexibleNode) flexibleTree.getNode(number);
                    flexibleTree.beginTreeEdit();
                    FlexibleNode flexibleNode3 = new FlexibleNode();
                    flexibleNode3.setHeight(height);
                    flexibleNode3.setAttribute("Time", Double.valueOf(heightToTime(height)));
                    flexibleNode3.setAttribute("partition", "Origin");
                    flexibleTree.addChild(flexibleNode3, flexibleNode2);
                    flexibleNode2.setLength(height - getHeight(earliestNodeInElement));
                    flexibleTree.setRoot(flexibleNode3);
                    flexibleTree.endTreeEdit();
                } else {
                    // Other cases: split the branch above the earliest node at the infection height.
                    NodeRef parent = this.treeModel.getParent(earliestNodeInElement);
                    double nodeTime = getNodeTime(earliestNodeInElement);
                    double infectionTime = getInfectionTime(next);
                    double height2 = getHeight(earliestNodeInElement) + (nodeTime - infectionTime);
                    FlexibleNode flexibleNode4 = (FlexibleNode) flexibleTree.getNode(number);
                    FlexibleNode flexibleNode5 = (FlexibleNode) flexibleTree.getParent(flexibleNode4);
                    flexibleTree.beginTreeEdit();
                    flexibleTree.removeChild(flexibleNode5, flexibleNode4);
                    FlexibleNode flexibleNode6 = new FlexibleNode();
                    flexibleNode6.setHeight(height2);
                    flexibleNode6.setLength(flexibleNode5.getHeight() - height2);
                    flexibleNode6.setAttribute("partition", getNodePartition(this.treeModel, parent));
                    flexibleNode6.setAttribute("Time", Double.valueOf(heightToTime(height2)));
                    flexibleNode4.setLength(nodeTime - infectionTime);
                    flexibleTree.addChild(flexibleNode5, flexibleNode6);
                    flexibleTree.addChild(flexibleNode6, flexibleNode4);
                    flexibleTree.endTreeEdit();
                }
            }
        }
        FlexibleTree flexibleTree2 = new FlexibleTree((FlexibleNode) flexibleTree.getRoot());
        // Sanity check: no child may be higher than its parent after rewiring.
        for (int i2 = 0; i2 < flexibleTree2.getNodeCount(); i2++) {
            NodeRef node = flexibleTree2.getNode(i2);
            NodeRef parent2 = flexibleTree2.getParent(node);
            if (parent2 != null && flexibleTree2.getNodeHeight(node) > flexibleTree2.getNodeHeight(parent2)) {
                // Dump the offending tree; try-with-resources closes the file so it is complete.
                try (PrintStream debugOut = new PrintStream("fancyProblem.nex")) {
                    new NexusExporter(debugOut).exportTree(flexibleTree2);
                } catch (IOException e) {
                    e.printStackTrace();
                }
                try {
                    ((PartitionedTreeModel) this.treeModel).checkPartitions();
                } catch (BadPartitionException e2) {
                    System.out.print("Rewiring messed up because of partition problem.");
                }
                throw new RuntimeException("Rewiring messed up; investigate");
            }
        }
        return flexibleTree2;
    }

    /**
     * Builds one log column per ever-infected case, named "&lt;case&gt;_infector". A column
     * logs "{Start}" when the case has no infector; otherwise it logs the infector wrapped in
     * GraphMLUtils.START_SECTION / END_SECTION.
     */
    @Override
    public LogColumn[] getColumns() {
        final LogColumn[] columns = new LogColumn[outbreak.infectedSize()];
        int nextSlot = 0;
        for (int caseIndex = 0; caseIndex < outbreak.size(); caseIndex++) {
            final AbstractCase infectee = outbreak.getCase(caseIndex);
            if (!infectee.wasEverInfected()) {
                continue;
            }
            columns[nextSlot] = new LogColumn.Abstract(infectee.toString() + "_infector") {
                @Override
                protected String getFormattedValue() {
                    final Object infector = ((PartitionedTreeModel) CaseToCaseTreeLikelihood.this.treeModel).getInfector(infectee);
                    if (infector == null) {
                        return "{Start}";
                    }
                    return GraphMLUtils.START_SECTION + infector.toString() + GraphMLUtils.END_SECTION;
                }
            };
            nextSlot++;
        }
        return columns;
    }

    /**
     * Builds the per-case timing columns and the summary-statistic columns for this likelihood.
     *
     * <p>Columns are emitted in this order. Per-case columns cover only cases for which
     * {@code wasEverInfected()} is true. Entries marked (latent) appear only when
     * {@code hasLatentPeriods} is set.
     * <ol>
     *   <li>{@code <case>_infection_date} for each case</li>
     *   <li>(latent) {@code <case>_infectious_date} for each case</li>
     *   <li>(latent) {@code <case>_latent_period} for each case</li>
     *   <li>{@code <case>_infectious_period} for each case</li>
     *   <li>(latent) {@code <case>_infected_period} for each case (infectious + latent period)</li>
     *   <li>{@code infectious_period.mean/.median/.var/.stdev} (entries 0..3 of
     *       {@code getSummaryStatistics})</li>
     *   <li>(latent) the same four statistics for {@code latent_period} and {@code infected_period}</li>
     *   <li>(latent) {@code <case>_ibp}, the first value of the case's infection branch position parameter</li>
     * </ol>
     *
     * @return the columns, in the order above
     */
    public LogColumn[] passColumns() {
        // Typed list: the original raw ArrayList needed an unchecked cast on return.
        List<LogColumn> columns = new ArrayList<LogColumn>();

        // Per-case infection times.
        for (int i = 0; i < this.outbreak.size(); i++) {
            final AbstractCase infected = this.outbreak.getCase(i);
            if (infected.wasEverInfected()) {
                columns.add(new LogColumn.Abstract(infected.toString() + "_infection_date") {
                    @Override // dr.inference.loggers.LogColumn.Abstract
                    protected String getFormattedValue() {
                        return String.valueOf(CaseToCaseTreeLikelihood.this.getInfectionTime(infected));
                    }
                });
            }
        }
        if (this.hasLatentPeriods) {
            // Per-case times of becoming infectious.
            for (int i = 0; i < this.outbreak.size(); i++) {
                final AbstractCase infected = this.outbreak.getCase(i);
                if (infected.wasEverInfected()) {
                    columns.add(new LogColumn.Abstract(infected.toString() + "_infectious_date") {
                        @Override // dr.inference.loggers.LogColumn.Abstract
                        protected String getFormattedValue() {
                            return String.valueOf(CaseToCaseTreeLikelihood.this.getInfectiousTime(infected));
                        }
                    });
                }
            }
            // Per-case latent periods.
            for (int i = 0; i < this.outbreak.size(); i++) {
                final AbstractCase infected = this.outbreak.getCase(i);
                if (infected.wasEverInfected()) {
                    columns.add(new LogColumn.Abstract(infected.toString() + "_latent_period") {
                        @Override // dr.inference.loggers.LogColumn.Abstract
                        protected String getFormattedValue() {
                            return String.valueOf(CaseToCaseTreeLikelihood.this.getLatentPeriod(infected));
                        }
                    });
                }
            }
        }
        // Per-case infectious periods.
        for (int i = 0; i < this.outbreak.size(); i++) {
            final AbstractCase infected = this.outbreak.getCase(i);
            if (infected.wasEverInfected()) {
                columns.add(new LogColumn.Abstract(infected.toString() + "_infectious_period") {
                    @Override // dr.inference.loggers.LogColumn.Abstract
                    protected String getFormattedValue() {
                        return String.valueOf(CaseToCaseTreeLikelihood.this.getInfectiousPeriod(infected));
                    }
                });
            }
        }
        if (this.hasLatentPeriods) {
            // Per-case total infected period = infectious period + latent period.
            for (int i = 0; i < this.outbreak.size(); i++) {
                final AbstractCase infected = this.outbreak.getCase(i);
                if (infected.wasEverInfected()) {
                    columns.add(new LogColumn.Abstract(infected.toString() + "_infected_period") {
                        @Override // dr.inference.loggers.LogColumn.Abstract
                        protected String getFormattedValue() {
                            return String.valueOf(CaseToCaseTreeLikelihood.this.getInfectiousPeriod(infected)
                                    + CaseToCaseTreeLikelihood.this.getLatentPeriod(infected));
                        }
                    });
                }
            }
        }

        // Summary statistics over the non-zero infectious periods: [0]=mean, [1]=median, [2]=var, [3]=stdev.
        columns.add(new LogColumn.Abstract("infectious_period.mean") {
            @Override // dr.inference.loggers.LogColumn.Abstract
            protected String getFormattedValue() {
                return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                        CaseToCaseTreeLikelihood.this.getNonzeroInfectiousPeriods())[0]);
            }
        });
        columns.add(new LogColumn.Abstract("infectious_period.median") {
            @Override // dr.inference.loggers.LogColumn.Abstract
            protected String getFormattedValue() {
                return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                        CaseToCaseTreeLikelihood.this.getNonzeroInfectiousPeriods())[1]);
            }
        });
        columns.add(new LogColumn.Abstract("infectious_period.var") {
            @Override // dr.inference.loggers.LogColumn.Abstract
            protected String getFormattedValue() {
                return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                        CaseToCaseTreeLikelihood.this.getNonzeroInfectiousPeriods())[2]);
            }
        });
        columns.add(new LogColumn.Abstract("infectious_period.stdev") {
            @Override // dr.inference.loggers.LogColumn.Abstract
            protected String getFormattedValue() {
                return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                        CaseToCaseTreeLikelihood.this.getNonzeroInfectiousPeriods())[3]);
            }
        });
        if (this.hasLatentPeriods) {
            // Same four statistics over the non-zero latent periods.
            columns.add(new LogColumn.Abstract("latent_period.mean") {
                @Override // dr.inference.loggers.LogColumn.Abstract
                protected String getFormattedValue() {
                    return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                            CaseToCaseTreeLikelihood.this.getNonzeroLatentPeriods())[0]);
                }
            });
            columns.add(new LogColumn.Abstract("latent_period.median") {
                @Override // dr.inference.loggers.LogColumn.Abstract
                protected String getFormattedValue() {
                    return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                            CaseToCaseTreeLikelihood.this.getNonzeroLatentPeriods())[1]);
                }
            });
            columns.add(new LogColumn.Abstract("latent_period.var") {
                @Override // dr.inference.loggers.LogColumn.Abstract
                protected String getFormattedValue() {
                    return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                            CaseToCaseTreeLikelihood.this.getNonzeroLatentPeriods())[2]);
                }
            });
            columns.add(new LogColumn.Abstract("latent_period.stdev") {
                @Override // dr.inference.loggers.LogColumn.Abstract
                protected String getFormattedValue() {
                    return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                            CaseToCaseTreeLikelihood.this.getNonzeroLatentPeriods())[3]);
                }
            });
            // Same four statistics over the non-zero infected (latent + infectious) periods.
            columns.add(new LogColumn.Abstract("infected_period.mean") {
                @Override // dr.inference.loggers.LogColumn.Abstract
                protected String getFormattedValue() {
                    return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                            CaseToCaseTreeLikelihood.this.getNonzeroInfectedPeriods())[0]);
                }
            });
            columns.add(new LogColumn.Abstract("infected_period.median") {
                @Override // dr.inference.loggers.LogColumn.Abstract
                protected String getFormattedValue() {
                    return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                            CaseToCaseTreeLikelihood.this.getNonzeroInfectedPeriods())[1]);
                }
            });
            columns.add(new LogColumn.Abstract("infected_period.var") {
                @Override // dr.inference.loggers.LogColumn.Abstract
                protected String getFormattedValue() {
                    return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                            CaseToCaseTreeLikelihood.this.getNonzeroInfectedPeriods())[2]);
                }
            });
            columns.add(new LogColumn.Abstract("infected_period.stdev") {
                @Override // dr.inference.loggers.LogColumn.Abstract
                protected String getFormattedValue() {
                    return String.valueOf(CaseToCaseTreeLikelihood.getSummaryStatistics(
                            CaseToCaseTreeLikelihood.this.getNonzeroInfectedPeriods())[3]);
                }
            });
            // Per-case value of the infection branch position parameter (index 0).
            for (int i = 0; i < this.outbreak.size(); i++) {
                final AbstractCase infected = this.outbreak.getCase(i);
                if (infected.wasEverInfected()) {
                    columns.add(new LogColumn.Abstract(infected.toString() + "_ibp") {
                        @Override // dr.inference.loggers.LogColumn.Abstract
                        protected String getFormattedValue() {
                            return String.valueOf(infected.getInfectionBranchPosition().getParameterValue(0));
                        }
                    });
                }
            }
        }
        return columns.toArray(new LogColumn[columns.size()]);
    }

    /**
     * Reports the citation category under which this model is listed.
     *
     * @return always {@code Citation.Category.TREE_PRIORS}
     */
    @Override // dr.util.Citable
    public Citation.Category getCategory() {
        final Citation.Category category = Citation.Category.TREE_PRIORS;
        return category;
    }

    /**
     * Gives the human-readable name used when citing this model.
     *
     * @return a fixed description string
     */
    @Override // dr.util.Citable
    public String getDescription() {
        final String description = "Case to Case Transmission Tree model";
        return description;
    }

    /**
     * Returns the publication describing this model.
     *
     * <p>The paper is Hall, Woolhouse and Rambaut, PLOS Computational Biology 11(12): e1004613,
     * doi:10.1371/journal.pcbi.1004613. Volume 11 of that journal is the 2015 volume, so the
     * year is recorded as 2015. The previous value, 2016, did not match the volume.
     *
     * @return a single-element list holding the citation
     */
    @Override // dr.util.Citable
    public List<Citation> getCitations() {
        Author[] authors = new Author[]{
                new Author("M", "Hall"),
                new Author("M", "Woolhouse"),
                new Author("A", "Rambaut")
        };
        // NOTE(review): the two 0 arguments after the volume are presumably start/end pages,
        // left unset because the article is identified by e-number/DOI; confirm against Citation.
        Citation citation = new Citation(authors,
                "Epidemic Reconstruction in a Phylogenetics Framework: Transmission Trees as Partitions of the Node Set",
                2015, "PLOS Comput Biol", 11, 0, 0,
                "10.1371/journal.pcbi.1004613", Citation.Status.PUBLISHED);
        return Arrays.asList(citation);
    }

    /**
     * Lists every tree trait this likelihood exposes, delegating to its trait helper.
     *
     * @return the helper's trait array
     */
    @Override // dr.evolution.tree.TreeTraitProvider
    public TreeTrait[] getTreeTraits() {
        final TreeTraitProvider.Helper helper = treeTraits;
        return helper.getTreeTraits();
    }

    /**
     * Looks up one tree trait by name, delegating to the trait helper.
     *
     * @param str the trait name to look up
     * @return whatever the helper returns for that name (its behavior for unknown names is not visible here)
     */
    @Override // dr.evolution.tree.TreeTraitProvider
    public TreeTrait getTreeTrait(String str) {
        final TreeTraitProvider.Helper helper = treeTraits;
        return helper.getTreeTrait(str);
    }

    /**
     * Returns the label of the partition (case) to which a node of {@code tree} belongs.
     *
     * <p>If {@code tree} is this likelihood's own {@code treeModel}, the branch map is indexed
     * directly by node number.
     *
     * <p>Any other tree must carry an {@code Integer} "Number" attribute on its nodes. That
     * attribute is the number of the corresponding {@code treeModel} node, and the two nodes
     * must have identical heights.
     *
     * <p>A node without a "Number" attribute (presumably one of the extra nodes inserted when
     * building a partitioned tree; verify) takes its parent's partition. If it is the root, the
     * result is {@code "Start"}.
     *
     * <p>This method used to detect the missing attribute by catching
     * {@code NullPointerException}. That also silently rerouted unrelated NPEs to the
     * parent-partition fallback. The attribute is now checked for {@code null} explicitly.
     *
     * @param tree    the tree containing {@code nodeRef}
     * @param nodeRef the node whose partition is wanted
     * @return the string form of the node's partition, or {@code "Start"}
     * @throws RuntimeException if a numbered node's height differs from its treeModel counterpart
     */
    public String getNodePartition(Tree tree, NodeRef nodeRef) {
        if (tree == this.treeModel) {
            return getBranchMap().get(nodeRef.getNumber()).toString();
        }
        Integer number = (Integer) tree.getNodeAttribute(nodeRef, "Number");
        if (number == null) {
            // No treeModel counterpart for this node: use the parent's partition.
            if (tree.isRoot(nodeRef)) {
                return "Start";
            }
            Integer parentNumber = (Integer) tree.getNodeAttribute(tree.getParent(nodeRef), "Number");
            return getBranchMap().get(parentNumber.intValue()).toString();
        }
        NodeRef node = this.treeModel.getNode(number.intValue());
        if (this.treeModel.getNodeHeight(node) != tree.getNodeHeight(nodeRef)) {
            throw new RuntimeException("Can only reconstruct states on treeModel given to constructor or a partitioned tree derived from it");
        }
        return getBranchMap().get(node.getNumber()).toString();
    }

    /**
     * Builds an array, indexed by each case's outbreak index, holding the outbreak index of
     * that case's infector.
     *
     * <p>Cases that were never infected get a {@code null} entry.
     * NOTE(review): when the partitioned tree reports a {@code null} infector for an infected
     * case, the stored value is whatever {@code outbreak.getCaseIndex(null)} returns; confirm.
     *
     * @return an array of length {@code outbreak.size()}
     */
    public Integer[] getParentsArray() {
        Integer[] parents = new Integer[outbreak.size()];
        for (Iterator<AbstractCase> cases = outbreak.getCases().iterator(); cases.hasNext(); ) {
            AbstractCase current = cases.next();
            boolean infected = current.wasEverInfected();
            int slot = outbreak.getCaseIndex(current);
            parents[slot] = infected
                    ? Integer.valueOf(outbreak.getCaseIndex(((PartitionedTreeModel) treeModel).getInfector(current)))
                    : null;
        }
        return parents;
    }

    /**
     * Returns the infector of the case at position {@code i} in the outbreak, as recorded by the
     * partitioned tree model.
     *
     * @param i the case's position in the outbreak
     * @return the infector reported by the tree model (may be {@code null}; callers elsewhere
     *         treat a null infector as the outbreak start)
     */
    public AbstractCase getInfector(int i) {
        // Cast first to keep the original evaluation order (receiver before argument).
        final PartitionedTreeModel partitionedTree = (PartitionedTreeModel) treeModel;
        final AbstractCase target = getOutbreak().getCase(i);
        return partitionedTree.getInfector(target);
    }
}
