package dr.evomodel.transmission;

import dr.evolution.coalescent.Coalescent;
import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.coalescent.Intervals;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.evolution.util.Units;
import dr.evomodel.coalescent.DemographicModel;
import dr.evomodel.transmission.TransmissionHistoryModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import dr.xml.XORRule;

/**
 * Likelihood of a virus (parasite) tree given a transmission history between hosts.
 *
 * <p>Each tip of the virus tree carries a {@code "host"} taxon attribute. The virus tree is
 * partitioned into per-host coalescent intervals by tracing every lineage back through the
 * transmission events. Those events are taken either from a host tree or from a
 * {@link TransmissionHistoryModel}.
 *
 * <p>The log-likelihood is the sum over hosts of the coalescent log-likelihood of that host's
 * intervals:
 * <ul>
 *   <li>Host 0 (the source patient) uses the source demographic model.</li>
 *   <li>Every other host uses the demographic function supplied by the
 *       {@link TransmissionDemographicModel}. That function is parameterized by the host's
 *       transmission time and by the size of the donor population at that time.</li>
 * </ul>
 * A virus tree that cannot be embedded in the transmission history yields
 * {@link Double#NEGATIVE_INFINITY}.
 */
public class TransmissionLikelihood extends AbstractModelLikelihood implements Units {
    public static final String TRANSMISSION_LIKELIHOOD = "transmissionLikelihood";
    public static final String SOURCE_PATIENT = "sourcePatient";
    public static final String HOST_TREE = "hostTree";
    public static final String PARASITE_TREE = "parasiteTree";

    /** Name of the taxon attribute on virus-tree tips that identifies the infected host. */
    private static final String HOST_ATTRIBUTE = "host";

    /**
     * XML parser for {@code <transmissionLikelihood>}. It requires:
     * <ul>
     *   <li>a source-patient demographic model;</li>
     *   <li>a transmission demographic model;</li>
     *   <li>either a host tree or a transmission history;</li>
     *   <li>the parasite (virus) tree.</li>
     * </ul>
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = {
                new ElementRule(SOURCE_PATIENT, DemographicModel.class,
                        "This describes the demographic process for the source donor patient."),
                new ElementRule(TransmissionDemographicModel.class,
                        "This describes the demographic process for the recipient patients."),
                new XORRule(
                        new ElementRule(HOST_TREE, new XMLSyntaxRule[]{new ElementRule(Tree.class)}),
                        new ElementRule(TransmissionHistoryModel.class,
                                "This describes the transmission history of the patients.")),
                new ElementRule(PARASITE_TREE, new XMLSyntaxRule[]{new ElementRule(Tree.class)})
        };

        @Override
        public String getParserName() {
            return TRANSMISSION_LIKELIHOOD;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            DemographicModel sourceDemographic = (DemographicModel) xo.getElementFirstChild(SOURCE_PATIENT);
            TransmissionDemographicModel transmissionDemographic =
                    (TransmissionDemographicModel) xo.getChild(TransmissionDemographicModel.class);
            Tree parasiteTree = (Tree) xo.getElementFirstChild(PARASITE_TREE);
            TransmissionHistoryModel history =
                    (TransmissionHistoryModel) xo.getChild(TransmissionHistoryModel.class);

            try {
                // The XOR rule guarantees exactly one of hostTree / transmission history is present.
                if (history != null) {
                    return new TransmissionLikelihood(history, parasiteTree, sourceDemographic,
                            transmissionDemographic);
                }
                Tree hostTree = (Tree) xo.getElementFirstChild(HOST_TREE);
                return new TransmissionLikelihood(hostTree, parasiteTree, sourceDemographic,
                        transmissionDemographic);
            } catch (TaxonList.MissingTaxonException e) {
                throw new XMLParseException(e.toString());
            }
        }

        @Override
        public String getParserDescription() {
            return "This element represents a likelihood function for transmission.";
        }

        @Override
        public Class getReturnType() {
            return TransmissionLikelihood.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }
    };

    /** Demographic model of the source patient (host index 0). */
    private final DemographicModel sourceDemographic;
    /** Supplies the demographic function of every recipient host. */
    private final TransmissionDemographicModel transmissionModel;
    /** Host tree describing the transmissions; null when a transmission history model is used. */
    private final Tree hostTree;
    /** Explicit transmission history; null when a host tree is used. */
    private final TransmissionHistoryModel transmissionHistoryModel;
    /** The virus (parasite) tree whose tips carry a "host" taxon attribute. */
    private final Tree virusTree;

