package dr.evomodel.treedatalikelihood.discrete;

import com.github.lbfgs4j.LbfgsMinimizer;
import com.github.lbfgs4j.liblbfgs.Function;
import com.github.lbfgs4j.liblbfgs.Lbfgs;
import com.github.lbfgs4j.liblbfgs.LbfgsConstant;
import dr.evomodelxml.tree.UniformNodeHeightPriorParser;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.Vector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Timer;
import dr.util.Transform;
import dr.xml.Reportable;

/**
 * Maximizes a {@link Likelihood} with respect to a {@link Parameter} by minimizing the negated
 * log-likelihood with the lbfgs4j L-BFGS minimizer.
 *
 * <p>When a {@link Transform} is supplied, the optimizer works on transformed values: points are
 * mapped back with {@code transform.inverse(...)} before being written into the parameter, and
 * gradients are adjusted with {@code transform.updateGradientUnWeightedLogDensity(...)}.
 *
 * <p>When no {@link GradientWrtParameterProvider} is supplied, the gradient is computed with
 * {@link NumericalDerivative#gradient}.
 *
 * <p>Note: evaluating the objective or its gradient writes into {@code parameter}; this class is
 * therefore stateful and not intended for concurrent use.
 */
public class MaximizerWrtParameter implements Reportable {
    private final GradientWrtParameterProvider gradient;
    private final GradientType gradientType;
    private final Parameter parameter;
    private final Likelihood likelihood;
    private final Transform transform;
    private final Function function;
    private final Settings settings;

    /** Accumulated wall time (seconds) spent inside the minimizer across all maximize() calls. */
    private double time = 0.0d;
    /** Number of log-likelihood evaluations requested through this object. */
    private long count = 0;
    /** Objective value (negated log-likelihood) at {@link #minimumPoint}; NaN until maximize() runs. */
    private double minimumValue = Double.NaN;
    /** Optimizer solution in (possibly transformed) optimizer space; null until maximize() runs. */
    private double[] minimumPoint = null;

    /** Records whether the gradient was supplied by the caller or computed numerically. */
    private enum GradientType {
        ANALYTIC(UniformNodeHeightPriorParser.ANALYTIC),
        NUMERICAL("numerical");

        private final String type;

        GradientType(String type) {
            this.type = type;
        }

        @Override
        public String toString() {
            return this.type;
        }
    }

    /** Options controlling a call to {@link MaximizerWrtParameter#maximize()}. */
    public static class Settings {
        /** Maximum number of L-BFGS iterations; values {@code <= 0} keep the lbfgs4j default. */
        int numberIterations;
        /** If true, start from the parameter's current values; otherwise pass a null start point. */
        boolean startAtCurrentState;
        /** Forwarded as the second argument of the {@code LbfgsMinimizer} constructor. */
        boolean printToScreen;

        public Settings(int numberIterations, boolean startAtCurrentState, boolean printToScreen) {
            this.numberIterations = numberIterations;
            this.startAtCurrentState = startAtCurrentState;
            this.printToScreen = printToScreen;
        }
    }

    /**
     * @param likelihood                   the likelihood to maximize
     * @param parameter                    the parameter optimized over (written during optimization)
     * @param gradientWrtParameterProvider gradient of the log density w.r.t. {@code parameter};
     *                                     if null, a numerical gradient is used
     * @param transform                    optional transform defining the optimizer space; may be null
     * @param settings                     optimizer options
     */
    public MaximizerWrtParameter(Likelihood likelihood, Parameter parameter, GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, Settings settings) {
        this.likelihood = likelihood;
        this.parameter = parameter;
        this.transform = transform;
        if (gradientWrtParameterProvider == null) {
            this.gradient = constructGradient();
            this.gradientType = GradientType.NUMERICAL;
        } else {
            this.gradient = gradientWrtParameterProvider;
            this.gradientType = GradientType.ANALYTIC;
        }
        this.function = constructFunction();
        this.settings = settings;
    }

    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    /**
     * Runs L-BFGS on the negated log-likelihood and leaves {@code parameter} set to the solution.
     * Elapsed minimizer time is added to the running total reported by {@link #getReport()}.
     */
    public void maximize() {
        LbfgsConstant.LBFGS_Param params = Lbfgs.defaultParams();
        if (this.settings.numberIterations > 0) {
            params.max_iterations = this.settings.numberIterations;
        }
        LbfgsMinimizer minimizer = new LbfgsMinimizer(params, this.settings.printToScreen);

        // A null start point is handed to the minimizer unless we start from the current state.
        double[] start = null;
        if (this.settings.startAtCurrentState) {
            start = this.parameter.getParameterValues();
            if (this.transform != null) {
                start = this.transform.transform(start, 0, start.length);
            }
        }

        Timer timer = new Timer();
        timer.start();
        this.minimumPoint = minimizer.minimize(this.function, start);
        timer.stop();
        this.time += timer.toSeconds();

        this.minimumValue = this.function.valueAt(this.minimumPoint);

        // Write the solution (in untransformed space) back into the parameter.
        double[] solution = this.transform != null
                ? this.transform.inverse(this.minimumPoint, 0, this.minimumPoint.length)
                : this.minimumPoint;
        ReadableVector.Utils.setParameter(new WrappedVector.Raw(solution), this.parameter);
    }

