package dr.evomodel.epidemiology.casetocase;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxon;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.MathUtils;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:dr/evomodel/epidemiology/casetocase/PartitionedTreeModel.class */
/**
 * A {@link TreeModel} whose nodes are additionally partitioned into elements, one per
 * {@link AbstractCase} of an outbreak. The node-to-case assignment is held in a
 * {@link BranchMapModel}. Tips are assigned from the cases' associated taxa. Internal nodes are
 * assigned either randomly or from a user-supplied transmission tree file.
 *
 * <p>Changes that may alter partition elements are reported to listeners as
 * {@link PartitionsChangedEvent}s. Structural changes made during a tree edit are queued and only
 * reported when the edit ends.
 */
public class PartitionedTreeModel extends TreeModel {
    private final AbstractOutbreak outbreak;
    private BranchMapModel branchMap;
    private final int elementCount;
    public static final String PARTITIONED_TREE_MODEL = "partitionedTreeModel";

    /** Nodes touched by addChild/removeChild during a tree edit; announced by flushQueue(). */
    Set<NodeRef> partitionsQueue;

    /** Event sent to model listeners naming the cases whose partition elements may have changed. */
    public class PartitionsChangedEvent {
        private final HashSet<AbstractCase> casesToRecalculate;

        public PartitionsChangedEvent(HashSet<AbstractCase> casesToRecalculate) {
            this.casesToRecalculate = casesToRecalculate;
        }

        /** @return the set passed at construction (not a copy) */
        public HashSet<AbstractCase> getCasesToRecalculate() {
            return this.casesToRecalculate;
        }
    }

    /**
     * Creates the model and assigns internal nodes to cases at random, consistent with the tips.
     *
     * @param id model id passed to {@link TreeModel}
     * @param tree starting phylogeny
     * @param outbreak outbreak whose cases form the partition elements
     */
    public PartitionedTreeModel(String id, Tree tree, AbstractOutbreak outbreak) {
        super(id, tree);
        this.partitionsQueue = new HashSet<>();
        this.outbreak = outbreak;
        this.elementCount = outbreak.infectedSize();
        this.branchMap = new BranchMapModel(this);
        partitionAccordingToRandomTT(false);
    }

    /**
     * Creates the model and assigns internal nodes from a transmission-tree file.
     *
     * @param id model id passed to {@link TreeModel}
     * @param tree starting phylogeny
     * @param outbreak outbreak whose cases form the partition elements
     * @param startingTTFileName path of a comma-separated file (first line skipped as a header) of
     *     {@code infectee,infector} rows, where the infector {@code Start} marks the root case
     */
    public PartitionedTreeModel(String id, Tree tree, AbstractOutbreak outbreak, String startingTTFileName) {
        super(id, tree);
        this.partitionsQueue = new HashSet<>();
        this.outbreak = outbreak;
        this.elementCount = outbreak.infectedSize();
        this.branchMap = new BranchMapModel(this);
        partitionAccordingToSpecificTT(startingTTFileName);
    }

    public PartitionedTreeModel(TreeModel treeModel, AbstractOutbreak outbreak) {
        this(PARTITIONED_TREE_MODEL, treeModel, outbreak);
    }

    public PartitionedTreeModel(TreeModel treeModel, AbstractOutbreak outbreak, String startingTTFileName) {
        this(PARTITIONED_TREE_MODEL, treeModel, outbreak, startingTTFileName);
    }

    /** Fires a {@link PartitionsChangedEvent} for the given cases to all model listeners. */
    public void partitionsChangingAlert(HashSet<AbstractCase> cases) {
        this.listenerHelper.fireModelChanged(this, new PartitionsChangedEvent(cases));
    }

    /** Fires a {@link PartitionsChangedEvent} for a single case. */
    public void partitionChangingAlert(AbstractCase changedCase) {
        HashSet<AbstractCase> cases = new HashSet<>();
        cases.add(changedCase);
        partitionsChangingAlert(cases);
    }

