package dr.evomodel.epidemiology.casetocase.operators;

import dr.evolution.tree.NodeRef;
import dr.evomodel.epidemiology.casetocase.AbstractCase;
import dr.evomodel.epidemiology.casetocase.BranchMapModel;
import dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood;
import dr.evomodel.epidemiology.casetocase.PartitionedTreeModel;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.Iterator;

/* loaded from: input_file:dr/evomodel/epidemiology/casetocase/operators/InfectionBranchMovementOperator.class */
/**
 * MCMC operator that moves the start of one case's partition element one node up or down
 * a {@link PartitionedTreeModel}.
 *
 * <p>A case that was ever infected and is not the case assigned to the root node is chosen at
 * random, and the earliest node of its partition element is located. The operator then either
 * extends that case's element to include the parent node ({@link #moveUp}) or hands the
 * earliest node over to the parent's case ({@link #moveDown}). The branch map is rewritten in
 * one {@code setAll} call per move.
 */
public class InfectionBranchMovementOperator extends SimpleMCMCOperator {
    public static final String INFECTION_BRANCH_MOVEMENT_OPERATOR = "infectionBranchMovementOperator";

    private final CaseToCaseTreeLikelihood c2cLikelihood;
    /** When true, affected cases get a new infection branch position from {@code MathUtils.nextDouble()}. */
    private final boolean resampleInfectionTimes;
    private static final boolean DEBUG = false;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public static final String RESAMPLE_INFECTION_TIMES = "resampleInfectionTimes";

        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule("weight"),
                AttributeRule.newBooleanRule(RESAMPLE_INFECTION_TIMES, true),
                new ElementRule(CaseToCaseTreeLikelihood.class)
        };

        @Override
        public String getParserName() {
            return INFECTION_BRANCH_MOVEMENT_OPERATOR;
        }

        /**
         * Builds the operator from its XML element.
         *
         * <p>Required: a {@code weight} attribute and a {@link CaseToCaseTreeLikelihood} child.
         * Optional: {@code resampleInfectionTimes} (defaults to {@code false}).
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            CaseToCaseTreeLikelihood likelihood =
                    (CaseToCaseTreeLikelihood) xo.getChild(CaseToCaseTreeLikelihood.class);
            double weight = xo.getDoubleAttribute("weight");
            boolean resample = false;
            if (xo.hasAttribute(RESAMPLE_INFECTION_TIMES)) {
                resample = xo.getBooleanAttribute(RESAMPLE_INFECTION_TIMES);
            }
            return new InfectionBranchMovementOperator(likelihood, weight, resample);
        }

        @Override
        public String getParserDescription() {
            return "This operator switches the painting of a random eligible internal node from the painting of one of its children to the painting of the other";
        }

        @Override
        public Class getReturnType() {
            return InfectionBranchMovementOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }
    };

    /**
     * @param caseToCaseTreeLikelihood likelihood supplying the tree, branch map and outbreak
     * @param weight                   operator weight passed to {@code setWeight}
     * @param resampleInfectionTimes   whether to redraw infection branch positions of moved cases
     */
    public InfectionBranchMovementOperator(CaseToCaseTreeLikelihood caseToCaseTreeLikelihood, double weight,
                                           boolean resampleInfectionTimes) {
        this.c2cLikelihood = caseToCaseTreeLikelihood;
        setWeight(weight);
        this.resampleInfectionTimes = resampleInfectionTimes;
    }

    @Override
    public String getOperatorName() {
        return INFECTION_BRANCH_MOVEMENT_OPERATOR;
    }

    /**
     * Picks an eligible case and moves its earliest node up or down.
     *
     * <p>A case is eligible if it was ever infected and is not the case of the root node.
     * Cases are drawn with {@code MathUtils.nextInt} until an eligible one is found.
     *
     * @return the value returned by the chosen move (a sum of log(2)/log(0.5) correction terms),
     *         or {@link Double#NEGATIVE_INFINITY} if no case is eligible or no move is possible
     */
    @Override
    public double doOperation() {
        PartitionedTreeModel treeModel = c2cLikelihood.getTreeModel();
        BranchMapModel branchMap = c2cLikelihood.getBranchMap();
        AbstractCase rootCase = branchMap.get(treeModel.getRoot().getNumber());

        // Without this check the rejection loop below would never terminate when every case is
        // either the root's case or was never infected. The check draws no random numbers, so
        // the RNG stream is unchanged whenever an eligible case exists.
        if (!hasEligibleCase(rootCase)) {
            return Double.NEGATIVE_INFINITY;
        }

        int caseCount = c2cLikelihood.getOutbreak().size();
        AbstractCase chosen;
        do {
            chosen = c2cLikelihood.getOutbreak().getCase(MathUtils.nextInt(caseCount));
        } while (chosen == rootCase || !chosen.wasEverInfected());

        return adjustTree(treeModel, treeModel.getEarliestNodeInElement(chosen));
    }

    /** Returns true if some case in the outbreak was ever infected and differs from {@code rootCase}. */
    private boolean hasEligibleCase(AbstractCase rootCase) {
        int caseCount = c2cLikelihood.getOutbreak().size();
        for (int i = 0; i < caseCount; i++) {
            AbstractCase candidate = c2cLikelihood.getOutbreak().getCase(i);
            if (candidate != rootCase && candidate.wasEverInfected()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Decides which direction(s) {@code node} can move in and performs the move.
     *
     * <p>Moving down is possible only if {@code node} is not the MRCA of its case. Moving up is
     * blocked when {@code isRootBlockedBy(nodeCase, parentCase)} holds and the parent is
     * ancestral. If both are possible, the direction is chosen with {@code MathUtils.nextBoolean()}.
     *
     * @return the move's return value, or {@link Double#NEGATIVE_INFINITY} if neither move is possible
     */
    private double adjustTree(PartitionedTreeModel tree, NodeRef node) {
        BranchMapModel branchMap = tree.getBranchMap();
        AbstractCase nodeCase = branchMap.get(node.getNumber());
        AbstractCase parentCase = branchMap.get(tree.getParent(node).getNumber());
        boolean canMoveDown = node != tree.caseMRCA(nodeCase);
        boolean canMoveUp = !(tree.isRootBlockedBy(nodeCase, parentCase) && tree.isAncestral(tree.getParent(node)));

        if (canMoveUp && canMoveDown) {
            return MathUtils.nextBoolean() ? moveUp(tree, node) : moveDown(tree, node);
        }
        if (canMoveUp) {
            return moveUp(tree, node);
        }
        if (canMoveDown) {
            return moveDown(tree, node);
        }
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Extends the case of {@code node} to include its parent node.
     *
     * <p>If the parent is ancestral, the parent's case may be redrawn first. Then, if the
     * grandparent shares the parent's case, that element above the parent (and the grandparent)
     * is relabelled too. Otherwise, if the sibling shares the parent's case, the sibling and
     * its same-element descendants are relabelled.
     *
     * @return sum of log(0.5)/log(2) correction terms for this move
     */
    private double moveUp(PartitionedTreeModel tree, NodeRef node) {
        BranchMapModel branchMap = tree.getBranchMap();
        AbstractCase movingCase = branchMap.get(node.getNumber());
        AbstractCase[] newMap = branchMap.getArrayCopy();
        NodeRef parent = tree.getParent(node);

        // The other child of parent (the last child that is not node).
        NodeRef sibling = node;
        for (int i = 0; i < tree.getChildCount(parent); i++) {
            NodeRef child = tree.getChild(parent, i);
            if (child != node) {
                sibling = child;
            }
        }

        AbstractCase parentCase = branchMap.get(parent.getNumber());
        NodeRef movingCaseMRCA = tree.caseMRCA(movingCase);
        NodeRef parentCaseMRCA = tree.caseMRCA(parentCase);

        if (tree.isAncestral(parent)) {
            if (resampleInfectionTimes) {
                parentCase.setInfectionBranchPosition(MathUtils.nextDouble());
            }
            NodeRef grandparent = tree.getParent(parent);
            if (grandparent != null && branchMap.get(grandparent.getNumber()) == parentCase) {
                for (int index : tree.samePartitionElementUpTree(parent)) {
                    newMap[index] = movingCase;
                }
                newMap[grandparent.getNumber()] = movingCase;
            }
        } else if (branchMap.get(sibling.getNumber()) == parentCase) {
            for (int index : tree.samePartitionElementDownTree(sibling)) {
                newMap[index] = movingCase;
            }
            newMap[sibling.getNumber()] = movingCase;
        }

        // NOTE(review): these terms presumably correct for moves where only one direction was
        // (or becomes) available, so adjustTree's 50/50 choice does not apply - confirm.
        double logRatio = node == movingCaseMRCA ? Math.log(0.5d) : 0.0d;

        newMap[parent.getNumber()] = movingCase;
        branchMap.setAll(newMap, false);

        // Ancestry of parent is queried again here, after the branch map has been updated.
        if (tree.isAncestral(parent)) {
            logRatio += sibling == parentCaseMRCA ? Math.log(2.0d) : 0.0d;
        } else {
            boolean blocked = tree.isRootBlockedBy(movingCase, parentCase)
                    && tree.isAncestral(tree.getParent(parent));
            logRatio += blocked ? Math.log(2.0d) : 0.0d;
        }

        if (resampleInfectionTimes) {
            movingCase.setInfectionBranchPosition(MathUtils.nextDouble());
        }
        return logRatio;
    }

    /**
     * Hands {@code node} over to its parent's case.
     *
     * <p>{@code node} is relabelled with the parent's case. So is every child of {@code node}
     * that is not ancestral, together with its same-element descendants.
     *
     * @return sum of log(2)/log(0.5) correction terms for this move
     */
    private double moveDown(PartitionedTreeModel tree, NodeRef node) {
        BranchMapModel branchMap = tree.getBranchMap();
        AbstractCase movingCase = branchMap.get(node.getNumber());
        AbstractCase parentCase = branchMap.get(tree.getParent(node).getNumber());
        AbstractCase[] newMap = branchMap.getArrayCopy();
        double logRatio = 0.0d;
        NodeRef parent = tree.getParent(node);
        NodeRef movingCaseMRCA = tree.caseMRCA(movingCase);

        for (int i = 0; i < tree.getChildCount(node); i++) {
            NodeRef child = tree.getChild(node, i);
            if (!tree.isAncestral(child)) {
                for (int index : tree.samePartitionElementDownTree(child)) {
                    newMap[index] = parentCase;
                }
                newMap[child.getNumber()] = parentCase;
            } else if (child == movingCaseMRCA && branchMap.get(child.getNumber()) == movingCase) {
                logRatio += Math.log(2.0d);
            }
        }

        if (tree.isRootBlockedBy(movingCase, parentCase) && tree.isAncestral(parent)) {
            logRatio += Math.log(0.5d);
        }

        if (resampleInfectionTimes) {
            movingCase.setInfectionBranchPosition(MathUtils.nextDouble());
        }
        newMap[node.getNumber()] = parentCase;
        branchMap.setAll(newMap, false);
        return logRatio;
    }

    public String getPerformanceSuggestion() {
        return "Not implemented";
    }
}
