package dr.inference.markovjumps;

import dr.evomodel.substmodel.DefaultEigenSystem;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.evomodel.substmodel.EigenSystem;
import dr.math.Binomial;
import dr.math.GammaFunction;
import dr.math.matrixAlgebra.Vector;

/* loaded from: input_file:dr/inference/markovjumps/TwoStateSericolaSeriesMarkovReward.class */
/**
 * Distribution of the reward accumulated by a finite-state continuous-time Markov chain over a
 * time interval, computed with the Sericola uniformization series.
 *
 * <p>{@code Q} is the row-major {@code dim x dim} generator and {@code r} holds one reward rate
 * per state. The recursions index reward levels {@code r[0..phi]} with {@code phi = dim - 1}.
 * They presumably require {@code r} to be sorted in increasing order and {@code Q} to have at
 * least one non-zero diagonal entry (NOTE(review): confirm with callers). The density methods only
 * use the first reward interval ({@code h = 1}), consistent with the two-state name of the class.
 *
 * <p>The coefficient table is grown lazily and cached in mutable fields without synchronization,
 * so an instance should not be shared across threads.
 */
public class TwoStateSericolaSeriesMarkovReward implements MarkovReward {

    /**
     * Coefficient table: {@code internalC[h][n][k]} is a flattened {@code dim x dim} matrix for
     * reward level {@code h}, series order {@code n} and binomial index {@code k}. Null until the
     * first density/distribution call.
     */
    private double[][][][] internalC;
    /** Lazily computed eigen-decomposition of {@code Q}. */
    private EigenDecomposition eigenDecomposition;
    /** Row-major generator matrix ({@code dim x dim}). */
    private final double[] Q;
    /** Reward rate of each state. */
    private final double[] r;
    /** Uniformization rate: minus the smallest diagonal entry of {@code Q}. */
    private final double lambda;
    /** Uniformized transition matrix {@code I + Q / lambda}, row-major. */
    private final double[] P;
    /** Index of the highest reward level, {@code dim - 1}. */
    private final int phi;
    /** Number of states. */
    private final int dim;
    /** Tolerance used to truncate the Poisson series. */
    private final double epsilon;
    /** Used for the matrix exponential of {@code Q}. */
    private final EigenSystem eigenSystem;
    /** Largest time the coefficient table has been requested for so far. */
    private double maxTime;

    /**
     * Creates a reward model with the default truncation tolerance {@code 1e-10}.
     *
     * @param Q   row-major {@code dim x dim} generator matrix (stored by reference)
     * @param r   reward rate of each state (stored by reference)
     * @param dim number of states
     */
    public TwoStateSericolaSeriesMarkovReward(double[] Q, double[] r, int dim) {
        this(Q, r, dim, 1.0E-10d);
    }

    /**
     * Creates a reward model.
     *
     * @param Q       row-major {@code dim x dim} generator matrix (stored by reference)
     * @param r       reward rate of each state (stored by reference)
     * @param dim     number of states
     * @param epsilon tolerance used to truncate the Poisson series
     */
    public TwoStateSericolaSeriesMarkovReward(double[] Q, double[] r, int dim, double epsilon) {
        this.Q = Q;
        this.r = r;
        this.maxTime = 0.0d;
        this.epsilon = epsilon;
        this.dim = dim;
        this.lambda = determineLambda(); // reads this.Q and this.dim, so must follow them
        this.phi = dim - 1;
        this.P = initializeP(Q, this.lambda);
        this.eigenSystem = new DefaultEigenSystem(dim);
    }

    /** Allocates one zeroed, flattened {@code dim x dim} result matrix per reward value. */
    private double[][] initializeW(int length, int dim) {
        return new double[length][dim * dim];
    }

    /**
     * Returns the reward interval index {@code h} such that
     * {@code r[h-1]*time <= x < r[h]*time}, starting the search at {@code h = 1}.
     *
     * <p>The index is capped at {@code phi}, so the maximal reward {@code x == r[phi]*time} maps to
     * the last interval instead of indexing past the end of {@code r}. Values above
     * {@code r[phi]*time} lie outside the support and also map to {@code phi}.
     */
    private int getHfromX(double x, double time) {
        int h = 1;
        while (h < this.phi && x >= this.r[h] * time) {
            h++;
        }
        return h;
    }