    /**
     * Summarizes the last optimization. Computing the gradient at the solution writes the
     * solution into {@code parameter} as a side effect.
     */
    @Override
    public String getReport() {
        StringBuilder sb = new StringBuilder();
        // `function` is always non-null after construction; the solution is what tells us
        // whether maximize() has run (and avoids an NPE on a null minimumPoint).
        if (this.minimumPoint == null) {
            sb.append("Not yet executed.");
        } else {
            if (this.transform != null) {
                sb.append("Gradient is taken with respect to the transformed parameter values.\n");
                sb.append("Untransformed X: ").append(new Vector(this.transform.inverse(this.minimumPoint, 0, this.minimumPoint.length))).append("\n");
            }
            sb.append("X: ").append(new Vector(this.minimumPoint)).append("\n");
            sb.append("Gradient: ").append(new Vector(this.function.gradientAt(this.minimumPoint))).append("\n");
            sb.append("Gradient type: ").append(this.gradientType).append("\n");
            sb.append("Fx: ").append(this.minimumValue).append("\n");
            sb.append("Time: ").append(this.time).append("s\n");
            sb.append("Count: ").append(this.count).append("\n");
        }
        return sb.toString();
    }

    /** Returns the current log-likelihood and increments the evaluation counter. */
    public double evaluateLogLikelihood() {
        this.count++;
        return this.likelihood.getLogLikelihood();
    }

    /** Sets {@code parameter} from a point in optimizer space, inverting the transform if present. */
    private double[] setFromOptimizerSpace(double[] point) {
        double[] values = this.transform != null
                ? this.transform.inverse(point, 0, point.length)
                : point;
        ReadableVector.Utils.setParameter(new WrappedVector.Raw(values), this.parameter);
        return values;
    }

    /** Builds the lbfgs4j objective: negated log-likelihood and its gradient in optimizer space. */
    private Function constructFunction() {
        return new Function() {
            @Override
            public int getDimension() {
                return MaximizerWrtParameter.this.gradient.getDimension();
            }

            @Override
            public double valueAt(double[] point) {
                setFromOptimizerSpace(point);
                // Minimizing -logL maximizes logL.
                return -evaluateLogLikelihood();
            }

            @Override
            public double[] gradientAt(double[] point) {
                double[] values = setFromOptimizerSpace(point);
                double[] grad = MaximizerWrtParameter.this.gradient.getGradientLogDensity();
                if (MaximizerWrtParameter.this.transform != null) {
                    // Map the gradient into optimizer space (values are untransformed here).
                    grad = MaximizerWrtParameter.this.transform.updateGradientUnWeightedLogDensity(grad, values, 0, values.length);
                }
                for (int i = 0; i < grad.length; i++) {
                    grad[i] = -grad[i];
                }
                return grad;
            }
        };
    }

    /** Builds a gradient provider that differentiates the log-likelihood numerically. */
    private GradientWrtParameterProvider constructGradient() {
        final MultivariateFunction logLikelihood = new MultivariateFunction() {
            @Override
            public double evaluate(double[] values) {
                ReadableVector.Utils.setParameter(new WrappedVector.Raw(values), MaximizerWrtParameter.this.parameter);
                return evaluateLogLikelihood();
            }

            @Override
            public int getNumArguments() {
                return MaximizerWrtParameter.this.parameter.getDimension();
            }

            @Override
            public double getLowerBound(int i) {
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            public double getUpperBound(int i) {
                return Double.POSITIVE_INFINITY;
            }
        };
        return new GradientWrtParameterProvider() {
            @Override
            public Likelihood getLikelihood() {
                return MaximizerWrtParameter.this.likelihood;
            }

            @Override
            public Parameter getParameter() {
                return MaximizerWrtParameter.this.parameter;
            }

            @Override
            public int getDimension() {
                return MaximizerWrtParameter.this.parameter.getDimension();
            }

            @Override
            public double[] getGradientLogDensity() {
                double[] original = MaximizerWrtParameter.this.parameter.getParameterValues();
                // The derivative routine evaluates perturbed points through `logLikelihood`,
                // which writes them into the parameter; hand it a copy and restore afterwards so
                // the parameter is left at the point where the gradient was requested.
                double[] grad = NumericalDerivative.gradient(logLikelihood, original.clone());
                ReadableVector.Utils.setParameter(new WrappedVector.Raw(original), MaximizerWrtParameter.this.parameter);
                return grad;
            }
        };
    }
}
