package dr.inference.trace;

import dr.evomodel.continuous.BivariateTraitBranchAttributeProvider;
import dr.math.LogTricks;
import dr.math.MathUtils;
import dr.util.TaskListener;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:dr/inference/trace/MarginalLikelihoodAnalysis.class */
/**
 * Estimates the log marginal likelihood (or the AICM score) of a model from a sample of
 * log-likelihood values, optionally attaching a bootstrapped standard error.
 *
 * <p>The estimator is chosen by the {@code analysisType} string passed to the constructor:
 * {@code "aicm"}, {@code "smoothed"} (smoothed harmonic mean), {@code "arithmetic"}
 * (arithmetic mean). Any other value selects the plain harmonic-mean estimator.
 *
 * <p>Results are computed on first access through {@link #getLogMarginalLikelihood()} or
 * {@link #getBootstrappedSE()} and then cached. Calling {@link #calculate()} explicitly
 * recomputes them.
 *
 * <p>No synchronization is performed by this class.
 */
public class MarginalLikelihoodAnalysis {
    private final String traceName;
    private final List<Double> sample;
    private final long burnin;
    private final String analysisType;
    private final int bootstrapLength;
    private double logMarginalLikelihood;
    private double bootstrappedSE;
    private boolean marginalLikelihoodCalculated = false;
    private TaskListener listener = null;

    /** Returns the name of the trace this analysis was built from. */
    public String getTraceName() {
        return this.traceName;
    }

    /** Returns the burn-in value supplied at construction (reported only, never applied here). */
    public long getBurnin() {
        return this.burnin;
    }

    /**
     * Creates an analysis over the given log-likelihood sample.
     *
     * @param sample          log-likelihood values; the list is referenced, not copied
     * @param traceName       name of the trace, used in {@link #toString()}
     * @param burnin          burn-in value; stored and reported but NOT applied to {@code sample}
     *                        by this class (NOTE(review): presumably the caller already removed it)
     * @param analysisType    {@code "aicm"}, {@code "smoothed"}, {@code "arithmetic"}, or anything
     *                        else for the harmonic-mean estimator
     * @param bootstrapLength number of bootstrap replicates; a standard error is only computed
     *                        when this is greater than 1
     */
    public MarginalLikelihoodAnalysis(List<Double> sample, String traceName, long burnin,
                                      String analysisType, int bootstrapLength) {
        this.sample = sample;
        this.traceName = traceName;
        this.burnin = burnin;
        this.analysisType = analysisType;
        this.bootstrapLength = bootstrapLength;
    }

    /**
     * Applies the estimator selected by {@code analysisType} to the given values.
     *
     * @param logLikelihoods log-likelihood values
     * @return the estimate (log marginal likelihood, or AICM for {@code "aicm"})
     */
    public double calculateLogMarginalLikelihood(List<Double> logLikelihoods) {
        if (this.analysisType.equals("aicm")) {
            return logMarginalLikelihoodAICM(logLikelihoods);
        } else if (this.analysisType.equals("smoothed")) {
            return logMarginalLikelihoodSmoothed(logLikelihoods);
        } else if (this.analysisType.equals("arithmetic")) {
            return logMarginalLikelihoodArithmetic(logLikelihoods);
        }
        return logMarginalLikelihoodHarmonic(logLikelihoods);
    }

    /**
     * Arithmetic-mean estimator: {@code log((1/m) * sum(exp(x_i)))} over the m values that are
     * neither NaN nor infinite. Invalid values are skipped and excluded from m.
     *
     * @param logLikelihoods log-likelihood values
     * @return the log arithmetic mean of the likelihoods, or {@code NaN} if no value is valid
     */
    public double logMarginalLikelihoodArithmetic(List<Double> logLikelihoods) {
        // -Double.MAX_VALUE acts as log(0) for the log-space accumulation.
        double logTotal = -Double.MAX_VALUE;
        int validCount = 0;
        // Count valid entries separately: shrinking the loop bound (as the original code did)
        // silently dropped the trailing elements of the list whenever an invalid value occurred.
        for (Double boxed : logLikelihoods) {
            double value = boxed.doubleValue();
            if (Double.isNaN(value) || Double.isInfinite(value)) {
                continue;
            }
            logTotal = LogTricks.logSum(logTotal, value);
            validCount++;
        }
        if (validCount == 0) {
            // log(0) in the divisor would otherwise yield a meaningless +Infinity.
            return Double.NaN;
        }
        return logTotal - StrictMath.log(validCount);
    }

