package dr.evomodel.epidemiology.casetocase;

import dr.app.tools.NexusExporter;
import dr.evolution.coalescent.Coalescent;
import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.coalescent.IntervalList;
import dr.evolution.coalescent.IntervalType;
import dr.evolution.coalescent.LinearGrowth;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.evomodel.coalescent.DemographicModel;
import dr.evomodel.epidemiology.casetocase.BranchMapModel;
import dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood;
import dr.evomodel.epidemiology.casetocase.PartitionedTreeModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.BigDecimalUtils;
import dr.math.Binomial;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.io.IOException;
import java.io.PrintStream;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;

/* loaded from: input_file:dr/evomodel/epidemiology/casetocase/WithinCaseCoalescent.class */
public class WithinCaseCoalescent extends CaseToCaseTreeLikelihood {
    // Name passed to the superclass constructor and returned by PARSER.getParserName().
    public static final String WITHIN_CASE_COALESCENT = "withinCaseCoalescent";
    // Cached coalescent log-likelihood of each case's treelet, indexed by the case's index in the outbreak.
    private double[] partitionTreeLogLikelihoods;
    // Copy of partitionTreeLogLikelihoods taken in storeState() and reinstated by restoreState().
    private double[] storedPartitionTreeLogLikelihoods;
    // Per-case dirty flags: true means the cached entry must be recomputed on the next likelihood evaluation.
    private boolean[] recalculateCoalescentFlags;
    // Demographic model whose demographic function is handed to each per-case coalescent calculation.
    private DemographicModel demoModel;
    // Chooses whether the per-case calculation is run with truncate=true (TRUNCATE) or truncate=false (NORMAL).
    private Mode mode;
    // Sum of the per-case log-likelihoods accumulated by the most recent calculateLogLikelihood() call.
    private double coalescencesLogLikelihood;
    // Value of coalescencesLogLikelihood saved by storeState().
    private double storedCoalescencesLogLikelihood;
    // When true, calculateLogLikelihood() calls explodeTree() before evaluating the cases.
    private boolean pleaseReExplode;
    /**
     * XML parser that builds a {@link WithinCaseCoalescent} from a
     * {@code withinCaseCoalescent} element.
     *
     * <p>The element supplies the partitioned tree, the outbreak, the
     * {@code maxFirstInfToRoot} parameter, the within-case demographic model and a
     * boolean {@code truncate} attribute that selects {@code Mode.TRUNCATE} when
     * true and {@code Mode.NORMAL} otherwise.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public static final String STARTING_NETWORK = "startingNetwork";
        public static final String MAX_FIRST_INF_TO_ROOT = "maxFirstInfToRoot";
        public static final String DEMOGRAPHIC_MODEL = "demographicModel";
        public static final String TRUNCATE = "truncate";

        // NOTE(review): the CategoryOutbreak rule appears twice with identical arguments;
        // kept as-is because the intended second outbreak type cannot be determined from here.
        private final XMLSyntaxRule[] rules = {
            new ElementRule(PartitionedTreeModel.class, "The tree"),
            new ElementRule(CategoryOutbreak.class, "The set of cases", 0, 1),
            new ElementRule(CategoryOutbreak.class, "The set of cases", 0, 1),
            new ElementRule(STARTING_NETWORK, String.class, "A CSV file containing a specified starting network", true),
            new ElementRule(MAX_FIRST_INF_TO_ROOT, Parameter.class, "The maximum time from the first infection to the root node"),
            new ElementRule(DEMOGRAPHIC_MODEL, DemographicModel.class, "The demographic model for within-case evolution"),
            AttributeRule.newBooleanRule(TRUNCATE)
        };

        @Override
        public String getParserName() {
            return WithinCaseCoalescent.WITHIN_CASE_COALESCENT;
        }

        /**
         * Reads the child elements and the {@code truncate} attribute and constructs the likelihood.
         *
         * @param xMLObject the {@code withinCaseCoalescent} element
         * @return a new {@link WithinCaseCoalescent}
         * @throws XMLParseException if a required child is unreadable or a taxon is missing
         */
        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            try {
                PartitionedTreeModel tree = (PartitionedTreeModel) xMLObject.getChild(TreeModel.class);
                AbstractOutbreak outbreak = (AbstractOutbreak) xMLObject.getChild(AbstractOutbreak.class);
                Parameter maxFirstInfToRoot = (Parameter) xMLObject.getElementFirstChild(MAX_FIRST_INF_TO_ROOT);
                DemographicModel demographicModel = (DemographicModel) xMLObject.getElementFirstChild(DEMOGRAPHIC_MODEL);
                // Short-circuit: only read the attribute when it is actually present.
                boolean truncate = xMLObject.hasAttribute(TRUNCATE) && xMLObject.getBooleanAttribute(TRUNCATE);
                return new WithinCaseCoalescent(tree, outbreak, maxFirstInfToRoot, demographicModel,
                        truncate ? Mode.TRUNCATE : Mode.NORMAL);
            } catch (TaxonList.MissingTaxonException e) {
                throw new XMLParseException(e.toString());
            }
        }

        @Override
        public String getParserDescription() {
            return "This element provides a tree prior for a partitioned tree, with each partitioned tree generated by a coalescent process";
        }

        @Override
        public Class getReturnType() {
            return WithinCaseCoalescent.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /* loaded from: input_file:dr/evomodel/epidemiology/casetocase/WithinCaseCoalescent$Mode.class */
    // Switch between the two per-case coalescent calculations; chosen by the parser's "truncate" attribute.
    private enum Mode {
        // calculateLogLikelihood() passes truncate=true to SpecifiedZeroCoalescent in this mode.
        TRUNCATE,
        // Any other case: truncate=false is passed.
        NORMAL
    }

    /* loaded from: input_file:dr/evomodel/epidemiology/casetocase/WithinCaseCoalescent$SpecifiedZeroCoalescent.class */
    private class SpecifiedZeroCoalescent extends Coalescent {
        private double zeroHeight;
        boolean truncate;

        private SpecifiedZeroCoalescent(Tree tree, DemographicModel demographicModel, double d, boolean z) {
            super(tree, demographicModel.getDemographicFunction());
            this.zeroHeight = d;
            this.truncate = z;
        }

        @Override // dr.evolution.coalescent.Coalescent
        public double calculateLogLikelihood() {
            return WithinCaseCoalescent.calculatePartitionTreeLogLikelihood(getIntervals(), getDemographicFunction(), 0.0d, this.zeroHeight, this.truncate);
        }
    }

    /**
     * Creates the likelihood, registers the demographic model and the outbreak as
     * sub-models, and marks every case as needing recalculation.
     *
     * @param partitionedTreeModel the partitioned tree, forwarded to the superclass
     * @param abstractOutbreak     the outbreak, forwarded to the superclass
     * @param parameter            forwarded to the superclass (the parser supplies maxFirstInfToRoot here)
     * @param demographicModel     demographic model used for every per-case coalescent
     * @param mode                 whether the truncated calculation is used
     * @throws TaxonList.MissingTaxonException propagated from the superclass constructor
     */
    public WithinCaseCoalescent(PartitionedTreeModel partitionedTreeModel, AbstractOutbreak abstractOutbreak, Parameter parameter, DemographicModel demographicModel, Mode mode) throws TaxonList.MissingTaxonException {
        super(WITHIN_CASE_COALESCENT, partitionedTreeModel, abstractOutbreak, parameter);
        this.pleaseReExplode = true;
        this.mode = mode;
        this.demoModel = demographicModel;
        addModel(demographicModel);
        addModel(this.outbreak);

        // One cache slot and one dirty flag per case; everything starts dirty.
        final int caseCount = this.outbreak.getCases().size();
        this.partitionTreeLogLikelihoods = new double[caseCount];
        this.storedPartitionTreeLogLikelihoods = new double[caseCount];
        this.recalculateCoalescentFlags = new boolean[caseCount];
        Arrays.fill(this.recalculateCoalescentFlags, true);

        // Every ever-infected case gets a map entry with no treelet yet.
        this.elementsAsTrees = new HashMap<>();
        for (AbstractCase aCase : this.outbreak.getCases()) {
            if (aCase.wasEverInfected()) {
                this.elementsAsTrees.put(aCase, null);
            }
        }
        this.storedElementsAsTrees = new HashMap<>();
    }

    /**
     * Sums the per-case coalescent log-likelihoods, recomputing only the cases
     * whose dirty flag is set and reusing the cached value for the rest.
     *
     * @return the total coalescent log-likelihood over all ever-infected cases
     */
    @Override
    protected double calculateLogLikelihood() {
        if (this.pleaseReExplode) {
            explodeTree();
        }
        final boolean truncate = this.mode == Mode.TRUNCATE;
        this.coalescencesLogLikelihood = 0.0d;
        for (AbstractCase aCase : this.outbreak.getCases()) {
            final int index = this.outbreak.getCaseIndex(aCase);

            // Never-infected cases contribute nothing; just clear their flag.
            if (!aCase.wasEverInfected()) {
                this.recalculateCoalescentFlags[index] = false;
                continue;
            }

            // Clean case: reuse the cached value.
            if (!this.recalculateCoalescentFlags[index]) {
                this.coalescencesLogLikelihood += this.partitionTreeLogLikelihoods[index];
                continue;
            }

            CaseToCaseTreeLikelihood.Treelet treelet = this.elementsAsTrees.get(aCase);
            if (treelet.getExternalNodeCount() > 1) {
                SpecifiedZeroCoalescent coalescent = new SpecifiedZeroCoalescent(treelet, this.demoModel, treelet.getZeroHeight(), truncate);
                this.partitionTreeLogLikelihoods[index] = coalescent.calculateLogLikelihood();
                this.coalescencesLogLikelihood += this.partitionTreeLogLikelihoods[index];
            } else {
                // A treelet with a single external node has no coalescences.
                this.partitionTreeLogLikelihoods[index] = 0.0d;
            }
            this.recalculateCoalescentFlags[index] = false;
        }
        this.likelihoodKnown = true;
        return this.coalescencesLogLikelihood;
    }

    /**
     * Saves the superclass state plus snapshots of the treelet map, the per-case
     * log-likelihood cache and the running total.
     */
    @Override
    public void storeState() {
        super.storeState();
        this.storedCoalescencesLogLikelihood = this.coalescencesLogLikelihood;
        // Array copy, so later writes to the live cache do not touch the stored one.
        this.storedPartitionTreeLogLikelihoods = this.partitionTreeLogLikelihoods.clone();
        // Shallow copy of the map: entries are shared, the map itself is not.
        this.storedElementsAsTrees = new HashMap<>(this.elementsAsTrees);
    }

    /**
     * Restores the superclass state and swaps the stored treelet map, per-case
     * cache and running total back in. The live fields alias the stored objects
     * until the next {@code storeState()} call makes fresh copies.
     */
    @Override
    public void restoreState() {
        super.restoreState();
        this.coalescencesLogLikelihood = this.storedCoalescencesLogLikelihood;
        this.partitionTreeLogLikelihoods = this.storedPartitionTreeLogLikelihoods;
        this.elementsAsTrees = this.storedElementsAsTrees;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Marks the affected cases dirty in response to a change in one of the
     * sub-models (tree, branch map, demographic model or outbreak).
     *
     * @param model the model that changed
     * @param obj   event payload; its expected type depends on {@code model}
     * @param i     index forwarded unchanged to the superclass
     * @throws RuntimeException if the branch map sends a payload that is not an {@code ArrayList}
     */
    @Override
    public void handleModelChangedEvent(Model model, Object obj, int i) {
        super.handleModelChangedEvent(model, obj, i);

        if (model == this.treeModel) {
            // Only partition-change events carry a list of cases; other tree events are ignored here.
            if (obj instanceof PartitionedTreeModel.PartitionsChangedEvent) {
                PartitionedTreeModel.PartitionsChangedEvent event = (PartitionedTreeModel.PartitionsChangedEvent) obj;
                for (Iterator<AbstractCase> cases = event.getCasesToRecalculate().iterator(); cases.hasNext(); ) {
                    recalculateCaseWCC(cases.next());
                }
            }
        } else if (model == getBranchMap()) {
            if (!(obj instanceof ArrayList)) {
                throw new RuntimeException("Unanticipated model changed event from BranchMapModel");
            }
            ArrayList events = (ArrayList) obj;
            for (int k = 0; k < events.size(); k++) {
                BranchMapModel.BranchMapChangedEvent change = (BranchMapModel.BranchMapChangedEvent) events.get(k);
                recalculateCaseWCC(change.getOldCase());
                recalculateCaseWCC(change.getNewCase());
                // The case assigned to the parent node is affected as well, when a parent exists.
                NodeRef parent = this.treeModel.getParent(this.treeModel.getNode(change.getNodeToRecalculate()));
                if (parent != null) {
                    recalculateCaseWCC(getBranchMap().get(parent.getNumber()));
                }
            }
        } else if (model == this.demoModel) {
            // A demographic change invalidates every cached per-case value.
            Arrays.fill(this.recalculateCoalescentFlags, true);
        } else if (model == this.outbreak && obj instanceof AbstractCase) {
            AbstractCase changed = (AbstractCase) obj;
            recalculateCaseWCC(changed);
            AbstractCase infector = ((PartitionedTreeModel) this.treeModel).getInfector(changed);
            if (infector != null) {
                recalculateCaseWCC(infector);
            }
        }
    }

    /**
     * Invalidates the case at the given outbreak index: drops its treelet,
     * requests a re-explode and sets its dirty flag.
     *
     * @param i index of the case in the outbreak
     */
    protected void recalculateCaseWCC(int i) {
        AbstractCase target = this.outbreak.getCase(i);
        this.elementsAsTrees.put(target, null);
        this.pleaseReExplode = true;
        this.recalculateCoalescentFlags[i] = true;
    }

    /**
     * Invalidates the given case, provided it was ever infected; otherwise does nothing.
     *
     * @param abstractCase the case to invalidate
     */
    protected void recalculateCaseWCC(AbstractCase abstractCase) {
        if (!abstractCase.wasEverInfected()) {
            return;
        }
        int index = this.outbreak.getCaseIndex(abstractCase);
        recalculateCaseWCC(index);
    }

    /**
     * Invalidates everything: all per-case dirty flags are set, the treelet of
     * every ever-infected case is dropped, and a re-explode is requested.
     */
    @Override
    public void makeDirty() {
        super.makeDirty();
        Arrays.fill(this.recalculateCoalescentFlags, true);
        for (AbstractCase aCase : this.outbreak.getCases()) {
            if (aCase.wasEverInfected()) {
                this.elementsAsTrees.put(aCase, null);
            }
        }
        this.pleaseReExplode = true;
    }

    /**
     * Lists the cases reachable from the case assigned to the tree root, with
     * every case appearing after all the cases it infected.
     *
     * @return the cases in post-order
     */
    public ArrayList<AbstractCase> postOrderTransmissionTreeTraversal() {
        int rootNumber = this.treeModel.getRoot().getNumber();
        return traverseTransmissionTree(getBranchMap().get(rootNumber));
    }

    /**
     * Recursive post-order walk of the transmission tree below one case.
     * Infectees are visited in outbreak-index order rather than set-iteration order.
     *
     * @param abstractCase the case whose subtree is traversed
     * @return the subtree's cases, infectees first, ending with {@code abstractCase}
     */
    private ArrayList<AbstractCase> traverseTransmissionTree(AbstractCase abstractCase) {
        HashSet<AbstractCase> infectees = ((PartitionedTreeModel) this.treeModel).getInfectees(abstractCase);
        ArrayList<AbstractCase> ordered = new ArrayList<>();
        for (int index = 0; index < getOutbreak().size(); index++) {
            AbstractCase candidate = getOutbreak().getCase(index);
            if (!infectees.contains(candidate)) {
                continue;
            }
            ordered.addAll(traverseTransmissionTree(candidate));
        }
        ordered.add(abstractCase);
        return ordered;
    }

    /**
     * Builds a copy of a treelet whose node heights are log-transformed.
     *
     * <p>Each node height {@code h} is mapped to {@code -log(zeroHeight - h)}, and
     * all mapped heights are then shifted so that the smallest becomes zero. The
     * copy is created with the negated minimum as its second constructor argument
     * and {@code resolveTree()} is called on it before it is returned.
     *
     * @param treelet the treelet to transform; it is not modified by this method
     * @return the transformed copy
     */
    private CaseToCaseTreeLikelihood.Treelet transformTreelet(CaseToCaseTreeLikelihood.Treelet treelet) {
        final int nodeCount = treelet.getNodeCount();
        final double zeroHeight = treelet.getZeroHeight();
        final double[] transformed = new double[nodeCount];

        // Transform every height and track the minimum in the same pass.
        // A plain '<' comparison is used so that NaN values never become the minimum.
        double minimum = Double.POSITIVE_INFINITY;
        for (int i = 0; i < nodeCount; i++) {
            transformed[i] = -Math.log(-(treelet.getNodeHeight(treelet.getNode(i)) - zeroHeight));
            if (transformed[i] < minimum) {
                minimum = transformed[i];
            }
        }

        CaseToCaseTreeLikelihood.Treelet result = new CaseToCaseTreeLikelihood.Treelet(treelet, -minimum);
        for (int i = 0; i < result.getNodeCount(); i++) {
            result.setNodeHeight(result.getNode(i), transformed[i] - minimum);
        }
        result.resolveTree();
        return result;
    }

    /**
     * Computes the coalescent log-likelihood of one set of tree intervals, with
     * time measured so that the first interval starts at {@code -d2}.
     *
     * <p>Intervals are processed consecutively: each one starts where the previous
     * one ended. The start time is advanced after <em>every</em> interval, including
     * those that add no term to the likelihood (fewer than two lineages in truncate
     * mode, non-coalescent intervals in non-truncate mode).
     *
     * @param intervalList        the intervals to evaluate
     * @param demographicFunction demographic function supplying integrals and population values
     * @param d                   threshold: a coalescent interval of non-zero length yields negative
     *                            infinity when {@code demographic * (integral / duration)} is below it
     * @param d2                  zero height; the first interval starts at time {@code -d2}
     * @param z                   {@code true} to use the truncated calculation, which adds a
     *                            normalisation term over the integral from the interval start to time 0
     * @return the log-likelihood, or {@code Double.NEGATIVE_INFINITY} when an interval is impossible
     * @throws RuntimeException if {@code z} is false, there is at least one interval, and
     *                          {@code demographicFunction} is not a {@code LinearGrowth}
     */
    public static double calculatePartitionTreeLogLikelihood(IntervalList intervalList, DemographicFunction demographicFunction, double d, double d2, boolean z) {
        final double threshold = d;
        final boolean truncate = z;
        double logL = 0.0d;
        double startTime = -d2;
        final int intervalCount = intervalList.getIntervalCount();

        for (int i = 0; i < intervalCount; i++) {
            final double finishTime;
            if (truncate) {
                final double duration = intervalList.getInterval(i);
                finishTime = startTime + duration;
                if (finishTime == 0.0d) {
                    return Double.NEGATIVE_INFINITY;
                }
                final double intervalArea = demographicFunction.getIntegral(startTime, finishTime);
                final double normalisationArea = demographicFunction.getIntegral(startTime, 0.0d);
                if (intervalArea == 0.0d && duration > tolerance) {
                    return Double.NEGATIVE_INFINITY;
                }
                final int lineageCount = intervalList.getLineageCount(i);
                // With fewer than two lineages nothing can coalesce, so no term is added.
                if (lineageCount >= 2) {
                    final double kChoose2 = Binomial.choose2(lineageCount);
                    if (intervalList.getIntervalType(i) == IntervalType.COALESCENT) {
                        logL += (-kChoose2) * intervalArea;
                        final double demographic = demographicFunction.getDemographic(finishTime);
                        if (duration != 0.0d && demographic * (intervalArea / duration) < threshold) {
                            return Double.NEGATIVE_INFINITY;
                        }
                        logL -= Math.log(demographic);
                    } else {
                        logL += Math.log(Math.exp((-kChoose2) * intervalArea) - Math.exp((-kChoose2) * normalisationArea));
                    }
                    // Normalisation: subtract log(1 - exp(-k * normalisationArea)); when the
                    // exponential rounds to exactly 1 the helper is used instead of log1p.
                    final double normExp = Math.exp((-kChoose2) * normalisationArea);
                    logL -= normExp != 1.0d
                            ? Math.log1p(-normExp)
                            : handleDenominatorUnderflow((-kChoose2) * normalisationArea);
                }
            } else {
                if (!(demographicFunction instanceof LinearGrowth)) {
                    throw new RuntimeException("Function must have zero population at t=0 if truncate=false");
                }
                final double duration = intervalList.getInterval(i);
                finishTime = startTime + duration;
                final double intervalArea = demographicFunction.getIntegral(startTime, finishTime);
                if (intervalArea == 0.0d && duration != 0.0d) {
                    return Double.NEGATIVE_INFINITY;
                }
                logL += (-Binomial.choose2(intervalList.getLineageCount(i))) * intervalArea;
                if (intervalList.getIntervalType(i) == IntervalType.COALESCENT) {
                    final double demographic = demographicFunction.getDemographic(finishTime);
                    if (duration != 0.0d && demographic * (intervalArea / duration) < threshold) {
                        return Double.NEGATIVE_INFINITY;
                    }
                    logL -= Math.log(demographic);
                }
            }
            // Always advance: the next interval starts where this one ended.
            startTime = finishTime;
        }
        return logL;
    }

    /**
     * Evaluates {@code ln(1 - exp(d))} using {@code BigDecimal} arithmetic; called
     * when the double-precision {@code exp} has rounded to exactly 1.
     *
     * @param d the exponent
     * @return {@code ln(1 - exp(d))} converted back to a double
     */
    private static double handleDenominatorUnderflow(double d) {
        final BigDecimal exponent = new BigDecimal(d);
        final BigDecimal expValue = BigDecimalUtils.exp(exponent, exponent.scale());
        final BigDecimal oneMinusExp = BigDecimal.ONE.subtract(expValue);
        final BigDecimal logValue = BigDecimalUtils.ln(oneMinusExp, oneMinusExp.scale());
        return logValue.doubleValue();
    }

    /**
     * Debugging aid: writes a copy of the tree to a file in Nexus format, with
     * each node annotated with its node number under the attribute "Number".
     *
     * <p>An {@code IOException} (for example, the file cannot be opened) is
     * reported on standard output and otherwise ignored.
     *
     * @param tree the tree to export
     * @param str  path of the output file
     */
    public void debugTreelet(Tree tree, String str) {
        try {
            FlexibleTree flexibleTree = new FlexibleTree(tree);
            for (int i = 0; i < flexibleTree.getNodeCount(); i++) {
                FlexibleNode flexibleNode = (FlexibleNode) flexibleTree.getNode(i);
                flexibleNode.setAttribute("Number", Integer.valueOf(flexibleNode.getNumber()));
            }
            // try-with-resources so the stream is flushed and closed even if the export fails.
            try (PrintStream out = new PrintStream(str)) {
                new NexusExporter(out).exportTree(flexibleTree);
            }
        } catch (IOException e) {
            System.out.println("IOException");
        }
    }

    /**
     * Extends the superclass log columns with one "coal_LL_&lt;index&gt;" column per
     * ever-infected case and a final "total_coal_LL" column.
     *
     * @return the combined columns, or {@code null} when the outbreak is not a
     *         {@code CategoryOutbreak}
     */
    @Override
    public LogColumn[] passColumns() {
        ArrayList<LogColumn> columns = new ArrayList<LogColumn>(Arrays.asList(super.passColumns()));
        // NOTE(review): callers receive null, not the superclass columns, for other outbreak types.
        if (!(this.outbreak instanceof CategoryOutbreak)) {
            return null;
        }
        final int caseCount = this.outbreak.size();
        for (int index = 0; index < caseCount; index++) {
            if (!this.outbreak.getCase(index).wasEverInfected()) {
                continue;
            }
            final int caseIndex = index;
            columns.add(new LogColumn.Abstract("coal_LL_" + index) {
                @Override
                protected String getFormattedValue() {
                    return String.valueOf(WithinCaseCoalescent.this.partitionTreeLogLikelihoods[caseIndex]);
                }
            });
        }
        columns.add(new LogColumn.Abstract("total_coal_LL") {
            @Override
            protected String getFormattedValue() {
                return String.valueOf(WithinCaseCoalescent.this.coalescencesLogLikelihood);
            }
        });
        return columns.toArray(new LogColumn[columns.size()]);
    }
}
