package dr.math.distributions;

import cern.jet.random.Gamma;
import cern.jet.random.engine.RandomEngine;
import dr.evomodel.tree.UniformNodeHeightPrior;
import dr.inference.model.GradientProvider;
import dr.inference.model.HessianProvider;
import dr.math.GammaFunction;
import dr.math.MathUtils;
import dr.math.UnivariateFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.GammaDistributionImpl;

/**
 * Gamma distribution parameterised by a shape {@code k} and a scale {@code theta}.
 *
 * <p>For {@code x > 0} and an ordinary shape, the density is
 * {@code x^(k-1) exp(-x/theta) / (Gamma(k) theta^k)}. The static density helpers also recognise
 * two special shapes, for which the scale is ignored:
 * <ul>
 *   <li>{@code shape == 0}: density {@code 1/x}</li>
 *   <li>{@code shape == -0.5}: density {@code sqrt(x)}</li>
 * </ul>
 *
 * <p>Besides the {@link Distribution} methods, the class provides static helpers for:
 * <ul>
 *   <li>the density and log-density;</li>
 *   <li>first and second derivatives of the log-density with respect to {@code x};</li>
 *   <li>the CDF and the quantile function;</li>
 *   <li>random sampling.</li>
 * </ul>
 */
public class GammaDistribution implements Distribution, GradientProvider, HessianProvider {
    protected double shape;
    protected double scale;

    // NOTE(review): the three Colt-related fields are not referenced anywhere in this class;
    // candidates for removal together with their imports.
    private static final boolean TRY_COLT = false;
    private static RandomEngine randomEngine;
    private static Gamma coltGamma;

    /** Upper bound on refinement steps inside {@link #pointChi2}; guards against non-convergence. */
    private static final int POINT_CHI2_MAX_ITERATIONS = 1000;

    /** Density of this instance on {@code [0, +Infinity)}, delegating to {@link #pdf(double)}. */
    private final UnivariateFunction pdfFunction = new UnivariateFunction() {
        @Override
        public final double evaluate(double x) {
            return GammaDistribution.this.pdf(x);
        }

        @Override
        public final double getLowerBound() {
            return 0.0;
        }

        @Override
        public final double getUpperBound() {
            return Double.POSITIVE_INFINITY;
        }
    };

    // Not read or written within this class; presumably used by subclasses - TODO confirm.
    protected int samples = 0;

    /**
     * Creates a gamma distribution.
     *
     * @param shape shape parameter
     * @param scale scale parameter
     */
    public GammaDistribution(double shape, double scale) {
        this.shape = shape;
        this.scale = scale;
    }

    public double getShape() {
        return this.shape;
    }

    public void setShape(double shape) {
        this.shape = shape;
    }

    public double getScale() {
        return this.scale;
    }

    public void setScale(double scale) {
        this.scale = scale;
    }

    @Override
    public double pdf(double x) {
        return pdf(x, this.shape, this.scale);
    }

    @Override
    public double logPdf(double x) {
        return logPdf(x, this.shape, this.scale);
    }

    @Override
    public double cdf(double x) {
        return cdf(x, this.shape, this.scale);
    }

    @Override
    public double quantile(double p) {
        return quantile(p, this.shape, this.scale);
    }

    @Override
    public double mean() {
        return mean(this.shape, this.scale);
    }

    @Override
    public double variance() {
        return variance(this.shape, this.scale);
    }

    /** Draws one sample using this instance's shape and scale. */
    public double nextGamma() {
        return nextGamma(this.shape, this.scale);
    }

    @Override
    public final UnivariateFunction getProbabilityDensityFunction() {
        return this.pdfFunction;
    }

    /**
     * Gamma density.
     *
     * @param x     point of evaluation
     * @param shape shape parameter; {@code 0} and {@code -0.5} select the special densities
     *              {@code 1/x} and {@code sqrt(x)}
     * @param scale scale parameter
     * @return {@code 0} for {@code x < 0}; at {@code x == 0} returns {@code 1/scale} when
     *         {@code shape == 1} and {@code 0} otherwise
     */
    public static double pdf(double x, double shape, double scale) {
        if (x < 0.0) {
            return 0.0;
        }
        if (x == 0.0) {
            return shape == 1.0 ? 1.0 / scale : 0.0;
        }
        if (shape == 0.0) {
            return 1.0 / x;
        }
        if (shape == -0.5) {
            return Math.sqrt(x);
        }
        final double xs = x / scale;
        if (shape == 1.0) {
            return Math.exp(-xs) / scale;
        }
        return Math.exp((((shape - 1.0) * Math.log(xs)) - xs) - GammaFunction.lnGamma(shape)) / scale;
    }

