package dr.evomodel.speciation;

import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import dr.evolution.util.Units;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.io.IOException;
import java.util.Arrays;
import java.util.Set;

/* loaded from: input_file:dr/evomodel/speciation/BirthDeathSerialSkylineModel.class */
/**
 * Speciation model whose tree log-likelihood is assembled from per-interval birth
 * ({@code lambda}), death ({@code mu}) and sampling ({@code psi}) rates.
 *
 * <p>Interval boundaries come from {@code times}; {@link #preCalculation(Tree)} converts them
 * into {@link #timesFromTips} and derives the per-interval constants {@link #Ai} and
 * {@link #Bi} that {@link #g(int, double, double)} and {@link #p0(int, double, double)} read.
 */
public class BirthDeathSerialSkylineModel extends SpeciationModel {
    /** Interval boundary times as supplied by the caller. */
    Variable<Double> times;
    /** Birth rate, read per interval through {@link #birth(int)}. */
    Variable<Double> lambda;
    /** Death rate (or death/birth ratio when {@link #relativeDeath} is set). */
    Variable<Double> mu;
    /** Sampling rate, read per interval through {@link #psi(int)}. */
    Variable<Double> psi;
    /** Value returned by {@link #p()}; bounded to [0, 1] by the constructor. */
    Variable<Double> p;
    /** Amount added to the root height to obtain {@link #x0}. */
    Variable<Double> origin;
    /** When true, {@link #death(int)} returns {@code mu * birth} instead of {@code mu}. */
    boolean relativeDeath = false;
    /** Number of intervals; set from {@code times.getSize()} in the constructor. */
    int size = 1;
    /** Root height of the tree most recently passed to {@link #preCalculation(Tree)}. */
    double t_root;
    /** {@link #t_root} plus the first value of {@link #origin}. */
    double x0;
    /** Entry {@code i} holds the 7-argument {@code p0} of interval {@code i} evaluated between t(i+1) and t(i). */
    protected double[] p0_iMinus1;
    /** Per-interval result of {@link #Ai(double, double, double)}. */
    protected double[] Ai;
    /** Per-interval result of {@link #Bi(double, double, double, double, double)}. */
    protected double[] Bi;
    /** When false, interval lookups of the birth rate always use index 0. */
    protected boolean birthChanges = true;
    /** When false, interval lookups of the death rate always use index 0. */
    protected boolean deathChanges = true;
    /** When false, interval lookups of the sampling rate always use index 0. */
    protected boolean samplingChanges = true;
    /** Selects how {@link #preCalculation(Tree)} turns {@link #times} into {@link #timesFromTips}. */
    protected boolean timesStartFromOrigin = true;
    /** Interval boundaries used by {@link #t(int)} and {@link #index(double)}; filled lazily. */
    protected double[] timesFromTips;

    /**
     * Creates the model under the fixed name {@code "birthDeathSerialSamplingModel"}.
     * See the named constructor for the meaning of each argument.
     */
    public BirthDeathSerialSkylineModel(Variable<Double> times, Variable<Double> lambda, Variable<Double> mu, Variable<Double> psi, Variable<Double> p, Variable<Double> origin, boolean relativeDeath, boolean unusedFlag, boolean timesStartFromOrigin, Units.Type type) {
        this("birthDeathSerialSamplingModel", times, lambda, mu, psi, p, origin, relativeDeath, unusedFlag, timesStartFromOrigin, type);
    }