    /**
     * Harmonic-mean estimator: {@code log(n) - log(sum(exp(-x_i)))}.
     *
     * @param logLikelihoods log-likelihood values
     * @return the log harmonic mean of the likelihoods
     */
    public double logMarginalLikelihoodHarmonic(List<Double> logLikelihoods) {
        final int size = logLikelihoods.size();
        // Every term is offset by the total log-likelihood before the log-space sum; the offset
        // cancels algebraically in the return expression.
        double shift = 0.0d;
        for (int i = 0; i < size; i++) {
            shift += logLikelihoods.get(i).doubleValue();
        }
        double logSumInverse = -Double.MAX_VALUE;
        for (int i = 0; i < size; i++) {
            logSumInverse = LogTricks.logSum(logSumInverse, shift - logLikelihoods.get(i).doubleValue());
        }
        return (shift - logSumInverse) + StrictMath.log(size);
    }

    /**
     * AICM score: {@code 2 * s^2 - 2 * mean}, where mean and s^2 are the sample mean and the
     * (n-1)-denominator sample variance of the log-likelihoods. This is not a log marginal
     * likelihood.
     *
     * @param logLikelihoods log-likelihood values
     * @return the AICM score
     */
    public double logMarginalLikelihoodAICM(List<Double> logLikelihoods) {
        final int size = logLikelihoods.size();
        double total = 0.0d;
        for (int i = 0; i < size; i++) {
            total += logLikelihoods.get(i).doubleValue();
        }
        final double mean = total / size;
        double sumSquares = 0.0d;
        for (int i = 0; i < size; i++) {
            double deviation = logLikelihoods.get(i).doubleValue() - mean;
            sumSquares += deviation * deviation;
        }
        return (2.0d * (sumSquares / (size - 1.0d))) - (2.0d * mean);
    }

    /**
     * Computes the estimate over the full sample and, if {@code bootstrapLength > 1}, its
     * bootstrapped standard error. Reports progress to the task listener, if one is set, and
     * marks the results as cached.
     */
    public void calculate() {
        this.logMarginalLikelihood = calculateLogMarginalLikelihood(this.sample);
        if (this.bootstrapLength > 1) {
            this.bootstrappedSE = bootstrapStandardError();
        }
        fireProgress(1.0d);
        this.marginalLikelihoodCalculated = true;
    }

    /**
     * Resamples the sample with replacement {@code bootstrapLength} times and returns the
     * (n-1)-denominator standard deviation of the per-replicate estimates.
     */
    private double bootstrapStandardError() {
        final int size = this.sample.size();
        final List<Double> resample = new ArrayList<>(size);
        final double[] replicates = new double[this.bootstrapLength];
        final double progressStep = 1.0d / this.bootstrapLength;
        double progress = 0.0d;
        double total = 0.0d;
        for (int r = 0; r < this.bootstrapLength; r++) {
            fireProgress(progress);
            progress += progressStep;
            int[] indices = MathUtils.sampleIndicesWithReplacement(size);
            for (int k = 0; k < size; k++) {
                resample.add(this.sample.get(indices[k]));
            }
            replicates[r] = calculateLogMarginalLikelihood(resample);
            total += replicates[r];
            resample.clear(); // reuse the buffer for the next replicate
        }
        final double mean = total / this.bootstrapLength;
        double sumSquares = 0.0d;
        for (double replicate : replicates) {
            double deviation = replicate - mean;
            sumSquares += deviation * deviation;
        }
        return Math.sqrt(sumSquares / (this.bootstrapLength - 1.0d));
    }

    /**
     * One fixed-point update of the smoothed harmonic-mean estimator. With weights
     * {@code w_i = 1 / (delta + (1 - delta) * exp(x_i - logPData))} it returns
     * {@code log(c * exp(logPData) + sum(w_i * exp(x_i))) - log(c + sum(w_i))}, where
     * {@code c = delta * n / (1 - delta)}.
     *
     * @param logLikelihoods log-likelihood values
     * @param delta          smoothing weight in (0, 1)
     * @param logPData       current guess of the log marginal likelihood
     * @return the updated log marginal likelihood estimate
     */
    public double logMarginalLikelihoodSmoothed(List<Double> logLikelihoods, double delta, double logPData) {
        final double logDelta = StrictMath.log(delta);
        final double logInvDelta = StrictMath.log(1.0d - delta);
        final int size = logLikelihoods.size();
        final double logN = StrictMath.log(size);
        final double offset = logInvDelta - logPData;
        double logBottom = (logN + logDelta) - logInvDelta;
        double logTop = logBottom + logPData;
        for (int i = 0; i < size; i++) {
            double value = logLikelihoods.get(i).doubleValue();
            double logWeight = -LogTricks.logSum(logDelta, offset + value);
            logTop = LogTricks.logSum(logTop, logWeight + value);
            logBottom = LogTricks.logSum(logBottom, logWeight);
        }
        return logTop - logBottom;
    }

