package dr.evomodel.coalescent.hmc;

import dr.evolution.coalescent.IntervalType;
import dr.evolution.coalescent.TreeIntervals;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.arg.ARGModel;
import dr.evomodel.coalescent.GMRFMultilocusSkyrideLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;
import java.util.List;

/* loaded from: input_file:dr/evomodel/coalescent/hmc/GMRFGradient.class */
/**
 * Supplies the gradient and the diagonal of the Hessian of a
 * {@link GMRFMultilocusSkyrideLikelihood} with respect to one of its parameters.
 *
 * <p>Which parameter is differentiated is selected by a {@link WrtParameter} constant; each
 * constant knows how to obtain the parameter, its gradient, its diagonal Hessian and its lower
 * bound from the likelihood.
 */
public class GMRFGradient implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable {

    /** Tolerance handed to the gradient and Hessian report helpers in {@link #getReport()}. */
    private static final Double TOLERANCE = 1.0E-4;

    private final GMRFMultilocusSkyrideLikelihood skygridLikelihood;
    private final WrtParameter wrtParameter;
    private final Parameter parameter;

    /**
     * The parameters of the skygrid likelihood that a {@link GMRFGradient} can differentiate
     * against. Each constant carries the textual name matched by {@link #factory(String)}.
     */
    public enum WrtParameter {

        /** Differentiates with respect to the likelihood's population-size parameter. */
        LOG_POPULATION_SIZES("logPopulationSizes") {
            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getPopSizeParameter();
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getGradientWrtLogPopulationSize();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getDiagonalHessianWrtLogPopulationSize();
            }

            @Override
            double getParameterLowerBound() {
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood likelihood) {
                // Nothing to validate for this parameter.
            }
        },

        /** Differentiates with respect to the likelihood's precision parameter. */
        PRECISION("precision") {
            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getPrecisionParameter();
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getGradientWrtPrecision();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getDiagonalHessianWrtPrecision();
            }

