package dr.geo;

import dr.evoxml.util.GraphMLUtils;
import dr.geo.SpaceTimeRejector;
import dr.math.distributions.MultivariateNormalDistribution;
import java.awt.geom.Point2D;
import java.awt.geom.Rectangle2D;
import java.io.FileNotFoundException;
import java.io.PrintWriter;

/* loaded from: input_file:dr/geo/NumericalSpaceTimeProbs2D.class */
/**
 * Monte Carlo estimator of space-time transition probabilities on a regular 2-D lattice.
 *
 * <p>The bounding rectangle passed to the constructor is split into
 * {@code latticeWidth x latticeHeight} equal cells. Paths are started uniformly inside a cell
 * (re-drawn until {@link #rejector} accepts them at time 0) and advanced by draws from
 * {@link #D}. One recorded time step of length {@link #dt} consists of {@link #subtsteps}
 * sub-steps of length {@code dt / subtsteps}; after each recorded step the cell occupied by the
 * path is tallied in {@link #counts}.
 *
 * <p>Memory note: {@link #counts} holds {@code (latticeWidth * latticeHeight)^2 * tsteps} ints.
 */
public class NumericalSpaceTimeProbs2D {
    /** Number of lattice cells along x. */
    final int latticeWidth;
    /** Number of lattice cells along y. */
    final int latticeHeight;
    /** Number of recorded time steps per path. */
    final int tsteps;
    /** Number of simulation sub-steps within one recorded time step. */
    final int subtsteps;
    /** Left edge of the lattice. */
    final double minx;
    /** Bottom edge of the lattice. */
    final double miny;
    /** Cell width. */
    final double dx;
    /** Cell height. */
    final double dy;
    /** Duration of one recorded time step. */
    final double dt;
    /** Distribution used to draw path moves. */
    final MultivariateNormalDistribution D;
    /** Rejects start points (at time 0) and proposed moves (at the time of the move). */
    final SpaceTimeRejector rejector;
    /**
     * {@code counts[x0][y0][x1][y1][k]}: number of observations of a path started in cell
     * {@code (x0, y0)} that occupied cell {@code (x1, y1)} at recorded time index {@code k}.
     */
    int[][][][][] counts;
    /** {@code normalization[x0][y0][k]}: total observations from {@code (x0, y0)} at index k. */
    int[][][] normalization;
    /** {@code maxCount[x0][y0][k]}: largest {@code counts[x0][y0][*][*][k]} entry seen so far. */
    int[][][] maxCount;

    /**
     * Creates an empty estimator. Call one of the {@code populate} methods to fill it.
     *
     * @param latticeWidth number of cells along x
     * @param latticeHeight number of cells along y
     * @param tsteps number of recorded time steps per path
     * @param subtsteps number of simulation sub-steps per recorded time step
     * @param dt duration of one recorded time step
     * @param bounds rectangle covered by the lattice
     * @param distribution distribution used to draw path moves
     * @param rejector rejects start points and proposed moves
     */
    public NumericalSpaceTimeProbs2D(int latticeWidth, int latticeHeight, int tsteps, int subtsteps,
            double dt, Rectangle2D bounds, MultivariateNormalDistribution distribution,
            SpaceTimeRejector rejector) {
        this.latticeWidth = latticeWidth;
        this.latticeHeight = latticeHeight;
        this.tsteps = tsteps;
        this.subtsteps = subtsteps;
        this.D = distribution;
        this.rejector = rejector;
        this.dt = dt;
        this.minx = bounds.getMinX();
        this.miny = bounds.getMinY();
        this.dx = (bounds.getMaxX() - this.minx) / latticeWidth;
        this.dy = (bounds.getMaxY() - this.miny) / latticeHeight;
        this.counts = new int[latticeWidth][latticeHeight][latticeWidth][latticeHeight][tsteps];
        this.normalization = new int[latticeWidth][latticeHeight][tsteps];
        this.maxCount = new int[latticeWidth][latticeHeight][tsteps];
    }

