package dr.evomodel.epidemiology.casetocase.periodpriors;

import dr.inference.loggers.LogColumn;
import dr.inference.model.Parameter;
import dr.math.distributions.NormalDistribution;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;

/**
 * Period prior based on a normal model with unknown mean and known standard deviation.
 *
 * <p>Values are modelled as draws from N(mu, sigma^2) with {@code sigma} fixed. The unknown mean
 * {@code mu} has a conjugate normal hyperprior N(mu0, sigma0^2). Two evaluation modes exist:
 * <ul>
 *   <li>{@link #calculateLogLikelihood(double[])} computes the closed-form marginal likelihood of a
 *       whole data set. It also stores the posterior mean and variance of {@code mu} in two
 *       parameters that are exposed through {@link #getColumns()}.</li>
 *   <li>{@link #calculateLogPosteriorProbability(double, double)} scores values one at a time under
 *       the current posterior predictive distribution, then conditions on each value. It starts
 *       from the hyperprior after {@link #reset()}.</li>
 * </ul>
 */
public class KnownVarianceNormalPeriodPriorDistribution extends AbstractPeriodPriorDistribution {
    /** XML element name handled by {@link #PARSER}. */
    public static final String NORMAL = "knownVarianceNormalPeriodPriorDistribution";
    public static final String LOG = "log";
    public static final String ID = "id";
    public static final String MU_0 = "mu0";
    public static final String SIGMA = "sigma";
    public static final String SIGMA_0 = "sigma0";

    /** Normal hyperprior N(mu0, sigma0^2) on the unknown mean. */
    private final NormalDistribution hyperprior;
    /** Posterior mean of mu computed by the last {@link #calculateLogLikelihood(double[])} call. */
    private final Parameter posteriorMean;
    /** Posterior variance of mu computed by the last {@link #calculateLogLikelihood(double[])} call. */
    private final Parameter posteriorVariance;
    /** Known standard deviation of the data distribution. */
    private final double sigma;

    /** Number of values conditioned on since the last {@link #reset()} (sequential mode). */
    private int dataCount;
    /** Running sum of those values, kept so that each sequential update is O(1). */
    private double dataSum;
    /**
     * Current {mean, standard deviation} of mu in sequential mode. It is null until the first
     * {@link #reset()} or update.
     */
    private double[] currentParameters;

    /** XML parser for {@code <knownVarianceNormalPeriodPriorDistribution>} elements. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        // The constants are qualified so that no inherited member of the parser base class can
        // shadow them.
        private final XMLSyntaxRule[] rules = {
                AttributeRule.newBooleanRule(KnownVarianceNormalPeriodPriorDistribution.LOG, true),
                AttributeRule.newStringRule(KnownVarianceNormalPeriodPriorDistribution.ID, false),
                AttributeRule.newDoubleRule(KnownVarianceNormalPeriodPriorDistribution.MU_0, false),
                AttributeRule.newDoubleRule(KnownVarianceNormalPeriodPriorDistribution.SIGMA, false),
                AttributeRule.newDoubleRule(KnownVarianceNormalPeriodPriorDistribution.SIGMA_0, false)
        };

        @Override
        public String getParserName() {
            return KnownVarianceNormalPeriodPriorDistribution.NORMAL;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            String id = (String) xo.getAttribute(KnownVarianceNormalPeriodPriorDistribution.ID);
            // The "log" attribute is optional and defaults to false.
            boolean log = xo.hasAttribute(KnownVarianceNormalPeriodPriorDistribution.LOG)
                    && xo.getBooleanAttribute(KnownVarianceNormalPeriodPriorDistribution.LOG);
            double dataSigma = xo.getDoubleAttribute(KnownVarianceNormalPeriodPriorDistribution.SIGMA);
            double mu0 = xo.getDoubleAttribute(KnownVarianceNormalPeriodPriorDistribution.MU_0);
            double sigma0 = xo.getDoubleAttribute(KnownVarianceNormalPeriodPriorDistribution.SIGMA_0);
            return new KnownVarianceNormalPeriodPriorDistribution(id, log, dataSigma, mu0, sigma0);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public String getParserDescription() {
            return "Calculates the probability of a set of doubles being drawn from the prior posterior distribution of a normal distribution of unknown mean and known standard deviation sigma";
        }

        @Override
        public Class getReturnType() {
            return KnownVarianceNormalPeriodPriorDistribution.class;
        }
    };

    /**
     * Creates the distribution with an explicit hyperprior.
     *
     * @param name       model name forwarded to the superclass
     * @param log        flag forwarded to the superclass constructor
     * @param sigma      known standard deviation of the data distribution
     * @param hyperprior normal hyperprior on the unknown mean
     */
    public KnownVarianceNormalPeriodPriorDistribution(String name, boolean log, double sigma,
                                                      NormalDistribution hyperprior) {
        super(name, log);
        this.hyperprior = hyperprior;
        this.posteriorVariance = new Parameter.Default(1);
        this.posteriorMean = new Parameter.Default(1);
        addVariable(this.posteriorVariance);
        addVariable(this.posteriorMean);
        this.sigma = sigma;
    }