    /**
     * Ensures the coefficient table is deep enough for every time requested so far.
     *
     * @param time       time of the current request
     * @param extraSteps additional series orders required by the caller; the density needs one
     *                   more level than the distribution because it reads {@code C(1, n + 1, .)}
     */
    private void growC(double time, int extraSteps) {
        if (time > this.maxTime) {
            this.maxTime = time;
        }
        if (this.maxTime <= 0.0d) {
            // Nothing to expand yet; log(0) would also turn the Poisson series into NaN.
            return;
        }
        // Re-evaluate on every call so a density request following a distribution request at the
        // same time still gets its extra level.
        int requiredN = determineNumberOfSteps(this.maxTime, this.lambda) + extraSteps;
        if (requiredN > getNfromC()) {
            if (requiredN > 200) {
                System.err.println("Growing C to N = " + requiredN + " with " + this.maxTime);
            }
            if (requiredN > 500) {
                System.err.println("Warning: > 500 recursion depth in SericolaSeriesMarkovReward");
            }
            initializeSpace(this.phi, requiredN);
            computeChnk();
        }
    }

    /** Allocates a zeroed coefficient table for levels {@code 0..maxLevel}, orders {@code 0..maxN}. */
    private void initializeSpace(int maxLevel, int maxN) {
        this.internalC = new double[maxLevel + 1][maxN + 1][maxN + 1][this.dim * this.dim];
    }

    /** Returns the (mutable) flattened matrix {@code C(h, n, k)}. */
    private double[] C(int h, int n, int k) {
        return this.internalC[h][n][k];
    }

    /** Returns the highest series order stored in the table, or {@code -1} if none is allocated. */
    private int getNfromC() {
        if (this.internalC == null) {
            return -1;
        }
        return this.internalC[0].length - 1;
    }

    /** Row-major flat index of matrix entry {@code (row, col)}. */
    private int idx(int row, int col) {
        return (row * this.dim) + col;
    }

    /** Classifies each reward value in {@code X} independently with {@link #getHfromX(double, double)}. */
    private int[] getHfromX(double[] X, double time) {
        int[] H = new int[X.length];
        for (int m = 0; m < X.length; m++) {
            H[m] = getHfromX(X[m], time);
        }
        return H;
    }

    /**
     * Returns entry {@code (i, j)} of the reward density matrix at reward {@code x} and time
     * {@code time}. This is presumably the joint density of reward {@code x} and end state
     * {@code j} given start state {@code i}.
     */
    @Override // dr.inference.markovjumps.MarkovReward
    public double computePdf(double x, double time, int i, int j) {
        growC(time, 1);
        final int entry = idx(i, j);
        final int lastOrder = getNfromC() - 1;
        double density = 0.0d;
        for (int n = 0; n <= lastOrder; n++) {
            density += accumulatePdf(x, n, time, entry);
        }
        return density;
    }

    /** Returns the flattened {@code dim x dim} reward density matrix at reward {@code x}. */
    @Override // dr.inference.markovjumps.MarkovReward
    public double[] computePdf(double x, double time) {
        return computePdf(new double[]{x}, time)[0];
    }

    /** Returns one flattened {@code dim x dim} reward density matrix per value in {@code X}. */
    public double[][] computePdf(double[] X, double time) {
        growC(time, 1);
        double[][] W = initializeW(X.length, this.dim);
        final int lastOrder = getNfromC() - 1;
        for (int n = 0; n <= lastOrder; n++) {
            accumulatePdf(W, X, n, time);
        }
        return W;
    }

    /** Returns entry {@code (i, j)} of the reward distribution matrix at reward {@code x}. */
    @Override // dr.inference.markovjumps.MarkovReward
    public double computeCdf(double x, double time, int i, int j) {
        return computeCdf(x, time)[idx(i, j)];
    }