    /**
     * Simulates paths starting in the cell that contains {@code start}.
     *
     * @see #populate(int, int, int, boolean)
     */
    public void populate(Point2D start, int paths, boolean useSubPaths) {
        populate(x(start.getX()), y(start.getY()), paths, useSubPaths);
    }

    /**
     * Simulates absorbing paths starting in the cell that contains {@code start}.
     *
     * @see #populateAbsorbing(int, int, int)
     */
    public int populateAbsorbing(Point2D start, int paths) {
        return populateAbsorbing(x(start.getX()), y(start.getY()), paths);
    }

    /**
     * Simulates {@code paths} paths from cell {@code (startX, startY)}. The first rejected move
     * terminates a path ("absorbs" it). Only the time steps completed before absorption are
     * tallied.
     *
     * @param startX start cell column
     * @param startY start cell row
     * @param paths number of paths to simulate
     * @return number of paths that completed all {@code tsteps} steps without being absorbed
     */
    public int populateAbsorbing(int startX, int startY, int paths) {
        final double subDt = dt / subtsteps;
        final double[] proposal = new double[2];
        final double[] current = new double[2];
        int survivors = 0;
        for (int path = 0; path < paths; path++) {
            double time = 0.0d;
            sampleStart(startX, startY, current);
            boolean absorbed = false;
            for (int step = 0; step < tsteps && !absorbed; step++) {
                for (int sub = 0; sub < subtsteps && !absorbed; sub++) {
                    // NOTE(review): presumably draws a move around `current` with scale subDt.
                    D.nextScaledMultivariateNormal(current, subDt, proposal);
                    time += subDt;
                    absorbed = rejector.reject(time, proposal);
                    if (!absorbed) {
                        current[0] = proposal[0];
                        current[1] = proposal[1];
                    }
                }
                if (!absorbed) {
                    increment(startX, startY, x(current[0]), y(current[1]), step);
                }
            }
            if (!absorbed) {
                survivors++;
            }
            if (path % 10000 == 0) {
                System.out.print(".");
                System.out.flush();
            }
        }
        System.out.println();
        return survivors;
    }

    /**
     * Simulates {@code paths} paths from cell {@code (startX, startY)}. A rejected move is
     * re-drawn until one is accepted, so every path runs for all {@code tsteps} steps.
     *
     * @param startX start cell column
     * @param startY start cell row
     * @param paths number of paths to simulate
     * @param useSubPaths if true, every later segment of a path is also tallied, starting from
     *     the cell of an earlier recorded step. A gap of {@code g} steps goes to index
     *     {@code g - 1}.
     */
    public void populate(int startX, int startY, int paths, boolean useSubPaths) {
        final double subDt = dt / subtsteps;
        final double[] proposal = new double[2];
        final double[] current = new double[2];
        final int[] pathX = new int[tsteps];
        final int[] pathY = new int[tsteps];
        for (int path = 0; path < paths; path++) {
            double time = 0.0d;
            sampleStart(startX, startY, current);
            for (int step = 0; step < tsteps; step++) {
                for (int sub = 0; sub < subtsteps; sub++) {
                    // All proposals for this sub-step are evaluated at the same time. Rejected
                    // draws must not advance the clock (consistent with populateAbsorbing).
                    final double nextTime = time + subDt;
                    do {
                        D.nextScaledMultivariateNormal(current, subDt, proposal);
                    } while (rejector.reject(nextTime, proposal));
                    time = nextTime;
                    current[0] = proposal[0];
                    current[1] = proposal[1];
                }
                pathX[step] = x(current[0]);
                pathY[step] = y(current[1]);
                increment(startX, startY, pathX[step], pathY[step], step);
            }
            if (useSubPaths) {
                for (int from = 0; from < tsteps; from++) {
                    for (int to = from + 1; to < tsteps; to++) {
                        increment(pathX[from], pathY[from], pathX[to], pathY[to], (to - from) - 1);
                    }
                }
            }
            if (path % 1000 == 0) {
                System.out.print(".");
                System.out.flush();
            }
        }
    }