    /**
     * Creates the distribution with a hyperprior N(mu0, sigma0^2) on the unknown mean.
     *
     * @param name   model name forwarded to the superclass
     * @param log    flag forwarded to the superclass constructor
     * @param sigma  known standard deviation of the data distribution
     * @param mu0    hyperprior mean
     * @param sigma0 hyperprior standard deviation
     */
    public KnownVarianceNormalPeriodPriorDistribution(String name, boolean log, double sigma,
                                                      double mu0, double sigma0) {
        this(name, log, sigma, new NormalDistribution(mu0, sigma0));
    }

    /**
     * Restarts sequential evaluation. It discards all conditioned values, resets the current
     * parameters to the hyperprior and zeroes the accumulated log-likelihood.
     */
    @Override
    public void reset() {
        dataCount = 0;
        dataSum = 0.0;
        // Allocate a fresh array. currentParameters is null until the first update, so writing
        // into it in place would throw a NullPointerException on a newly constructed instance.
        currentParameters = new double[]{hyperprior.getMean(), hyperprior.getSD()};
        this.logL = 0.0;
    }

    /**
     * Scores {@code newValue} under the current posterior predictive distribution and then
     * conditions on it.
     *
     * <p>When {@code minValue} is finite, the log upper-tail predictive CDF at {@code minValue} is
     * subtracted. This conditions the density on the value exceeding {@code minValue}.
     *
     * @param newValue the value to score and then add to the data
     * @param minValue lower bound, or {@code Double.NEGATIVE_INFINITY} for none
     * @return the log predictive probability contribution, which is also added to {@code logL}
     */
    @Override
    public double calculateLogPosteriorProbability(double newValue, double minValue) {
        double logProbability = calculateLogPosteriorPredictiveProbability(newValue);
        if (minValue != Double.NEGATIVE_INFINITY) {
            logProbability -= calculateLogPosteriorPredictiveCDF(minValue, true);
        }
        this.logL += logProbability;
        update(newValue);
        return logProbability;
    }

    /** Delegates to {@link #calculateLogPosteriorPredictiveCDF(double, boolean)}. */
    @Override
    public double calculateLogPosteriorCDF(double value, boolean upperTail) {
        return calculateLogPosteriorPredictiveCDF(value, upperTail);
    }

    /**
     * Log density of {@code value} under the posterior predictive distribution
     * N(mean_n, sd_n^2 + sigma^2), where {mean_n, sd_n} are the current parameters.
     */
    public double calculateLogPosteriorPredictiveProbability(double value) {
        return NormalDistribution.logPdf(value, currentParameters[0], predictiveStandardDeviation());
    }

    /**
     * Log CDF of the posterior predictive distribution at {@code value}.
     *
     * <p>Assumes the second argument of {@code NormalDistribution.standardCDF} selects log scale.
     * TODO(review): confirm this against that method's contract.
     *
     * @param value     point at which to evaluate
     * @param upperTail if true, evaluate P(X &gt; value); otherwise P(X &lt;= value)
     * @return the log tail probability
     */
    public double calculateLogPosteriorPredictiveCDF(double value, boolean upperTail) {
        double standardized = (value - currentParameters[0]) / predictiveStandardDeviation();
        // By symmetry of the standard normal, the upper tail at s equals the lower tail at -s.
        return NormalDistribution.standardCDF(upperTail ? -standardized : standardized, true);
    }

