package dr.math.distributions;

import dr.inference.model.GradientProvider;
import dr.math.ErrorFunction;
import dr.math.MathUtils;
import dr.math.UnivariateFunction;

/* loaded from: input_file:dr/math/distributions/NormalDistribution.class */
/**
 * Univariate normal (Gaussian) distribution N(m, sd^2), parameterised by its mean {@code m} and
 * standard deviation {@code sd}.
 *
 * <p>Instance methods evaluate the distribution at its current parameters; the static overloads
 * take the parameters explicitly. The CDF routines appear to be a port of R's {@code pnorm}
 * (W. J. Cody's rational approximations), and {@link #standardTail} appears to be Hill's
 * algorithm AS 66 — NOTE(review): provenance inferred from the constants, confirm if it matters.
 *
 * <p>The parameters are mutable through setters and no synchronisation is performed.
 */
public class NormalDistribution implements Distribution, RandomGenerator, GradientProvider {

    /** Density of this distribution exposed as a function over the whole real line. */
    private final UnivariateFunction pdfFunction = new UnivariateFunction() {
        @Override
        public final double evaluate(double x) {
            return NormalDistribution.this.pdf(x);
        }

        @Override
        public final double getLowerBound() {
            return Double.NEGATIVE_INFINITY;
        }

        @Override
        public final double getUpperBound() {
            return Double.POSITIVE_INFINITY;
        }
    };

    /** Mean of the distribution. */
    protected double m;
    /** Standard deviation of the distribution (not validated). */
    protected double sd;

    // Coefficients of the rational approximations used by standardCDF:
    // a/b for |x| <= 0.67448975, c/d for |x| <= sqrt(32), p_/q beyond that.
    private static final double[] a = {2.2352520354606837d, 161.02823106855587d, 1067.6894854603709d, 18154.98125334356d, 0.06568233791820745d};
    private static final double[] b = {47.202581904688245d, 976.0985517377767d, 10260.932208618979d, 45507.78933502673d};
    private static final double[] c = {0.39894151208813466d, 8.883149794388377d, 93.50665613217785d, 597.2702763948002d, 2494.5375852903726d, 6848.190450536283d, 11602.65143764735d, 9842.714838383978d, 1.0765576773720192E-8d};
    private static final double[] d = {22.266688044328117d, 235.387901782625d, 1519.3775994075547d, 6485.558298266761d, 18615.571640885097d, 34900.95272114598d, 38912.00328609327d, 19685.429676859992d};
    private static final double[] p_ = {0.215898534057957d, 0.12740116116024736d, 0.022235277870649807d, 0.0014216191932278934d, 2.9112874951168793E-5d, 0.023073441764940174d};
    private static final double[] q = {1.284260096144911d, 0.4682382124808651d, 0.06598813786892856d, 0.0037823963320275824d, 7.297515550839662E-5d};
    /** x^2 is split using a multiple of 1/CUTOFF when evaluating exp(-x^2/2). */
    private static final int CUTOFF = 16;
    /** sqrt(32). */
    private static final double M_SQRT_32 = 5.656854249492381d;
    /** 1 / sqrt(2 pi). */
    private static final double M_1_SQRT_2PI = 0.3989422804014327d;
    /** Machine epsilon for double (2^-52). */
    private static final double DBL_EPSILON = 2.220446049250313E-16d;
    /** sqrt(2 pi); 2.0 * Math.PI is exactly 6.283185307179586. */
    private static final double SQRT_TWO_PI = Math.sqrt(2.0 * Math.PI);

    /**
     * Creates a normal distribution.
     *
     * @param mean the mean
     * @param sd   the standard deviation (not validated here)
     */
    public NormalDistribution(double mean, double sd) {
        this.m = mean;
        this.sd = sd;
    }

    /** @return the mean */
    public double getMean() {
        return this.m;
    }

    /** @param mean the new mean */
    public void setMean(double mean) {
        this.m = mean;
    }

    /** @return the standard deviation */
    public double getSD() {
        return this.sd;
    }

    /** @param sd the new standard deviation (not validated) */
    public void setSD(double sd) {
        this.sd = sd;
    }

    @Override
    public double pdf(double x) {
        return pdf(x, this.m, this.sd);
    }