    /**
     * Natural log of {@link #pdf(double, double, double)}.
     *
     * @return {@code -Infinity} for {@code x < 0}, and at {@code x == 0} unless {@code shape == 1}
     */
    public static double logPdf(double x, double shape, double scale) {
        if (x < 0.0) {
            return Double.NEGATIVE_INFINITY;
        }
        if (x == 0.0) {
            return shape == 1.0 ? Math.log(1.0 / scale) : Double.NEGATIVE_INFINITY;
        }
        if (shape == 1.0) {
            return (-x / scale) - Math.log(scale);
        }
        if (shape == 0.0) {
            return -Math.log(x);
        }
        if (shape == -0.5) {
            return 0.5 * Math.log(x);
        }
        return ((((shape - 1.0) * (Math.log(x) - Math.log(scale))) - (x / scale))
                - GammaFunction.lnGamma(shape)) - Math.log(scale);
    }

    /**
     * Alternative derivative of the log-density with respect to {@code x}.
     *
     * <p>It differs from {@link #gradLogPdf} only in its boundary handling:
     * <ul>
     *   <li>it returns {@code +Infinity} for {@code x < 0};</li>
     *   <li>it returns {@code +Infinity} at {@code x == 0}, unless {@code shape == 1}, in which
     *       case it returns {@code -1/scale}.</li>
     * </ul>
     */
    public static double gradLogPdf2(double x, double shape, double scale) {
        if (x < 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        if (shape == 1.0) {
            return -1.0 / scale;
        }
        if (x == 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        if (shape == 0.0) {
            return -1.0 / x;
        }
        if (shape == -0.5) {
            return 0.5 / x;
        }
        return ((shape - 1.0) / x) - (1.0 / scale);
    }

    /**
     * Cumulative distribution function.
     *
     * @return {@code 0} when {@code x < 0} or {@code shape <= 0}; otherwise
     *         {@code GammaFunction.incompleteGammaP(shape, x / scale)}
     */
    public static double cdf(double x, double shape, double scale) {
        if (x < 0.0 || shape <= 0.0) {
            return 0.0;
        }
        return GammaFunction.incompleteGammaP(shape, x / scale);
    }

    /**
     * Quantile (inverse CDF).
     *
     * <p>Uses the identity {@code Gamma(k, theta) = (theta / 2) * Chi2(2k)}.
     *
     * @param p     lower-tail probability
     * @param shape shape parameter (must be positive)
     * @param scale scale parameter
     * @throws IllegalArgumentException if {@code shape <= 0} or the chi-square inversion fails
     */
    public static double quantile(double p, double shape, double scale) {
        return 0.5 * scale * pointChi2(p, 2.0 * shape);
    }

    /** Mean {@code shape * scale}. */
    public static double mean(double shape, double scale) {
        return scale * shape;
    }

    /** Variance {@code shape * scale^2}. */
    public static double variance(double shape, double scale) {
        return scale * scale * shape;
    }

    /** Draws one gamma sample using the fast code paths; see {@link #nextGamma(double, double, boolean)}. */
    public static double nextGamma(double shape, double scale) {
        return nextGamma(shape, scale, false);
    }

    /**
     * Draws one gamma sample.
     *
     * <p>Strategies, in order:
     * <ol>
     *   <li>{@code shape < 1e-5}: draws from {@code 1/x} restricted to {@code [1e-20, 50)} and
     *       thins by {@code exp(-x)}. NOTE(review): {@code scale} is not applied in this branch
     *       - confirm intended.</li>
     *   <li>When {@code slowCode} is set and {@code shape} is an integer greater than 4: sums
     *       {@code shape} unit exponentials.</li>
     *   <li>{@code shape} equal to 1, 2, 3 or 4: minus the log of a product of uniforms.</li>
     *   <li>Otherwise: inversion via {@link #quantile(double, double, double)}, retrying when the
     *       inversion yields {@code 0} or throws {@link IllegalArgumentException}.</li>
     * </ol>
     *
     * @throws IllegalArgumentException if {@code shape < 0}
     */
    public static double nextGamma(double shape, double scale, boolean slowCode) {
        if (shape < 1.0E-5) {
            if (shape < 0.0) {
                System.out.println("Negative shape parameter");
                throw new IllegalArgumentException("Negative shape parameter");
            }
            final double logRange = Math.log(50.0) - Math.log(1.0E-20);
            double candidate;
            do {
                candidate = Math.exp(Math.log(1.0E-20) + (logRange * MathUtils.nextDouble()));
            } while (Math.exp(-candidate) < MathUtils.nextDouble());
            return candidate;
        }
        if (slowCode && Math.floor(shape) == shape && shape > 4.0) {
            double total = 0.0;
            for (int i = 0; i < shape; i++) {
                total += -Math.log(MathUtils.nextDouble());
            }
            return total * scale;
        }
        if (shape == 1.0) {
            return (-Math.log(MathUtils.nextDouble())) * scale;
        }
        if (shape == 2.0) {
            return (-Math.log(MathUtils.nextDouble() * MathUtils.nextDouble())) * scale;
        }
        if (shape == 3.0) {
            return (-Math.log(MathUtils.nextDouble() * MathUtils.nextDouble() * MathUtils.nextDouble())) * scale;
        }
        if (shape == 4.0) {
            return (-Math.log(MathUtils.nextDouble() * MathUtils.nextDouble()
                    * MathUtils.nextDouble() * MathUtils.nextDouble())) * scale;
        }
        double sample;
        do {
            try {
                sample = quantile(MathUtils.nextDouble(), shape, scale);
            } catch (IllegalArgumentException e) {
                // Inversion failed for this uniform draw; try another one.
                sample = 0.0;
            }
        } while (sample == 0.0);
        return sample;
    }

    /** Equivalent to {@code nextExpGamma(shape, scale, bias, false)}. */
    public static double nextExpGamma(double shape, double scale, double bias) {
        return nextExpGamma(shape, scale, bias, false);
    }

    /**
     * Draws a sample by rejection from a gamma proposal.
     *
     * <p>When {@code slowCode} is set, a {@code Gamma(shape, scale)} draw {@code x} is accepted
     * with probability {@code exp(-1 / (bias * x))}. The fast path uses the following cases:
     * <ul>
     *   <li>A negative {@code shape} returns {@code 1 / nextExpGamma(-shape, bias, scale)}.</li>
     *   <li>{@code shape == 0} uses a shifted exponential proposal.</li>
     *   <li>A positive {@code shape} uses a gamma proposal with an adjusted scale.</li>
     * </ul>
     * The fast path gives up after 10000 rejected proposals, prints a warning, and returns the
     * last proposal.
     *
     * @throws IllegalArgumentException if the fast path receives a shape that is none of the above
     */
    public static double nextExpGamma(double shape, double scale, double bias, boolean slowCode) {
        if (slowCode) {
            double sample;
            do {
                sample = nextGamma(shape, scale);
            } while (MathUtils.nextDouble() > Math.exp(-1.0 / (bias * sample)));
            return sample;
        }
        if (shape < 0.0) {
            return 1.0 / nextExpGamma(-shape, bias, scale);
        }
        int attempts = 0;
        boolean accepted = false;
        if (shape == 0.0) {
            final double offset = Math.sqrt(scale / bias);
            double peak = 1.0 / bias;
            if (peak < offset) {
                peak = offset;
            }
            final double envelope = (1.0 / peak) * Math.exp(-1.0 / (bias * peak));
            double sample;
            do {
                sample = nextGamma(1.0, scale) + offset;
                attempts++;
                if (MathUtils.nextDouble() <= ((1.0 / sample) * Math.exp(-1.0 / (sample * bias))) / envelope) {
                    accepted = true;
                    break;
                }
            } while (attempts < 10000);
            if (!accepted) {
                System.out.println("Severe Warning: nextExpGamma (shape=0) failed to generate a sample - returning bogus value!");
            }
            if (MathUtils.nextDouble() > 0.5) {
                sample = scale / (bias * sample);
            }
            return sample;
        }
        if (shape <= 0.0) {
            System.out.println("nextExpGamma: Illegal argument (shape parameter is must be positive)");
            throw new IllegalArgumentException("");
        }
        final double centre = ((shape * scale)
                + Math.sqrt(((4.0 * scale) / bias) + (((shape * shape) * scale) * scale))) / 2.0;
        final double proposalScale = 1.0 / ((1.0 / scale) - (1.0 / ((bias * centre) * centre)));
        double sample;
        double acceptance;
        do {
            sample = nextGamma(shape, proposalScale);
            final double r = (sample / centre) - 1.0;
            acceptance = Math.exp(((-r) * r) / (bias * sample));
            attempts++;
            if (MathUtils.nextDouble() <= acceptance) {
                accepted = true;
                break;
            }
        } while (attempts < 10000);
        if (acceptance > 1.0) {
            System.out.println("PROBLEM!!  This should be impossible!!  Contact the authors.");
        }
        if (proposalScale < 0.0) {
            System.out.println("PROBLEM!! This should be impossible too!!  Contact the authors.");
        }
        if (!accepted) {
            System.out.println("Severe Warning: nextExpGamma failed to generate a sample - returning bogus value!");
        }
        return sample;
    }

    /**
     * Derivative of {@link #logPdf(double, double, double)} with respect to {@code x}.
     *
     * @return {@code 0} for {@code x < 0}
     */
    public static double gradLogPdf(double x, double shape, double scale) {
        if (x < 0.0) {
            return 0.0;
        }
        if (shape == -0.5) {
            return 0.5 / x;
        }
        if (shape == 0.0) {
            return -1.0 / x;
        }
        if (shape == 1.0) {
            return -1.0 / scale;
        }
        return ((shape - 1.0) / x) - (1.0 / scale);
    }

    /**
     * Second derivative of {@link #logPdf(double, double, double)} with respect to {@code x}.
     *
     * @return {@code 0} for {@code x < 0}
     */
    public static double hessianLogPdf(double x, double shape, double scale) {
        if (x < 0.0) {
            return 0.0;
        }
        if (shape == -0.5) {
            return -0.5 / (x * x);
        }
        if (shape == 0.0) {
            return 1.0 / (x * x);
        }
        if (shape == 1.0) {
            return 0.0;
        }
        return (1.0 - shape) / (x * x);
    }

    /**
     * Returns {@code z} such that {@code P(X < z) = prob} for {@code X ~ Chi2(v)}.
     *
     * <p>Implements algorithm AS 91 (Best &amp; Roberts, 1975, Applied Statistics 24:385-388).
     * A starting value is chosen from one of three regimes (small {@code p}, {@code v > 0.32},
     * small {@code v}). It is then refined with a Taylor-series correction based on
     * {@link GammaFunction#incompleteGammaP} until the relative change falls to {@code 5e-7} or
     * below.
     *
     * @param prob lower-tail probability
     * @param v    degrees of freedom (must be positive)
     * @return the {@code prob}-quantile. Special cases:
     *         <ul>
     *           <li>{@code 0} when {@code prob <= 0};</li>
     *           <li>{@code +Infinity} when {@code prob >= 1};</li>
     *           <li>{@code NaN} when either argument is NaN.</li>
     *         </ul>
     * @throws IllegalArgumentException in any of these cases:
     *         <ul>
     *           <li>{@code v <= 0};</li>
     *           <li>the incomplete gamma evaluation is negative;</li>
     *           <li>an iteration does not converge within {@link #POINT_CHI2_MAX_ITERATIONS}
     *               steps.</li>
     *         </ul>
     */
    private static double pointChi2(double prob, double v) {
        if (Double.isNaN(prob) || Double.isNaN(v)) {
            return Double.NaN;
        }
        if (v <= 0.0) {
            throw new IllegalArgumentException("pointChi2: degrees of freedom must be positive, got " + v);
        }
        if (prob <= 0.0) {
            return 0.0;
        }
        if (prob >= 1.0) {
            return Double.POSITIVE_INFINITY;
        }

        final double e = 0.5e-6;
        final double aa = 0.6931471805; // ln(2)
        final double xx = v / 2.0;
        final double c = xx - 1.0;
        final double g = GammaFunction.lnGamma(xx);

        double ch;
        if (v < -1.24 * Math.log(prob)) {
            // Small-p starting value; a tiny result is already accurate enough to return.
            ch = Math.pow(prob * xx * Math.exp(g + xx * aa), 1.0 / xx);
            if (ch - e < 0.0) {
                return ch;
            }
        } else if (v > 0.32) {
            // Wilson-Hilferty starting value, with a fallback for the far upper tail.
            final double x = pointNormal(prob);
            final double p1 = 0.222222 / v;
            ch = v * Math.pow(x * Math.sqrt(p1) + 1.0 - p1, 3.0);
            if (ch > 2.2 * v + 6.0) {
                ch = -2.0 * (Math.log(1.0 - prob) - c * Math.log(0.5 * ch) + g);
            }
        } else {
            // Small-v starting value, found by a coarse (1%) iteration.
            ch = 0.4;
            final double a = Math.log(1.0 - prob);
            int iterations = 0;
            double previous;
            do {
                if (++iterations > POINT_CHI2_MAX_ITERATIONS) {
                    throw new IllegalArgumentException(
                            "pointChi2 failed to converge for prob=" + prob + ", v=" + v);
                }
                previous = ch;
                final double p1 = 1.0 + ch * (4.67 + ch);
                final double p2 = ch * (6.73 + ch * (6.66 + ch));
                final double t = -0.5 + (4.67 + 2.0 * ch) / p1 - (6.73 + ch * (13.32 + 3.0 * ch)) / p2;
                ch -= (1.0 - Math.exp(a + g + 0.5 * ch + c * aa) * p2 / p1) / t;
            } while (Math.abs(previous / ch - 1.0) > 0.01);
        }

        // Seventh-order Taylor-series refinement.
        int iterations = 0;
        double previous;
        do {
            if (++iterations > POINT_CHI2_MAX_ITERATIONS) {
                throw new IllegalArgumentException(
                        "pointChi2 failed to converge for prob=" + prob + ", v=" + v);
            }
            previous = ch;
            final double p1 = 0.5 * ch;
            double t = GammaFunction.incompleteGammaP(xx, p1);
            if (t < 0.0) {
                throw new IllegalArgumentException("Arguments out of range: t < 0");
            }
            final double p2 = prob - t;
            t = p2 * Math.exp(xx * aa + g + p1 - c * Math.log(ch));
            final double b = t / ch;
            final double a = 0.5 * t - b * c;

            final double s1 = (210 + a * (140 + a * (105 + a * (84 + a * (70 + 60 * a))))) / 420;
            final double s2 = (420 + a * (735 + a * (966 + a * (1141 + 1278 * a)))) / 2520;
            final double s3 = (210 + a * (462 + a * (707 + 932 * a))) / 2520;
            final double s4 = (252 + a * (672 + 1182 * a) + c * (294 + a * (889 + 1740 * a))) / 5040;
            final double s5 = (84 + 264 * a + c * (175 + 606 * a)) / 2520;
            final double s6 = (120 + c * (346 + 127 * c)) / 5040;
            ch += t * (1 + 0.5 * t * s1 - b * c * (s1 - b * (s2 - b * (s3 - b * (s4 - b * (s5 - b * s6))))));
        } while (Math.abs(previous / ch - 1.0) > e);

        return ch;
    }

    /**
     * Approximate standard normal quantile (Odeh &amp; Evans, 1974, AS 70).
     *
     * <p>This is only the starting value for {@link #pointChi2}, which refines its result
     * afterwards.
     *
     * @param prob probability in {@code (0, 1)}
     * @return {@code z} with {@code Phi(z) ~= prob}; {@code +/-999} in the extreme tails
     */
    private static double pointNormal(double prob) {
        final double a0 = -0.322232431088, a1 = -1.0, a2 = -0.342242088547;
        final double a3 = -0.0204231210245, a4 = -0.453642210148e-4;
        final double b0 = 0.0993484626060, b1 = 0.588581570495, b2 = 0.531103462366;
        final double b3 = 0.103537752850, b4 = 0.0038560700634;

        final double tail = prob < 0.5 ? prob : 1.0 - prob;
        final double z;
        if (tail < 1e-20) {
            z = 999.0;
        } else {
            final double y = Math.sqrt(Math.log(1.0 / (tail * tail)));
            z = y + ((((y * a4 + a3) * y + a2) * y + a1) * y + a0)
                    / ((((y * b4 + b3) * y + b2) * y + b1) * y + b0);
        }
        return prob < 0.5 ? -z : z;
    }

    /** Ad-hoc diagnostics: quantile timings versus commons-math, plus a sampling mean/variance check. */
    public static void main(String[] args) {
        final double shape = 0.878328435043444;
        final double scale = 0.0013696236839573005;
        testQuantile(1.0E-10, shape, scale);
        testQuantile(0.5, shape, scale);
        testQuantile(0.9999999999, shape, scale);
        testQuantileCM(1.0E-10, shape, scale);
        testQuantileCM(0.5, shape, scale);
        testQuantileCM(0.9999999999, shape, scale);
        for (double p = 0.0125; p < 1.0; p += 0.025) {
            System.out.print(p + ": ");
            try {
                System.out.println(new GammaDistributionImpl(shape, scale).inverseCumulativeProbability(p));
            } catch (MathException e) {
                System.out.println(e.getMessage());
            }
        }

        GammaDistribution gamma = new GammaDistribution(0.01, 100.0);
        double[] draws = new double[UniformNodeHeightPrior.DEFAULT_MC_SAMPLE];
        double sum = 0.0;
        for (int i = 0; i < draws.length; i++) {
            draws[i] = gamma.nextGamma();
            sum += draws[i];
        }
        double mean = sum / draws.length;
        System.out.println("Mean = " + mean);
        double sumSquares = 0.0;
        for (double draw : draws) {
            sumSquares += Math.pow(draw - mean, 2.0);
        }
        System.out.println("Variance = " + (sumSquares / draws.length));
    }

    /** Times 1000 calls of {@link #quantile(double, double, double)} and prints the result. */
    private static void testQuantile(double p, double shape, double scale) {
        long start = System.currentTimeMillis();
        for (int i = 0; i < 1000; i++) {
            quantile(p, shape, scale);
        }
        System.out.println("Quantile, " + p + ", for shape=" + shape + ", scale=" + scale + " : "
                + quantile(p, shape, scale) + ", time=" + (System.currentTimeMillis() - start) + "ms");
    }

    /** Times 1000 commons-math inverse-CDF evaluations and prints the result. */
    private static void testQuantileCM(double p, double shape, double scale) {
        long start = System.currentTimeMillis();
        double value = 0.0;
        // inverseCumulativeProbability throws the checked MathException, so every call,
        // including the final one, must be inside the try.
        try {
            for (int i = 0; i < 1000; i++) {
                value = new GammaDistributionImpl(shape, scale).inverseCumulativeProbability(p);
            }
            value = new GammaDistributionImpl(shape, scale).inverseCumulativeProbability(p);
        } catch (MathException e) {
            e.printStackTrace();
        }
        System.out.println("commons.maths inverseCDF, " + p + ", for shape=" + shape + ", scale=" + scale
                + " : " + value + ", time=" + (System.currentTimeMillis() - start) + "ms");
    }

    /**
     * Two-sample statistic over sorted samples.
     *
     * <p>For each index {@code i} of {@code first}, it counts the elements of {@code second}
     * that are below {@code first[i]} and subtracts {@code i}. It returns the largest such
     * difference divided by {@code sqrt(2 * first.size())}. Both lists are expected to be sorted
     * ascending.
     */
    private static double kolmogorovSmirnov(List<Double> first, List<Double> second) {
        int below = 0;
        int maxDifference = 0;
        for (int i = 0; i < first.size(); i++) {
            while (below < second.size() && second.get(below) < first.get(i)) {
                below++;
            }
            maxDifference = Math.max(maxDifference, below - i);
        }
        return maxDifference / Math.sqrt(2.0 * first.size());
    }

    /** Prints the sample mean, variance and median of {@code count} fast nextExpGamma draws, with a z-score against {@code expectedMean}. */
    private static void testExpGamma2(double shape, double scale, double bias, int count, double expectedMean) {
        double n = 0.0;
        double sum = 0.0;
        double sumSquares = 0.0;
        List<Double> values = new ArrayList<>(0);
        for (int i = 0; i < count; i++) {
            double value = nextExpGamma(shape, scale, bias, false);
            n += 1.0;
            sum += value;
            sumSquares += value * value;
            values.add(value);
        }
        Collections.sort(values);
        double mean = sum / n;
        double var = (sumSquares - ((sum * sum) / n)) / n;
        System.out.println("Equal-mean test: (shape=" + shape + " scale=" + scale + " bias=" + bias
                + " mean=" + mean + " expected=" + expectedMean + " var=" + var
                + " median=" + values.get(count / 2) + "): z=" + ((expectedMean - mean) / Math.sqrt(var / count)));
    }

    /** Compares slow and fast nextExpGamma samplers with {@link #kolmogorovSmirnov} and reports timings. */
    private static void testExpGamma(double shape, double scale, double bias, int count) {
        List<Double> slow = new ArrayList<>(0);
        List<Double> fast = new ArrayList<>(0);
        long slowStart = System.currentTimeMillis();
        for (int i = 0; i < count; i++) {
            slow.add(nextExpGamma(shape, scale, bias, true));
        }
        long fastStart = System.currentTimeMillis();
        for (int i = 0; i < count; i++) {
            fast.add(nextExpGamma(shape, scale, bias, false));
        }
        Collections.sort(slow);
        Collections.sort(fast);
        System.out.println("KS test for shape=" + shape + ", bias=" + bias + " : " + kolmogorovSmirnov(slow, fast)
                + " and " + kolmogorovSmirnov(fast, slow) + " slow=" + (fastStart - slowStart)
                + "ms, fast=" + (System.currentTimeMillis() - fastStart) + "ms");
    }

    /** Compares slow and fast nextGamma samplers with {@link #kolmogorovSmirnov}. */
    private static void test(double shape, double scale, int count) {
        List<Double> slow = new ArrayList<>(0);
        List<Double> fast = new ArrayList<>(0);
        for (int i = 0; i < count; i++) {
            slow.add(nextGamma(shape, scale, true));
            fast.add(nextGamma(shape, scale, false));
        }
        Collections.sort(slow);
        Collections.sort(fast);
        System.out.println("KS test for shape=" + shape + " : " + kolmogorovSmirnov(slow, fast)
                + " and " + kolmogorovSmirnov(fast, slow));
    }

    /** Compares sums of {@code terms} gamma draws against single draws with shape {@code shape * terms}. */
    private static void testAddition(double shape, double scale, int terms, int count) {
        List<Double> slowSums = new ArrayList<>(0);
        List<Double> fastSums = new ArrayList<>(0);
        List<Double> direct = new ArrayList<>(0);
        for (int i = 0; i < count; i++) {
            double slowSum = 0.0;
            for (int j = 0; j < terms; j++) {
                slowSum += nextGamma(shape, scale, true);
            }
            slowSums.add(slowSum);
            double fastSum = 0.0;
            for (int j = 0; j < terms; j++) {
                fastSum += nextGamma(shape, scale, false);
            }
            fastSums.add(fastSum);
            direct.add(nextGamma(shape * terms, scale, true));
        }
        Collections.sort(slowSums);
        Collections.sort(fastSums);
        Collections.sort(direct);
        System.out.println("KS test for shape=" + shape + " : slow=" + kolmogorovSmirnov(slowSums, direct)
                + " & " + kolmogorovSmirnov(direct, slowSums) + "; fast=" + kolmogorovSmirnov(fastSums, direct)
                + " & " + kolmogorovSmirnov(direct, fastSums));
    }

    @Override
    public int getDimension() {
        return 1;
    }

    /** Element-wise {@link #gradLogPdf} at this instance's shape and scale. */
    @Override
    public double[] getGradientLogDensity(Object obj) {
        double[] values = GradientProvider.toDoubleArray(obj);
        double[] gradient = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            gradient[i] = gradLogPdf(values[i], this.shape, this.scale);
        }
        return gradient;
    }

    /** Element-wise {@link #hessianLogPdf} at this instance's shape and scale. */
    @Override
    public double[] getDiagonalHessianLogDensity(Object obj) {
        double[] values = GradientProvider.toDoubleArray(obj);
        double[] diagonal = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            diagonal[i] = hessianLogPdf(values[i], this.shape, this.scale);
        }
        return diagonal;
    }

    /** Full Hessian built from the diagonal by {@code HessianProvider.expandDiagonals}. */
    @Override
    public double[][] getHessianLogDensity(Object obj) {
        return HessianProvider.expandDiagonals(getDiagonalHessianLogDensity(obj));
    }
}
