package dr.evomodel.coalescent;

import dr.evomodel.coalescent.OldAbstractCoalescentLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightTransform;
import dr.evomodelxml.operators.RandomWalkIntegerNodeHeightWeightedOperatorParser;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;

/**
 * Analytic gradient of the GMRF skyride coalescent log-likelihood, taken with respect to either the
 * (internal) node heights of a tree or the coalescent intervals; see {@link WrtParameter}.
 *
 * <p>The diagonal Hessian is currently a placeholder that returns zeros, and the full Hessian is not
 * implemented. {@link #getReport()} prints the analytic results next to numerical estimates from
 * {@link NumericalDerivative} for debugging.
 */
public class GMRFSkyrideGradient implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable {

    private final GMRFSkyrideLikelihood skyrideLikelihood;
    private final WrtParameter wrtParameter;
    /** Parameter the gradient is reported against (proxy node heights, or the transform's parameter). */
    private final Parameter parameter;
    private final OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping;
    /** Optional transform; {@code null} means the node-height proxy parameter is set directly. */
    private final NodeHeightTransform nodeHeightTransform;

    /**
     * Log-likelihood as a function of the parameter values, used only by {@link #getReport()} to
     * obtain numerical derivatives. Evaluating it overwrites the current parameter values.
     */
    private final MultivariateFunction numeric1 = new NumericLogLikelihood();

    /** Quantity the gradient is taken with respect to. */
    public enum WrtParameter {
        COALESCENT_INTERVAL {
            @Override
            double[] getGradientLogDensity(GMRFSkyrideLikelihood likelihood,
                                           OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping) {
                double[] wrtHeights = getGradientLogDensityWrtUnsortedNodeHeight(likelihood);
                double[] gradient = new double[wrtHeights.length];
                // Suffix sum from the last entry to the first. This is the chain rule if each sorted
                // node height is the cumulative sum of the preceding intervals.
                // NOTE(review): confirm against the likelihood's interval definition.
                double suffixSum = 0.0;
                for (int i = wrtHeights.length - 1; i >= 0; i--) {
                    suffixSum += wrtHeights[i];
                    gradient[i] = suffixSum;
                }
                return gradient;
            }

            @Override
            void update(NodeHeightTransform nodeHeightTransform, double[] values) {
                nodeHeightTransform.inverse(values, 0, values.length);
            }
        },
        NODE_HEIGHTS {
            @Override
            double[] getGradientLogDensity(GMRFSkyrideLikelihood likelihood,
                                           OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping) {
                // Re-order the per-coalescent-event gradient to node-number order.
                return intervalNodeMapping.sortByNodeNumbers(getGradientLogDensityWrtUnsortedNodeHeight(likelihood));
            }

            @Override
            void update(NodeHeightTransform nodeHeightTransform, double[] values) {
                nodeHeightTransform.transform(values, 0, values.length);
            }
        };

        /**
         * Computes the gradient of the log-likelihood for this choice of parameter.
         *
         * @param likelihood          the skyride likelihood supplying intervals and population sizes
         * @param intervalNodeMapping mapping used to re-order entries by node number
         * @return the gradient vector
         */
        abstract double[] getGradientLogDensity(GMRFSkyrideLikelihood likelihood,
                                                OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping);

        /**
         * Pushes numerically perturbed values through the node-height transform.
         *
         * @param nodeHeightTransform the transform (non-null)
         * @param values              the new values, covering the whole array
         */
        abstract void update(NodeHeightTransform nodeHeightTransform, double[] values);

        /**
         * Computes the gradient with respect to the height of each coalescent event, in interval order
         * (not node-number order). This is one entry per coalescent interval dimension.
         *
         * <p>Each coalescent event ends interval {@code k} (population index {@code c}) and, unless it
         * is the last one, starts interval {@code k + 1} (population index {@code c + 1}). Its entry is
         * therefore half the difference of the two {@code exp(-value) * n * (n - 1)} rate terms.
         *
         * @param likelihood the skyride likelihood
         * @return gradient entries, indexed by coalescent event
         */
        double[] getGradientLogDensityWrtUnsortedNodeHeight(GMRFSkyrideLikelihood likelihood) {
            final int dimension = likelihood.getCoalescentIntervalDimension();
            final int intervalCount = likelihood.getIntervalCount();
            double[] gradient = new double[dimension];
            // Population-size values enter as exp(-value), i.e. they are used on the log scale.
            double[] logPopSizes = likelihood.getPopSizeParameter().getParameterValues();

            int coalescentIndex = 0;
            for (int interval = 0; interval < intervalCount; interval++) {
                if (likelihood.getIntervalType(interval) != OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT) {
                    continue; // only coalescent events carry a gradient entry
                }
                double lineages = likelihood.getLineageCount(interval);
                double value = -Math.exp(-logPopSizes[coalescentIndex]) * lineages * (lineages - 1);

                // Contribution of the following interval, which this event's height also bounds.
                if (coalescentIndex < dimension - 1 && interval < intervalCount - 1) {
                    double nextLineages = likelihood.getLineageCount(interval + 1);
                    value -= -Math.exp(-logPopSizes[coalescentIndex + 1]) * nextLineages * (nextLineages - 1);
                }
                gradient[coalescentIndex] = value / 2.0;
                coalescentIndex++;
            }
            return gradient;
        }
    }