    @Override
    public double logPdf(double x) {
        return logPdf(x, this.m, this.sd);
    }

    @Override
    public double cdf(double x) {
        return cdf(x, this.m, this.sd);
    }

    @Override
    public double quantile(double p) {
        return quantile(p, this.m, this.sd);
    }

    @Override
    public double mean() {
        return mean(this.m, this.sd);
    }

    @Override
    public double variance() {
        return variance(this.m, this.sd);
    }

    @Override
    public final UnivariateFunction getProbabilityDensityFunction() {
        return this.pdfFunction;
    }

    /**
     * Density of N(m, sd^2) at {@code x}.
     *
     * @param x  the point at which to evaluate
     * @param m  the mean
     * @param sd the standard deviation
     * @return the density
     */
    public static double pdf(double x, double m, double sd) {
        final double diff = x - m;
        return (1.0 / (SQRT_TWO_PI * sd)) * Math.exp((-diff * diff) / (2.0 * sd * sd));
    }

    /**
     * Natural log of the density of N(m, sd^2) at {@code x}.
     *
     * @param x  the point at which to evaluate
     * @param m  the mean
     * @param sd the standard deviation
     * @return the log density
     */
    public static double logPdf(double x, double m, double sd) {
        final double diff = x - m;
        return Math.log(1.0 / (SQRT_TWO_PI * sd)) + ((-diff * diff) / (2.0 * sd * sd));
    }

    /**
     * First derivative of the log density with respect to {@code x}: (m - x) / sd^2.
     */
    public static double gradLogPdf(double x, double m, double sd) {
        return (m - x) / (sd * sd);
    }

    /**
     * Second derivative of the log density with respect to {@code x}: -1 / sd^2 (independent of x
     * and m).
     */
    public static double hessianLogPdf(double x, double m, double sd) {
        return (-1.0) / (sd * sd);
    }

    /**
     * Cumulative distribution function P(X <= x) of N(m, sd^2) on the linear scale.
     */
    public static double cdf(double x, double m, double sd) {
        return cdf(x, m, sd, false);
    }

    /**
     * Quantile (inverse CDF) of N(m, sd^2), computed through {@code ErrorFunction.inverseErf}.
     *
     * @param p  the probability
     * @param m  the mean
     * @param sd the standard deviation
     * @return x such that P(X <= x) = p
     */
    public static double quantile(double p, double m, double sd) {
        return m + (Math.sqrt(2.0) * sd * ErrorFunction.inverseErf((2.0 * p) - 1.0));
    }

    /** Mean of N(m, sd^2), i.e. {@code m}. */
    public static double mean(double m, double sd) {
        return m;
    }

    /** Variance of N(m, sd^2), i.e. {@code sd * sd}. */
    public static double variance(double m, double sd) {
        return sd * sd;
    }

    /**
     * Cumulative distribution function of N(m, sd^2), optionally on the log scale.
     *
     * @param x    the point at which to evaluate
     * @param m    the mean
     * @param sd   the standard deviation; {@code 0} is treated as a point mass at {@code m}
     * @param logP if {@code true}, return log P(X <= x) instead of P(X <= x)
     * @return P(X <= x), or its log; NaN if any argument is NaN, if {@code sd < 0}, or if
     *     {@code x} and {@code m} are the same infinity
     */
    public static double cdf(double x, double m, double sd, boolean logP) {
        if (Double.isNaN(x) || Double.isNaN(m) || Double.isNaN(sd)) {
            return Double.NaN;
        }
        if (Double.isInfinite(x) && m == x) {
            // x - m would be NaN.
            return Double.NaN;
        }
        if (sd < 0.0) {
            return Double.NaN;
        }
        if (sd > 0.0) {
            final double z = (x - m) / sd;
            if (!Double.isInfinite(z)) {
                return standardCDF(z, logP);
            }
        }
        // Point mass (sd == 0) or a standardized value that overflowed: the probability is
        // exactly 0 or 1, which must still honour the requested scale.
        final boolean below = x < m;
        if (logP) {
            return below ? Double.NEGATIVE_INFINITY : 0.0;
        }
        return below ? 0.0 : 1.0;
    }

