package dr.inference.hmc;

import dr.inference.model.CompoundLikelihood;
import dr.inference.model.DerivativeOrder;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inferencexml.hmc.GradientWrapperParser;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:dr/inference/hmc/SumDerivative.class */
/**
 * Derivative provider that returns the element-wise sum of the derivatives of several
 * component providers.
 *
 * <p>All components must report the same dimension and, when there is more than one, the same
 * current parameter values (checked at construction). The exposed likelihood is either the
 * single component's likelihood or a {@link CompoundLikelihood} over the de-duplicated union of
 * the components' likelihood sets.
 *
 * <p>When more than one component contributes, the returned arrays are freshly allocated. The
 * summation therefore never writes into an array owned by a component provider. With a single
 * component, that component's array is returned as-is.
 */
public class SumDerivative implements GradientWrtParameterProvider, HessianWrtParameterProvider, DerivativeWrtParameterProvider, Reportable {
    private final int dimension;
    private final Likelihood likelihood;
    private final Parameter parameter;
    /** All component providers, in the order supplied to the constructor. */
    private final List<GradientWrtParameterProvider> derivativeList;
    /** The subset of {@link #derivativeList} that also implements {@link DerivativeWrtParameterProvider}. */
    private final List<DerivativeWrtParameterProvider> newDerivativeList;
    private final DerivativeOrder highestOrder;

    /**
     * Selects which per-component derivative vector is summed by the private aggregation helper.
     */
    public enum DerivativeType {
        GRADIENT(GradientWrapperParser.NAME) {
            @Override
            public double[] getDerivativeLogDensity(GradientWrtParameterProvider gradientWrtParameterProvider) {
                return gradientWrtParameterProvider.getGradientLogDensity();
            }
        },
        DIAGONAL_HESSIAN("diagonalHessian") {
            @Override
            public double[] getDerivativeLogDensity(GradientWrtParameterProvider gradientWrtParameterProvider) {
                // Throws ClassCastException if the component does not provide Hessians.
                return ((HessianWrtParameterProvider) gradientWrtParameterProvider).getDiagonalHessianLogDensity();
            }
        };

        private final String type;

        DerivativeType(String str) {
            this.type = str;
        }

        /**
         * Extracts this derivative type from one component provider.
         *
         * @param gradientWrtParameterProvider the component to query
         * @return the component's derivative vector of this type
         */
        public abstract double[] getDerivativeLogDensity(GradientWrtParameterProvider gradientWrtParameterProvider);
    }

    /**
     * Creates a provider whose derivatives are the sum of those of {@code list}.
     *
     * @param list component providers; must be non-empty and share one dimension. When it has
     *             more than one element, all components must report equal parameter values.
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws RuntimeException if components disagree on dimension or parameter values
     */
    public SumDerivative(List<GradientWrtParameterProvider> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("SumDerivative requires at least one component provider");
        }
        this.derivativeList = list;
        GradientWrtParameterProvider first = list.get(0);
        this.dimension = first.getDimension();
        this.parameter = first.getParameter();

        if (list.size() == 1) {
            this.likelihood = first.getLikelihood();
        } else {
            List<Likelihood> uniqueLikelihoods = new ArrayList<>();
            for (GradientWrtParameterProvider component : list) {
                if (component.getDimension() != this.dimension) {
                    throw new RuntimeException("Unequal parameter dimensions");
                }
                if (!Arrays.equals(component.getParameter().getParameterValues(), this.parameter.getParameterValues())) {
                    throw new RuntimeException("Unequal parameter values");
                }
                // Union of all components' likelihoods, first occurrence wins, order preserved.
                for (Likelihood componentLikelihood : component.getLikelihood().getLikelihoodSet()) {
                    if (!uniqueLikelihoods.contains(componentLikelihood)) {
                        uniqueLikelihoods.add(componentLikelihood);
                    }
                }
            }
            this.likelihood = new CompoundLikelihood(uniqueLikelihoods);
        }