    private int hostCount;
    /** Per-host coalescent intervals; reused across evaluations while the host count is unchanged. */
    private Intervals[] intervals;
    /** donorHost[i] is the index of the host that infected host i (-1 for the source, host 0). */
    private int[] donorHost;
    /** Height at which host i was infected (+infinity for the source, host 0). */
    private double[] transmissionTime;
    /** Cached donor population size at each host's transmission time; values <= 0 mean "not computed". */
    private double[] donorSize;
    private boolean likelihoodKnown = false;
    private double logLikelihood;

    /**
     * Thrown while building the per-host intervals when two virus lineages coalesce at a
     * height where they are in different hosts.
     */
    public static class IncompatibleException extends Exception {
        private static final long serialVersionUID = 8439923064799668934L;

        public IncompatibleException(String message) {
            super(message);
        }
    }

    /**
     * Creates a transmission likelihood whose transmission history is read from a host tree.
     *
     * <p>This constructor uses the default name {@value #TRANSMISSION_LIKELIHOOD}.
     *
     * @param hostTree            tree of hosts; see {@link #setupHostsTree} for how donors are read from it
     * @param virusTree           virus tree whose tips carry a "host" taxon attribute
     * @param sourceDemographic   demographic model of the source patient
     * @param transmissionModel   demographic model for recipient hosts
     * @throws TaxonList.MissingTaxonException if a virus tip lacks a host, or its host is not in the host tree
     */
    public TransmissionLikelihood(Tree hostTree, Tree virusTree, DemographicModel sourceDemographic,
                                  TransmissionDemographicModel transmissionModel)
            throws TaxonList.MissingTaxonException {
        this(TRANSMISSION_LIKELIHOOD, hostTree, virusTree, sourceDemographic, transmissionModel);
    }

    /**
     * Creates a named transmission likelihood whose transmission history is read from a host tree.
     *
     * @param name                model name
     * @param hostTree            tree of hosts; registered as a sub-model if it is a {@link TreeModel}
     * @param virusTree           virus tree; registered as a sub-model if it is a {@link TreeModel}
     * @param sourceDemographic   demographic model of the source patient
     * @param transmissionModel   demographic model for recipient hosts
     * @throws TaxonList.MissingTaxonException if a virus tip lacks a host, or its host is not in the host tree
     */
    public TransmissionLikelihood(String name, Tree hostTree, Tree virusTree, DemographicModel sourceDemographic,
                                  TransmissionDemographicModel transmissionModel)
            throws TaxonList.MissingTaxonException {
        super(name);
        this.hostTree = hostTree;
        this.transmissionHistoryModel = null;
        this.virusTree = virusTree;
        this.sourceDemographic = sourceDemographic;
        this.transmissionModel = transmissionModel;

        if (hostTree instanceof TreeModel) {
            addModel((TreeModel) hostTree);
        }
        if (virusTree instanceof TreeModel) {
            addModel((TreeModel) virusTree);
        }
        addModel(sourceDemographic);
        addModel(transmissionModel);

        checkVirusTreeHosts();
        setupHosts();
    }

    /**
     * Creates a transmission likelihood from an explicit transmission history.
     *
     * <p>This constructor uses the default name {@value #TRANSMISSION_LIKELIHOOD}.
     *
     * @param transmissionHistoryModel  history of donor/recipient transmission events
     * @param virusTree                 virus tree whose tips carry a "host" taxon attribute
     * @param sourceDemographic         demographic model of the source patient
     * @param transmissionModel         demographic model for recipient hosts
     * @throws TaxonList.MissingTaxonException if a virus tip lacks a host, or its host is not in the history
     */
    public TransmissionLikelihood(TransmissionHistoryModel transmissionHistoryModel, Tree virusTree,
                                  DemographicModel sourceDemographic,
                                  TransmissionDemographicModel transmissionModel)
            throws TaxonList.MissingTaxonException {
        this(TRANSMISSION_LIKELIHOOD, transmissionHistoryModel, virusTree, sourceDemographic, transmissionModel);
    }

