package dr.evomodel.speciation;

import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import dr.evolution.util.Units;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.util.Collections;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:dr/evomodel/speciation/BirthDeathSerialSamplingModel.class */
/**
 * Birth-death serial-sampling ("sampling-through-time") tree prior after Stadler (2010).
 *
 * <p>Two parameterizations are supported:
 * <ul>
 *   <li>direct: birth rate {@code lambda}, death rate {@code mu} (optionally relative to
 *       {@code lambda}), serial sampling rate {@code psi} and extant sampling probability
 *       {@code p};</li>
 *   <li>epidemiological: {@code R0}, {@code recoveryRate} and {@code samplingProbability}, from
 *       which lambda = R0 * recoveryRate, mu = recoveryRate * (1 - samplingProbability) and
 *       psi = recoveryRate * samplingProbability are derived.</li>
 * </ul>
 *
 * <p>The tree likelihood is conditioned on a sampled origin time; integrating the origin out is
 * not implemented.
 */
public class BirthDeathSerialSamplingModel extends MaskableSpeciationModel implements Citable {

    // Epidemiological parameterization; only set by the R0 constructor, null otherwise.
    Variable<Double> R0;
    Variable<Double> recoveryRate;
    Variable<Double> samplingProbability;

    // Direct parameterization; only set by the lambda/mu/psi constructors, null otherwise.
    Variable<Double> lambda;
    Variable<Double> mu;
    Variable<Double> psi;
    Variable<Double> p;
    // If true, death() returns mu * birth() instead of mu.
    boolean relativeDeath;
    // Registered and bounded to [0, 1] but not read by any method of this class.
    Variable<Double> r;
    // If true, tips at height 0 are treated as sampled with probability p at the present.
    boolean hasFinalSample;

    // Origin time of the process; may be null (then the likelihood cannot be evaluated).
    Variable<Double> origin;

    // When non-null, the rate accessors read their values from this model instead.
    BirthDeathSerialSamplingModel mask;

    /**
     * Creates a directly-parameterized model with id {@code "birthDeathSerialSamplingModel"}.
     *
     * @param lambda         birth rate, bounded to [0, +inf)
     * @param mu             death rate (or death/birth ratio if {@code relativeDeath}), bounded to [0, +inf)
     * @param psi            serial sampling rate, bounded to [0, +inf)
     * @param p              sampling probability at the present, bounded to [0, 1]
     * @param relativeDeath  whether {@code mu} is relative to the birth rate
     * @param r              additional probability parameter, bounded to [0, 1]
     * @param hasFinalSample whether tips at height 0 form a sample taken with probability {@code p}
     * @param origin         origin time, bounded to [0, +inf); may be null
     * @param units          units of the tree
     */
    public BirthDeathSerialSamplingModel(Variable<Double> lambda, Variable<Double> mu, Variable<Double> psi, Variable<Double> p, boolean relativeDeath, Variable<Double> r, boolean hasFinalSample, Variable<Double> origin, Units.Type units) {
        this("birthDeathSerialSamplingModel", lambda, mu, psi, p, relativeDeath, r, hasFinalSample, origin, units);
    }

    /**
     * Creates a directly-parameterized model with the given id.
     *
     * @param modelName      model id
     * @param lambda         birth rate, bounded to [0, +inf)
     * @param mu             death rate (or death/birth ratio if {@code relativeDeath}), bounded to [0, +inf)
     * @param psi            serial sampling rate, bounded to [0, +inf)
     * @param p              sampling probability at the present, bounded to [0, 1]
     * @param relativeDeath  whether {@code mu} is relative to the birth rate
     * @param r              additional probability parameter, bounded to [0, 1]
     * @param hasFinalSample whether tips at height 0 form a sample taken with probability {@code p}
     * @param origin         origin time, bounded to [0, +inf); may be null
     * @param units          units of the tree
     */
    public BirthDeathSerialSamplingModel(String modelName, Variable<Double> lambda, Variable<Double> mu, Variable<Double> psi, Variable<Double> p, boolean relativeDeath, Variable<Double> r, boolean hasFinalSample, Variable<Double> origin, Units.Type units) {
        super(modelName, units);
        this.relativeDeath = relativeDeath;
        this.hasFinalSample = hasFinalSample;

        this.lambda = lambda;
        registerBounded(lambda, Double.POSITIVE_INFINITY);
        this.mu = mu;
        registerBounded(mu, Double.POSITIVE_INFINITY);
        this.psi = psi;
        registerBounded(psi, Double.POSITIVE_INFINITY);
        this.p = p;
        registerBounded(p, 1.0);
        this.r = r;
        registerBounded(r, 1.0);
        this.origin = origin;
        if (origin != null) {
            registerBounded(origin, Double.POSITIVE_INFINITY);
        }
    }