    /**
     * Standard normal cumulative distribution function Phi(x), optionally on the log scale.
     *
     * <p>Uses one of three rational approximations depending on |x| (&lt;= 0.67448975,
     * &lt;= sqrt(32), beyond). On the linear scale, inputs with |x| &gt; sqrt(32) return
     * exactly 0 for x &lt;= -37.5193 and exactly 1 for x &gt;= 8.2924.
     *
     * @param x    the standardized argument
     * @param logP if {@code true}, return log Phi(x)
     * @return Phi(x) or log Phi(x); NaN for NaN input
     */
    public static double standardCDF(double x, boolean logP) {
        if (Double.isNaN(x)) {
            return Double.NaN;
        }
        final double y = Math.abs(x);

        if (y <= 0.67448975) {
            // Central region: Phi(x) = 1/2 + x * R(x^2).
            double xnum;
            double xden;
            if (y > DBL_EPSILON * 0.5) {
                final double xsq = x * x;
                xnum = a[4] * xsq;
                xden = xsq;
                for (int i = 0; i < 3; i++) {
                    xnum = (xnum + a[i]) * xsq;
                    xden = (xden + b[i]) * xsq;
                }
            } else {
                xnum = 0.0;
                xden = 0.0;
            }
            final double p = 0.5 + x * (xnum + a[3]) / (xden + b[3]);
            return logP ? Math.log(p) : p;
        }

        if (y <= M_SQRT_32) {
            // Intermediate region: tail = exp(-x^2/2) * R(|x|).
            double xnum = c[8] * y;
            double xden = y;
            for (int i = 0; i < 7; i++) {
                xnum = (xnum + c[i]) * y;
                xden = (xden + d[i]) * y;
            }
            final double factor = (xnum + c[7]) / (xden + d[7]);
            return cdfFromTailFactor(x, y, factor, logP);
        }

        if (logP || (-37.5193 < x && x < 8.2924)) {
            // Outer region: asymptotic form in 1/x^2.
            final double xsq = 1.0 / (x * x);
            double xnum = p_[5] * xsq;
            double xden = xsq;
            for (int i = 0; i < 4; i++) {
                xnum = (xnum + p_[i]) * xsq;
                xden = (xden + q[i]) * xsq;
            }
            final double factor = (M_1_SQRT_2PI - xsq * (xnum + p_[4]) / (xden + q[4])) / y;
            return cdfFromTailFactor(x, x, factor, logP);
        }

        // Linear scale, |x| large: the result is 0 or 1 to double precision.
        return x > 0.0 ? 1.0 : 0.0;
    }

    /**
     * Shared final step of the two outer regions of {@link #standardCDF}: forms the tail mass
     * exp(-v^2/2) * factor and converts it into Phi(x) (or log Phi(x)).
     *
     * <p>v^2 is split as xsq^2 + del, with xsq truncated to a multiple of 1/{@link #CUTOFF}, so
     * exp(-v^2/2) is evaluated as a product of two exponentials with a small-error remainder.
     *
     * @param x      the standardized argument; its sign selects which tail {@code factor} describes
     * @param v      the value whose square appears in the exponent ({@code |x|} or {@code x}; the
     *               callers differ and are kept as-is)
     * @param factor the rational-approximation factor for this region
     * @param logP   whether to return log Phi(x)
     */
    private static double cdfFromTailFactor(double x, double v, double factor, boolean logP) {
        final double xsq = ((int) (v * CUTOFF)) * 1.0 / CUTOFF;
        final double del = (v - xsq) * (v + xsq);
        if (logP && x <= 0.0) {
            // Stay on the log scale so very small lower-tail probabilities do not underflow.
            return (-xsq * xsq * 0.5) + (-del * 0.5) + Math.log(factor);
        }
        final double tail = Math.exp(-xsq * xsq * 0.5) * Math.exp(-del * 0.5) * factor;
        if (x <= 0.0) {
            return tail;
        }
        if (logP) {
            // log1p keeps tiny upper-tail masses that log(1 - tail) would round to 0;
            // the explicit zero avoids returning -0.0 when the tail underflows.
            return tail == 0.0 ? 0.0 : Math.log1p(-tail);
        }
        return 1.0 - tail;
    }