        // NOTE(review): components that do not implement DerivativeWrtParameterProvider are
        // silently excluded from getDerivativeLogDensity(DerivativeOrder); confirm this is intended.
        this.newDerivativeList = new ArrayList<>();
        for (GradientWrtParameterProvider component : list) {
            if (component instanceof DerivativeWrtParameterProvider) {
                this.newDerivativeList.add((DerivativeWrtParameterProvider) component);
            }
        }
        this.highestOrder = DerivativeWrtParameterProvider.getHighestOrder(this.newDerivativeList);
    }

    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    public int getDimension(DerivativeOrder derivativeOrder) {
        return derivativeOrder.getDerivativeDimension(this.dimension);
    }

    @Override
    public int getDimension() {
        return this.dimension;
    }

    /**
     * Sums the requested-order derivative over every component that implements
     * {@link DerivativeWrtParameterProvider}.
     *
     * @param derivativeOrder the order requested; must not exceed {@link #getHighestOrder()}
     *                        (checked only when assertions are enabled)
     * @return the element-wise sum
     * @throws IllegalStateException if no component implements {@link DerivativeWrtParameterProvider}
     */
    @Override
    public double[] getDerivativeLogDensity(DerivativeOrder derivativeOrder) {
        assert this.highestOrder.getValue() >= derivativeOrder.getValue();
        if (this.newDerivativeList.isEmpty()) {
            throw new IllegalStateException("No component implements DerivativeWrtParameterProvider");
        }
        double[] total = this.newDerivativeList.get(0).getDerivativeLogDensity(derivativeOrder);
        if (this.newDerivativeList.size() > 1) {
            // Copy so the first component's (possibly cached) array is not modified.
            total = total.clone();
            for (int i = 1; i < this.newDerivativeList.size(); i++) {
                addInPlace(total, this.newDerivativeList.get(i).getDerivativeLogDensity(derivativeOrder));
            }
        }
        return total;
    }

    @Override
    public DerivativeOrder getHighestOrder() {
        return this.highestOrder;
    }

    /**
     * Returns the sum of the components' diagonal Hessians. Every component must be a
     * {@link HessianWrtParameterProvider}.
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        return getDerivativeLogDensity(DerivativeType.DIAGONAL_HESSIAN);
    }

    /**
     * Returns the element-wise sum of the components' full Hessian matrices. Every component must
     * be a {@link HessianWrtParameterProvider}; this is checked only when assertions are enabled.
     */
    @Override
    public double[][] getHessianLogDensity() {
        assert this.derivativeList.get(0) instanceof HessianWrtParameterProvider;
        double[][] total = ((HessianWrtParameterProvider) this.derivativeList.get(0)).getHessianLogDensity();
        if (this.derivativeList.size() > 1) {
            // Deep-copy so none of the first component's rows are modified.
            total = copyOf(total);
            for (int i = 1; i < this.derivativeList.size(); i++) {
                assert this.derivativeList.get(i) instanceof HessianWrtParameterProvider;
                double[][] term = ((HessianWrtParameterProvider) this.derivativeList.get(i)).getHessianLogDensity();
                // Iterate by each matrix's own row/column counts, not the first row's length.
                for (int row = 0; row < term.length; row++) {
                    addInPlace(total[row], term[row]);
                }
            }
        }
        return total;
    }

    /**
     * Sums the given derivative type over all components.
     *
     * @param derivativeType which per-component derivative to extract
     * @return the element-wise sum; a fresh array when more than one component contributes
     */
    private double[] getDerivativeLogDensity(DerivativeType derivativeType) {
        double[] total = derivativeType.getDerivativeLogDensity(this.derivativeList.get(0));
        if (this.derivativeList.size() > 1) {
            // Copy so the first component's (possibly cached) array is not modified.
            total = total.clone();
            for (int i = 1; i < this.derivativeList.size(); i++) {
                addInPlace(total, derivativeType.getDerivativeLogDensity(this.derivativeList.get(i)));
            }
        }
        return total;
    }

    /** Adds {@code term} element-wise into {@code total}, over the length of {@code term}. */
    private static void addInPlace(double[] total, double[] term) {
        for (int i = 0; i < term.length; i++) {
            total[i] += term[i];
        }
    }

    /** Returns a deep copy of {@code matrix}, cloning each row. */
    private static double[][] copyOf(double[][] matrix) {
        double[][] copy = new double[matrix.length][];
        for (int row = 0; row < matrix.length; row++) {
            copy[row] = matrix[row].clone();
        }
        return copy;
    }

    /** Returns the summed gradient over all components. */
    @Override
    public double[] getGradientLogDensity() {
        return getDerivativeLogDensity(DerivativeType.GRADIENT);
    }

    @Override
    public String getReport() {
        return "jointGradient." + this.parameter.getParameterName() + "\n" + GradientWrtParameterProvider.getReportAndCheckForError(this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, GradientWrtParameterProvider.TOLERANCE);
    }
}