    /**
     * Creates the model, registers every variable with the model and attaches bounds to each.
     *
     * <p>NOTE(review): {@code lambda} and {@code mu} are accepted with length one, but the
     * {@code *Changes} flags stay {@code true}, so the per-interval lookups still index them by
     * interval. A length-one variable combined with more than one interval therefore looks like
     * it would be read out of range - confirm against {@code Variable#getValue}.
     *
     * @param name                 model name handed to the superclass
     * @param times                interval boundaries; its size defines the number of intervals
     * @param lambda               birth rate(s); length must be 1 or {@code times.getSize()}
     * @param mu                   death rate(s); length must be 1 or {@code times.getSize()}
     * @param psi                  sampling rate(s)
     * @param p                    variable bounded to [0, 1]
     * @param origin               offset added to the root height
     * @param relativeDeath        whether {@code mu} is multiplied by the birth rate
     * @param unusedFlag           accepted for signature compatibility; not read by this class
     * @param timesStartFromOrigin how {@code times} is interpreted in {@link #preCalculation(Tree)}
     * @param type                 units handed to the superclass
     * @throws RuntimeException if {@code lambda} or {@code mu} has an unsupported length
     */
    public BirthDeathSerialSkylineModel(String name, Variable<Double> times, Variable<Double> lambda, Variable<Double> mu, Variable<Double> psi, Variable<Double> p, Variable<Double> origin, boolean relativeDeath, boolean unusedFlag, boolean timesStartFromOrigin, Units.Type type) {
        super(name, type);
        this.size = times.getSize();
        if (lambda.getSize() != 1 && lambda.getSize() != this.size) {
            throw new RuntimeException("Length of Lambda parameter should be one or equal to the size of time parameter (size = " + this.size + ")");
        }
        if (mu.getSize() != 1 && mu.getSize() != this.size) {
            throw new RuntimeException("Length of mu parameter should be one or equal to the size of time parameter (size = " + this.size + ")");
        }
        this.timesStartFromOrigin = timesStartFromOrigin;

        this.times = times;
        addVariable(times);
        times.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0d, times.getSize()));

        this.lambda = lambda;
        addVariable(lambda);
        lambda.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0d, lambda.getSize()));

        this.mu = mu;
        addVariable(mu);
        mu.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0d, mu.getSize()));

        this.p = p;
        addVariable(p);
        p.addBounds(new Parameter.DefaultBounds(1.0d, 0.0d, p.getSize()));

        this.origin = origin;
        addVariable(origin);
        // Bounds belong on origin itself; they were previously attached to p by mistake.
        origin.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0d, origin.getSize()));

        this.psi = psi;
        addVariable(psi);
        psi.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0d, psi.getSize()));

        this.relativeDeath = relativeDeath;
    }

    /**
     * Counts lineages at a given height: starts from one, adds one for every internal node
     * strictly above {@code d} and subtracts one for every external node strictly above {@code d}.
     *
     * @param d    height at which to count
     * @param tree tree whose nodes are inspected
     * @return the resulting count
     */
    public int lineageCountAtTime(double d, Tree tree) {
        int lineages = 1;
        final int internalCount = tree.getInternalNodeCount();
        for (int k = 0; k < internalCount; k++) {
            if (tree.getNodeHeight(tree.getInternalNode(k)) > d) {
                lineages++;
            }
        }
        final int externalCount = tree.getExternalNodeCount();
        for (int k = 0; k < externalCount; k++) {
            if (tree.getNodeHeight(tree.getExternalNode(k)) > d) {
                lineages--;
            }
        }
        return lineages;
    }

    /**
     * Computes {@code sqrt((d - d2 - d3)^2 + 4 * d * d3)}.
     * {@link #preCalculation(Tree)} calls it with (birth, death, psi).
     */
    public double Ai(double d, double d2, double d3) {
        final double diff = (d - d2) - d3;
        return Math.sqrt((diff * diff) + (4.0d * d * d3));
    }

    /**
     * Computes {@code -((1 - 2 * d5) * d + d2 + d3) / d4}.
     * {@link #preCalculation(Tree)} calls it with (birth, death, psi, Ai, previous p0).
     */
    public double Bi(double d, double d2, double d3, double d4, double d5) {
        return (-((((1.0d - (2.0d * d5)) * d) + d2) + d3)) / d4;
    }

    /**
     * Evaluates the 7-argument {@code p0} with the rates and cached constants of interval {@code i}.
     * Requires {@link #preCalculation(Tree)} to have filled {@link #Ai} and {@link #Bi}.
     */
    public double p0(int i, double d, double d2) {
        return p0(intervalBirth(i), intervalDeath(i), intervalPsi(i), this.Ai[i], this.Bi[i], d, d2);
    }

    /**
     * Closed-form expression in (birth {@code d}, death {@code d2}, psi {@code d3},
     * A {@code d4}, B {@code d5}) over the time difference {@code d6 - d7}.
     */
    public double p0(double d, double d2, double d3, double d4, double d5, double d6, double d7) {
        final double growth = Math.exp(d4 * (d6 - d7)) * (1.0d - d5);
        final double onePlusB = 1.0d + d5;
        return (((d + d2) + d3) - ((d4 * (growth - onePlusB)) / (growth + onePlusB))) / (2.0d * d);
    }

    /**
     * Evaluates
     * {@code 4 / (2(1 - B^2) + e^{A(d - d2)} (1 - B)^2 + e^{-A(d - d2)} (1 + B)^2)}
     * using the cached {@code A = Ai[i]} and {@code B = Bi[i]}.
     */
    public double g(int i, double d, double d2) {
        final double a = this.Ai[i];
        final double b = this.Bi[i];
        final double dt = d - d2;
        return 4.0d / (((2.0d * (1.0d - (b * b))) + (Math.exp(a * dt) * ((1.0d - b) * (1.0d - b)))) + (Math.exp((-a) * dt) * ((1.0d + b) * (1.0d + b))));
    }

    /** Returns entry {@code i} of {@link #timesFromTips}. */
    public double t(int i) {
        return this.timesFromTips[i];
    }

    /** Returns the {@code i}-th value of {@link #lambda}. */
    public double birth(int i) {
        return this.lambda.getValue(i).doubleValue();
    }

    /** Returns the {@code i}-th value of {@link #mu}, multiplied by {@code birth(i)} when {@link #relativeDeath} is set. */
    public double death(int i) {
        final double rate = this.mu.getValue(i).doubleValue();
        return this.relativeDeath ? rate * birth(i) : rate;
    }

    /** Returns the {@code i}-th value of {@link #psi}. */
    public double psi(int i) {
        return this.psi.getValue(i).doubleValue();
    }

    /** Returns the first value of {@link #p}. */
    public double p() {
        return this.p.getValue(0).doubleValue();
    }

    /** Returns the value of {@link #lambda} at the interval selected by {@link #index(double)}. */
    public double lambda(double d) {
        return this.lambda.getValue(index(d)).doubleValue();
    }

    /** Returns the raw value of {@link #mu} (no relative-death scaling) at the interval selected by {@link #index(double)}. */
    public double mu(double d) {
        return this.mu.getValue(index(d)).doubleValue();
    }

    /**
     * Maps a time onto an interval index by binary-searching {@link #timesFromTips}.
     * A miss is converted to its insertion point; the result is that position minus one,
     * floored at zero.
     */
    public int index(double d) {
        int position = Arrays.binarySearch(this.timesFromTips, d);
        if (position < 0) {
            position = (-position) - 1;
        }
        return Math.max(position - 1, 0);
    }

    /**
     * Refreshes everything that depends on the tree or on the current variable values:
     * {@link #t_root}, {@link #x0}, {@link #timesFromTips}, {@link #Ai}, {@link #Bi} and
     * {@link #p0_iMinus1}.
     *
     * @param tree tree whose root height anchors {@link #x0}
     */
    public void preCalculation(Tree tree) {
        this.t_root = tree.getNodeHeight(tree.getRoot());
        this.x0 = this.t_root + this.origin.getValue(0).doubleValue();

        if (this.timesFromTips == null) {
            this.timesFromTips = new double[this.times.getSize()];
        }
        final int boundaryCount = this.timesFromTips.length;
        if (this.timesStartFromOrigin) {
            // Boundaries are measured back from x0, read in reverse order and clamped at zero.
            this.timesFromTips[0] = 0.0d;
            for (int k = 1; k < boundaryCount; k++) {
                this.timesFromTips[k] = Math.max(0.0d, this.x0 - this.times.getValue(boundaryCount - k).doubleValue());
            }
        } else {
            for (int k = 0; k < boundaryCount; k++) {
                this.timesFromTips[k] = this.times.getValue(k).doubleValue();
            }
        }

        this.Ai = new double[this.size];
        this.Bi = new double[this.size];
        this.p0_iMinus1 = new double[this.size];

        for (int k = 0; k < this.size; k++) {
            this.Ai[k] = Ai(intervalBirth(k), intervalDeath(k), intervalPsi(k));
        }

        // Interval 0 is seeded with 1.0 in place of a previous-interval p0.
        this.Bi[0] = Bi(birth(0), death(0), psi(0), this.Ai[0], 1.0d);
        for (int k = 1; k < this.size; k++) {
            final int prev = k - 1;
            // Each Bi depends on p0 of the previous interval, so the order matters here.
            this.p0_iMinus1[prev] = p0(intervalBirth(prev), intervalDeath(prev), intervalPsi(prev), this.Ai[prev], this.Bi[prev], t(k), t(prev));
            this.Bi[k] = Bi(intervalBirth(k), intervalDeath(k), intervalPsi(k), this.Ai[k], this.p0_iMinus1[prev]);
        }
    }

    /**
     * Computes the log-likelihood of {@code tree}. The sum has four parts, accumulated in this
     * order: the term for {@link #x0} in the last interval, one term per internal node, one term
     * per external node, and one term per interval boundary weighted by the lineage count there.
     *
     * @param tree tree to score
     * @return the log-likelihood
     */
    @Override // dr.evomodel.speciation.SpeciationModel
    public final double calculateTreeLogLikelihood(Tree tree) {
        final int tipCount = tree.getExternalNodeCount();
        preCalculation(tree);

        final int last = this.size - 1;
        double logL = Math.log(g(last, this.x0, t(last)));

        for (int k = 0; k < tree.getInternalNodeCount(); k++) {
            final double height = tree.getNodeHeight(tree.getInternalNode(k));
            final int interval = index(height);
            logL += Math.log(intervalBirth(interval) * g(interval, height, t(interval)));
        }

        for (int k = 0; k < tipCount; k++) {
            final double height = tree.getNodeHeight(tree.getExternalNode(k));
            final int interval = index(height);
            logL += Math.log(intervalPsi(interval)) - Math.log(g(interval, height, t(interval)));
        }

        for (int k = 0; k < this.size - 1; k++) {
            final double boundary = t(k + 1);
            final int lineages = lineageCountAtTime(boundary, tree);
            if (lineages > 0) {
                logL += lineages * Math.log(g(k, boundary, t(k)));
            }
        }
        return logL;
    }

    /**
     * Delegates to {@link #calculateTreeLogLikelihood(Tree)} when {@code set} is empty.
     *
     * @throws RuntimeException with message {@code "Not implemented!"} for a non-empty set
     */
    @Override // dr.evomodel.speciation.SpeciationModel
    public double calculateTreeLogLikelihood(Tree tree, Set<Taxon> set) {
        if (set.size() == 0) {
            return calculateTreeLogLikelihood(tree);
        }
        throw new RuntimeException("Not implemented!");
    }

    /**
     * Small manual driver: builds a ten-interval model, scores a fixed four-taxon Newick tree
     * and prints {@code mu} and {@code p0} for the integers 0..9.
     *
     * <p>NOTE(review): psi is created with length one while there are ten intervals and
     * {@link #samplingChanges} is true, so this driver may fail inside
     * {@link #preCalculation(Tree)} - confirm against {@code Variable.D#getValue}.
     */
    public static void main(String[] strArr) throws IOException, Importer.ImportException {
        Variable.D times = new Variable.D(1.0d, 10);
        Variable.D mu = new Variable.D(1.0d, 10);
        for (int i = 0; i < mu.getSize(); i++) {
            // Pass the double directly; casting a boxed Double to int is not legal Java.
            times.setValue(i, (i + 1) * 2.0d);
            mu.setValue(i, i + 1.0d);
        }
        BirthDeathSerialSkylineModel model = new BirthDeathSerialSkylineModel(times, new Variable.D(1.0d, 10), mu, new Variable.D(0.5d, 1), new Variable.D(0.5d, 1), new Variable.D(0.5d, 1), false, false, false, Units.Type.SUBSTITUTIONS);
        model.calculateTreeLogLikelihood(new NewickImporter("((A:6,B:5):4,(C:3,D:2):1);").importNextTree());
        for (int i = 0; i < times.getSize(); i++) {
            System.out.println("mu at time " + i + " is " + model.mu(i));
            System.out.println("p0 at time " + i + " is " + model.p0(0, i, i));
        }
    }

    /** Birth rate for interval {@code i}, or for interval 0 when {@link #birthChanges} is false. */
    private double intervalBirth(int i) {
        return birth(this.birthChanges ? i : 0);
    }

    /** Death rate for interval {@code i}, or for interval 0 when {@link #deathChanges} is false. */
    private double intervalDeath(int i) {
        return death(this.deathChanges ? i : 0);
    }

    /** Sampling rate for interval {@code i}, or for interval 0 when {@link #samplingChanges} is false. */
    private double intervalPsi(int i) {
        return psi(this.samplingChanges ? i : 0);
    }
}