    /** Returns the flattened {@code dim x dim} reward distribution matrix at reward {@code x}. */
    public double[] computeCdf(double x, double time) {
        return computeCdf(new double[]{x}, time)[0];
    }

    /** Returns one flattened {@code dim x dim} reward distribution matrix per value in {@code X}. */
    public double[][] computeCdf(double[] X, double time) {
        int[] H = getHfromX(X, time);
        growC(time, 0);
        double[][] W = initializeW(X.length, this.dim);
        final int lastOrder = getNfromC();
        for (int n = 0; n <= lastOrder; n++) {
            accumulateCdf(W, X, H, n, time);
        }
        return W;
    }

    /** Builds the uniformized matrix {@code I + generator / rate} (row-major). */
    private double[] initializeP(double[] generator, double rate) {
        double[] uniformized = new double[this.dim * this.dim];
        for (int i = 0; i < this.dim; i++) {
            for (int j = 0; j < this.dim; j++) {
                uniformized[idx(i, j)] = (i == j ? 1.0d : 0.0d) + (generator[idx(i, j)] / rate);
            }
        }
        return uniformized;
    }

    /**
     * Poisson({@code rate * time}) probability mass at {@code n}, evaluated in log space.
     * The operation order matches the expression that was previously inlined at each call site.
     */
    private static double poissonProbability(double rate, double time, int n) {
        return Math.exp((((-rate) * time) + (n * (Math.log(rate) + Math.log(time))))
                - GammaFunction.lnGamma(n + 1.0d));
    }

    /** Adds the order-{@code n} term of the distribution series to every row of {@code W}. */
    private void accumulateCdf(double[][] W, double[] X, int[] H, int n, double time) {
        final double weight = poissonProbability(this.lambda, time, n);
        final int size = this.dim * this.dim;
        for (int m = 0; m < X.length; m++) {
            final int h = H[m];
            // Position of X[m] inside its reward interval [r[h-1]*time, r[h]*time), rescaled to [0, 1).
            final double xh = (X[m] - (this.r[h - 1] * time)) / ((this.r[h] - this.r[h - 1]) * time);
            final double[] sum = new double[size];
            for (int k = 0; k <= n; k++) {
                final double bernstein = Binomial.choose(n, k) * Math.pow(xh, k) * Math.pow(1.0d - xh, n - k);
                final double[] c = C(h, n, k);
                for (int e = 0; e < size; e++) {
                    sum[e] += bernstein * c[e];
                }
            }
            final double[] row = W[m];
            for (int e = 0; e < size; e++) {
                row[e] += weight * sum[e];
            }
        }
    }

    /** Returns the order-{@code n} term of the density series for one flat matrix entry. */
    private double accumulatePdf(double x, int n, double time, int entry) {
        final int h = 1; // the density is only evaluated on the first reward interval
        final double weight = poissonProbability(this.lambda, time, n);
        final double scale = this.lambda / (this.r[h] - this.r[h - 1]);
        final double xh = (x - (this.r[h - 1] * time)) / ((this.r[h] - this.r[h - 1]) * time);
        double sum = 0.0d;
        for (int k = 0; k <= n; k++) {
            sum += Binomial.choose(n, k) * Math.pow(xh, k) * Math.pow(1.0d - xh, n - k)
                    * (C(h, n + 1, k + 1)[entry] - C(h, n + 1, k)[entry]);
        }
        return scale * weight * sum;
    }

    /** Adds the order-{@code n} term of the density series to every row of {@code W}. */
    private void accumulatePdf(double[][] W, double[] X, int n, double time) {
        final int h = 1; // the density is only evaluated on the first reward interval
        final double weight = poissonProbability(this.lambda, time, n);
        final double scaledWeight = (this.lambda / (this.r[h] - this.r[h - 1])) * weight;
        final int size = this.dim * this.dim;
        for (int m = 0; m < X.length; m++) {
            final double xh = (X[m] - (this.r[h - 1] * time)) / ((this.r[h] - this.r[h - 1]) * time);
            final double[] sum = new double[size];
            for (int k = 0; k <= n; k++) {
                final double bernstein = Binomial.choose(n, k) * Math.pow(xh, k) * Math.pow(1.0d - xh, n - k);
                final double[] upper = C(h, n + 1, k + 1);
                final double[] lower = C(h, n + 1, k);
                for (int e = 0; e < size; e++) {
                    sum[e] += bernstein * (upper[e] - lower[e]);
                }
            }
            final double[] row = W[m];
            for (int e = 0; e < size; e++) {
                row[e] += scaledWeight * sum[e];
            }
        }
    }