    /**
     * Draws a start point uniformly inside cell {@code (cellX, cellY)}. The point is re-drawn
     * until the rejector accepts it at time 0.
     */
    private void sampleStart(int cellX, int cellY, double[] out) {
        do {
            out[0] = ((cellX + Math.random()) * dx) + minx;
            out[1] = ((cellY + Math.random()) * dy) + miny;
        } while (rejector.reject(0.0d, out));
    }

    /** Records one observation of a path from one cell to another at time index {@code k}. */
    private void increment(int fromX, int fromY, int toX, int toY, int k) {
        final int updated = ++counts[fromX][fromY][toX][toY][k];
        normalization[fromX][fromY][k]++;
        if (updated > maxCount[fromX][fromY][k]) {
            maxCount[fromX][fromY][k] = updated;
        }
    }

    /**
     * Simulates {@code paths} paths (with sub-path counting) from every lattice cell.
     *
     * @param paths number of paths per start cell
     */
    public void populate(int paths) {
        System.out.println("Populating numerical transition probabilities");
        for (int cx = 0; cx < latticeWidth; cx++) {
            for (int cy = 0; cy < latticeHeight; cy++) {
                populate(cx, cy, paths, true);
            }
            System.out.print(".");
            System.out.flush();
        }
        // long arithmetic: the total can exceed Integer.MAX_VALUE for large lattices.
        System.out.println(((long) latticeWidth * latticeHeight * paths) + " new paths computed.");
    }

    /**
     * Returns the lattice column for x-coordinate {@code d}. The value is not clamped, and the
     * cast truncates toward zero.
     */
    public final int x(double d) {
        return (int) ((d - minx) / dx);
    }

    /**
     * Returns the lattice row for y-coordinate {@code d}. The value is not clamped, and the cast
     * truncates toward zero.
     */
    public final int y(double d) {
        return (int) ((d - miny) / dy);
    }

    /** Returns the recorded time index containing time {@code d} (truncated {@code d / dt}). */
    public final int t(double d) {
        return (int) (d / dt);
    }

    /**
     * Estimates the probability of moving from the cell containing {@code from} to the cell
     * containing {@code to} in time {@code d}. It linearly interpolates between adjacent time
     * slices.
     *
     * <p>If {@code d} is past the last slice, the last slice is used (after a warning when
     * {@code d > tsteps * dt}). The result is NaN if the start cell was never populated.
     *
     * <p>NOTE(review): the populate methods record slice {@code k} after {@code (k + 1) * dt}
     * of elapsed time, while this interpolation treats slice {@code k} as time {@code k * dt}.
     * Confirm the intended alignment.
     */
    public double getProb(Point2D from, Point2D to, double d) {
        final int x0 = x(from.getX());
        final int y0 = y(from.getY());
        final int x1 = x(to.getX());
        final int y1 = y(to.getY());
        final int last = tsteps - 1;
        if (d > tsteps * dt) {
            System.err.println("Time = " + d + ", max time estimated is " + (tsteps * dt));
            return p(x0, y0, x1, y1, last);
        }
        final int k = t(d);
        if (k >= last) {
            // No later slice to interpolate towards (k + 1 would be out of bounds).
            return p(x0, y0, x1, y1, last);
        }
        final double weight = (((k * dt) + dt) - d) / dt;
        return (weight * p(x0, y0, x1, y1, k)) + ((1.0d - weight) * p(x0, y0, x1, y1, k + 1));
    }

    /**
     * Returns the fraction of observations from cell {@code (i, i2)} at time index {@code i5}
     * that were in cell {@code (i3, i4)}. Returns NaN if that start cell and time index has no
     * observations.
     */
    public double p(int i, int i2, int i3, int i4, int i5) {
        // Floating-point division: integer division would truncate every probability to 0 or 1.
        return (double) counts[i][i2][i3][i4][i5] / normalization[i][i2][i5];
    }

    /**
     * Returns {@code counts} for the given cells and time index, relative to the largest count
     * for the same start cell and time index. Returns NaN if that start cell and time index has
     * no observations.
     */
    public double r(int i, int i2, int i3, int i4, int i5) {
        return (double) counts[i][i2][i3][i4][i5] / maxCount[i][i2][i5];
    }

