package dr.evomodel.branchmodel.lineagespecific;

import dr.app.bss.Utils;
import dr.inference.distribution.ParametricMultivariateDistributionModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:dr/evomodel/branchmodel/lineagespecific/DirichletProcessPrior.class */
public class DirichletProcessPrior extends AbstractModelLikelihood {

    /** When set, per-category counts are printed on every partition-density evaluation. */
    private boolean verbose = false;
    /** Category index (stored as a double and truncated to int) for each item. */
    private final Parameter categoriesParameter;
    /** Pool of category values; sub-parameter {@code k} is scored as the value of category {@code k}. */
    private final CompoundParameter uniquelyRealizedParameters;
    /** Base distribution scored against every category value; may be {@code null} (see {@link #getLogDensity}). */
    public ParametricMultivariateDistributionModel baseModel;
    /** Concentration parameter of the process; only element 0 is read. */
    private final Parameter gamma;
    /** Number of categories, i.e. number of sub-parameters in the pool, fixed at construction. */
    private final int categoryCount;
    /** Number of items assigned to categories, fixed at construction. */
    private final int itemCount;
    /** Whether {@link #logLikelihood} reflects the current state. */
    private boolean likelihoodKnown;
    /** Most recently computed log likelihood. */
    private double logLikelihood;
    /** Memo of log(k!) indexed by k; entry 0 is log(0!) = 0 and the list grows on demand. */
    private final List<Double> cachedLogFactorials;

    /**
     * Creates the prior and registers its parameters (and the base model, if any) for change events.
     *
     * @param parameter category index of each item
     * @param compoundParameter pool of per-category values; its sub-parameter count fixes the number of categories
     * @param parametricMultivariateDistributionModel base distribution for the category values; may be {@code null}
     * @param parameter2 concentration parameter (element 0 is used)
     */
    public DirichletProcessPrior(Parameter parameter, CompoundParameter compoundParameter, ParametricMultivariateDistributionModel parametricMultivariateDistributionModel, Parameter parameter2) {
        super("");
        this.categoriesParameter = parameter;
        this.baseModel = parametricMultivariateDistributionModel;
        this.uniquelyRealizedParameters = compoundParameter;
        this.gamma = parameter2;
        // Count sub-parameters, not total dimension: categories are indexed by sub-parameter
        // (see getRealizedValuesLogDensity / getUniqueParameter), and a multivariate base model
        // implies sub-parameters may have dimension > 1.
        this.categoryCount = compoundParameter.getParameterCount();
        this.itemCount = parameter.getDimension();
        this.cachedLogFactorials = new ArrayList<>();
        this.cachedLogFactorials.add(0.0d);
        addVariable(this.categoriesParameter);
        addVariable(this.gamma);
        addVariable(this.uniquelyRealizedParameters);
        if (parametricMultivariateDistributionModel != null) {
            addModel(parametricMultivariateDistributionModel);
        }
        this.likelihoodKnown = false;
    }

    /**
     * Returns log(n!), extending the memo table as needed.
     *
     * @param n non-negative integer
     * @return natural log of n factorial
     */
    private double getLogFactorial(int n) {
        for (int k = cachedLogFactorials.size(); k <= n; k++) {
            // log(k!) = log((k-1)!) + log(k)
            cachedLogFactorials.add(cachedLogFactorials.get(k - 1) + Math.log(k));
        }
        return cachedLogFactorials.get(n);
    }

    /**
     * Counts how many items are assigned to each category.
     *
     * @return array of length {@code categoryCount}; an out-of-range category index in
     *     {@code categoriesParameter} raises {@link ArrayIndexOutOfBoundsException}
     */
    private int[] getCounts() {
        int[] counts = new int[categoryCount];
        for (int i = 0; i < itemCount; i++) {
            counts[getMapping(i)]++;
        }
        return counts;
    }

    /** Returns the current concentration parameter (element 0 of {@code gamma}). */
    public double getGamma() {
        return gamma.getParameterValue(0);
    }

    /** Returns the category index of item {@code i}, truncating the stored double to an int. */
    private int getMapping(int i) {
        return (int) categoriesParameter.getParameterValue(i);
    }

    /**
     * Scores a parameter's values under the base model.
     *
     * @param parameter parameter whose values are passed to {@code baseModel.logPdf}
     * @return base-model log density; requires a non-null {@code baseModel}
     */
    public double getLogDensity(Parameter parameter) {
        return baseModel.logPdf(parameter.getAttributeValue());
    }

    /**
     * Sums the base-model log density over every category value in the pool, occupied or not.
     *
     * @return total log density of the realized category values
     */
    public double getRealizedValuesLogDensity() {
        double total = 0.0d;
        for (int k = 0; k < categoryCount; k++) {
            total += getLogDensity(uniquelyRealizedParameters.getParameter(k));
        }
        return total;
    }