    /**
     * @param gMRFSkyrideLikelihood the skyride likelihood to differentiate
     * @param wrtParameter          quantity the gradient is taken with respect to
     * @param treeModel             tree whose internal node heights back the proxy parameter when
     *                              {@code nodeHeightTransform} is {@code null}
     * @param nodeHeightTransform   optional transform; if non-null, its parameter is used instead
     */
    public GMRFSkyrideGradient(GMRFSkyrideLikelihood gMRFSkyrideLikelihood, WrtParameter wrtParameter, TreeModel treeModel, NodeHeightTransform nodeHeightTransform) {
        this.skyrideLikelihood = gMRFSkyrideLikelihood;
        this.intervalNodeMapping = gMRFSkyrideLikelihood.getIntervalNodeMapping();
        this.wrtParameter = wrtParameter;
        this.nodeHeightTransform = nodeHeightTransform;
        if (nodeHeightTransform == null) {
            this.parameter = new NodeHeightProxyParameter(RandomWalkIntegerNodeHeightWeightedOperatorParser.INTERNAL_NODE_HEIGHTS, treeModel, true);
        } else {
            this.parameter = nodeHeightTransform.getParameter();
        }
    }

    @Override
    public Likelihood getLikelihood() {
        return skyrideLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return parameter;
    }

    @Override
    public int getDimension() {
        return getParameter().getDimension();
    }

    @Override
    public double[] getGradientLogDensity() {
        return wrtParameter.getGradientLogDensity(skyrideLikelihood, intervalNodeMapping);
    }

    /**
     * Reports the analytic gradient and diagonal Hessian alongside numerical estimates.
     *
     * <p>The numerical estimates perturb the parameter. Its values are therefore restored after each
     * estimate, before the analytic quantities are computed.
     */
    @Override
    public String getReport() {
        double[] savedValues = getParameter().getParameterValues();

        double[] numericGradient = NumericalDerivative.gradient(numeric1, getParameter().getParameterValues());
        restoreParameterValues(savedValues);

        double[] numericDiagonalHessian = NumericalDerivative.diagonalHessian(numeric1, getParameter().getParameterValues());
        restoreParameterValues(savedValues);

        StringBuilder sb = new StringBuilder();
        sb.append("analytic: ").append(new Vector(getGradientLogDensity()));
        sb.append("\n");
        sb.append("numeric: ").append(new Vector(numericGradient));
        sb.append("\n");
        sb.append("analytic diagonal Hessian: ").append(new Vector(getDiagonalHessianLogDensity()));
        sb.append("\n");
        sb.append("numeric diagonal Hessian: ").append(new Vector(numericDiagonalHessian));
        sb.append("\n");
        return sb.toString();
    }

    /**
     * Placeholder: the analytic diagonal Hessian is not implemented, so this returns zeros.
     *
     * @return an all-zero array of length {@link #getDimension()}
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        return new double[getDimension()];
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    /**
     * Writes {@code values} back into the parameter, then marks the likelihood dirty.
     *
     * <p>Marking it dirty mirrors what the numerical evaluation does. Later analytic computations then
     * read state for the restored values rather than the last perturbed point.
     */
    private void restoreParameterValues(double[] values) {
        for (int i = 0; i < values.length; i++) {
            getParameter().setParameterValue(i, values[i]);
        }
        skyrideLikelihood.makeDirty();
    }

    /** Log-likelihood evaluated at arbitrary parameter values, for numerical differentiation. */
    private final class NumericLogLikelihood implements MultivariateFunction {

        @Override
        public double evaluate(double[] argument) {
            if (nodeHeightTransform != null) {
                wrtParameter.update(nodeHeightTransform, argument);
            } else {
                // Set every value silently and fire a single change event at the end.
                for (int i = 0; i < parameter.getDimension(); i++) {
                    parameter.setParameterValueQuietly(i, argument[i]);
                }
                parameter.fireParameterChangedEvent();
            }
            // Force recomputation so the returned value reflects the new arguments.
            skyrideLikelihood.makeDirty();
            return skyrideLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return parameter.getDimension();
        }

        @Override
        public double getLowerBound(int i) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int i) {
            return Double.POSITIVE_INFINITY;
        }
    }
}