    /**
     * Forward recursion: computes {@code C(h, n, k)[u, v]} for {@code u >= h} from
     * {@code C(h, n, k - 1)} and {@code P * C(h, n - 1, k - 1)}.
     */
    private double relationTwelve(int h, int n, int k, int u, int v) {
        final double own = ((this.r[u] - this.r[h]) / (this.r[u] - this.r[h - 1])) * C(h, n, k - 1)[idx(u, v)];
        double mixed = 0.0d;
        for (int w = 0; w <= this.phi; w++) {
            mixed += this.P[idx(u, w)] * C(h, n - 1, k - 1)[idx(w, v)];
        }
        return own + (mixed * ((this.r[h] - this.r[h - 1]) / (this.r[u] - this.r[h - 1])));
    }

    /**
     * Backward recursion: computes {@code C(h, n, k)[u, v]} for {@code u < h} from
     * {@code C(h, n, k + 1)} and {@code P * C(h, n - 1, k)}.
     */
    private double relationThirteen(int h, int n, int k, int u, int v) {
        final double own = ((this.r[h - 1] - this.r[u]) / (this.r[h] - this.r[u])) * C(h, n, k + 1)[idx(u, v)];
        double mixed = 0.0d;
        for (int w = 0; w <= this.phi; w++) {
            mixed += this.P[idx(u, w)] * C(h, n - 1, k)[idx(w, v)];
        }
        return own + (mixed * ((this.r[h] - this.r[h - 1]) / (this.r[h] - this.r[u])));
    }

    /** Row-major product {@code a * b} of two flattened {@code dim x dim} matrices. */
    private double[] product(double[] a, double[] b) {
        double[] result = new double[this.dim * this.dim];
        for (int i = 0; i < this.dim; i++) {
            for (int j = 0; j < this.dim; j++) {
                final int target = idx(i, j);
                for (int l = 0; l < this.dim; l++) {
                    result[target] += a[idx(i, l)] * b[idx(l, j)];
                }
            }
        }
        return result;
    }

    /** Fills the freshly allocated coefficient table for all orders {@code 0..getNfromC()}. */
    private void computeChnk() {
        // Base case n = 0: C(h, 0, 0) has ones on the diagonal for rows u < h.
        for (int h = 1; h <= this.phi; h++) {
            for (int u = 0; u <= h - 1; u++) {
                C(h, 0, 0)[idx(u, u)] = 1.0d;
            }
        }

        // Running power P^n, starting from the identity.
        double[] pPower = new double[this.dim * this.dim];
        for (int i = 0; i < this.dim; i++) {
            pPower[idx(i, i)] = 1.0d;
        }

        final int maxN = getNfromC();
        for (int n = 1; n <= maxN; n++) {
            // Forward sweep over levels: relationTwelve fills rows u >= h for k = 1..n.
            for (int h = 1; h <= this.phi; h++) {
                for (int k = 1; k <= n; k++) {
                    for (int u = h; u <= this.phi; u++) {
                        for (int v = 0; v <= this.phi; v++) {
                            C(h, n, k)[idx(u, v)] = relationTwelve(h, n, k, u, v);
                        }
                    }
                }
                // Seed the next level: C(h + 1, n, 0) equals C(h, n, n) on rows u > h.
                for (int u = h + 1; u <= this.phi; u++) {
                    for (int v = 0; v <= this.phi; v++) {
                        C(h + 1, n, 0)[idx(u, v)] = C(h, n, n)[idx(u, v)];
                    }
                }
            }

            // Top-level boundary: C(phi, n, n) = P^n on rows u < phi.
            pPower = product(pPower, this.P);
            for (int u = 0; u <= this.phi - 1; u++) {
                for (int v = 0; v <= this.phi; v++) {
                    C(this.phi, n, n)[idx(u, v)] = pPower[idx(u, v)];
                }
            }

            // Backward sweep over levels: relationThirteen fills rows u < h for k = n-1 down to 0.
            for (int h = this.phi; h >= 1; h--) {
                for (int k = n - 1; k >= 0; k--) {
                    for (int u = 0; u <= h - 1; u++) {
                        for (int v = 0; v <= this.phi; v++) {
                            C(h, n, k)[idx(u, v)] = relationThirteen(h, n, k, u, v);
                        }
                    }
                }
                // Seed the level below: C(h - 1, n, n) equals C(h, n, 0) on rows u < h - 1.
                // relationThirteen at level h never reads level h - 1, so copying once after the
                // k sweep gives the same table as copying after every k step.
                for (int u = 0; u <= h - 2; u++) {
                    for (int v = 0; v <= this.phi; v++) {
                        C(h - 1, n, n)[idx(u, v)] = C(h, n, 0)[idx(u, v)];
                    }
                }
            }
        }
    }