    /**
     * Log probability of the partition induced by the category assignments under the Chinese
     * restaurant process (Ewens sampling formula):
     * {@code log[ gamma^K * prod_k (n_k - 1)! / (gamma (gamma+1) ... (gamma+N-1)) ]},
     * where {@code n_k} is the size of category k and K is the number of non-empty categories.
     *
     * @return partition log density
     */
    public double getCategoriesLogDensity() {
        int[] counts = getCounts();
        if (verbose) {
            Utils.printArray(counts);
        }
        final double concentration = getGamma();

        int occupied = 0;
        double logDensity = 0.0d;
        for (int count : counts) {
            if (count > 0) {
                occupied++;
                logDensity += getLogFactorial(count - 1);
            }
        }
        // gamma^K counts only non-empty categories; empty slots in the pool do not contribute.
        // Guarded so that an empty assignment contributes gamma^0 = 1 even when gamma == 0.
        if (occupied > 0) {
            logDensity += occupied * Math.log(concentration);
        }
        // Rising factorial gamma^(N) = gamma (gamma + 1) ... (gamma + N - 1).
        for (int i = 0; i < itemCount; i++) {
            logDensity -= Math.log(concentration + i);
        }
        return logDensity;
    }

    @Override // dr.inference.model.Likelihood
    public Model getModel() {
        return this;
    }

    /**
     * Returns the joint log density of the category assignments and the realized values.
     *
     * <p>Notifies listeners and recomputes on every call, as the original implementation did.
     */
    @Override // dr.inference.model.Likelihood
    public double getLogLikelihood() {
        // NOTE(review): the cache flag is maintained by the event handlers, but it is not trusted
        // here because this file cannot confirm that every change to the base model raises an
        // event. The listener notification is also preserved as an existing side effect. Confirm
        // both before switching to cached evaluation.
        fireModelChanged();
        logLikelihood = calculateLogLikelihood();
        likelihoodKnown = true;
        return logLikelihood;
    }

    /** Partition log density plus base-model log density of the realized values. */
    private double calculateLogLikelihood() {
        return getCategoriesLogDensity() + getRealizedValuesLogDensity();
    }

    /** Marks the cached log likelihood as stale. */
    @Override // dr.inference.model.Likelihood
    public void makeDirty() {
        likelihoodKnown = false;
    }

    /** Any change in the registered base model invalidates the cached value. */
    @Override // dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
        likelihoodKnown = false;
    }

    /** Returns the number of categories (sub-parameters in the realized-value pool). */
    public int getCategoryCount() {
        return categoryCount;
    }

    /** Returns the whole pool of realized category values. */
    public Parameter getUniqueParameters() {
        return uniquelyRealizedParameters;
    }

    /** Returns the realized value of category {@code i}. */
    public Parameter getUniqueParameter(int i) {
        return uniquelyRealizedParameters.getParameter(i);
    }

    /**
     * Invalidates the cached value and notifies listeners when any registered parameter changes.
     *
     * @throws IllegalArgumentException if {@code variable} is not one of this prior's parameters
     */
    @Override // dr.inference.model.AbstractModel
    protected void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        boolean known = variable == categoriesParameter
                || variable == gamma
                || variable == uniquelyRealizedParameters;
        if (!known) {
            throw new IllegalArgumentException("Unknown parameter");
        }
        likelihoodKnown = false;
        fireModelChanged();
    }

    /** Enables printing of per-category counts for this instance only. */
    public void setVerbose() {
        verbose = true;
    }

    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
        // Nothing to store: the likelihood is recomputed after a restore.
    }

    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        likelihoodKnown = false;
    }

    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
    }

    /** Prints computed vs. expected partition log densities for a few hand-checked cases. */
    public static void main(String[] strArr) {
        testDirichletProcess(new double[]{0.0d, 1.0d, 2.0d}, 3, 1.0d, -Math.log(6.0d));
        testDirichletProcess(new double[]{0.0d, 0.0d, 1.0d}, 3, 1.0d, -Math.log(6.0d));
        testDirichletProcess(new double[]{0.0d, 1.0d, 2.0d, 3.0d, 4.0d}, 5, 0.5d, -6.851184927493743d);
        // Empty category with gamma != 1: Ewens gives 0.5^2 * 1! * 0! / (0.5 * 1.5 * 2.5) = 2/15.
        testDirichletProcess(new double[]{0.0d, 0.0d, 1.0d}, 3, 0.5d, -Math.log(7.5d));
    }

    /**
     * Builds a prior with {@code categories} one-dimensional category values and no base model,
     * then prints its partition log density next to {@code expected}.
     */
    private static void testDirichletProcess(double[] assignments, int categories, double concentration, double expected) {
        Parameter.Default assignmentParameter = new Parameter.Default(assignments);
        Parameter.Default concentrationParameter = new Parameter.Default(concentration);
        CompoundParameter pool = new CompoundParameter("dummy");
        for (int k = 0; k < categories; k++) {
            pool.addParameter(new Parameter.Default(1.0d));
        }
        DirichletProcessPrior prior = new DirichletProcessPrior(assignmentParameter, pool, null, concentrationParameter);
        prior.setVerbose();
        System.out.println("lnL:          " + prior.getCategoriesLogDensity());
        System.out.println("expected lnL: " + expected);
    }
}
