package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.StrictClockBranchRates;
import dr.evomodel.continuous.MultivariateDiffusionModel;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.inference.model.Model;
import dr.math.KroneckerOperation;
import java.util.Iterator;
import java.util.List;
import org.ejml.data.DenseMatrix64F;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/continuous/AbstractDriftDiffusionModelDelegate.class */
/**
 * Base class for diffusion-process delegates that include a drift term.
 *
 * <p>The drift is supplied by an optional list of {@link BranchRateModel}s, one per
 * dimension ({@code dim}, inherited from the superclass). When the list is
 * {@code null}, all drift rates reported by this class are zero.
 */
public abstract class AbstractDriftDiffusionModelDelegate extends AbstractDiffusionModelDelegate {
    /** Per-dimension drift rate models; may be {@code null} when none were supplied. */
    private final List<BranchRateModel> branchRateModels;

    /**
     * Creates the delegate and registers every supplied branch-rate model as a sub-model.
     *
     * @param tree                       forwarded to the superclass constructor
     * @param multivariateDiffusionModel forwarded to the superclass constructor
     * @param list                       drift rate models, one per dimension, or {@code null}
     * @param i                          forwarded to the superclass constructor
     * @throws IllegalArgumentException if {@code list} is non-null and its size differs from {@code dim}
     */
    public AbstractDriftDiffusionModelDelegate(Tree tree, MultivariateDiffusionModel multivariateDiffusionModel, List<BranchRateModel> list, int i) {
        super(tree, multivariateDiffusionModel, i);
        this.branchRateModels = list;
        if (list != null) {
            for (BranchRateModel rateModel : list) {
                addModel(rateModel);
            }
            if (list.size() != this.dim) {
                throw new IllegalArgumentException("Invalid dimensions");
            }
        }
    }

    /**
     * Re-fires a model-changed event when one of the drift rate models changed;
     * every other event is passed to the superclass handler.
     */
    @Override
    public void handleModelChangedEvent(Model model, Object obj, int i) {
        // branchRateModels is allowed to be null (see constructor), so guard before
        // calling contains(); without drift models every event belongs to the superclass.
        if (this.branchRateModels != null && this.branchRateModels.contains(model)) {
            fireModelChanged(model);
        } else {
            super.handleModelChangedEvent(model, obj, i);
        }
    }

    /** Always {@code true} for this class. */
    @Override
    public boolean hasDrift() {
        return true;
    }

    /**
     * Collects the drift rates for the first {@code i} node numbers in {@code iArr}.
     *
     * @param iArr node numbers, resolved through {@code tree.getNode}
     * @param i    how many entries of {@code iArr} to process
     * @return an array of length {@code i * dim}, laid out node by node with {@code dim}
     *         consecutive rates per node; all zeros when no drift rate models were supplied
     */
    @Override
    public double[] getDriftRates(int[] iArr, int i) {
        double[] rates = new double[i * this.dim];
        if (this.branchRateModels != null) {
            int offset = 0;
            for (int branch = 0; branch < i; branch++) {
                NodeRef node = this.tree.getNode(iArr[branch]);
                for (int trait = 0; trait < this.dim; trait++) {
                    rates[offset] = this.branchRateModels.get(trait).getBranchRate(this.tree, node);
                    offset++;
                }
            }
        }
        return rates;
    }

    /**
     * Returns the drift rate of each dimension on the branch of {@code nodeRef}.
     *
     * @return an array of length {@code dim}; all zeros when no drift rate models were supplied
     */
    public double[] getDriftRate(NodeRef nodeRef) {
        double[] rates = new double[this.dim];
        if (this.branchRateModels != null) {
            for (int trait = 0; trait < this.dim; trait++) {
                rates[trait] = this.branchRateModels.get(trait).getBranchRate(this.tree, nodeRef);
            }
        }
        return rates;
    }

    /**
     * Reports whether every drift rate model is a {@link StrictClockBranchRates}.
     *
     * @return {@code false} when no drift rate models were supplied or any of the first
     *         {@code dim} models is of another type; {@code true} otherwise
     */
    public boolean isConstantDrift() {
        if (this.branchRateModels == null) {
            return false;
        }
        for (int trait = 0; trait < this.dim; trait++) {
            if (!(this.branchRateModels.get(trait) instanceof StrictClockBranchRates)) {
                return false;
            }
        }
        return true;
    }

    /** Delegates directly to {@code scaleGradient} with the same arguments. */
    public DenseMatrix64F getGradientDisplacementWrtDrift(NodeRef nodeRef, ContinuousDiffusionIntegrator continuousDiffusionIntegrator, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, DenseMatrix64F denseMatrix64F) {
        return scaleGradient(nodeRef, continuousDiffusionIntegrator, continuousDataLikelihoodDelegate, denseMatrix64F);
    }

    /**
     * Accumulates the drift along the path from the root down to {@code nodeRef}.
     *
     * @param nodeRef                       node at which the accumulation ends
     * @param dArr                          starting values; copied, not modified
     * @param continuousDiffusionIntegrator source of the per-branch quantities
     * @param i                             length of the returned array
     * @return a new array of length {@code i} holding the accumulated values
     */
    @Override
    public double[] getAccumulativeDrift(NodeRef nodeRef, double[] dArr, ContinuousDiffusionIntegrator continuousDiffusionIntegrator, int i) {
        double[] accumulated = new double[i];
        System.arraycopy(dArr, 0, accumulated, 0, dArr.length);
        double[] displacement = new double[i];
        // The actualization buffer is only allocated when the process uses one:
        // length i when it is diagonal, i * i otherwise.
        double[] actualization = null;
        if (hasActualization()) {
            actualization = hasDiagonalActualization() ? new double[i] : new double[i * i];
        }
        recursivelyAccumulateDrift(nodeRef, accumulated, continuousDiffusionIntegrator, displacement, actualization, i);
        return accumulated;
    }

    /**
     * Recursive worker for {@link #getAccumulativeDrift}. Recurses to the parent before
     * handling the current branch, so branches are applied from the root downwards.
     * {@code dArr} is updated in place; {@code dArr2} and {@code dArr3} are scratch buffers
     * for the branch displacement and actualization respectively.
     */
    private void recursivelyAccumulateDrift(NodeRef nodeRef, double[] dArr, ContinuousDiffusionIntegrator continuousDiffusionIntegrator, double[] dArr2, double[] dArr3, int i) {
        if (this.tree.isRoot(nodeRef)) {
            return;
        }
        recursivelyAccumulateDrift(this.tree.getParent(nodeRef), dArr, continuousDiffusionIntegrator, dArr2, dArr3, i);
        int bufferIndex = getMatrixBufferOffsetIndex(nodeRef.getNumber());
        continuousDiffusionIntegrator.getBranchDisplacement(bufferIndex, dArr2);
        if (hasActualization()) {
            continuousDiffusionIntegrator.getBranchActualization(getMatrixBufferOffsetIndex(nodeRef.getNumber()), dArr3);
        }
        double[] expectation = new double[i];
        continuousDiffusionIntegrator.getBranchExpectation(dArr3, dArr, dArr2, expectation);
        System.arraycopy(expectation, 0, dArr, 0, i);
    }

    /**
     * Returns the Kronecker product of {@code dArr} and {@code dArr3};
     * {@code d} and {@code dArr2} are not used by this implementation.
     */
    @Override
    public double[][] getJointVariance(double d, double[][] dArr, double[][] dArr2, double[][] dArr3) {
        return KroneckerOperation.product(dArr, dArr3);
    }
}