    /** Standard deviation of the posterior predictive: sqrt(sd_n^2 + sigma^2). */
    private double predictiveStandardDeviation() {
        return Math.sqrt(square(currentParameters[1]) + square(sigma));
    }

    /**
     * Conditions the sequential posterior on one more observation. It uses the conjugate update
     * 1/var_n = n/sigma^2 + 1/sigma0^2 and mean_n = var_n * (mu0/sigma0^2 + sum/sigma^2).
     */
    private void update(double value) {
        dataCount++;
        dataSum += value;
        double priorVariance = square(hyperprior.getSD());
        double dataVariance = square(sigma);
        double varianceN = 1.0 / (dataCount / dataVariance + 1.0 / priorVariance);
        double meanN = varianceN * (hyperprior.getMean() / priorVariance + dataSum / dataVariance);
        currentParameters = new double[]{meanN, Math.sqrt(varianceN)};
    }

    /**
     * Computes the closed-form marginal log-likelihood of {@code values}. The values are treated
     * as iid N(mu, sigma^2), with mu integrated out against the normal hyperprior. The method also
     * stores the posterior mean and variance of mu in the logged parameters.
     *
     * @param values observed values; an empty array yields a log-likelihood of 0
     * @return the marginal log-likelihood, which is also stored in {@code logL}
     */
    @Override
    public double calculateLogLikelihood(double[] values) {
        int n = values.length;
        double mu0 = hyperprior.getMean();
        double sigma0 = hyperprior.getSD();
        double dataVariance = square(sigma);
        double priorVariance = square(sigma0);

        double sum = 0.0;
        double sumOfSquares = 0.0;
        for (double value : values) {
            sum += value;
            sumOfSquares += square(value);
        }

        double posteriorPrecision = 1.0 / priorVariance + n / dataVariance;
        this.posteriorMean.setParameterValue(0, (mu0 / priorVariance + sum / dataVariance) / posteriorPrecision);
        this.posteriorVariance.setParameterValue(0, 1.0 / posteriorPrecision);

        // Variance-like term n*sigma0^2 + sigma^2 that appears in the normaliser and exponent.
        // The data enter only through `sum` (= n * sample mean). Dividing by n to form the sample
        // mean would give 0/0 = NaN for empty input.
        double marginalScale = n * priorVariance + dataVariance;
        this.logL = Math.log(sigma)
                - n * Math.log(Math.sqrt(2.0 * Math.PI) * sigma)
                - 0.5 * Math.log(marginalScale)
                - sumOfSquares / (2.0 * dataVariance)
                - square(mu0) / (2.0 * priorVariance)
                + (square(sigma0 * sum / sigma) + square(sigma * mu0 / sigma0) + 2.0 * sum * mu0)
                        / (2.0 * marginalScale);
        return this.logL;
    }

    /**
     * Returns the superclass log columns, followed by the posterior mean and posterior variance
     * of mu from the last {@link #calculateLogLikelihood(double[])} call.
     */
    @Override
    public LogColumn[] getColumns() {
        ArrayList<LogColumn> columns = new ArrayList<>(Arrays.asList(super.getColumns()));
        columns.add(new LogColumn.Abstract(getModelName() + "_posteriorMean") {
            @Override
            protected String getFormattedValue() {
                return String.valueOf(KnownVarianceNormalPeriodPriorDistribution.this.posteriorMean.getParameterValue(0));
            }
        });
        columns.add(new LogColumn.Abstract(getModelName() + "_posteriorVariance") {
            @Override
            protected String getFormattedValue() {
                return String.valueOf(KnownVarianceNormalPeriodPriorDistribution.this.posteriorVariance.getParameterValue(0));
            }
        });
        return columns.toArray(new LogColumn[0]);
    }

    /** Returns {@code x * x}. */
    private static double square(double x) {
        return x * x;
    }
}
