package dr.evomodel.epidemiology.casetocase;

import dr.evomodel.coalescent.DemographicModel;
import dr.evomodel.epidemiology.casetocase.periodpriors.AbstractPeriodPriorDistribution;
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;

/* loaded from: input_file:dr/evomodel/epidemiology/casetocase/CaseToCaseTransmissionLikelihood.class */
public class CaseToCaseTransmissionLikelihood extends AbstractModelLikelihood implements Loggable {
    // Debug switch; not referenced anywhere in the visible part of this class.
    private static final boolean DEBUG = false;
    // Outbreak that supplies the cases, the index-case weights, the infectious categories/priors
    // and the spatial kernel values used by this likelihood.
    private CategoryOutbreak outbreak;
    // Tree likelihood that supplies infection times, infectious times, infectors and the tree log probability.
    private CaseToCaseTreeLikelihood treeLikelihood;
    // Optional spatial kernel; null when no kernel was supplied (see hasGeography).
    private SpatialKernel spatialKernel;
    // Transmission rate parameter; only element 0 is read.
    private Parameter transmissionRate;
    // Cache-validity flags for the total and for each of the three components, each with the
    // copy saved by storeState() and put back by restoreState().
    private boolean likelihoodKnown;
    private boolean storedLikelihoodKnown;
    private boolean transProbKnown;
    private boolean storedTransProbKnown;
    private boolean periodsProbKnown;
    private boolean storedPeriodsProbKnown;
    private boolean treeProbKnown;
    private boolean storedTreeProbKnown;
    // Cached log values: the total, the transmission-process part, the infectious-periods part
    // and the tree part, each with its stored copy.
    private double logLikelihood;
    private double storedLogLikelihood;
    private double transLogProb;
    private double storedTransLogProb;
    private double periodsLogProb;
    private double storedPeriodsLogProb;
    private double treeLogProb;
    private double storedTreeLogProb;
    // Optional prior on the time of the first infection event; may be null.
    private ParametricDistributionModel initialInfectionTimePrior;
    // Normalised weight of each ever-infected case, built in the constructor from the outbreak's weight map.
    private HashMap<AbstractCase, Double> indexCasePrior;
    // True when a spatial kernel was supplied to the constructor.
    private final boolean hasGeography;
    // Copied from treeLikelihood.hasLatentPeriods() in the constructor.
    private final boolean hasLatentPeriods;
    // Events ordered by time, built by sortEvents(); set to null whenever they must be rebuilt.
    private ArrayList<TreeEvent> sortedTreeEvents;
    private ArrayList<TreeEvent> storedSortedTreeEvents;
    // Case of the earliest event in sortedTreeEvents; nulled together with that list.
    private AbstractCase indexCase;
    private AbstractCase storedIndexCase;
    // Name used both as the model name passed by the parser and as the XML parser name.
    public static final String CASE_TO_CASE_TRANSMISSION_LIKELIHOOD = "caseToCaseTransmissionLikelihood";
    /**
     * XML parser that builds a {@link CaseToCaseTransmissionLikelihood} from a
     * {@code caseToCaseTransmissionLikelihood} element.
     *
     * <p>The element must contain a {@link CaseToCaseTreeLikelihood} and a
     * {@code transmissionRate} parameter; a {@link SpatialKernel} and an
     * {@code initialInfectionTimePrior} are optional. The field is {@code final} so that the
     * shared parser instance cannot be replaced at runtime.
     */
    public static final XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public static final String TRANSMISSION_RATE = "transmissionRate";
        public static final String INITIAL_INFECTION_TIME_PRIOR = "initialInfectionTimePrior";

        private final XMLSyntaxRule[] rules = {
            new ElementRule(CaseToCaseTreeLikelihood.class, "The tree likelihood"),
            new ElementRule(SpatialKernel.class, "The spatial kernel", 0, 1),
            new ElementRule(TRANSMISSION_RATE, Parameter.class, "The transmission rate"),
            new ElementRule(INITIAL_INFECTION_TIME_PRIOR, ParametricDistributionModel.class,
                    "The prior probability distribution of the first infection", true)
        };