    /** Returns the cached estimate, computing it (and the bootstrap, if enabled) on first use. */
    public double getLogMarginalLikelihood() {
        if (!this.marginalLikelihoodCalculated) {
            calculate();
        }
        return this.logMarginalLikelihood;
    }

    /**
     * Returns the cached bootstrapped standard error, computing results on first use. The value
     * stays 0.0 when {@code bootstrapLength <= 1}.
     */
    public double getBootstrappedSE() {
        if (!this.marginalLikelihoodCalculated) {
            calculate();
        }
        return this.bootstrappedSE;
    }

    /**
     * Builds a one-line summary: the estimator label, trace name, estimate, optional bootstrap
     * SE, burn-in and replicate count. This triggers the computation if it has not run yet.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (this.analysisType.equals("smoothed")) {
            sb.append("log marginal likelihood (using smoothed harmonic mean)");
        } else if (this.analysisType.equals("aicm")) {
            sb.append("AICM");
        } else if (this.analysisType.equals("arithmetic")) {
            sb.append("log marginal likelihood (using arithmetic mean)");
        } else {
            sb.append("log marginal likelihood (using harmonic mean)");
        }
        // NOTE(review): the number format is borrowed from an unrelated class's constant (likely a
        // decompiler artifact); consider a local constant with the same value.
        sb.append(" from ").append(this.traceName).append(" = ")
                .append(String.format(BivariateTraitBranchAttributeProvider.FORMAT, getLogMarginalLikelihood()));
        if (this.bootstrapLength > 1) {
            sb.append(" +/- ").append(String.format(BivariateTraitBranchAttributeProvider.FORMAT, getBootstrappedSE()));
        } else {
            sb.append("           ");
        }
        sb.append(" burnin=").append(this.burnin);
        if (this.bootstrapLength > 1) {
            sb.append(" replicates=").append(this.bootstrapLength);
        }
        return sb.toString();
    }

    /**
     * Smoothed harmonic-mean estimator with smoothing weight 0.01. Starting from the harmonic-mean
     * estimate, it iterates a safeguarded Newton/fixed-point search on
     * {@code g(p) = smoothed(p) - p} until a step smaller than 1e-3 is taken.
     *
     * @param logLikelihoods log-likelihood values
     * @return the converged log marginal likelihood, or {@code -Double.MAX_VALUE} after a message
     *         on stderr if more than 400 iterations are needed
     */
    public double logMarginalLikelihoodSmoothed(List<Double> logLikelihoods) {
        final double delta = 0.01d;
        final double tolerance = 0.001d;
        final int maxIterations = 400;

        double logPData = logMarginalLikelihoodHarmonic(logLikelihoods);
        double step = 1.0d;
        int iterations = 0;
        while (Math.abs(step) > tolerance) {
            double g1 = logMarginalLikelihoodSmoothed(logLikelihoods, delta, logPData) - logPData;
            double logPData2 = logPData + g1;
            double dx = g1 * 10.0d;
            double g2 = logMarginalLikelihoodSmoothed(logLikelihoods, delta, logPData + dx) - (logPData + dx);
            double slope = (g2 - g1) / dx; // finite-difference derivative of g at logPData

            // Newton step, replaced by ten fixed-point steps if it lands outside the allowed range.
            double logPData3 = logPData - (g1 / slope);
            if (logPData3 < 2.0d * logPData || logPData3 > 0.0d || logPData3 > 0.5d * logPData) {
                logPData3 = logPData + (10.0d * g1);
            }
            double g3 = logMarginalLikelihoodSmoothed(logLikelihoods, delta, logPData3) - logPData3;

            if (Math.abs(g3) <= Math.abs(g2) && (g3 > 0.0d || Math.abs(slope) > 0.01d)) {
                // accept the Newton (or 10x) step
                step = logPData3 - logPData;
                logPData = logPData3;
            } else if (Math.abs(g2) <= Math.abs(g1)) {
                // take two fixed-point steps
                logPData2 += g2;
                step = logPData2 - logPData;
                logPData = logPData2;
            } else {
                // take a single fixed-point step
                step = g1;
                logPData += g1;
            }

            iterations++;
            if (iterations > maxIterations) {
                System.err.println("Probabilities are not converging!!!");
                return -Double.MAX_VALUE;
            }
        }
        return logPData;
    }

    /** Registers a listener to receive progress updates from {@link #calculate()}; may be null. */
    public void setTaskListener(TaskListener taskListener) {
        this.listener = taskListener;
    }

    /** Forwards a progress fraction to the listener, if one is set. */
    private void fireProgress(double fraction) {
        if (this.listener != null) {
            this.listener.progress(fraction);
        }
    }
}
