package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.branchmodel.BranchSpecificSubstitutionParameterBranchModel;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.substmodel.DifferentiableSubstitutionModel;
import dr.evomodel.substmodel.DifferentialMassProvider;
import dr.evomodel.tree.TreeParameterModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.ProcessSimulation;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MachineAccuracy;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;

import java.util.ArrayList;
import java.util.List;

/**
 * Gradient of a {@link TreeDataLikelihood}'s log density with respect to a branch-specific
 * substitution-model parameter ({@link #branchParameter}).
 *
 * <p>The analytic ("peeling") gradient comes from a tree trait produced by a
 * {@link BranchSubstitutionParameterDelegate}, which this class registers on the likelihood the
 * first time it is needed. Each per-branch derivative is multiplied by a chain-rule factor that
 * is taken from the {@link ArbitraryBranchRates} transform when the supplied branch rate model is
 * of that type, and is 1 otherwise. Hessians are not yet supported.
 */
public class BranchSubstitutionParameterGradient
        implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable, Loggable {

    protected final TreeDataLikelihood treeDataLikelihood;
    /** Tree trait that yields the per-branch derivatives (a {@code double[]}). */
    protected final TreeTrait treeTraitProvider;
    protected final Tree tree;
    /** Whether {@link #getReport(Parameter)} also computes a numerical diagonal Hessian. */
    protected final boolean useHessian;
    protected final CompoundParameter branchParameter;
    private final BranchRateModel branchRateModel;
    /** Maps node numbers to and from the nodeCount - 1 branch parameter indices. */
    protected final TreeParameterModel parameterIndexHelper;

    /** Enables the numerical gradient check in {@link #getReport(Parameter)}. */
    private static final boolean DEBUG = true;
    protected static final boolean COUNT_TOTAL_OPERATIONS = true;

    /** Number of calls made to {@link #getGradientLogDensity()}. */
    protected long getGradientLogDensityCount = 0;
    /** Numerical wrapper most recently built by {@link #getReport(Parameter)}. */
    protected MultivariateFunction numeric;

    /**
     * Creates the gradient provider and, if necessary, registers the peeling trait on
     * {@code treeDataLikelihood}.
     *
     * @param traitName          base name passed to {@link BranchSubstitutionParameterDelegate}
     * @param treeDataLikelihood likelihood being differentiated; its tree is cast to
     *                           {@link MutableTreeModel}
     * @param likelihoodDelegate BEAGLE delegate whose branch model must be a
     *                           {@link BranchSpecificSubstitutionParameterBranchModel}
     * @param branchParameter    parameter the gradient is taken with respect to
     * @param branchRateModel    used for the chain-rule factor and the numerical check
     * @param useHessian         whether the report also computes a numerical diagonal Hessian
     * @throws RuntimeException if the likelihood delegate has more than one trait
     */
    public BranchSubstitutionParameterGradient(String traitName,
                                               TreeDataLikelihood treeDataLikelihood,
                                               BeagleDataLikelihoodDelegate likelihoodDelegate,
                                               CompoundParameter branchParameter,
                                               BranchRateModel branchRateModel,
                                               boolean useHessian) {
        // Fail fast, before the shared likelihood is mutated by registering new traits.
        if (treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount() != 1) {
            throw new RuntimeException("Not yet implemented for >1 traits");
        }

        this.treeDataLikelihood = treeDataLikelihood;
        this.tree = treeDataLikelihood.getTree();
        this.branchParameter = branchParameter;
        this.branchRateModel = branchRateModel;
        this.useHessian = useHessian;
        // One index per branch: the helper parameter has dimension nodeCount - 1.
        this.parameterIndexHelper = new TreeParameterModel((MutableTreeModel) tree,
                new Parameter.Default(tree.getNodeCount() - 1), false);

        String name = BranchSubstitutionParameterDelegate.getName(traitName);
        // Several gradient objects may share one likelihood; only register the trait once.
        if (treeDataLikelihood.getTreeTrait(name) == null) {
            registerGradientTrait(traitName, treeDataLikelihood, likelihoodDelegate, branchParameter);
        }

        this.treeTraitProvider = treeDataLikelihood.getTreeTrait(name);
        assert treeTraitProvider != null;
    }

    /**
     * Builds one differential mass provider per branch and adds the traits of the resulting
     * {@link BranchSubstitutionParameterDelegate} to {@code treeDataLikelihood}.
     */
    private void registerGradientTrait(String traitName,
                                       TreeDataLikelihood treeDataLikelihood,
                                       BeagleDataLikelihoodDelegate likelihoodDelegate,
                                       CompoundParameter branchParameter) {
        BranchSpecificSubstitutionParameterBranchModel branchModel =
                (BranchSpecificSubstitutionParameterBranchModel) likelihoodDelegate.getBranchModel();

        int branchCount = parameterIndexHelper.getParameterSize();
        List<DifferentialMassProvider> providers = new ArrayList<>(branchCount);
        for (int i = 0; i < branchCount; i++) {
            NodeRef node = tree.getNode(parameterIndexHelper.getNodeNumberFromParameterIndex(i));
            DifferentiableSubstitutionModel substitutionModel =
                    (DifferentiableSubstitutionModel) branchModel.getSubstitutionModel(node);
            // NOTE(review): branchParameter is addressed by node number here (and in
            // getChainGradient), not by parameter index i -- confirm this matches its layout.
            providers.add(new DifferentialMassProvider.DifferentialWrapper(substitutionModel,
                    substitutionModel.factory(branchParameter.getParameter(node.getNumber()))));
        }

        BranchSubstitutionParameterDelegate delegate = new BranchSubstitutionParameterDelegate(
                traitName,
                treeDataLikelihood.getTree(),
                likelihoodDelegate,
                treeDataLikelihood.getBranchRateModel(),
                new BranchDifferentialMassProvider(parameterIndexHelper, providers));
        treeDataLikelihood.addTraits(new ProcessSimulation(treeDataLikelihood, delegate).getTreeTraits());
    }

    /** Not yet supported. */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not yet supported. */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Likelihood getLikelihood() {
        return treeDataLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return branchParameter;
    }

    @Override
    public int getDimension() {
        return branchParameter.getDimension();
    }

    /**
     * Returns the analytic gradient: one entry per branch, in parameter-index order. Each entry is
     * the peeled derivative for that branch multiplied by {@link #getChainGradient}.
     */
    @Override
    public double[] getGradientLogDensity() {
        double[] gradient = new double[tree.getNodeCount() - 1];
        double[] branchDerivatives = (double[]) treeTraitProvider.getTrait(tree, null);

        for (int i = 0; i < gradient.length; i++) {
            NodeRef node = tree.getNode(parameterIndexHelper.getNodeNumberFromParameterIndex(i));
            // Node i maps straight back to parameter index i, so write it in place.
            gradient[i] = branchDerivatives[i] * getChainGradient(tree, node);
        }

        getGradientLogDensityCount++;
        return gradient;
    }

    /**
     * Chain-rule factor for {@code node}. When the branch rate model is an
     * {@link ArbitraryBranchRates}, this is its transform's differential evaluated at this node's
     * value of {@link #branchParameter}; otherwise it is 1.
     */
    protected double getChainGradient(Tree tree, NodeRef node) {
        double value = branchParameter.getParameterValue(node.getNumber());
        if (branchRateModel instanceof ArbitraryBranchRates) {
            return ((ArbitraryBranchRates) branchRateModel).getTransform().differential(value, tree, node);
        }
        return 1.0;
    }

    /** Logs the full {@link #getReport()} text as a single column. */
    @Override
    public LogColumn[] getColumns() {
        return new LogColumn[]{new LogColumn.Default("gradient report", new Object() {
            @Override
            public String toString() {
                return "\n" + getReport();
            }
        })};
    }

    /**
     * Wraps the log-likelihood as a function of the branch values, for numerical differentiation.
     * Evaluating the function writes each argument into the {@link ArbitraryBranchRates} model,
     * skipping the root, and returns the resulting log-likelihood.
     *
     * @throws RuntimeException (on evaluation) if the branch rate model is not ArbitraryBranchRates
     */
    private MultivariateFunction numericWrap(final Parameter parameter) {
        return new MultivariateFunction() {
            @Override
            public double evaluate(double[] values) {
                if (!(branchRateModel instanceof ArbitraryBranchRates)) {
                    throw new RuntimeException("Not yet tested with ProxyParameter.");
                }
                ArbitraryBranchRates rates = (ArbitraryBranchRates) branchRateModel;
                Tree currentTree = treeDataLikelihood.getTree();
                // NOTE(review): values[i] is assigned to node number i. This is consistent with
                // the node-number addressing used elsewhere here; confirm it for the parameter.
                for (int i = 0; i < values.length; i++) {
                    NodeRef node = currentTree.getNode(i);
                    if (!currentTree.isRoot(node)) {
                        rates.setBranchRate(currentTree, node, values[i]);
                    }
                }
                return treeDataLikelihood.getLogLikelihood();
            }

            @Override
            public int getNumArguments() {
                return parameter.getDimension();
            }

            @Override
            public double getLowerBound(int n) {
                return 0.0;
            }

            @Override
            public double getUpperBound(int n) {
                return Double.POSITIVE_INFINITY;
            }
        };
    }

    /**
     * Returns {@code true} if every value has magnitude of at least
     * {@code 1.2 * MachineAccuracy.SQRT_EPSILON}. Presumably this keeps finite-difference steps
     * from being dominated by values too close to zero.
     */
    protected boolean valuesAreSufficientlyLarge(double[] values) {
        for (double value : values) {
            if (Math.abs(value) < MachineAccuracy.SQRT_EPSILON * 1.2) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares the analytic gradient with a numerical one for {@code parameter}.
     *
     * <p>The numerical gradient is computed only when every value is sufficiently far from zero.
     * When {@link #useHessian} is set, a numerical diagonal Hessian is also computed and reported.
     * The parameter's original values are restored afterwards, even if numerical evaluation
     * throws.
     */
    public String getReport(Parameter parameter) {
        double[] savedValues = parameter.getParameterValues();
        boolean largeEnoughValues = valuesAreSufficientlyLarge(savedValues);

        numeric = numericWrap(parameter);
        double[] numericGradient = null;
        double[] numericDiagonalHessian = null;
        try {
            if (DEBUG && largeEnoughValues) {
                // Fresh copies: the numerical routines receive arrays they may perturb.
                numericGradient = NumericalDerivative.gradient(numeric, parameter.getParameterValues());
                if (useHessian) {
                    numericDiagonalHessian =
                            NumericalDerivative.diagonalHessian(numeric, parameter.getParameterValues());
                }
            }
        } finally {
            // Numerical differentiation perturbs the model; put the original values back.
            for (int i = 0; i < savedValues.length; i++) {
                parameter.setParameterValue(i, savedValues[i]);
            }
        }

        StringBuilder sb = new StringBuilder();
        sb.append("Gradient Peeling: ").append(new Vector(getGradientLogDensity()));
        sb.append("\n");
        if (numericGradient == null) {
            sb.append("Gradient numeric: too close to 0");
        } else {
            sb.append("Gradient numeric: ").append(new Vector(numericGradient));
        }
        sb.append("\n");
        if (numericDiagonalHessian != null) {
            sb.append("Diagonal Hessian numeric: ").append(new Vector(numericDiagonalHessian));
            sb.append("\n");
        }
        return sb.toString();
    }

    /** Report for {@link #branchParameter}; see {@link #getReport(Parameter)}. */
    @Override
    public String getReport() {
        return getReport(branchParameter);
    }
}
