package dr.evomodel.epidemiology.casetocase.operators;

import dr.evolution.tree.NodeRef;
import dr.evomodel.epidemiology.casetocase.AbstractCase;
import dr.evomodel.epidemiology.casetocase.BranchMapModel;
import dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood;
import dr.evomodel.epidemiology.casetocase.PartitionedTreeModel;
import dr.evomodel.operators.AbstractTreeOperator;
import dr.evomodel.tree.TreeModel;
import dr.math.MathUtils;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/* loaded from: input_file:dr/evomodel/epidemiology/casetocase/operators/TransmissionWilsonBaldingB.class */
/**
 * Wilson-Balding subtree-transplant move for a partitioned (transmission) tree.
 *
 * <p>The pruned node is always one whose branch-map case differs from its parent's. That means
 * the transplanted phylogenetic subtree also carries a transmission subtree with it. The parent
 * of the moved node is re-painted with the case of one of the two nodes that bound its new branch.
 */
public class TransmissionWilsonBaldingB extends AbstractTreeOperator {
    public static final String TRANSMISSION_WILSON_BALDING_B = "transmissionWilsonBaldingB";

    /** Optional boolean XML attribute: resample infection branch positions of affected cases. */
    public static final String RESAMPLE_INFECTION_TIMES = "resampleInfectionTimes";

    private final CaseToCaseTreeLikelihood c2cLikelihood;
    /** Tip count at construction time; re-checked after every proposal as a sanity test. */
    private final int tipCount;
    private final boolean resampleInfectionTimes;
    /** Log Hastings ratio of the most recent proposal (negative infinity for a rejected move). */
    private double logq;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = {
            AttributeRule.newDoubleRule("weight"),
            AttributeRule.newBooleanRule(RESAMPLE_INFECTION_TIMES, true),
            new ElementRule(CaseToCaseTreeLikelihood.class)
        };

        @Override
        public String getParserName() {
            return TRANSMISSION_WILSON_BALDING_B;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            CaseToCaseTreeLikelihood likelihood =
                    (CaseToCaseTreeLikelihood) xo.getChild(CaseToCaseTreeLikelihood.class);
            double weight = xo.getDoubleAttribute("weight");
            // The attribute is declared optional above; when absent, resampling is disabled.
            boolean resample = false;
            if (xo.hasAttribute(RESAMPLE_INFECTION_TIMES)) {
                resample = xo.getBooleanAttribute(RESAMPLE_INFECTION_TIMES);
            }
            return new TransmissionWilsonBaldingB(likelihood, weight, resample);
        }

        @Override
        public String getParserDescription() {
            return "This element represents a Wilson-Balding move operator, such that the transplantation of the phylogenetic subtree is also transplantation of a transmission subtree.";
        }