    /**
     * Creates an epidemiologically-parameterized model (no final sample, absolute death rate).
     *
     * @param modelName           model id
     * @param R0                  basic reproductive number, bounded to [0, +inf)
     * @param recoveryRate        rate of becoming non-infectious, bounded to [0, +inf)
     * @param samplingProbability probability of being sampled on recovery, bounded to [0, 1]
     * @param origin              origin time, bounded to [0, +inf); may be null
     * @param units               units of the tree
     */
    public BirthDeathSerialSamplingModel(String modelName, Variable<Double> R0, Variable<Double> recoveryRate, Variable<Double> samplingProbability, Variable<Double> origin, Units.Type units) {
        super(modelName, units);
        // relativeDeath and hasFinalSample keep their default value (false).
        this.R0 = R0;
        registerBounded(R0, Double.POSITIVE_INFINITY);
        this.recoveryRate = recoveryRate;
        registerBounded(recoveryRate, Double.POSITIVE_INFINITY);
        this.samplingProbability = samplingProbability;
        registerBounded(samplingProbability, 1.0);
        this.origin = origin;
        if (origin != null) {
            registerBounded(origin, Double.POSITIVE_INFINITY);
        }
    }

    /** Registers {@code variable} with this model and bounds its single value to [0, upper]. */
    private void registerBounded(Variable<Double> variable, double upper) {
        addVariable(variable);
        variable.addBounds(new Parameter.DefaultBounds(upper, 0.0d, 1));
    }

    /**
     * Computes p0(t) of Stadler (2010): the probability that a lineage alive at time {@code t}
     * before the present leaves no sampled descendant (p0(0) = 1 - rho).
     *
     * @param birth birth rate lambda
     * @param death death rate mu
     * @param rho   sampling probability at the present
     * @param psi   serial sampling rate
     * @param t     time before the present
     * @return p0(t)
     */
    public static double p0(double birth, double death, double rho, double psi, double t) {
        double c1 = c1(birth, death, psi);
        double c2 = c2(birth, death, rho, psi);
        double decay = Math.exp((-c1) * t) * (1.0d - c2);
        return (((birth + death) + psi) + (c1 * ((decay - (1.0d + c2)) / (decay + (1.0d + c2))))) / (2.0d * birth);
    }

    /**
     * Computes log q(t) = c1*t + 2*log((1 - c2)*exp(-c1*t) + 1 + c2), the logarithm of the
     * q(t) term of Stadler (2010).
     *
     * @param birth birth rate lambda
     * @param death death rate mu
     * @param rho   sampling probability at the present
     * @param psi   serial sampling rate
     * @param t     time before the present
     * @return log q(t)
     */
    public static double q(double birth, double death, double rho, double psi, double t) {
        double c1 = c1(birth, death, psi);
        double c2 = c2(birth, death, rho, psi);
        return (c1 * t) + (2.0d * Math.log((Math.exp((-c1) * t) * (1.0d - c2)) + 1.0d + c2));
    }

    /** c1 = sqrt((lambda - mu - psi)^2 + 4 * lambda * psi); sqrt is already non-negative. */
    private static double c1(double birth, double death, double psi) {
        return Math.sqrt(Math.pow((birth - death) - psi, 2.0d) + (4.0d * birth * psi));
    }

    /** c2 = -(lambda - mu - 2 * lambda * rho - psi) / c1. */
    private static double c2(double birth, double death, double rho, double psi) {
        return (-(((birth - death) - ((2.0d * birth) * rho)) - psi)) / c1(birth, death, psi);
    }

    /** Returns p0(t) evaluated at this model's current (possibly masked) rates. */
    public double p0(double t) {
        return p0(birth(), death(), p(), psi(), t);
    }

    /** Returns log q(t) evaluated at this model's current (possibly masked) rates. */
    public double q(double t) {
        return q(birth(), death(), p(), psi(), t);
    }

    /** Returns the birth rate: the mask's, lambda, or R0 * recoveryRate. */
    public double birth() {
        if (this.mask != null) {
            return this.mask.birth();
        }
        if (this.lambda != null) {
            return this.lambda.getValue(0).doubleValue();
        }
        return this.R0.getValue(0).doubleValue() * this.recoveryRate.getValue(0).doubleValue();
    }

    /**
     * Returns the death rate: the mask's, mu (times birth() if relativeDeath), or
     * recoveryRate * (1 - samplingProbability).
     */
    public double death() {
        if (this.mask != null) {
            return this.mask.death();
        }
        if (this.mu != null) {
            double muValue = this.mu.getValue(0).doubleValue();
            return this.relativeDeath ? muValue * birth() : muValue;
        }
        return this.recoveryRate.getValue(0).doubleValue() * (1.0d - this.samplingProbability.getValue(0).doubleValue());
    }