    /**
     * Creates a named transmission likelihood from an explicit transmission history.
     *
     * @param name                      model name
     * @param transmissionHistoryModel  history of donor/recipient transmission events
     * @param virusTree                 virus tree; registered as a sub-model if it is a {@link TreeModel}
     * @param sourceDemographic         demographic model of the source patient
     * @param transmissionModel         demographic model for recipient hosts
     * @throws TaxonList.MissingTaxonException if a virus tip lacks a host, or its host is not in the history
     */
    public TransmissionLikelihood(String name, TransmissionHistoryModel transmissionHistoryModel, Tree virusTree,
                                  DemographicModel sourceDemographic,
                                  TransmissionDemographicModel transmissionModel)
            throws TaxonList.MissingTaxonException {
        super(name);
        this.hostTree = null;
        this.transmissionHistoryModel = transmissionHistoryModel;
        this.virusTree = virusTree;
        this.sourceDemographic = sourceDemographic;
        this.transmissionModel = transmissionModel;

        addModel(transmissionHistoryModel);
        if (virusTree instanceof TreeModel) {
            addModel((TreeModel) virusTree);
        }
        addModel(sourceDemographic);
        addModel(transmissionModel);

        checkVirusTreeHosts();
        setupHosts();
    }

    /**
     * Verifies that every virus-tree tip names a host known to the host tree or transmission history.
     *
     * @throws TaxonList.MissingTaxonException if a tip lacks the "host" attribute or names an unknown host
     */
    private void checkVirusTreeHosts() throws TaxonList.MissingTaxonException {
        for (int i = 0; i < virusTree.getExternalNodeCount(); i++) {
            Taxon host = (Taxon) virusTree.getTaxonAttribute(i, HOST_ATTRIBUTE);
            if (host == null) {
                throw new TaxonList.MissingTaxonException(
                        "One or more of the viruses tree's taxa are missing the 'host' attribute");
            }
            if (getHostIndex(host) == -1) {
                String source = transmissionHistoryModel != null ? "the transmission history" : "the host tree";
                throw new TaxonList.MissingTaxonException("One of the viruses tree's host attribute, "
                        + host.getId() + ", was not found as a taxon in " + source);
            }
        }
    }

    /** Returns the index of {@code host} in the transmission history or host tree, or -1 if absent. */
    private int getHostIndex(Taxon host) {
        return transmissionHistoryModel != null
                ? transmissionHistoryModel.getHostIndex(host)
                : hostTree.getTaxonIndex(host);
    }

    /**
     * (Re)builds the donor/transmission-time tables from the current host tree or transmission history.
     * The per-host {@link Intervals} are allocated only when the host count changes; callers reset
     * their events before use.
     */
    private void setupHosts() {
        hostCount = transmissionHistoryModel != null
                ? transmissionHistoryModel.getHostCount()
                : hostTree.getTaxonCount();

        if (intervals == null || intervals.length != hostCount) {
            int capacity = virusTree.getExternalNodeCount() * 3;
            intervals = new Intervals[hostCount];
            for (int i = 0; i < hostCount; i++) {
                intervals[i] = new Intervals(capacity);
            }
        }

        donorHost = new int[hostCount];
        transmissionTime = new double[hostCount];
        donorSize = new double[hostCount];

        // Host 0 is the source patient: it has no donor and was "infected" infinitely long ago.
        donorHost[0] = -1;
        transmissionTime[0] = Double.POSITIVE_INFINITY;

        if (transmissionHistoryModel == null) {
            setupHostsTree(hostTree.getRoot());
            return;
        }

        for (int i = 0; i < transmissionHistoryModel.getTransmissionEventCount(); i++) {
            TransmissionHistoryModel.TransmissionEvent event = transmissionHistoryModel.getTransmissionEvent(i);
            int donor = transmissionHistoryModel.getHostIndex(event.getDonor());
            int recipient = transmissionHistoryModel.getHostIndex(event.getRecipient());
            donorHost[recipient] = donor;
            transmissionTime[recipient] = event.getTransmissionTime();
        }
    }

    /**
     * Recursively reads donor/recipient relationships from the host tree.
     *
     * <p>At each internal node, the host returned by child 0 is treated as the donor. The host
     * returned by child 1 is treated as the recipient, infected at the node's height. An external
     * node's host index is taken to be its node number.
     *
     * <p>NOTE(review): this assumes node numbers agree with {@code hostTree.getTaxonIndex}, which is
     * used when mapping virus tips. It also assumes the host reached by always following child 0
     * from the root is host 0, the source — confirm.
     *
     * @param node host-tree node to process
     * @return index of the host that this subtree's lineage belongs to
     */
    private int setupHostsTree(NodeRef node) {
        if (hostTree.isExternal(node)) {
            return node.getNumber();
        }
        int donor = setupHostsTree(hostTree.getChild(node, 0));
        int recipient = setupHostsTree(hostTree.getChild(node, 1));
        donorHost[recipient] = donor;
        transmissionTime[recipient] = hostTree.getNodeHeight(node);
        return donor;
    }