    /** Fires a {@link PartitionsChangedEvent} naming every case currently in the branch map. */
    public void universalAlert() {
        partitionsChangingAlert(new HashSet<>(Arrays.asList(this.branchMap.getArrayCopy())));
    }

    public BranchMapModel getBranchMap() {
        return this.branchMap;
    }

    /** No-op: replaces TreeModel's handler, so events from sub-models are ignored here. */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
    }

    /**
     * Reports that the element containing {@code nodeRef} may have changed. During a tree edit the
     * node is queued until {@link #endTreeEdit()}; otherwise listeners are alerted immediately.
     */
    public void pushNodePartitionsChangedEvent(NodeRef nodeRef) {
        if (inTreeEdit()) {
            this.partitionsQueue.add(nodeRef);
        } else {
            partitionChangingAlert(this.branchMap.get(nodeRef.getNumber()));
        }
    }

    /**
     * Forwards to TreeModel, then alerts listeners: every case when all values changed, otherwise
     * the cases adjacent to the node owning the changed parameter.
     */
    @Override
    public void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType changeType) {
        super.handleVariableChangedEvent(variable, index, changeType);
        if (changeType == Variable.ChangeType.ALL_VALUES_CHANGED) {
            universalAlert();
        } else {
            partitionsChangingAlert(adjacentElements(getNodeOfParameter((Parameter) variable)));
        }
    }

    /**
     * Collects the cases assigned to {@code nodeRef}, its parent and its children 0 and 1.
     * Neighbours that are {@code null} (e.g. the root's parent) are skipped.
     */
    public HashSet<AbstractCase> adjacentElements(NodeRef nodeRef) {
        HashSet<AbstractCase> elements = new HashSet<>();
        NodeRef[] neighbourhood = {nodeRef, getParent(nodeRef), getChild(nodeRef, 0), getChild(nodeRef, 1)};
        for (NodeRef node : neighbourhood) {
            if (node != null) {
                elements.add(this.branchMap.get(node.getNumber()));
            }
        }
        return elements;
    }

    /**
     * Alerts listeners for every queued node, and for its parent when the parent lies in a
     * different element. Then empties the queue.
     *
     * @throws RuntimeException if called while a tree edit is still in progress
     */
    private void flushQueue() {
        if (inTreeEdit()) {
            throw new RuntimeException("Wait until you've finished editing the tree before flushing the partitionqueue");
        }
        for (NodeRef node : this.partitionsQueue) {
            partitionChangingAlert(this.branchMap.get(node.getNumber()));
            NodeRef parent = getParent(node);
            if (parent != null) {
                AbstractCase parentCase = this.branchMap.get(parent.getNumber());
                if (this.branchMap.get(node.getNumber()) != parentCase) {
                    partitionChangingAlert(parentCase);
                }
            }
        }
        this.partitionsQueue.clear();
    }

    @Override
    public void addChild(NodeRef parent, NodeRef child) {
        pushNodePartitionsChangedEvent(parent);
        pushNodePartitionsChangedEvent(child);
        super.addChild(parent, child);
    }

    @Override
    public void removeChild(NodeRef parent, NodeRef child) {
        pushNodePartitionsChangedEvent(parent);
        pushNodePartitionsChangedEvent(child);
        super.removeChild(parent, child);
    }

    /** Alerts listeners about the cases adjacent to {@code nodeRef}, then changes its height. */
    @Override
    public void setNodeHeight(NodeRef nodeRef, double height) {
        partitionsChangingAlert(adjacentElements(nodeRef));
        super.setNodeHeight(nodeRef, height);
    }

    /** Ends the tree edit and announces all partition changes queued during it. */
    @Override
    public void endTreeEdit() {
        super.endTreeEdit();
        flushQueue();
    }

    /** Checks the current partition and prints the first unconnected internal node, if any. */
    public boolean checkPartitions() {
        return checkPartitions(this.branchMap, true);
    }

    /**
     * Validates the partition.
     *
     * @param branchMapModel NOTE(review): currently ignored. Every check reads {@code this.branchMap}
     *     (as do samePartitionElement and caseMRCA); kept for interface compatibility.
     * @param verbose whether to print the first internal node whose element contains no tip
     * @return {@code false} if some internal node's element contains no tip
     * @throws BadPartitionException if some case's tips are not joined by a connected element
     */
    protected boolean checkPartitions(BranchMapModel branchMapModel, boolean verbose) {
        boolean foundUnconnected = false;
        for (int i = 0; i < getInternalNodeCount(); i++) {
            if (!elementContainsTip(getInternalNode(i))) {
                foundUnconnected = true;
                if (verbose) {
                    System.out.println("Node " + (i + getExternalNodeCount()) + " is not connected to a tip");
                }
                // Only the first offending node is reported, so further scanning is pointless.
                break;
            }
        }
        for (int i = 0; i < getExternalNodeCount(); i++) {
            AbstractCase tipCase = this.branchMap.get(i);
            if (this.branchMap.get(caseMRCA(tipCase).getNumber()) != tipCase) {
                throw new BadPartitionException("Node partition disconnected");
            }
        }
        return !foundUnconnected;
    }

    /** Whether the connected element containing {@code node} includes at least one tip. */
    private boolean elementContainsTip(NodeRef node) {
        for (Integer number : samePartitionElement(node)) {
            if (isExternal(getNode(number))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Numbers of the nodes in {@code nodeRef}'s element that are reached by walking upwards. Each
     * ancestor with the same case is included, and so are siblings' subtrees where both children of
     * that ancestor share the case. {@code nodeRef} itself is not included.
     */
    public HashSet<Integer> samePartitionElementUpTree(NodeRef nodeRef) {
        HashSet<Integer> result = new HashSet<>();
        AbstractCase element = this.branchMap.get(nodeRef.getNumber());
        NodeRef current = nodeRef;
        NodeRef parent = getParent(current);
        while (parent != null && this.branchMap.get(parent.getNumber()) == element) {
            result.add(parent.getNumber());
            if (countChildrenInSameElement(parent) == 2) {
                NodeRef siblingNode = sibling(this, current);
                result.add(siblingNode.getNumber());
                result.addAll(samePartitionElementDownTree(siblingNode));
            }
            current = parent;
            parent = getParent(current);
        }
        return result;
    }

    /** Numbers of the descendants of {@code nodeRef} connected to it within the same case. */
    public HashSet<Integer> samePartitionElementDownTree(NodeRef nodeRef) {
        HashSet<Integer> result = new HashSet<>();
        AbstractCase element = this.branchMap.get(nodeRef.getNumber());
        for (int i = 0; i < getChildCount(nodeRef); i++) {
            NodeRef child = getChild(nodeRef, i);
            if (this.branchMap.get(child.getNumber()) == element) {
                result.add(child.getNumber());
                result.addAll(samePartitionElementDownTree(child));
            }
        }
        return result;
    }

    /** Numbers of all nodes in the connected element containing {@code nodeRef}, itself included. */
    public Integer[] samePartitionElement(NodeRef nodeRef) {
        HashSet<Integer> element = new HashSet<>();
        element.add(nodeRef.getNumber());
        element.addAll(samePartitionElementUpTree(nodeRef));
        element.addAll(samePartitionElementDownTree(nodeRef));
        return element.toArray(new Integer[0]);
    }

    /** External-node indices whose assigned case is {@code abstractCase}, in ascending order. */
    public int[] allTipsForThisCase(AbstractCase abstractCase) {
        ArrayList<Integer> tips = new ArrayList<>();
        for (int i = 0; i < getExternalNodeCount(); i++) {
            if (this.branchMap.get(i) == abstractCase) {
                tips.add(i);
            }
        }
        int[] result = new int[tips.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = tips.get(i);
        }
        return result;
    }

    /**
     * Highest node of the case's element: starts at the MRCA of its tips and climbs while the parent
     * has the same case.
     *
     * @return {@code null} if the case was never infected
     * @throws BadPartitionException if the MRCA is not assigned to the case
     */
    public NodeRef getEarliestNodeInElement(AbstractCase abstractCase) {
        if (!abstractCase.wasEverInfected()) {
            return null;
        }
        NodeRef node = caseMRCA(abstractCase);
        if (this.branchMap.get(node.getNumber()) != abstractCase) {
            throw new BadPartitionException("Node partition element disconnected");
        }
        NodeRef parent = getParent(node);
        while (parent != null && this.branchMap.get(parent.getNumber()) == this.branchMap.get(node.getNumber())) {
            node = parent;
            parent = getParent(node);
        }
        return node;
    }

    /**
     * All cases reachable from {@code abstractCase} through repeated {@link #getInfectees} calls.
     * Uses an explicit worklist: the previous version added to the set it was iterating, which is
     * a {@link java.util.ConcurrentModificationException} (or silently missed descendants).
     */
    public HashSet<AbstractCase> getDescendants(AbstractCase abstractCase) {
        HashSet<AbstractCase> descendants = new HashSet<>(getInfectees(abstractCase));
        ArrayList<AbstractCase> pending = new ArrayList<>(descendants);
        while (!pending.isEmpty()) {
            AbstractCase next = pending.remove(pending.size() - 1);
            for (AbstractCase infectee : getInfectees(next)) {
                // add() returns false for already-seen cases, so each case is expanded once.
                if (descendants.add(infectee)) {
                    pending.add(infectee);
                }
            }
        }
        return descendants;
    }

    /**
     * Case assigned to the first node above the case's MRCA that belongs to a different case.
     *
     * @return {@code null} if the case was never infected or its element reaches the root
     * @throws BadPartitionException if the MRCA is not assigned to the case
     */
    public AbstractCase getInfector(AbstractCase abstractCase) {
        if (!abstractCase.wasEverInfected()) {
            return null;
        }
        NodeRef node = caseMRCA(abstractCase);
        if (this.branchMap.get(node.getNumber()) != abstractCase) {
            throw new BadPartitionException("Node partition element disconnected");
        }
        while (this.branchMap.get(node.getNumber()) == abstractCase) {
            node = getParent(node);
            if (node == null) {
                return null;
            }
        }
        return this.branchMap.get(node.getNumber());
    }

    /** Case assigned to the root node. */
    public AbstractCase getRootCase() {
        return this.branchMap.get(getRoot().getNumber());
    }

    /** Cases bordering the case's element from below; empty if the case was never infected. */
    public HashSet<AbstractCase> getInfectees(AbstractCase abstractCase) {
        return abstractCase.wasEverInfected()
                ? getInfecteesInClade(getEarliestNodeInElement(abstractCase))
                : new HashSet<>();
    }

    /**
     * Walks down from {@code nodeRef} through nodes of the same case, collecting the case of every
     * child where the assignment changes. Returns an empty set for a tip.
     */
    public HashSet<AbstractCase> getInfecteesInClade(NodeRef nodeRef) {
        HashSet<AbstractCase> infectees = new HashSet<>();
        if (isExternal(nodeRef)) {
            return infectees;
        }
        AbstractCase ownCase = this.branchMap.get(nodeRef.getNumber());
        for (int i = 0; i < getChildCount(nodeRef); i++) {
            NodeRef child = getChild(nodeRef, i);
            AbstractCase childCase = this.branchMap.get(child.getNumber());
            if (childCase != ownCase) {
                infectees.add(childCase);
            } else {
                infectees.addAll(getInfecteesInClade(child));
            }
        }
        return infectees;
    }

    /**
     * Case of the first ancestor of {@code nodeRef} whose case differs from that of
     * {@code nodeRef}, or {@code null} if the root is reached first.
     */
    public AbstractCase getInfector(NodeRef nodeRef) {
        AbstractCase ownCase = this.branchMap.get(nodeRef.getNumber());
        NodeRef current = nodeRef;
        while (!isRoot(current) && current.getNumber() != getRoot().getNumber()) {
            NodeRef parent = getParent(current);
            AbstractCase parentCase = this.branchMap.get(parent.getNumber());
            if (parentCase != ownCase) {
                return parentCase;
            }
            current = parent;
        }
        return null;
    }

    /** Case assigned to the parent of {@code nodeRef}; throws NPE for the root. */
    public AbstractCase getParentCase(NodeRef nodeRef) {
        return this.branchMap.get(getParent(nodeRef).getNumber());
    }

    /** Value of {@code outbreak.infectedSize()} captured at construction. */
    public int getElementCount() {
        return this.elementCount;
    }

    /** Number of children sharing {@code nodeRef}'s case, or -1 if {@code nodeRef} is a tip. */
    public int countChildrenInSameElement(NodeRef nodeRef) {
        if (isExternal(nodeRef)) {
            return -1;
        }
        AbstractCase ownCase = this.branchMap.get(nodeRef.getNumber());
        int count = 0;
        for (int i = 0; i < getChildCount(nodeRef); i++) {
            if (this.branchMap.get(getChild(nodeRef, i).getNumber()) == ownCase) {
                count++;
            }
        }
        return count;
    }

    /**
     * First child of {@code nodeRef}'s parent that is not {@code nodeRef} (compared by reference).
     *
     * @return {@code null} for the root or when no other child exists
     */
    public static NodeRef sibling(TreeModel treeModel, NodeRef nodeRef) {
        if (treeModel.isRoot(nodeRef)) {
            return null;
        }
        NodeRef parent = treeModel.getParent(nodeRef);
        for (int i = 0; i < treeModel.getChildCount(parent); i++) {
            NodeRef candidate = treeModel.getChild(parent, i);
            if (candidate != nodeRef) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * MRCA of all tips currently assigned to {@code abstractCase}.
     *
     * @param validate if {@code true}, the MRCA must itself be assigned to the case
     * @throws BadPartitionException if {@code validate} is set and the MRCA belongs to another case
     */
    public NodeRef caseMRCA(AbstractCase abstractCase, boolean validate) {
        NodeRef mrca = TreeUtils.getCommonAncestor(this, allTipsForThisCase(abstractCase));
        if (!validate || this.branchMap.get(mrca.getNumber()) == abstractCase) {
            return mrca;
        }
        throw new BadPartitionException("A partition element is disconnected");
    }

    /** Validating form of {@link #caseMRCA(AbstractCase, boolean)}. */
    public NodeRef caseMRCA(AbstractCase abstractCase) {
        return caseMRCA(abstractCase, true);
    }

    /** All tips in the subtree rooted at {@code nodeRef} (the node itself if it is a tip). */
    private HashSet<NodeRef> getDescendantTips(NodeRef nodeRef) {
        HashSet<NodeRef> tips = new HashSet<>();
        if (isExternal(nodeRef)) {
            tips.add(nodeRef);
            return tips;
        }
        for (int i = 0; i < getChildCount(nodeRef); i++) {
            tips.addAll(getDescendantTips(getChild(nodeRef, i)));
        }
        return tips;
    }

    /** Whether some tip below {@code nodeRef} is assigned the same case as {@code nodeRef}. */
    public boolean isAncestral(NodeRef nodeRef) {
        AbstractCase ownCase = this.branchMap.get(nodeRef.getNumber());
        for (NodeRef tip : getDescendantTips(nodeRef)) {
            if (this.branchMap.get(tip.getNumber()) == ownCase) {
                return true;
            }
        }
        return false;
    }

    /** Whether {@code blocker}'s tip MRCA is an ancestor-or-self of {@code abstractCase}'s MRCA. */
    public boolean isRootBlockedBy(AbstractCase abstractCase, AbstractCase blocker) {
        return directDescendant(caseMRCA(abstractCase), caseMRCA(blocker));
    }

    /** Whether any other infected case root-blocks {@code abstractCase}. */
    public boolean isRootBlocked(AbstractCase abstractCase) {
        for (AbstractCase other : this.outbreak.getCases()) {
            if (other.wasEverInfected && other != abstractCase && isRootBlockedBy(abstractCase, other)) {
                return true;
            }
        }
        return false;
    }

    /** Tip nodes currently assigned to {@code abstractCase}. */
    private HashSet<NodeRef> getTipsInThisPartitionElement(AbstractCase abstractCase) {
        HashSet<NodeRef> tips = new HashSet<>();
        for (int i = 0; i < getExternalNodeCount(); i++) {
            if (this.branchMap.get(i) == abstractCase) {
                tips.add(getExternalNode(i));
            }
        }
        return tips;
    }

    /**
     * Whether {@code ancestor} is {@code node} itself or appears on the path from {@code node} to
     * the root. Returns {@code false} if either argument is {@code null}.
     */
    private boolean directDescendant(NodeRef node, NodeRef ancestor) {
        for (NodeRef current = node; current != null; current = getParent(current)) {
            if (current == ancestor) {
                return true;
            }
        }
        return false;
    }

    /**
     * Fills {@code map} at each tip's number with an infected case listing the tip's taxon (when
     * several match, the last one iterated wins). Tips with no matching case keep their existing
     * entry.
     */
    private AbstractCase[] prepareExternalNodeMap(AbstractCase[] map) {
        for (int i = 0; i < getExternalNodeCount(); i++) {
            TreeModel.Node tip = (TreeModel.Node) getExternalNode(i);
            Taxon tipTaxon = tip.taxon;
            for (AbstractCase aCase : this.outbreak.getCases()) {
                if (!aCase.wasEverInfected()) {
                    continue;
                }
                for (Taxon taxon : aCase.getAssociatedTaxa()) {
                    if (taxon.equals(tipTaxon)) {
                        map[tip.getNumber()] = aCase;
                    }
                }
            }
        }
        return map;
    }

    /**
     * Reads an {@code infectee,infector} table (first line is a header, quotes stripped, blank lines
     * skipped; infector {@code Start} marks the root case) and partitions the tree accordingly.
     *
     * @throws RuntimeException if the file cannot be read (cause preserved), a line has fewer than
     *     two fields, or a case id is unknown to the outbreak
     */
    private void partitionAccordingToSpecificTT(String fileName) {
        System.out.println("Using specified starting transmission tree.");
        HashMap<AbstractCase, AbstractCase> infectorMap = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {
            reader.readLine(); // header row
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] fields = line.replace("\"", "").split("\\,");
                if (fields.length < 2) {
                    throw new RuntimeException("Malformed line in starting transmission tree file "
                            + fileName + ": " + line);
                }
                AbstractCase infectee = lookUpCase(fields[0], fileName);
                AbstractCase infector = fields[1].equals("Start") ? null : lookUpCase(fields[1], fileName);
                infectorMap.put(infectee, infector);
            }
        } catch (IOException e) {
            throw new RuntimeException("Cannot read file: " + fileName, e);
        }
        partitionAccordingToSpecificTT(infectorMap);
    }

    /** Resolves a case id from the transmission-tree file, failing loudly if it is unknown. */
    private AbstractCase lookUpCase(String caseId, String fileName) {
        AbstractCase found = this.outbreak.getCase(caseId);
        if (found == null) {
            throw new RuntimeException("Case " + caseId + " in starting transmission tree file "
                    + fileName + " is not part of this outbreak");
        }
        return found;
    }

    /**
     * Assigns tips from the taxa, then internal nodes from an infectee-to-infector map
     * ({@code null} infector = root). Requires exactly one infected case without an infector.
     *
     * @throws RuntimeException if the map involves never-infected cases, has no unique root, or is
     *     incompatible with the phylogeny
     */
    private void partitionAccordingToSpecificTT(HashMap<AbstractCase, AbstractCase> infectorMap) {
        this.branchMap.setAll(prepareExternalNodeMap(new AbstractCase[getNodeCount()]), true);
        for (AbstractCase infectee : infectorMap.keySet()) {
            AbstractCase infector = infectorMap.get(infectee);
            if (!infectee.wasEverInfected || (infector != null && !infector.wasEverInfected)) {
                throw new RuntimeException("This starting transmission tree involves never-infected cases");
            }
        }
        AbstractCase rootCase = null;
        int rootCount = 0;
        for (AbstractCase aCase : this.outbreak.getCases()) {
            if (aCase.wasEverInfected() && infectorMap.get(aCase) == null) {
                rootCase = aCase;
                rootCount++;
            }
        }
        if (rootCount == 0) {
            throw new RuntimeException("Given starting transmission tree appears to have a cycle");
        }
        if (rootCount > 1) {
            throw new RuntimeException("Given starting transmission tree appears not to be connected");
        }
        specificallyPartitionDownwards(getRoot(), rootCase, infectorMap);
        if (!checkPartitions()) {
            throw new RuntimeException("Given starting transmission tree is not compatible with the starting tree");
        }
    }

    /**
     * Recursively assigns {@code node} and its subtree, given that the node's parent belongs to
     * {@code thisCase}. A node that has a {@code thisCase} tip below it stays in {@code thisCase}.
     * Otherwise the node goes to the single infectee of {@code thisCase} whose tip MRCA is the node
     * or below it, if exactly one such infectee exists; failing that, it stays in {@code thisCase}.
     */
    private void specificallyPartitionDownwards(NodeRef node, AbstractCase thisCase,
                                                HashMap<AbstractCase, AbstractCase> infectorMap) {
        if (isExternal(node)) {
            return;
        }
        this.branchMap.set(node.getNumber(), thisCase, true);
        if (isAncestral(node)) {
            for (int i = 0; i < getChildCount(node); i++) {
                specificallyPartitionDownwards(getChild(node, i), thisCase, infectorMap);
            }
            return;
        }
        this.branchMap.set(node.getNumber(), null, true);

        HashSet<AbstractCase> infectees = new HashSet<>();
        for (AbstractCase aCase : this.outbreak.getCases()) {
            if (infectorMap.get(aCase) == thisCase) {
                infectees.add(aCase);
            }
        }

        HashSet<AbstractCase> candidates = new HashSet<>();
        for (AbstractCase infectee : infectees) {
            // Non-validating MRCA: nodes at or below `node` are still unassigned, so the validating
            // overload would throw for any infectee with more than one tip.
            NodeRef infecteeMRCA = caseMRCA(infectee, false);
            if (infecteeMRCA == node) {
                // Must be tested before the ancestor check, which also holds for equality.
                candidates.clear();
                candidates.add(infectee);
                break;
            }
            if (directDescendant(node, infecteeMRCA)) {
                throw new RuntimeException("Starting transmission tree is incompatible with starting phylogeny");
            }
            if (directDescendant(infecteeMRCA, node)) {
                candidates.add(infectee);
            }
        }

        AbstractCase assigned = candidates.size() == 1 ? candidates.iterator().next() : thisCase;
        this.branchMap.set(node.getNumber(), assigned, true);
        for (int i = 0; i < getChildCount(node); i++) {
            specificallyPartitionDownwards(getChild(node, i), this.branchMap.get(node.getNumber()), infectorMap);
        }
    }

    /**
     * Assigns tips from the taxa, then internal nodes at random.
     *
     * @param allowCreep if true, non-root internal nodes may also be left unassigned for an
     *     ancestor to fill
     */
    private void partitionAccordingToRandomTT(boolean allowCreep) {
        System.out.println("Generating a random starting partition of the tree");
        this.branchMap.setAll(prepareExternalNodeMap(new AbstractCase[getNodeCount()]), true);
        randomlyAssignNode(getRoot(), allowCreep);
    }

    /**
     * Randomly assigns {@code node} and its subtree, returning the case given to {@code node} (a
     * tip returns its existing case; a non-root node may return {@code null}).
     */
    private AbstractCase randomlyAssignNode(NodeRef node, boolean allowCreep) {
        if (isExternal(node)) {
            return this.branchMap.get(node.getNumber());
        }
        // Cases whose tip MRCA is at or above `node` and that have a tip below it: `node` must
        // belong to such a case.
        ArrayList<AbstractCase> forcedCases = new ArrayList<>();
        for (AbstractCase aCase : this.outbreak.getCases()) {
            if (!aCase.wasEverInfected) {
                continue;
            }
            NodeRef mrca = caseMRCA(aCase, false);
            if (!directDescendant(node, mrca) || forcedCases.contains(aCase)) {
                continue;
            }
            for (NodeRef tip : getTipsInThisPartitionElement(aCase)) {
                if (directDescendant(tip, node)) {
                    forcedCases.add(aCase);
                    break;
                }
            }
        }
        if (forcedCases.size() > 1) {
            throw new RuntimeException("Starting phylogeny is incompatible with this tip partition");
        }
        if (forcedCases.size() == 1) {
            AbstractCase forced = forcedCases.get(0);
            this.branchMap.set(node.getNumber(), forced, true);
            for (int i = 0; i < getChildCount(node); i++) {
                if (!isExternal(getChild(node, i))) {
                    randomlyAssignNode(getChild(node, i), allowCreep);
                }
            }
            return forced;
        }

        AbstractCase[] childCases = new AbstractCase[2];
        // The root must end up with a case, so re-draw its children until one of them has one.
        do {
            randomlyAssignChildren(node, childCases, allowCreep);
        } while (isRoot(node) && childCases[0] == null && childCases[1] == null);

        if (isRoot(node)) {
            int pick = MathUtils.nextInt(2);
            if (childCases[pick] == null) {
                pick = 1 - pick;
            }
            AbstractCase rootCase = childCases[pick];
            fillDownTree(node, rootCase);
            return rootCase;
        }

        int choice = MathUtils.nextInt(allowCreep ? 3 : 2);
        if (choice == 2) {
            return null;
        }
        AbstractCase winner = childCases[choice];
        AbstractCase other = childCases[1 - choice];
        NodeRef chosenChild = getChild(node, choice);
        NodeRef otherChild = getChild(node, 1 - choice);
        // If the chosen child's node is higher than the other case's infection point on its branch,
        // the other case must be the one present at `node`. Skip when the other side is unassigned.
        if (other != null) {
            double otherInfectionHeight = getNodeHeight(otherChild)
                    + other.getInfectionBranchPosition().getParameterValue(0) * getBranchLength(otherChild);
            if (getNodeHeight(chosenChild) > otherInfectionHeight) {
                winner = other;
            }
        }
        if (winner != null) {
            fillDownTree(node, winner);
        } else {
            this.branchMap.set(node.getNumber(), null, true);
        }
        return winner;
    }

    /** Stores in {@code childCases[i]} the case of child i, randomly assigning internal children. */
    private void randomlyAssignChildren(NodeRef node, AbstractCase[] childCases, boolean allowCreep) {
        for (int i = 0; i < getChildCount(node); i++) {
            NodeRef child = getChild(node, i);
            childCases[i] = isExternal(child)
                    ? this.branchMap.get(child.getNumber())
                    : randomlyAssignNode(child, allowCreep);
        }
    }

    /**
     * Assigns {@code abstractCase} to {@code node} and to every connected unassigned descendant.
     * Stops at any node that already has a case. Uses getChildCount so that an unassigned tip does
     * not dereference a missing child.
     */
    private void fillDownTree(NodeRef node, AbstractCase abstractCase) {
        if (this.branchMap.get(node.getNumber()) == null) {
            this.branchMap.set(node.getNumber(), abstractCase, true);
            for (int i = 0; i < getChildCount(node); i++) {
                fillDownTree(getChild(node, i), abstractCase);
            }
        }
    }
}
