package dr.inference.distribution.shrinkage;

import dr.inference.model.Parameter;
import dr.math.distributions.NormalDistribution;

/* loaded from: input_file:dr/inference/distribution/shrinkage/JointBayesianBridgeDistributionModel.class */
/**
 * Bayesian bridge distribution represented jointly with explicit per-coordinate local scales.
 *
 * <p>Each coordinate {@code x[i]} receives an independent zero-mean normal density. Its standard
 * deviation is {@code s_i = globalScale[0] * localScale[i]}. When a slab width {@code c} is
 * supplied, that value is shrunk to {@code s_i / sqrt(1 + (s_i / c)^2)}, which caps it below
 * {@code c}.
 */
public class JointBayesianBridgeDistributionModel extends BayesianBridgeDistributionModel {

    /** Per-coordinate local scales; its dimension must equal the model dimension. */
    private final Parameter localScale;

    /** Optional slab width (only element 0 is read); {@code null} disables the slab. */
    private final Parameter slabWidth;

    /**
     * Creates the joint model.
     *
     * @param globalScale first argument forwarded to the base-class constructor; assumed to be the
     *                    global scale the base class exposes as {@code globalScale} — confirm
     * @param localScale  per-coordinate local scales; must have dimension {@code dim}
     * @param exponent    second argument forwarded to the base-class constructor (bridge exponent,
     *                    presumably — confirm against the base class)
     * @param slabWidth   optional slab width, may be {@code null}
     * @param dim         number of coordinates of the distribution
     * @throws IllegalArgumentException if {@code localScale.getDimension() != dim}
     */
    public JointBayesianBridgeDistributionModel(Parameter globalScale, Parameter localScale,
                                                Parameter exponent, Parameter slabWidth, int dim) {
        super(globalScale, exponent, dim);
        this.localScale = localScale;
        this.slabWidth = slabWidth;

        if (dim != localScale.getDimension()) {
            throw new IllegalArgumentException("Invalid dimensions");
        }

        addVariable(localScale);
        // The densities below read slabWidth, so it must be registered just like localScale;
        // otherwise changes to the slab width go unnoticed by this model. It is optional, hence
        // the null guard.
        if (slabWidth != null) {
            addVariable(slabWidth);
        }
    }

    /** Returns the per-coordinate local scale parameter. */
    @Override
    public Parameter getLocalScale() {
        return localScale;
    }

    /** Returns the slab width parameter, or {@code null} if no slab is used. */
    @Override
    public Parameter getSlabWidth() {
        return slabWidth;
    }

    /**
     * Computes the gradient of {@link #logPdf(double[])} with respect to {@code x}, coordinate by
     * coordinate.
     *
     * <p>Declared {@code public} (wider than a package-private base declaration would require) to
     * stay compatible with existing callers.
     *
     * @param x point at which to evaluate; only the first {@code dim} entries are read
     * @return a new array of length {@code dim} holding d/dx_i log N(x_i | 0, sd_i)
     */
    @Override
    public double[] gradientLogPdf(double[] x) {
        double[] gradient = new double[dim];
        for (int i = 0; i < dim; i++) {
            gradient[i] = NormalDistribution.gradLogPdf(x[i], 0.0, getStandardDeviation(i));
        }
        return gradient;
    }

    /**
     * Evaluates the joint log density as the sum of independent zero-mean normal log densities.
     *
     * @param x point at which to evaluate; only the first {@code dim} entries are read
     * @return sum over i of log N(x_i | 0, sd_i)
     */
    @Override
    public double logPdf(double[] x) {
        double logDensity = 0.0;
        for (int i = 0; i < dim; i++) {
            logDensity += NormalDistribution.logPdf(x[i], 0.0, getStandardDeviation(i));
        }
        return logDensity;
    }

    /**
     * Returns the standard deviation of coordinate {@code index}.
     *
     * <p>The base value is {@code s = globalScale[0] * localScale[index]}. When a slab width
     * {@code c} is present, the result is {@code s / sqrt(1 + (s / c)^2)}.
     */
    private double getStandardDeviation(int index) {
        double scale = globalScale.getParameterValue(0) * localScale.getParameterValue(index);
        if (slabWidth != null) {
            double ratio = scale / slabWidth.getParameterValue(0);
            scale /= Math.sqrt(1.0 + ratio * ratio);
        }
        return scale;
    }
}
