package dr.evomodel.epidemiology.casetocase.operators;

import dr.evolution.tree.NodeRef;
import dr.evomodel.epidemiology.casetocase.AbstractCase;
import dr.evomodel.epidemiology.casetocase.BranchMapModel;
import dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood;
import dr.evomodel.epidemiology.casetocase.PartitionedTreeModel;
import dr.evomodel.operators.AbstractTreeOperator;
import dr.evomodel.tree.TreeModel;
import dr.math.MathUtils;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.HashSet;

/* loaded from: input_file:dr/evomodel/epidemiology/casetocase/operators/TransmissionWilsonBaldingA.class */
/**
 * Wilson-Balding style subtree-transplant operator on a partitioned (case-to-case) tree.
 *
 * <p>A subtree root is chosen uniformly from the nodes accepted by {@link #eligibleForMove}, and its
 * parent is regrafted onto a branch ending at a node of the same partition element as the old
 * parent, or at a child of such a node. Per the parser description, the intent is that the
 * transplant does not change the topology of the transmission tree. Moves involving the root are
 * rejected by returning a log ratio of negative infinity.
 */
public class TransmissionWilsonBaldingA extends AbstractTreeOperator {
    private final CaseToCaseTreeLikelihood c2cLikelihood;
    public static final String TRANSMISSION_WILSON_BALDING_A = "transmissionWilsonBaldingA";
    /** Log proposal ratio computed by the most recent {@link #proposeTree()} call. */
    private double logq;
    /** External node count at construction; checked after each move as a sanity test. */
    private final int tipCount;
    /** When true, infection branch positions of the affected cases are redrawn on each move. */
    private final boolean resampleInfectionTimes;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public static final String RESAMPLE_INFECTION_TIMES = "resampleInfectionTimes";
        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule("weight"),
                AttributeRule.newBooleanRule(RESAMPLE_INFECTION_TIMES, true),
                new ElementRule(CaseToCaseTreeLikelihood.class)
        };

        @Override
        public String getParserName() {
            return TransmissionWilsonBaldingA.TRANSMISSION_WILSON_BALDING_A;
        }

        /**
         * Builds the operator from its XML element. The required {@code weight} attribute sets the
         * operator weight. The optional {@code resampleInfectionTimes} attribute defaults to false.
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            CaseToCaseTreeLikelihood likelihood =
                    (CaseToCaseTreeLikelihood) xo.getChild(CaseToCaseTreeLikelihood.class);
            double weight = xo.getDoubleAttribute("weight");
            boolean resample = false;
            if (xo.hasAttribute(RESAMPLE_INFECTION_TIMES)) {
                resample = xo.getBooleanAttribute(RESAMPLE_INFECTION_TIMES);
            }
            return new TransmissionWilsonBaldingA(likelihood, weight, resample);
        }

        @Override
        public String getParserDescription() {
            return "This element represents a Wilson-Balding move operator, such that the transplantation of the subtree does not affect the topology of the transmission tree.";
        }

        @Override
        public Class getReturnType() {
            return TransmissionWilsonBaldingA.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }
    };

    /**
     * @param caseToCaseTreeLikelihood supplies the partitioned tree and the branch-to-case map
     * @param d                        operator weight
     * @param z                        whether to resample infection branch positions on each move
     */
    public TransmissionWilsonBaldingA(CaseToCaseTreeLikelihood caseToCaseTreeLikelihood, double d, boolean z) {
        this.c2cLikelihood = caseToCaseTreeLikelihood;
        setWeight(d);
        this.tipCount = caseToCaseTreeLikelihood.getTreeModel().getExternalNodeCount();
        this.resampleInfectionTimes = z;
    }

    /**
     * Performs one move and returns its log proposal ratio.
     *
     * @throws RuntimeException if the move changed the number of external nodes
     */
    @Override
    public double doOperation() {
        proposeTree();
        int currentTips = c2cLikelihood.getTreeModel().getExternalNodeCount();
        if (currentTips != tipCount) {
            throw new RuntimeException("Lost some tips in modified SPR! (" + tipCount + "-> " + currentTips + ")");
        }
        return logq;
    }

    /**
     * Proposes a transplant and stores its log proposal ratio in {@code logq}.
     *
     * <p>The ratio is negative infinity for rejected (root-touching or degenerate) moves. In that
     * case the tree is left untouched.
     */
    public void proposeTree() {
        PartitionedTreeModel tree = c2cLikelihood.getTreeModel();
        BranchMapModel branchMap = c2cLikelihood.getBranchMap();

        // Choose the subtree root uniformly among eligible nodes.
        ArrayList<NodeRef> eligibleNodes = getEligibleNodes(tree, branchMap);
        NodeRef subtree = eligibleNodes.get(MathUtils.nextInt(eligibleNodes.size()));
        int eligibleNodeCount = eligibleNodes.size();
        NodeRef oldParent = tree.getParent(subtree);

        Integer[] destinations = getPossibleDestinations(tree, oldParent);

        // Rejection-sample a destination whose parent branch lies above the subtree root.
        // The old parent itself always qualifies, so this terminates with probability 1.
        NodeRef destination;
        NodeRef destinationParent;
        boolean acceptable;
        do {
            destination = tree.getNode(destinations[MathUtils.nextInt(destinations.length)]);
            destinationParent = tree.getParent(destination);
            boolean parentAboveSubtree = destinationParent == null
                    || tree.getNodeHeight(destinationParent) > tree.getNodeHeight(subtree);
            acceptable = parentAboveSubtree && destination != subtree;
        } while (!acceptable);

        // This operator does not handle moves that involve the root.
        if (oldParent == tree.getRoot() || destination == tree.getRoot()) {
            logq = Double.NEGATIVE_INFINITY;
            return;
        }
        // Degenerate moves: regrafting onto the same or an adjacent branch.
        if (destinationParent == oldParent || destination == oldParent || destinationParent == subtree) {
            logq = Double.NEGATIVE_INFINITY;
            return;
        }

        NodeRef oldSibling = getOtherChild(tree, oldParent, subtree);
        // Non-null: oldParent is not the root (checked above).
        NodeRef oldGrandparent = tree.getParent(oldParent);

        if (resampleInfectionTimes) {
            resampleInfectionBranchPositions(branchMap, subtree, oldParent, oldSibling,
                    oldGrandparent, destination, destinationParent);
        }

        // New height for the moved parent: uniform between the taller child and the destination parent.
        double newMinAge = Math.max(tree.getNodeHeight(subtree), tree.getNodeHeight(destination));
        double newRange = tree.getNodeHeight(destinationParent) - newMinAge;
        double newAge = newMinAge + (MathUtils.nextDouble() * newRange);
        double oldMinAge = Math.max(tree.getNodeHeight(subtree), tree.getNodeHeight(oldSibling));
        double oldRange = tree.getNodeHeight(oldGrandparent) - oldMinAge;
        double heightRatio = newRange / Math.abs(oldRange);

        // Root cases were rejected above, so only the general transplant is needed.
        tree.beginTreeEdit();
        tree.removeChild(destinationParent, destination);
        tree.removeChild(oldParent, oldSibling);
        tree.removeChild(oldGrandparent, oldParent);
        tree.addChild(oldParent, destination);
        tree.addChild(destinationParent, oldParent);
        tree.addChild(oldGrandparent, oldSibling);
        tree.setNodeHeight(oldParent, newAge);
        tree.endTreeEdit();

        logq = Math.log(heightRatio);
        // Correct for the changed number of eligible subtree roots (forward vs. reverse choice).
        logq += Math.log(eligibleNodeCount / (double) getEligibleNodes(tree, branchMap).size());
    }

    /**
     * Collects the node numbers that may serve as the lower end of the destination branch.
     *
     * <p>These are every node in the same partition element as {@code node}, plus both children of
     * each such internal node.
     */
    private Integer[] getPossibleDestinations(PartitionedTreeModel tree, NodeRef node) {
        HashSet<Integer> destinations = new HashSet<>();
        for (Integer number : tree.samePartitionElement(node)) {
            destinations.add(number);
            NodeRef member = tree.getNode(number);
            if (!tree.isExternal(member)) {
                destinations.add(tree.getChild(member, 0).getNumber());
                destinations.add(tree.getChild(member, 1).getNumber());
            }
        }
        return destinations.toArray(new Integer[destinations.size()]);
    }

    /**
     * Redraws the infection branch position of each case whose infection branch is affected by the move.
     *
     * <p>The order of the random draws matches the move's fixed sequence: subtree case, sibling
     * case, destination case.
     *
     * @throws RuntimeException if the old parent's case matches neither end of the destination branch
     */
    private void resampleInfectionBranchPositions(BranchMapModel branchMap, NodeRef subtree,
                                                  NodeRef oldParent, NodeRef oldSibling,
                                                  NodeRef oldGrandparent, NodeRef destination,
                                                  NodeRef destinationParent) {
        AbstractCase subtreeCase = branchMap.get(subtree.getNumber());
        AbstractCase parentCase = branchMap.get(oldParent.getNumber());
        AbstractCase siblingCase = branchMap.get(oldSibling.getNumber());
        AbstractCase grandparentCase = oldGrandparent != null ? branchMap.get(oldGrandparent.getNumber()) : null;

        if (subtreeCase != parentCase) {
            subtreeCase.setInfectionBranchPosition(MathUtils.nextDouble());
        }
        if (grandparentCase == null || siblingCase != grandparentCase) {
            siblingCase.setInfectionBranchPosition(MathUtils.nextDouble());
        }

        AbstractCase destinationCase = branchMap.get(destination.getNumber());
        AbstractCase destinationParentCase = branchMap.get(destinationParent.getNumber());
        if (parentCase != destinationCase && parentCase != destinationParentCase) {
            throw new RuntimeException("TWBA misbehaving.");
        }
        destinationCase.setInfectionBranchPosition(MathUtils.nextDouble());
    }

    public String getPerformanceSuggestion() {
        return "Not implemented";
    }

    /**
     * Returns whether {@code node} may be chosen as the subtree root.
     *
     * <p>A node qualifies only if it is not the root, and its parent's case equals either the
     * grandparent's case (when a grandparent exists) or the sibling's case.
     */
    private boolean eligibleForMove(NodeRef node, TreeModel tree, BranchMapModel branchMap) {
        if (tree.isRoot(node)) {
            return false;
        }
        NodeRef parent = tree.getParent(node);
        NodeRef grandparent = tree.getParent(parent);
        AbstractCase parentCase = branchMap.get(parent.getNumber());
        if (grandparent != null && parentCase == branchMap.get(grandparent.getNumber())) {
            return true;
        }
        NodeRef sibling = getOtherChild(tree, parent, node);
        return parentCase == branchMap.get(sibling.getNumber());
    }

    /** Returns all nodes of {@code tree} that satisfy {@link #eligibleForMove}, in tree node order. */
    private ArrayList<NodeRef> getEligibleNodes(TreeModel tree, BranchMapModel branchMap) {
        ArrayList<NodeRef> eligible = new ArrayList<>();
        for (NodeRef candidate : tree.getNodes()) {
            if (eligibleForMove(candidate, tree, branchMap)) {
                eligible.add(candidate);
            }
        }
        return eligible;
    }

    @Override
    public String getOperatorName() {
        return "transmissionWilsonBaldingA (" + c2cLikelihood.getTreeModel().getId() + ")";
    }
}
