package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.preorder.NewTipFullConditionalDistributionDelegate;
import dr.evomodel.treedatalikelihood.preorder.TipFullConditionalDistributionDelegate;
import dr.evomodel.treedatalikelihood.preorder.TipGradientViaFullConditionalDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/continuous/TreeTipGradient.class */
public class TreeTipGradient implements GradientWrtParameterProvider, Reportable {

    /** Likelihood whose tip full-conditional gradient trait is read. */
    private final TreeDataLikelihood treeDataLikelihood;
    /** Tree trait returning a {@code double[]} gradient contribution for each external node. */
    private final TreeTrait treeTraitProvider;
    private final Tree tree;
    /** Tip-trait parameter of the continuous data model; gradients are w.r.t. this parameter. */
    private final Parameter traitParameter;
    private final int nTaxa;
    private final int nTraits;
    private final int dimTrait;
    /**
     * Optional mask (may be {@code null}). Gradient entries whose mask value equals 0.0 are zeroed.
     * When present it must have the same dimension as {@link #traitParameter}.
     */
    private final Parameter maskParameter;

    /**
     * Builds a gradient provider for the tip traits of a continuous-trait tree likelihood.
     *
     * <p>All configuration checks run before any tree traits are registered on the delegate. A
     * rejected configuration therefore leaves the delegate unchanged.
     *
     * @param traitName                        name of the continuous trait; used to derive the
     *                                         names of the full-conditional trait helpers
     * @param treeDataLikelihood               owning likelihood; must not be {@code null}
     * @param continuousDataLikelihoodDelegate delegate on which missing full-conditional traits are
     *                                         registered and from whose data model the trait
     *                                         parameter is taken
     * @param maskParameter                    optional mask, or {@code null} for no masking
     * @throws RuntimeException if the likelihood carries more than one trait, or if the mask and
     *                          trait parameters differ in dimension
     */
    public TreeTipGradient(String traitName,
                           TreeDataLikelihood treeDataLikelihood,
                           ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate,
                           Parameter maskParameter) {
        assert treeDataLikelihood != null;

        this.treeDataLikelihood = treeDataLikelihood;
        this.tree = treeDataLikelihood.getTree();
        this.maskParameter = maskParameter;

        this.nTaxa = tree.getExternalNodeCount();
        this.nTraits = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount();
        this.dimTrait = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();
        this.traitParameter = continuousDataLikelihoodDelegate.getDataModel().getParameter();

        // Validate before mutating the delegate, so that a failed construction has no side effects.
        if (nTraits != 1) {
            throw new RuntimeException("Not yet implemented for >1 traits");
        }
        if (maskParameter != null && maskParameter.getDimension() != traitParameter.getDimension()) {
            throw new RuntimeException("Trait and mask parameters must be the same size");
        }

        // Register the full-conditional helpers only if no one has added them already.
        final String gradientTraitName = TipGradientViaFullConditionalDelegate.getName(traitName);
        if (treeDataLikelihood.getTreeTrait(gradientTraitName) == null) {
            continuousDataLikelihoodDelegate.addFullConditionalGradientTrait(traitName);
        }
        if (treeDataLikelihood.getTreeTrait(TipFullConditionalDistributionDelegate.getName(traitName)) == null) {
            continuousDataLikelihoodDelegate.addFullConditionalDensityTrait(traitName);
        }
        if (treeDataLikelihood.getTreeTrait(NewTipFullConditionalDistributionDelegate.getName(traitName)) == null) {
            continuousDataLikelihoodDelegate.addNewFullConditionalDensityTrait(traitName);
        }

        this.treeTraitProvider = treeDataLikelihood.getTreeTrait(gradientTraitName);
        assert treeTraitProvider != null;
    }

    /** Returns the tree data likelihood this gradient belongs to. */
    @Override
    public Likelihood getLikelihood() {
        return treeDataLikelihood;
    }

    /** Returns the tip-trait parameter with respect to which the gradient is computed. */
    @Override
    public Parameter getParameter() {
        return traitParameter;
    }

    /** Returns the dimension of {@link #getParameter()}. */
    @Override
    public int getDimension() {
        return getParameter().getDimension();
    }

    /**
     * Collects the per-tip gradients into one flat array, ordered by external-node index.
     * Entries whose mask value is exactly 0.0 are then zeroed.
     *
     * @return a new array of length {@code nTaxa * dimTrait * nTraits}
     */
    @Override
    public double[] getGradientLogDensity() {
        final double[] gradient = new double[nTaxa * dimTrait * nTraits];

        int offset = 0;
        for (int taxon = 0; taxon < nTaxa; taxon++) {
            final double[] tipGradient =
                    (double[]) treeTraitProvider.getTrait(tree, tree.getExternalNode(taxon));
            System.arraycopy(tipGradient, 0, gradient, offset, tipGradient.length);
            offset += tipGradient.length;
        }

        if (maskParameter != null) {
            for (int index = 0; index < maskParameter.getDimension(); index++) {
                if (maskParameter.getParameterValue(index) == 0.0) {
                    gradient[index] = 0.0;
                }
            }
        }
        return gradient;
    }

    /** Reports the current gradient as a formatted vector string. */
    @Override
    public String getReport() {
        return new Vector(getGradientLogDensity()).toString();
    }
}