    /**
     * Writes the lattice parameters, the scale matrix and every count to {@code str}.
     * Each count is written as a tab-separated
     * {@code x0 y0 x1 y1 k count} line.
     *
     * @param str output file path
     * @throws FileNotFoundException if the file cannot be opened for writing
     */
    public void writeToFile(String str) throws FileNotFoundException {
        // try-with-resources closes the writer even if writing fails part-way.
        try (PrintWriter printWriter = new PrintWriter(str)) {
            printWriter.write("xsteps=" + latticeWidth + "\n");
            printWriter.write("ysteps=" + latticeHeight + "\n");
            printWriter.write("tsteps=" + tsteps + "\n");
            printWriter.write("dx=" + dx + "\n");
            printWriter.write("dy=" + dy + "\n");
            printWriter.write("dt=" + dt + "\n");
            printWriter.write("minx=" + minx + "\n");
            printWriter.write("miny=" + miny + "\n");
            // Terminate the matrix line so the first count row starts on its own line.
            printWriter.write("D=" + matrixString() + "\n");
            for (int x0 = 0; x0 < latticeWidth; x0++) {
                for (int y0 = 0; y0 < latticeHeight; y0++) {
                    for (int x1 = 0; x1 < latticeWidth; x1++) {
                        for (int y1 = 0; y1 < latticeHeight; y1++) {
                            for (int k = 0; k < tsteps; k++) {
                                printWriter.write(x0 + "\t" + y0 + "\t" + x1 + "\t" + y1 + "\t" + k
                                        + "\t" + counts[x0][y0][x1][y1][k] + "\n");
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Formats the distribution's scale matrix row by row. Each row is wrapped in
     * {@code GraphMLUtils} start/end attribute markers with comma-separated entries.
     */
    private String matrixString() {
        final double[][] scaleMatrix = D.getScaleMatrix();
        final StringBuilder sb = new StringBuilder();
        sb.append(GraphMLUtils.START_ATTRIBUTE);
        for (double[] row : scaleMatrix) {
            sb.append(GraphMLUtils.START_ATTRIBUTE);
            sb.append(row[0]);
            for (int col = 1; col < row.length; col++) {
                // Was row[0] for every column; each entry must come from its own column.
                sb.append(",").append(row[col]);
            }
            sb.append(GraphMLUtils.END_ATTRIBUTE);
        }
        sb.append(GraphMLUtils.END_ATTRIBUTE);
        return sb.toString();
    }

    /**
     * Small demo on the unit square. It populates paths from cell (0, 0) only, so probabilities
     * for other start cells print NaN. Note that the 50^5-int count array needs about 1.25 GB.
     */
    public static void main(String[] strArr) throws FileNotFoundException {
        Rectangle2D.Double bounds = new Rectangle2D.Double(0.0d, 0.0d, 1.0d, 1.0d);
        // 2-element mean to match the 2x2 matrix.
        MultivariateNormalDistribution distribution = new MultivariateNormalDistribution(
                new double[]{0.0d, 0.0d}, new double[][]{{1.0d, 0.0d}, {0.0d, 1.0d}});
        NumericalSpaceTimeProbs2D probs = new NumericalSpaceTimeProbs2D(50, 50, 50, 1, 0.02d,
                bounds, distribution, SpaceTimeRejector.Utils.createSimpleBounds2D(bounds));
        long startMillis = System.currentTimeMillis();
        probs.populate(0, 0, 1000, true);
        System.out.println("Time taken = " + ((System.currentTimeMillis() - startMillis) / 1000) + " seconds");
        for (int i = 0; i < 10; i++) {
            Point2D.Double from = new Point2D.Double(Math.random(), Math.random());
            Point2D.Double to = new Point2D.Double(Math.random(), Math.random());
            double time = Math.random();
            System.out.println("Pr(" + to.getX() + ", " + to.getY() + " | " + from.getX() + ", "
                    + from.getY() + ", t=" + time + ") = " + probs.getProb(from, to, time));
        }
    }
}