        @Override
        public Class getReturnType() {
            return TransmissionWilsonBaldingB.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }
    };

    /**
     * Creates the operator.
     *
     * @param c2cLikelihood          likelihood supplying the partitioned tree and its branch map
     * @param weight                 operator weight
     * @param resampleInfectionTimes whether affected cases get fresh uniform infection positions
     */
    public TransmissionWilsonBaldingB(CaseToCaseTreeLikelihood c2cLikelihood, double weight,
                                      boolean resampleInfectionTimes) {
        this.c2cLikelihood = c2cLikelihood;
        setWeight(weight);
        this.tipCount = c2cLikelihood.getTreeModel().getExternalNodeCount();
        this.resampleInfectionTimes = resampleInfectionTimes;
    }

    /**
     * Proposes a new tree and returns the log Hastings ratio.
     *
     * @return log Hastings ratio of the proposal (negative infinity when the move was disallowed)
     * @throws RuntimeException if the tree's tip count changed, which indicates a corrupted edit
     */
    @Override
    public double doOperation() {
        proposeTree();
        final int tipsNow = c2cLikelihood.getTreeModel().getExternalNodeCount();
        if (tipsNow != tipCount) {
            throw new RuntimeException("Lost some tips in modified SPR! (" + tipCount + "-> " + tipsNow + ")");
        }
        return logq;
    }

    /**
     * Performs one transplant proposal and stores its log Hastings ratio in {@code logq}.
     *
     * <p>Node i (eligible per {@link #eligibleForMove}) is pruned together with its parent iP.
     * iP is then reinserted on the branch above j, whose parent is k. iP's new height is drawn
     * uniformly between max(h(i), h(j)) and h(k). Disallowed moves leave the tree untouched and
     * set {@code logq} to negative infinity, which presumably forces rejection by the MCMC driver.
     */
    public void proposeTree() {
        final PartitionedTreeModel tree = c2cLikelihood.getTreeModel();
        final BranchMapModel branchMap = c2cLikelihood.getBranchMap();
        final int nodeCount = tree.getNodeCount();

        // i: the subtree to move; it must begin a new infection branch.
        NodeRef i;
        do {
            i = tree.getNode(MathUtils.nextInt(nodeCount));
        } while (!eligibleForMove(i, tree, branchMap));
        final NodeRef iP = tree.getParent(i);

        // j: the destination branch (j, k). k must be older than i, or j must be the root.
        NodeRef j;
        NodeRef k;
        boolean legalDestination;
        do {
            j = tree.getNode(MathUtils.nextInt(nodeCount));
            k = tree.getParent(j);
            legalDestination = (k == null || tree.getNodeHeight(k) > tree.getNodeHeight(i)) && i != j;
        } while (!legalDestination);

        // Moves involving the root are not supported. Moves where iP would be re-attached adjacent
        // to its current position, or under its own child, are also disallowed. Because of this
        // check, neither iP nor j is the root below, so k and piP are both non-null.
        if (iP == tree.getRoot() || j == tree.getRoot() || k == iP || j == iP || k == i) {
            logq = Double.NEGATIVE_INFINITY;
            return;
        }

        final NodeRef ciP = getOtherChild(tree, iP, i);
        final NodeRef piP = tree.getParent(iP);

        if (resampleInfectionTimes) {
            resampleInfectionPositions(branchMap, i, iP, ciP, piP, j);
        }

        final double newMinAge = Math.max(tree.getNodeHeight(i), tree.getNodeHeight(j));
        final double newRange = tree.getNodeHeight(k) - newMinAge;
        final double newAge = newMinAge + (MathUtils.nextDouble() * newRange);
        final double oldMinAge = Math.max(tree.getNodeHeight(i), tree.getNodeHeight(ciP));
        final double oldRange = tree.getNodeHeight(piP) - oldMinAge;
        double q = newRange / Math.abs(oldRange);

        // Account for the random re-painting of iP. The reverse move would choose between
        // piP's and ciP's cases; the forward move chooses between k's and j's.
        if (branchMap.get(piP.getNumber()) != branchMap.get(ciP.getNumber())) {
            q *= 0.5d;
        }
        if (branchMap.get(k.getNumber()) != branchMap.get(j.getNumber())) {
            q *= 2.0d;
        }

        tree.beginTreeEdit();
        tree.removeChild(k, j);
        tree.removeChild(iP, ciP);
        tree.removeChild(piP, iP);
        tree.addChild(iP, j);
        tree.addChild(k, iP);
        tree.addChild(piP, ciP);
        tree.setNodeHeight(iP, newAge);
        tree.endTreeEdit();

        logq = Math.log(q);

        // Re-paint iP with the case of k or j, each with probability 1/2.
        final NodeRef paintSource = (MathUtils.nextInt(2) == 0) ? k : j;
        branchMap.set(iP.getNumber(), branchMap.get(paintSource.getNumber()), true);
    }

    /**
     * Draws fresh uniform infection branch positions for the cases affected by the move.
     *
     * <p>Affected cases are: i's case (if it differs from iP's), ciP's case (if piP is absent or
     * has a different case), and j's case (always).
     */
    private static void resampleInfectionPositions(BranchMapModel branchMap, NodeRef i, NodeRef iP,
                                                   NodeRef ciP, NodeRef piP, NodeRef j) {
        final AbstractCase iCase = branchMap.get(i.getNumber());
        final AbstractCase iPCase = branchMap.get(iP.getNumber());
        final AbstractCase ciPCase = branchMap.get(ciP.getNumber());
        final AbstractCase piPCase = (piP == null) ? null : branchMap.get(piP.getNumber());

        if (iCase != iPCase) {
            iCase.setInfectionBranchPosition(MathUtils.nextDouble());
        }
        if (piPCase == null || ciPCase != piPCase) {
            ciPCase.setInfectionBranchPosition(MathUtils.nextDouble());
        }
        branchMap.get(j.getNumber()).setInfectionBranchPosition(MathUtils.nextDouble());
    }

    public String getPerformanceSuggestion() {
        return "Not implemented";
    }

    /**
     * A node may be moved only if it is not the root and its branch-map case differs from its
     * parent's, i.e. its branch is where a new case's infection begins.
     */
    private boolean eligibleForMove(NodeRef node, TreeModel tree, BranchMapModel branchMap) {
        return !tree.isRoot(node)
                && branchMap.get(tree.getParent(node).getNumber()) != branchMap.get(node.getNumber());
    }

    @Override
    public String getOperatorName() {
        return "transmissionWilsonBaldingB (" + c2cLikelihood.getTreeModel().getId() + ")";
    }
}