    /** Returns minus the smallest diagonal entry of {@code Q} (the uniformization rate). */
    private double determineLambda() {
        double minDiagonal = this.Q[0];
        for (int i = 1; i < this.dim; i++) {
            final double diagonal = this.Q[idx(i, i)];
            if (diagonal < minDiagonal) {
                minDiagonal = diagonal;
            }
        }
        return -minDiagonal;
    }

    /** Converts a flattened row-major matrix into a {@code dim x dim} array. */
    private double[][] squareMatrix(double[] flat) {
        double[][] square = new double[this.dim][this.dim];
        for (int i = 0; i < this.dim; i++) {
            for (int j = 0; j < this.dim; j++) {
                square[i][j] = flat[idx(i, j)];
            }
        }
        return square;
    }

    /**
     * Returns the last Poisson({@code rate * time}) order {@code n} needed to reach a cumulative
     * mass within {@code epsilon} of {@code 1 - epsilon} (or at least 1).
     *
     * @throws IllegalStateException if the cumulative mass becomes NaN (e.g. a zero or NaN rate),
     *                               which would otherwise loop forever
     */
    private int determineNumberOfSteps(double time, double rate) {
        final double target = 1.0d - this.epsilon;
        int n = -1;
        double mass = 0.0d;
        while (Math.abs(mass - target) > this.epsilon && mass < 1.0d) {
            n++;
            mass += poissonProbability(rate, time, n);
            if (Double.isNaN(mass)) {
                throw new IllegalStateException(
                        "Poisson truncation failed for rate = " + rate + ", time = " + time);
            }
        }
        return n;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Q: " + new Vector(this.Q) + "\n");
        sb.append("r: " + new Vector(this.r) + "\n");
        sb.append("lambda: " + this.lambda + "\n");
        sb.append("N: " + getNfromC() + "\n");
        sb.append("maxTime: " + this.maxTime + "\n");
        sb.append("cprob at maxTime: " + new Vector(computeConditionalProbabilities(this.maxTime)) + "\n");
        return sb.toString();
    }

    /** Lazily decomposes {@code Q} and caches the result. */
    private EigenDecomposition getEigenDecomposition() {
        if (this.eigenDecomposition == null) {
            this.eigenDecomposition = this.eigenSystem.decomposeMatrix(squareMatrix(this.Q));
        }
        return this.eigenDecomposition;
    }

    /** Returns the flattened matrix exponential of {@code Q * time}, computed by the eigen system. */
    public double[] computeConditionalProbabilities(double time) {
        double[] probabilities = new double[this.dim * this.dim];
        this.eigenSystem.computeExponential(getEigenDecomposition(), time, probabilities);
        return probabilities;
    }

    /** Returns entry {@code (i, j)} of the matrix exponential of {@code Q * time}. */
    @Override // dr.inference.markovjumps.MarkovReward
    public double computeConditionalProbability(double time, int i, int j) {
        return this.eigenSystem.computeExponential(getEigenDecomposition(), time, i, j);
    }
}