    /** Returns the serial sampling rate: the mask's, psi, or recoveryRate * samplingProbability. */
    public double psi() {
        if (this.mask != null) {
            return this.mask.psi();
        }
        if (this.psi != null) {
            return this.psi.getValue(0).doubleValue();
        }
        return this.recoveryRate.getValue(0).doubleValue() * this.samplingProbability.getValue(0).doubleValue();
    }

    /**
     * Returns the sampling probability at the present: the mask's p(), p if there is a final
     * sample, and 0 otherwise.
     */
    public double p() {
        if (this.mask != null) {
            // Delegate like birth()/death()/psi() so the mask's hasFinalSample flag is honoured
            // and an R0-parameterized mask (whose p field is null) does not throw an NPE.
            return this.mask.p();
        }
        return this.hasFinalSample ? this.p.getValue(0).doubleValue() : 0.0d;
    }

    /** Returns true if an origin variable was supplied. */
    public boolean isSamplingOrigin() {
        return this.origin != null;
    }

    /** Returns the origin time; throws NullPointerException if no origin was supplied. */
    public double x0() {
        return this.origin.getValue(0).doubleValue();
    }

    /**
     * Computes the log-likelihood of {@code tree} under this model, conditioned on the origin.
     *
     * @return the log-likelihood, or negative infinity if the origin is younger than the root
     * @throws IllegalArgumentException      if there is no final sample and no tip at height 0
     * @throws UnsupportedOperationException if no origin was supplied
     */
    @Override // dr.evomodel.speciation.SpeciationModel
    public final double calculateTreeLogLikelihood(Tree tree) {
        if (isSamplingOrigin() && x0() < tree.getNodeHeight(tree.getRoot())) {
            return Double.NEGATIVE_INFINITY;
        }

        int tipsAtPresent = 0;
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            if (tree.getNodeHeight(tree.getExternalNode(i)) == 0.0d) {
                tipsAtPresent++;
            }
        }
        if (!this.hasFinalSample && tipsAtPresent < 1) {
            throw new IllegalArgumentException("For sampling-through-time model there must be at least one tip at time zero.");
        }
        if (!isSamplingOrigin()) {
            throw new UnsupportedOperationException("The origin must be sampled, as integrating it out is not implemented!");
        }

        double logL = -q(x0());
        if (this.hasFinalSample) {
            logL += tipsAtPresent * Math.log(4.0d * p());
        }

        double logBirth = Math.log(birth());
        for (int i = 0; i < tree.getInternalNodeCount(); i++) {
            logL += logBirth - q(tree.getNodeHeight(tree.getInternalNode(i)));
        }

        double logPsi = Math.log(psi());
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            double height = tree.getNodeHeight(tree.getExternalNode(i));
            // Tips at height 0 are already accounted for by the final-sample term when present.
            if (height > 0.0d || !this.hasFinalSample) {
                logL += logPsi + q(height);
            }
        }
        return logL;
    }

    /**
     * Computes the log-likelihood; only the case with no excluded taxa is supported.
     *
     * @throws UnsupportedOperationException if {@code set} is non-empty
     */
    @Override // dr.evomodel.speciation.SpeciationModel
    public double calculateTreeLogLikelihood(Tree tree, Set<Taxon> set) {
        if (set.isEmpty()) {
            return calculateTreeLogLikelihood(tree);
        }
        throw new UnsupportedOperationException("Not implemented!");
    }

    /**
     * Makes the rate accessors read from {@code speciationModel}.
     *
     * @throws IllegalArgumentException if it is not a BirthDeathSerialSamplingModel
     */
    @Override // dr.evomodel.speciation.MaskableSpeciationModel
    public void mask(SpeciationModel speciationModel) {
        if (!(speciationModel instanceof BirthDeathSerialSamplingModel)) {
            throw new IllegalArgumentException("Mask must be a BirthDeathSerialSamplingModel but was "
                    + (speciationModel == null ? "null" : speciationModel.getClass().getName()));
        }
        this.mask = (BirthDeathSerialSamplingModel) speciationModel;
    }

    /** Removes any mask so the accessors read this model's own variables again. */
    @Override // dr.evomodel.speciation.MaskableSpeciationModel
    public void unmask() {
        this.mask = null;
    }

    @Override // dr.util.Citable
    public Citation.Category getCategory() {
        return Citation.Category.TREE_PRIORS;
    }

    @Override // dr.util.Citable
    public String getDescription() {
        return "Stadler 2010 Birth Death Serial Sampling Tree Model";
    }

    @Override // dr.util.Citable
    public List<Citation> getCitations() {
        // The p0/q/c1/c2 formulas above are those of Stadler (2010), not Gernhard (2008).
        return Collections.singletonList(new Citation(new Author[]{new Author("T", "Stadler")}, "Sampling-through-time in birth-death trees", 2010, "Journal of Theoretical Biology", 267, 396, 404, "10.1016/j.jtbi.2010.09.010"));
    }
}