            @Override
            double getParameterLowerBound() {
                return 0.0;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood likelihood) {
                // Nothing to validate for this parameter.
            }
        },

        /**
         * Differentiates with respect to the single regression-coefficient parameter held in the
         * likelihood's beta list. Lists with more than one entry are rejected.
         */
        REGRESSION_COEFFICIENTS("regressionCoefficients") {
            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood) {
                final List<Parameter> betaList = likelihood.getBetaListParameter();
                // Fail with a descriptive message instead of a bare NullPointerException or
                // IndexOutOfBoundsException when there is nothing to differentiate against.
                if (betaList == null || betaList.isEmpty()) {
                    throw new RuntimeException(
                            "The likelihood has no regression coefficient parameter");
                }
                if (betaList.size() > 1) {
                    throw new RuntimeException("This is not the correct way of handling multidimensional parameters");
                }
                return betaList.get(0);
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getGradientWrtRegressionCoefficients();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getDiagonalHessianWrtRegressionCoefficients();
            }

            @Override
            double getParameterLowerBound() {
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood likelihood) {
                // Nothing to validate for this parameter.
            }
        },

        /**
         * Differentiates with respect to the internal node heights of the likelihood's first
         * tree (index 0). Only a single locus is supported; see {@link #getWarning}.
         */
        NODE_HEIGHT(ARGModel.NODE_HEIGHT) {
            // Enum constants are JVM-wide singletons, so the cached parameter must be tied to
            // the tree it was created for; otherwise a second likelihood built on a different
            // tree would silently receive the first tree's node-height parameter.
            private Parameter cachedParameter;
            private Tree cachedTree;

            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood) {
                final Tree tree = likelihood.getTree(0);
                if (cachedParameter == null || cachedTree != tree) {
                    cachedParameter = ((TreeModel) tree).createNodeHeightsParameter(true, true, false);
                    cachedTree = tree;
                }
                return cachedParameter;
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return getGradientWrtNodeHeights(likelihood);
            }

            /** Returns an all-zero array with one entry per internal node of tree 0. */
            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return new double[likelihood.getTree(0).getInternalNodeCount()];
            }

            @Override
            double getParameterLowerBound() {
                return 0.0;
            }

            /**
             * @throws UnsupportedOperationException if the likelihood reports more than one locus
             */
            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood likelihood) {
                if (likelihood.nLoci() > 1) {
                    throw new UnsupportedOperationException("Not yet implemented for multiple loci.");
                }
            }

            /**
             * Accumulates, for every coalescent interval of tree 0, a contribution to the
             * gradient entry of the node that ends the interval, then scales all entries by
             * {@code 0.5 / populationFactor}.
             */
            private double[] getGradientWrtNodeHeights(GMRFMultilocusSkyrideLikelihood likelihood) {
                // Return value deliberately ignored.
                // NOTE(review): presumably called so the likelihood's interval state is up to
                // date before it is read below - confirm.
                likelihood.getLogLikelihood();

                final Tree tree = likelihood.getTree(0);
                final double[] gradient = new double[tree.getInternalNodeCount()];
                final double[] logPopSizes = likelihood.getPopSizeParameter().getParameterValues();
                final double inversePopulationFactor = 1.0 / likelihood.getPopulationFactor(0);
                final TreeIntervals intervals = likelihood.getTreeIntervals(0);
                final int[] gridIndices = getGridIndexForInternalNodes(likelihood, 0);

                for (int i = 0; i < intervals.getIntervalCount(); i++) {
                    if (intervals.getIntervalType(i) != IntervalType.COALESCENT) {
                        continue;
                    }
                    final NodeRef node = intervals.getCoalescentNode(i);
                    final int nodeIndex = getNodeHeightParameterIndex(node, tree);

                    // k * (k - 1) with k the lineage count of the interval this node closes.
                    final int lineageCount = intervals.getLineageCount(i);
                    gradient[nodeIndex] +=
                            -Math.exp(-logPopSizes[gridIndices[nodeIndex]]) * lineageCount * (lineageCount - 1);

                    if (!tree.isRoot(node)) {
                        // A non-root node also opens the next interval, which enters with the
                        // opposite sign.
                        // NOTE(review): the "+ 1" grid offset is kept exactly as found; confirm
                        // it cannot run past the end of the population-size array.
                        final int nextLineageCount = intervals.getLineageCount(i + 1);
                        gradient[nodeIndex] -=
                                -Math.exp(-logPopSizes[gridIndices[nodeIndex] + 1])
                                        * nextLineageCount * (nextLineageCount - 1);
                    }
                }

                final double scale = 0.5 * inversePopulationFactor;
                for (int i = 0; i < gradient.length; i++) {
                    gradient[i] *= scale;
                }
                return gradient;
            }

            /** Maps a node to its slot in the gradient array: node number minus the tip count. */
            private int getNodeHeightParameterIndex(NodeRef node, Tree tree) {
                return node.getNumber() - tree.getExternalNodeCount();
            }

            /**
             * For each internal node of the given tree, records the grid index reached when its
             * coalescent interval is visited. The grid index only ever advances.
             */
            private int[] getGridIndexForInternalNodes(GMRFMultilocusSkyrideLikelihood likelihood, int treeIndex) {
                final Tree tree = likelihood.getTree(treeIndex);
                final TreeIntervals intervals = likelihood.getTreeIntervals(treeIndex);
                final int[] indices = new int[tree.getInternalNodeCount()];
                final double[] gridPoints = likelihood.getGridPoints();

                int gridIndex = 0;
                for (int i = 0; i < intervals.getIntervalCount(); i++) {
                    if (intervals.getIntervalType(i) == IntervalType.COALESCENT) {
                        // NOTE(review): no upper bound on gridIndex, and the comparison uses
                        // getInterval(i) - confirm both are intended.
                        while (gridPoints[gridIndex] < intervals.getInterval(i)) {
                            gridIndex++;
                        }
                        indices[getNodeHeightParameterIndex(intervals.getCoalescentNode(i), tree)] = gridIndex;
                    }
                }
                return indices;
            }
        };

        private final String name;

        WrtParameter(String name) {
            this.name = name;
        }

        /** Returns the parameter of {@code likelihood} that this constant differentiates against. */
        abstract Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood);

        /** Returns the gradient of the log density with respect to this constant's parameter. */
        abstract double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood);

        /** Returns the diagonal of the Hessian of the log density for this constant's parameter. */
        abstract double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood);

        /** Returns the lower bound of this constant's parameter, as used by {@link GMRFGradient#getReport()}. */
        abstract double getParameterLowerBound();

        /**
         * Checks whether this constant can be used with {@code likelihood}. Despite the name,
         * nothing is returned: implementations either do nothing or throw.
         */
        public abstract void getWarning(GMRFMultilocusSkyrideLikelihood likelihood);

        /**
         * Looks up a constant by its textual name, ignoring case.
         *
         * @param str the name to match; may be {@code null}
         * @return the matching constant, or {@code null} if none matches
         */
        public static WrtParameter factory(String str) {
            for (WrtParameter candidate : values()) {
                // Receiver is the never-null constant name, so a null argument falls through
                // to the "not found" result instead of throwing.
                if (candidate.name.equalsIgnoreCase(str)) {
                    return candidate;
                }
            }
            return null;
        }
    }

    /**
     * @param gMRFMultilocusSkyrideLikelihood the likelihood to differentiate
     * @param wrtParameter                    selects the parameter to differentiate against
     */
    public GMRFGradient(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood, WrtParameter wrtParameter) {
        this.skygridLikelihood = gMRFMultilocusSkyrideLikelihood;
        this.wrtParameter = wrtParameter;
        this.parameter = wrtParameter.getParameter(gMRFMultilocusSkyrideLikelihood);
    }

    @Override
    public Likelihood getLikelihood() {
        return skygridLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return parameter;
    }

    @Override
    public int getDimension() {
        return parameter.getDimension();
    }

    @Override
    public double[] getGradientLogDensity() {
        return wrtParameter.getGradientLogDensity(skygridLikelihood);
    }

    @Override
    public double[] getDiagonalHessianLogDensity() {
        return wrtParameter.getDiagonalHessianLogDensity(skygridLikelihood);
    }

    /**
     * Not supported; only the diagonal is available.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Builds a report consisting of a header line ({@code likelihood.parameterName}) followed by
     * the output of the gradient and Hessian providers' static report helpers.
     */
    @Override
    public String getReport() {
        final String header = skygridLikelihood + "." + wrtParameter.name + "\n";
        return header
                + GradientWrtParameterProvider.getReportAndCheckForError(
                        this, wrtParameter.getParameterLowerBound(), Double.POSITIVE_INFINITY, TOLERANCE)
                + " \n"
                + HessianWrtParameterProvider.getReportAndCheckForError(this, TOLERANCE);
    }
}