    @Override
    protected final void handleModelChangedEvent(Model model, Object object, int index) {
        // Any change to the trees, transmission history or demographic models invalidates the likelihood.
        likelihoodKnown = false;
    }

    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
    }

    @Override
    protected final void storeState() {
    }

    @Override
    protected final void restoreState() {
        likelihoodKnown = false;
    }

    @Override
    protected final void acceptState() {
    }

    @Override
    public final Model getModel() {
        return this;
    }

    @Override
    public final double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = calculateLogLikelihood();
            likelihoodKnown = true;
        }
        return logLikelihood;
    }

    @Override
    public final void makeDirty() {
        likelihoodKnown = false;
    }

    /**
     * Recomputes the log-likelihood from the current state of all sub-models.
     *
     * @return the sum over hosts of the per-host coalescent log-likelihoods, or
     *         {@link Double#NEGATIVE_INFINITY} if the virus tree is incompatible with the transmission history
     */
    public double calculateLogLikelihood() {
        makeDirty();
        setupHosts();

        for (int i = 0; i < hostCount; i++) {
            intervals[i].resetEvents();
            donorSize[i] = -1.0;
        }

        try {
            setupIntervals(virusTree.getRoot());
        } catch (IncompatibleException e) {
            return Double.NEGATIVE_INFINITY;
        }

        double logL = 0.0;
        for (int i = 0; i < hostCount; i++) {
            logL += Coalescent.calculateLogLikelihood(intervals[i], getHostDemographicFunction(i));
        }
        return logL;
    }

    /**
     * Returns the demographic function that governs coalescence within {@code host}.
     * Host 0 uses the source demographic. Every other host uses the transmission model,
     * parameterized by its transmission time and its donor size.
     */
    private DemographicFunction getHostDemographicFunction(int host) {
        if (host == 0) {
            return sourceDemographic.getDemographicFunction();
        }
        return transmissionModel.getDemographicFunction(transmissionTime[host], getDonorSize(host), host);
    }

    /**
     * Returns the population size of {@code host}'s donor at the moment {@code host} was infected.
     * Results are memoized in {@link #donorSize}.
     */
    private double getDonorSize(int host) {
        if (donorSize[host] > 0.0) {
            return donorSize[host];
        }
        // The donor's population at the time of transmission is given by the donor's OWN
        // demographic function, the same one used for its coalescent likelihood. That function
        // is evaluated at this host's transmission time.
        DemographicFunction donorDemographic = getHostDemographicFunction(donorHost[host]);
        donorSize[host] = donorDemographic.getDemographic(transmissionTime[host]);
        return donorSize[host];
    }

    /**
     * Walks a virus lineage that currently sits in {@code host} back through transmission events
     * until it reaches {@code height}. At each crossing, a "nothing" event (lineage leaves) is
     * recorded in the recipient and a sample event (lineage enters) is recorded in the donor.
     *
     * @return the host containing the lineage at {@code height}
     */
    private int climbToHeight(int host, double height) {
        while (height > transmissionTime[host]) {
            double time = transmissionTime[host];
            intervals[host].addNothingEvent(time);
            host = donorHost[host];
            intervals[host].addSampleEvent(time);
        }
        return host;
    }

    /**
     * Post-order traversal of the virus tree that fills the per-host intervals.
     *
     * @param node virus-tree node to process
     * @return index of the host containing the lineage at {@code node}'s height
     * @throws IncompatibleException if the two child lineages are in different hosts where they coalesce
     */
    private int setupIntervals(NodeRef node) throws IncompatibleException {
        double height = virusTree.getNodeHeight(node);

        if (virusTree.isExternal(node)) {
            Taxon host = (Taxon) virusTree.getTaxonAttribute(node.getNumber(), HOST_ATTRIBUTE);
            int hostIndex = getHostIndex(host);
            intervals[hostIndex].addSampleEvent(height);
            return hostIndex;
        }

        int left = setupIntervals(virusTree.getChild(node, 0));
        int right = setupIntervals(virusTree.getChild(node, 1));
        left = climbToHeight(left, height);
        right = climbToHeight(right, height);

        if (left != right) {
            throw new IncompatibleException("Virus tree is not compatible with transmission history");
        }
        intervals[left].addCoalescentEvent(height);
        return left;
    }

    @Override
    public final void setUnits(Units.Type type) {
        transmissionModel.setUnits(type);
    }

    @Override
    public final Units.Type getUnits() {
        return transmissionModel.getUnits();
    }
}