    /**
     * Standard normal tail probability.
     *
     * <p>For |x| beyond 8 (or beyond 37 when the requested tail is the upper tail of |x|) the
     * tail mass is taken to be exactly 0.
     *
     * @param x     the standardized argument
     * @param upper {@code true} for P(Z &gt; x), {@code false} for P(Z &lt;= x)
     * @return the requested tail probability
     */
    public static double standardTail(double x, boolean upper) {
        double z = x;
        boolean upperTail = upper;
        if (z < 0.0) {
            // Reflect to z >= 0 and flip the requested tail.
            upperTail = !upperTail;
            z = -z;
        }

        // Upper-tail mass P(Z > z) for z >= 0.
        double tail;
        if (z <= 8.0 || (upperTail && z <= 37.0)) {
            final double y = 0.5 * z * z;
            if (z >= 1.28) {
                // Continued fraction, evaluated innermost-first.
                double den = z + 3.99019417011;
                den = (z + 0.742380924027) + (30.789933034 / den);
                den = (z + 4.8385912808) - (15.1508972451 / den);
                den = (z - 0.151679116635) + (5.29330324926 / den);
                den = (z + 3.98064794E-4) + (1.98615381364 / den);
                den = (z - 3.8052E-8) + (1.00000615302 / den);
                tail = (0.398942280385 * Math.exp(-y)) / den;
            } else {
                double den = y + 5.92885724438;
                den = (y + 2.62433121679) + (48.6959930692 / den);
                den = (y + 5.75885480458) - (29.8213557808 / den);
                tail = 0.5 - (z * (0.398942280444 - ((0.399903438504 * y) / den)));
            }
        } else {
            tail = 0.0;
        }
        return upperTail ? tail : 1.0 - tail;
    }

    /** Upper-tail probability P(X &gt; x) of N(m, sd^2). */
    public static double tailCDF(double x, double m, double sd) {
        return standardTail((x - m) / sd, true);
    }

    /**
     * Tail probability of N(m, sd^2): P(X &gt; x) if {@code upper}, else P(X &lt;= x).
     */
    public static double tailCDF(double x, double m, double sd, boolean upper) {
        return standardTail((x - m) / sd, upper);
    }

    /** Upper-tail probability P(X &gt; x) at the current parameters. */
    public double tailCDF(double x) {
        return standardTail((x - this.m) / this.sd, true);
    }

    /**
     * Debug helper: prints 1 - cdf (through both cdf overloads) next to tailCDF for one point.
     */
    static void testTail(double x, double m, double sd) {
        final double viaCdf = 1.0 - cdf(x, m, sd);
        final double viaCdfExplicit = 1.0 - cdf(x, m, sd, false);
        final double viaTail = tailCDF(x, m, sd);
        System.out.println(">" + x + " N(" + m + ", " + sd + ")");
        System.out.println("Original CDF: " + viaCdf);
        System.out.println("     New CDF: " + viaCdfExplicit);
        System.out.println("     tailCDF: " + viaTail);
    }

    /** Debug entry point comparing tail computations on N(0, 1). */
    public static void main(String[] args) {
        final double[] points = {0.1, 1.0, 5.0, 7.0, 8.0, 8.25, 10.0};
        for (int i = 0; i < points.length; i++) {
            if (i > 0) {
                System.out.println();
            }
            testTail(points[i], 0.0, 1.0);
        }
        System.out.println(standardCDF(4.0, true));
    }

    /** Draws a value as {@code nextGaussian() * sd + mean}, boxed as a {@link Double}. */
    @Override
    public Object nextRandom() {
        return Double.valueOf((MathUtils.nextGaussian() * getSD()) + getMean());
    }

    /**
     * Log density of a boxed value.
     *
     * @param obj a {@link Double}; other types raise {@link ClassCastException}
     * @return the log density at the unboxed value
     */
    @Override
    public double logPdf(Object obj) {
        // Unbox explicitly so the primitive overload is selected (passing obj would recurse).
        return logPdf(((Double) obj).doubleValue());
    }

    @Override
    public int getDimension() {
        return 1;
    }

    /**
     * Element-wise derivative of the log density with respect to each value in {@code obj}, at the
     * current mean and standard deviation.
     */
    @Override
    public double[] getGradientLogDensity(Object obj) {
        final double[] values = GradientProvider.toDoubleArray(obj);
        final double[] gradient = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            gradient[i] = gradLogPdf(values[i], getMean(), getSD());
        }
        return gradient;
    }
}
