package dr.inference.distribution;

import dr.inference.model.AbstractModel;
import dr.inference.model.GradientProvider;
import dr.inference.model.HessianProvider;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.UnivariateFunction;
import dr.math.distributions.GammaDistribution;
import java.util.Objects;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.GammaDistributionImpl;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

/**
 * Gamma distribution model with a fixed offset and a selectable parameterization.
 *
 * <p>The shape is always supplied as a {@link Variable}. The scale used by every calculation is
 * produced by {@link #getScale()} according to the {@link GammaParameterizationType} chosen at
 * construction. Density, CDF and derivative methods evaluate the underlying gamma at
 * {@code x - offset}; {@link #quantile(double)} and {@link #mean()} add the offset back.
 */
public class GammaDistributionModel extends AbstractModel implements ParametricDistributionModel, GradientProvider, HessianProvider {
    /** Name passed to the {@link AbstractModel} constructor. */
    public static final String GAMMA_DISTRIBUTION_MODEL = "gammaDistributionModel";
    /** Not referenced in this class; presumably an element name used by a parser — verify. */
    public static final String ONE_P_GAMMA_DISTRIBUTION_MODEL = "onePGammaDistributionModel";

    /** Density view of this model, returned by {@link #getProbabilityDensityFunction()}. */
    private final UnivariateFunction pdfFunction;
    /** How the second parameter is interpreted; see {@link #getScale()}. */
    private final GammaParameterizationType parameterization;
    /** Shape variable; always present. */
    private final Variable<Double> shape;
    /** Scale variable; non-null only for {@link GammaParameterizationType#ShapeScale}. */
    private final Variable<Double> scale;
    /** Rate variable; non-null only for {@link GammaParameterizationType#ShapeRate}. */
    private final Variable<Double> rate;
    /** Mean variable; non-null only for {@link GammaParameterizationType#ShapeMean}. */
    private final Variable<Double> mean;
    /** Fixed shift of the support: the density is zero below this value. */
    private final double offset;

    /** How the second gamma parameter is supplied; see {@link #getScale()}. */
    public enum GammaParameterizationType {
        /** Second variable is the scale. */
        ShapeScale,
        /** Second variable is the rate; scale = 1 / rate. */
        ShapeRate,
        /** Second variable is the mean; scale = mean / shape. */
        ShapeMean,
        /** No second variable; scale = 1 / shape, so shape * scale == 1. */
        OneParameter
    }

    /**
     * Creates a shape/scale gamma model with zero offset.
     *
     * @param shapeVariable the shape variable
     * @param scaleVariable the scale variable
     * @throws NullPointerException if either variable is {@code null}
     */
    public GammaDistributionModel(Variable<Double> shapeVariable, Variable<Double> scaleVariable) {
        this(GammaParameterizationType.ShapeScale, shapeVariable, scaleVariable, 0.0d);
    }

    /**
     * Creates a one-parameter gamma model (scale = 1 / shape) with zero offset.
     *
     * @param shapeVariable the shape variable
     * @throws NullPointerException if {@code shapeVariable} is {@code null}
     */
    public GammaDistributionModel(Variable<Double> shapeVariable) {
        this(GammaParameterizationType.OneParameter, shapeVariable, null, 0.0d);
    }

    /**
     * Creates a gamma model with the given parameterization and offset.
     *
     * <p>All arguments are validated before anything is registered, so a failed construction
     * leaves the supplied variables untouched.
     *
     * @param parameterization how {@code secondVariable} is interpreted
     * @param shapeVariable the shape variable
     * @param secondVariable scale, rate or mean depending on {@code parameterization}; ignored
     *     (may be {@code null}) for {@link GammaParameterizationType#OneParameter}
     * @param offset shift applied to the support of the distribution
     * @throws NullPointerException if {@code parameterization} or {@code shapeVariable} is
     *     {@code null}, or {@code secondVariable} is {@code null} for a two-parameter form
     */
    public GammaDistributionModel(GammaParameterizationType parameterization, Variable<Double> shapeVariable,
                                  Variable<Double> secondVariable, double offset) {
        super(GAMMA_DISTRIBUTION_MODEL);
        this.parameterization = Objects.requireNonNull(parameterization, "parameterization");
        this.shape = Objects.requireNonNull(shapeVariable, "shape variable");
        this.offset = offset;

        // Resolve which field the second variable fills before registering anything.
        Variable<Double> scaleVariable = null;
        Variable<Double> rateVariable = null;
        Variable<Double> meanVariable = null;
        switch (parameterization) {
            case ShapeScale:
                scaleVariable = requireSecondVariable(secondVariable, "scale", parameterization);
                break;
            case ShapeRate:
                rateVariable = requireSecondVariable(secondVariable, "rate", parameterization);
                break;
            case ShapeMean:
                meanVariable = requireSecondVariable(secondVariable, "mean", parameterization);
                break;
            case OneParameter:
                // Scale is derived from the shape; any second variable is ignored.
                break;
            default:
                throw new IllegalArgumentException("Unknown parameterization type");
        }
        this.scale = scaleVariable;
        this.rate = rateVariable;
        this.mean = meanVariable;

        this.pdfFunction = new UnivariateFunction() {
            @Override
            public final double evaluate(double x) {
                return GammaDistributionModel.this.pdf(x);
            }

            @Override
            public final double getLowerBound() {
                return GammaDistributionModel.this.offset;
            }

            @Override
            public final double getUpperBound() {
                return Double.POSITIVE_INFINITY;
            }
        };

        // Same registration order as before: shape first, then the second variable (if any).
        registerPositiveVariable(shapeVariable);
        if (parameterization != GammaParameterizationType.OneParameter) {
            registerPositiveVariable(secondVariable);
        }
    }