        /**
         * Reads the child elements and constructs the likelihood.
         *
         * @param xMLObject the parsed XML element
         * @return a new {@link CaseToCaseTransmissionLikelihood}
         * @throws XMLParseException if the tree likelihood's outbreak is not a {@link CategoryOutbreak}
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            CaseToCaseTreeLikelihood caseToCaseTreeLikelihood = (CaseToCaseTreeLikelihood) xMLObject.getChild(CaseToCaseTreeLikelihood.class);
            SpatialKernel spatialKernel = (SpatialKernel) xMLObject.getChild(SpatialKernel.class);
            Parameter parameter = (Parameter) xMLObject.getElementFirstChild(TRANSMISSION_RATE);

            // The prior on the first infection time is optional; null means "no prior term".
            ParametricDistributionModel parametricDistributionModel = null;
            if (xMLObject.hasChildNamed(INITIAL_INFECTION_TIME_PRIOR)) {
                parametricDistributionModel = (ParametricDistributionModel) xMLObject.getElementFirstChild(INITIAL_INFECTION_TIME_PRIOR);
            }

            // Report an unsuitable outbreak as a parse error rather than a bare ClassCastException.
            Object parsedOutbreak = caseToCaseTreeLikelihood.getOutbreak();
            if (!(parsedOutbreak instanceof CategoryOutbreak)) {
                throw new XMLParseException(CaseToCaseTransmissionLikelihood.CASE_TO_CASE_TRANSMISSION_LIKELIHOOD
                        + " requires the tree likelihood's outbreak to be a CategoryOutbreak");
            }

            return new CaseToCaseTransmissionLikelihood(CaseToCaseTransmissionLikelihood.CASE_TO_CASE_TRANSMISSION_LIKELIHOOD, (CategoryOutbreak) parsedOutbreak, caseToCaseTreeLikelihood, spatialKernel, parameter, parametricDistributionModel);
        }

        /** @return the syntax rules describing the permitted child elements */
        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        /** @return a human-readable description of the element */
        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "This element represents a probability distribution for epidemiological parameters of an outbreak given a phylogenetic tree";
        }

        /** @return the class of object this parser produces */
        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return CaseToCaseTransmissionLikelihood.class;
        }

        /** @return the XML element name handled by this parser */
        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return CaseToCaseTransmissionLikelihood.CASE_TO_CASE_TRANSMISSION_LIKELIHOOD;
        }
    };

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dr/evomodel/epidemiology/casetocase/CaseToCaseTransmissionLikelihood$EventComparator.class */
    /** Orders {@link TreeEvent}s by ascending event time. */
    public class EventComparator implements Comparator<TreeEvent> {
        private EventComparator() {
        }

        /**
         * Compares two events by their times using {@link Double#compare(double, double)}.
         *
         * @param treeEvent  the first event
         * @param treeEvent2 the second event
         * @return negative, zero or positive as the first time is less than, equal to or greater than the second
         */
        @Override // java.util.Comparator
        public int compare(TreeEvent treeEvent, TreeEvent treeEvent2) {
            final double firstTime = treeEvent.getTime();
            final double secondTime = treeEvent2.getTime();
            return Double.compare(firstTime, secondTime);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dr/evomodel/epidemiology/casetocase/CaseToCaseTransmissionLikelihood$EventType.class */
    /** Kinds of event placed on the sorted event list built by sortEvents(). */
    public enum EventType {
        /** A case becomes infected; created with the case's infection time and its infector. */
        INFECTION,
        /** A case becomes infectious; only created when the tree likelihood has latent periods. */
        INFECTIOUSNESS,
        /** End of a case's infectious period; created with the case's endOfInfectiousTime. */
        END
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dr/evomodel/epidemiology/casetocase/CaseToCaseTransmissionLikelihood$TreeEvent.class */
    /** A single timed event for one case, optionally carrying the case that caused it. */
    public class TreeEvent {
        private EventType type;
        private double time;
        private AbstractCase aCase;
        private AbstractCase infectorCase;

        /**
         * Creates an event of the given type with no infector recorded.
         *
         * @param eventType    the kind of event
         * @param d            the event time
         * @param abstractCase the case the event happens to
         */
        private TreeEvent(EventType eventType, double d, AbstractCase abstractCase) {
            this.infectorCase = null;
            this.aCase = abstractCase;
            this.time = d;
            this.type = eventType;
        }

        /**
         * Creates an {@link EventType#INFECTION} event.
         *
         * @param d             the infection time
         * @param abstractCase  the case being infected
         * @param abstractCase2 the infecting case
         */
        private TreeEvent(double d, AbstractCase abstractCase, AbstractCase abstractCase2) {
            this(EventType.INFECTION, d, abstractCase);
            this.infectorCase = abstractCase2;
        }

        /** @return the time of this event */
        public double getTime() {
            return time;
        }

        /** @return the kind of this event */
        public EventType getType() {
            return type;
        }

        /** @return the case this event happens to */
        public AbstractCase getCase() {
            return aCase;
        }

        /** @return the infecting case, or null when none was recorded */
        public AbstractCase getInfector() {
            return infectorCase;
        }
    }

    /**
     * Builds the likelihood, registers the models and variable it depends on, derives the
     * index-case prior from the outbreak's weight map and sorts the events once.
     *
     * @param str                         the model name
     * @param categoryOutbreak            the outbreak supplying cases and weights
     * @param caseToCaseTreeLikelihood    the tree likelihood supplying timings
     * @param spatialKernel               the spatial kernel, or null when there is none
     * @param parameter                   the transmission rate parameter
     * @param parametricDistributionModel the prior on the first infection time, or null
     */
    public CaseToCaseTransmissionLikelihood(String str, CategoryOutbreak categoryOutbreak, CaseToCaseTreeLikelihood caseToCaseTreeLikelihood, SpatialKernel spatialKernel, Parameter parameter, ParametricDistributionModel parametricDistributionModel) {
        super(str);
        this.outbreak = categoryOutbreak;
        this.treeLikelihood = caseToCaseTreeLikelihood;
        this.spatialKernel = spatialKernel;
        this.transmissionRate = parameter;
        this.initialInfectionTimePrior = parametricDistributionModel;
        this.likelihoodKnown = false;

        // Listen to the kernel (when present), then the tree likelihood, then the rate parameter.
        this.hasGeography = spatialKernel != null;
        if (this.hasGeography) {
            addModel(spatialKernel);
        }
        addModel(caseToCaseTreeLikelihood);
        addVariable(parameter);
        this.hasLatentPeriods = caseToCaseTreeLikelihood.hasLatentPeriods();

        // Total weight over the weight-map keys that were ever infected.
        final HashMap<AbstractCase, Double> weights = categoryOutbreak.getWeightMap();
        double totalInfectedWeight = 0.0d;
        for (AbstractCase weighted : weights.keySet()) {
            if (!weighted.wasEverInfected) {
                continue;
            }
            totalInfectedWeight += weights.get(weighted).doubleValue();
        }

        // Normalise each ever-infected case's weight by that total.
        this.indexCasePrior = new HashMap<>();
        for (AbstractCase candidate : categoryOutbreak.getCases()) {
            if (candidate.wasEverInfected) {
                final double normalised = weights.get(candidate).doubleValue() / totalInfectedWeight;
                this.indexCasePrior.put(candidate, Double.valueOf(normalised));
            }
        }

        sortEvents();
    }

    /**
     * Invalidates the cached components affected by a change in a sub-model.
     *
     * <p>A tree-likelihood change always invalidates the tree probability; unless the changed
     * object is a {@link DemographicModel} it also invalidates the transmission and period
     * probabilities and the sorted events. A kernel change invalidates only the transmission
     * probability. An outbreak change invalidates transmission, periods and sorted events.
     * The total is always invalidated.
     *
     * @param model the sub-model that changed
     * @param obj   the object reported as changed
     * @param i     the index reported with the change (unused)
     */
    @Override // dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
        boolean eventsStale = false;
        boolean transmissionStale = false;

        if (model instanceof CaseToCaseTreeLikelihood) {
            this.treeProbKnown = false;
            eventsStale = !(obj instanceof DemographicModel);
        } else if (model instanceof SpatialKernel) {
            transmissionStale = true;
        } else {
            eventsStale = model instanceof AbstractOutbreak;
        }

        if (eventsStale) {
            // Timings may have moved: everything derived from the event order must be rebuilt.
            this.transProbKnown = false;
            this.periodsProbKnown = false;
            this.sortedTreeEvents = null;
            this.indexCase = null;
        } else if (transmissionStale) {
            this.transProbKnown = false;
        }
        this.likelihoodKnown = false;
    }

    /**
     * Invalidates the total, and also the transmission probability when the changed variable
     * is the transmission rate.
     *
     * @param variable   the variable that changed
     * @param i          the changed index (unused)
     * @param changeType the kind of change (unused)
     */
    @Override // dr.inference.model.AbstractModel
    protected void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        this.likelihoodKnown = false;
        if (variable != this.transmissionRate) {
            return;
        }
        this.transProbKnown = false;
    }

    /**
     * Saves the cached values, their validity flags, the sorted events and the index case so
     * that {@link #restoreState()} can put them back.
     *
     * <p>The sorted-event list is set to null by the change handlers and by
     * {@link #makeDirty()} until it is rebuilt, so it may be null here. In that case null is
     * stored rather than copying, which would throw a {@link NullPointerException}.
     */
    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
        this.storedLogLikelihood = this.logLikelihood;
        this.storedLikelihoodKnown = this.likelihoodKnown;
        this.storedPeriodsLogProb = this.periodsLogProb;
        this.storedPeriodsProbKnown = this.periodsProbKnown;
        this.storedTransLogProb = this.transLogProb;
        this.storedTransProbKnown = this.transProbKnown;
        this.storedTreeLogProb = this.treeLogProb;
        this.storedTreeProbKnown = this.treeProbKnown;
        // Shallow copy when present; a null cache is stored as null and rebuilt lazily after restore.
        this.storedSortedTreeEvents = this.sortedTreeEvents == null
                ? null
                : new ArrayList<>(this.sortedTreeEvents);
        this.storedIndexCase = this.indexCase;
    }

    /**
     * Puts back every value saved by {@link #storeState()}: the total, the three components,
     * their validity flags, the sorted events and the index case.
     */
    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        // Validity flags.
        likelihoodKnown = storedLikelihoodKnown;
        transProbKnown = storedTransProbKnown;
        periodsProbKnown = storedPeriodsProbKnown;
        treeProbKnown = storedTreeProbKnown;

        // Cached log values.
        logLikelihood = storedLogLikelihood;
        transLogProb = storedTransLogProb;
        periodsLogProb = storedPeriodsLogProb;
        treeLogProb = storedTreeLogProb;

        // Event ordering and the case of its first event.
        sortedTreeEvents = storedSortedTreeEvents;
        indexCase = storedIndexCase;
    }

    /** Does nothing; this class takes no action when a state is accepted. */
    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
        // Intentionally empty.
    }

    /** @return the spatial kernel given to the constructor, or null when none was supplied */
    public SpatialKernel getSpatialKernel() {
        return spatialKernel;
    }

    /** @return this object, which is both the likelihood and its model */
    @Override // dr.inference.model.Likelihood
    public Model getModel() {
        final Model self = this;
        return self;
    }

    /** @return the tree likelihood given to the constructor */
    public CaseToCaseTreeLikelihood getTreeLikelihood() {
        return treeLikelihood;
    }

    /**
     * Returns the log likelihood, recomputing only the components whose cache is stale.
     *
     * <p>The total is the sum of three parts: the transmission-process log probability
     * (walked over the time-sorted events), the infectious-period log probability (one prior
     * per infectious category) and the tree log likelihood.
     *
     * <p>If the event walk throws a BadPartitionException the result is cached as negative
     * infinity. If any component equals positive infinity a message is printed and negative
     * infinity is returned without updating the cached total.
     *
     * @return the log likelihood, or negative infinity for an invalid configuration
     */
    @Override // dr.inference.model.Likelihood
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            if (!this.treeProbKnown) {
                // NOTE(review): presumably refreshes the timings read below via getInfectiousTime etc. — confirm.
                this.treeLikelihood.prepareTimings();
            }
            if (!this.transProbKnown) {
                try {
                    this.transLogProb = 0.0d;
                    if (this.sortedTreeEvents == null) {
                        sortEvents();
                    }
                    double parameterValue = this.transmissionRate.getParameterValue(0);
                    // Cases that have become infectious so far in the walk. END events never remove
                    // a case; instead each case's exposure is capped at its endOfInfectiousTime below.
                    ArrayList arrayList = new ArrayList();
                    // True until the first INFECTION event (the index case) has been processed.
                    boolean z = true;
                    Iterator<TreeEvent> it = this.sortedTreeEvents.iterator();
                    while (it.hasNext()) {
                        TreeEvent next = it.next();
                        double time = next.getTime();
                        AbstractCase abstractCase = next.getCase();
                        if (next.getType() == EventType.INFECTION) {
                            if (z) {
                                // Index case: contributes only the index-case prior and the optional
                                // prior on the time of the first infection.
                                if (this.indexCasePrior != null) {
                                    this.transLogProb += Math.log(this.indexCasePrior.get(abstractCase).doubleValue());
                                }
                                if (this.initialInfectionTimePrior != null) {
                                    this.transLogProb += this.initialInfectionTimePrior.logPdf(time);
                                }
                                // Without latent periods a case is infectious from its infection time.
                                if (!this.hasLatentPeriods) {
                                    arrayList.add(abstractCase);
                                }
                                z = false;
                            } else {
                                AbstractCase infector = next.getInfector();
                                // Consistency checks on the timings of this infection; a failed check
                                // marks the whole configuration as impossible.
                                if (abstractCase.wasEverInfected()) {
                                    if (arrayList.contains(abstractCase)) {
                                        throw new BadPartitionException(abstractCase.caseID + " infected after it was infectious");
                                    }
                                    if (next.getTime() > abstractCase.endOfInfectiousTime) {
                                        throw new BadPartitionException(abstractCase.caseID + " ceased to be infected before it was infected");
                                    }
                                    if (infector.endOfInfectiousTime < next.getTime()) {
                                        throw new BadPartitionException(abstractCase.caseID + " infected by " + infector.caseID + " after the latter ceased to be infectious");
                                    }
                                    if (this.treeLikelihood.getInfectiousTime(infector) > next.getTime()) {
                                        throw new BadPartitionException(abstractCase.caseID + " infected by " + infector.caseID + " before the latter became infectious");
                                    }
                                    if (!arrayList.contains(infector)) {
                                        throw new RuntimeException("Infector not previously infected");
                                    }
                                }
                                // For every case infectious so far, subtract rate * exposure duration,
                                // where exposure runs from that case's infectious time to the earlier of
                                // its end of infectiousness and this event's time. This loop runs for
                                // never-infected cases too, since it is outside the wasEverInfected() check.
                                Iterator it2 = arrayList.iterator();
                                while (it2.hasNext()) {
                                    AbstractCase abstractCase2 = (AbstractCase) it2.next();
                                    double infectiousTime = abstractCase2.endOfInfectiousTime < next.getTime() ? abstractCase2.endOfInfectiousTime - this.treeLikelihood.getInfectiousTime(abstractCase2) : next.getTime() - this.treeLikelihood.getInfectiousTime(abstractCase2);
                                    if (infectiousTime < 0.0d) {
                                        throw new RuntimeException("negative time");
                                    }
                                    // Pairwise rate: the base rate, scaled by the kernel value when there is geography.
                                    double d = parameterValue;
                                    if (this.hasGeography) {
                                        d *= this.outbreak.getKernelValue(abstractCase, abstractCase2, this.spatialKernel);
                                    }
                                    this.transLogProb += (-d) * infectiousTime;
                                }
                                // Add the log rate of the infection that actually happened.
                                if (abstractCase.wasEverInfected()) {
                                    double d2 = parameterValue;
                                    if (this.hasGeography) {
                                        d2 *= this.outbreak.getKernelValue(abstractCase, infector, this.spatialKernel);
                                    }
                                    this.transLogProb += Math.log(d2);
                                }
                                if (!this.hasLatentPeriods) {
                                    arrayList.add(abstractCase);
                                }
                            }
                        } else if (next.getType() == EventType.INFECTIOUSNESS && next.getTime() < Double.POSITIVE_INFINITY) {
                            // With latent periods a case joins the infectious list at its INFECTIOUSNESS event.
                            if (next.getTime() > next.getCase().endOfInfectiousTime) {
                                throw new BadPartitionException(next.getCase().caseID + " noninfectious beforeinfectious");
                            }
                            if (z) {
                                throw new RuntimeException("First event is not an infection");
                            }
                            arrayList.add(abstractCase);
                        }
                    }
                    this.transProbKnown = true;
                } catch (BadPartitionException e) {
                    // Impossible configuration: cache negative infinity for both the component and the total.
                    this.transLogProb = Double.NEGATIVE_INFINITY;
                    this.transProbKnown = true;
                    this.logLikelihood = Double.NEGATIVE_INFINITY;
                    this.likelihoodKnown = true;
                    return this.logLikelihood;
                }
            }
            if (!this.periodsProbKnown) {
                this.periodsLogProb = 0.0d;
                // Infectious periods of the ever-infected cases, grouped by infectious category.
                HashMap hashMap = new HashMap();
                Iterator<AbstractCase> it3 = this.outbreak.getCases().iterator();
                while (it3.hasNext()) {
                    AbstractCase next2 = it3.next();
                    if (next2.wasEverInfected()) {
                        String infectiousCategory = this.outbreak.getInfectiousCategory(next2);
                        if (!hashMap.keySet().contains(infectiousCategory)) {
                            hashMap.put(infectiousCategory, new ArrayList());
                        }
                        ((ArrayList) hashMap.get(infectiousCategory)).add(Double.valueOf(this.treeLikelihood.getInfectiousPeriod(next2)));
                    }
                }
                // Evaluate each category's prior on the periods collected for it.
                Iterator<String> it4 = this.outbreak.getInfectiousCategories().iterator();
                while (it4.hasNext()) {
                    String next3 = it4.next();
                    // NOTE(review): hashMap.get(next3) is null for a category with no ever-infected case,
                    // which would throw a NullPointerException here — confirm that cannot occur.
                    Double[] dArr = (Double[]) ((ArrayList) hashMap.get(next3)).toArray(new Double[((ArrayList) hashMap.get(next3)).size()]);
                    AbstractPeriodPriorDistribution infectiousCategoryPrior = this.outbreak.getInfectiousCategoryPrior(next3);
                    double[] dArr2 = new double[dArr.length];
                    for (int i = 0; i < dArr.length; i++) {
                        dArr2[i] = dArr[i].doubleValue();
                    }
                    this.periodsLogProb += infectiousCategoryPrior.getLogLikelihood(dArr2);
                }
                this.periodsProbKnown = true;
            }
            if (!this.treeProbKnown) {
                this.treeLogProb = this.treeLikelihood.getLogLikelihood();
                this.treeProbKnown = true;
            }
            // A component of positive infinity is reported and rejected. These returns leave
            // logLikelihood and likelihoodKnown untouched, so the next call repeats the checks.
            if (this.transLogProb == Double.POSITIVE_INFINITY) {
                System.out.println("TransLogProb +INF");
                return Double.NEGATIVE_INFINITY;
            }
            if (this.periodsLogProb == Double.POSITIVE_INFINITY) {
                System.out.println("PeriodsLogProb +INF");
                return Double.NEGATIVE_INFINITY;
            }
            if (this.treeLogProb == Double.POSITIVE_INFINITY) {
                System.out.println("TreeLogProb +INF");
                return Double.NEGATIVE_INFINITY;
            }
            this.logLikelihood = this.treeLogProb + this.periodsLogProb + this.transLogProb;
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    /**
     * Invalidates every cached component, discards the sorted events and index case, and
     * marks the tree likelihood dirty as well.
     */
    @Override // dr.inference.model.Likelihood
    public void makeDirty() {
        // Nothing cached may be trusted after this call.
        likelihoodKnown = false;
        treeProbKnown = false;
        periodsProbKnown = false;
        transProbKnown = false;

        sortedTreeEvents = null;
        treeLikelihood.makeDirty();
        indexCase = null;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Rebuilds the time-ordered event list and the index case.
     *
     * <p>Every case gets an INFECTION event. Ever-infected cases also get an END event and,
     * when there are latent periods, an INFECTIOUSNESS event. The index case is the case of
     * the earliest event after sorting.
     */
    public void sortEvents() {
        final ArrayList<TreeEvent> events = new ArrayList<>();
        for (AbstractCase current : this.outbreak.getCases()) {
            final double infectionTime = this.treeLikelihood.getInfectionTime(current);
            final AbstractCase infector = this.treeLikelihood.getInfector(this.outbreak.getCaseIndex(current));
            events.add(new TreeEvent(infectionTime, current, infector));

            if (!current.wasEverInfected()) {
                continue;
            }
            events.add(new TreeEvent(EventType.END, current.endOfInfectiousTime, current));
            if (this.hasLatentPeriods) {
                final double infectiousTime = this.treeLikelihood.getInfectiousTime(current);
                events.add(new TreeEvent(EventType.INFECTIOUSNESS, infectiousTime, current));
            }
        }

        Collections.sort(events, new EventComparator());
        final TreeEvent earliest = events.get(0);
        this.indexCase = earliest.getCase();
        this.sortedTreeEvents = events;
    }

    /**
     * Builds the log columns for this likelihood: the transmission and period log
     * probabilities, the tree likelihood's columns, the columns of each infectious-period
     * prior, the first infection time and the index of the index case.
     *
     * <p>Both index-case columns rebuild the sorted events first when they have been
     * invalidated, so neither reads a null index case regardless of the order in which the
     * columns are evaluated.
     *
     * @return the log columns in the order listed above
     */
    @Override // dr.inference.model.AbstractModelLikelihood, dr.inference.loggers.Loggable
    public LogColumn[] getColumns() {
        ArrayList<LogColumn> columns = new ArrayList<>();
        columns.add(new LogColumn.Abstract("trans_LL") {
            @Override // dr.inference.loggers.LogColumn.Abstract
            protected String getFormattedValue() {
                return String.valueOf(CaseToCaseTransmissionLikelihood.this.transLogProb);
            }
        });
        columns.add(new LogColumn.Abstract("period_LL") {
            @Override // dr.inference.loggers.LogColumn.Abstract
            protected String getFormattedValue() {
                return String.valueOf(CaseToCaseTransmissionLikelihood.this.periodsLogProb);
            }
        });
        columns.addAll(Arrays.asList(this.treeLikelihood.passColumns()));
        for (AbstractPeriodPriorDistribution prior : this.outbreak.getInfectiousMap().values()) {
            columns.addAll(Arrays.asList(prior.getColumns()));
        }
        columns.add(new LogColumn.Abstract("FirstInfectionTime") {
            @Override // dr.inference.loggers.LogColumn.Abstract
            protected String getFormattedValue() {
                // The index case is only valid while the sorted events are.
                if (CaseToCaseTransmissionLikelihood.this.sortedTreeEvents == null) {
                    CaseToCaseTransmissionLikelihood.this.sortEvents();
                }
                return String.valueOf(CaseToCaseTransmissionLikelihood.this.treeLikelihood.getInfectionTime(CaseToCaseTransmissionLikelihood.this.indexCase));
            }
        });
        columns.add(new LogColumn.Abstract("IndexCaseIndex") {
            @Override // dr.inference.loggers.LogColumn.Abstract
            protected String getFormattedValue() {
                // Same guard as the FirstInfectionTime column: indexCase is nulled with the event list.
                if (CaseToCaseTransmissionLikelihood.this.sortedTreeEvents == null
                        || CaseToCaseTransmissionLikelihood.this.indexCase == null) {
                    CaseToCaseTransmissionLikelihood.this.sortEvents();
                }
                return String.valueOf(CaseToCaseTransmissionLikelihood.this.treeLikelihood.getOutbreak().getCaseIndex(CaseToCaseTransmissionLikelihood.this.indexCase));
            }
        });
        return columns.toArray(new LogColumn[columns.size()]);
    }
}