    /** Returns {@code variable}, or throws a descriptive NPE naming the missing role. */
    private static Variable<Double> requireSecondVariable(Variable<Double> variable, String role,
                                                          GammaParameterizationType parameterization) {
        return Objects.requireNonNull(variable,
                role + " variable is required for parameterization " + parameterization);
    }

    /**
     * Registers {@code variable} with this model and attaches
     * {@code DefaultBounds(+inf, 0, 1)} to it (argument order assumed upper, lower, dimension).
     */
    private void registerPositiveVariable(Variable<Double> variable) {
        addVariable(variable);
        variable.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0d, 1));
    }

    /** Density at {@code x}; 0 below the offset. */
    @Override
    public double pdf(double x) {
        if (x < this.offset) {
            return 0.0d;
        }
        return GammaDistribution.pdf(x - this.offset, getShape(), getScale());
    }

    /** Log density at {@code x}; negative infinity below the offset. */
    @Override
    public double logPdf(double x) {
        if (x < this.offset) {
            return Double.NEGATIVE_INFINITY;
        }
        return GammaDistribution.logPdf(x - this.offset, getShape(), getScale());
    }

    /** Cumulative probability at {@code x}; 0 below the offset. */
    @Override
    public double cdf(double x) {
        if (x < this.offset) {
            return 0.0d;
        }
        return GammaDistribution.cdf(x - this.offset, getShape(), getScale());
    }

    /**
     * Inverse CDF, shifted by the offset.
     *
     * @return the quantile for probability {@code p}, or {@code NaN} if the numerical inversion
     *     fails
     */
    @Override
    public double quantile(double p) {
        try {
            return new GammaDistributionImpl(getShape(), getScale()).inverseCumulativeProbability(p) + this.offset;
        } catch (MathException ignored) {
            // Inversion failure is reported as NaN rather than propagated to callers.
            return Double.NaN;
        }
    }

    /** Mean of the shifted distribution (gamma mean plus offset). */
    @Override
    public double mean() {
        return GammaDistribution.mean(getShape(), getScale()) + this.offset;
    }

    /** Variance of the shifted distribution; the offset does not change it. */
    @Override
    public double variance() {
        return GammaDistribution.variance(getShape(), getScale());
    }

    /** Returns a density function whose lower bound is the offset and upper bound is +infinity. */
    @Override
    public final UnivariateFunction getProbabilityDensityFunction() {
        return this.pdfFunction;
    }

    /** No-op: this model has no sub-models and keeps no cached state. */
    @Override
    public void handleModelChangedEvent(Model model, Object object, int index) {
    }

    /** No-op: values are read from the variables on every call, so nothing needs invalidating. */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
    }

    /** No-op: no cached state to store. */
    @Override
    protected void storeState() {
    }

    /** No-op: no cached state to restore. */
    @Override
    protected void restoreState() {
    }

    /** No-op: no cached state to accept. */
    @Override
    protected void acceptState() {
    }

    /** Reports a dimension of 1 to {@link GradientProvider}. */
    @Override
    public int getDimension() {
        return 1;
    }

    /**
     * Element-wise derivative of the log density, evaluated at {@code x - offset} for each
     * value in {@code obj}.
     *
     * <p>NOTE(review): unlike {@link #logPdf(double)}, no check against the offset is made here.
     */
    @Override
    public double[] getGradientLogDensity(Object obj) {
        double[] values = GradientProvider.toDoubleArray(obj);
        double shapeValue = getShape();
        double scaleValue = getScale();
        double[] gradient = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            gradient[i] = GammaDistribution.gradLogPdf(values[i] - this.offset, shapeValue, scaleValue);
        }
        return gradient;
    }

    /**
     * Element-wise second derivative of the log density, evaluated at {@code x - offset} for
     * each value in {@code obj}.
     *
     * <p>NOTE(review): unlike {@link #logPdf(double)}, no check against the offset is made here.
     */
    @Override
    public double[] getDiagonalHessianLogDensity(Object obj) {
        double[] values = GradientProvider.toDoubleArray(obj);
        double shapeValue = getShape();
        double scaleValue = getScale();
        double[] diagonal = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            diagonal[i] = GammaDistribution.hessianLogPdf(values[i] - this.offset, shapeValue, scaleValue);
        }
        return diagonal;
    }

    /** Full Hessian built from {@link #getDiagonalHessianLogDensity(Object)} via {@code expandDiagonals}. */
    @Override
    public double[][] getHessianLogDensity(Object obj) {
        return HessianProvider.expandDiagonals(getDiagonalHessianLogDensity(obj));
    }

    /** XML serialization is not supported for this model. */
    @Override
    public Element createElement(Document document) {
        throw new RuntimeException("Not implemented!");
    }

    /** Current value of the shape variable (first element). */
    public double getShape() {
        return this.shape.getValue(0).doubleValue();
    }

    /**
     * Scale implied by the current parameterization.
     *
     * @return scale directly, {@code 1 / rate}, {@code mean / shape}, or {@code 1 / shape}
     */
    public double getScale() {
        switch (this.parameterization) {
            case ShapeScale:
                return this.scale.getValue(0).doubleValue();
            case ShapeRate:
                return 1.0d / this.rate.getValue(0).doubleValue();
            case ShapeMean:
                return this.mean.getValue(0).doubleValue() / getShape();
            case OneParameter:
                return 1.0d / getShape();
            default:
                throw new IllegalArgumentException("Unknown parameterization type");
        }
    }

    /** Log density of the first element of {@code x}; remaining elements are ignored. */
    @Override
    public double logPdf(double[] x) {
        return logPdf(x[0]);
    }

    /** Not supported: this model exposes no location variable. */
    @Override
    public Variable<Double> getLocationVariable() {
        throw new UnsupportedOperationException("Not implemented");
    }
}
