package dr.evomodel.treedatalikelihood.preorder;

import beagle.Beagle;
import dr.evolution.alignment.PatternList;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.EvolutionaryProcessDelegate;
import dr.evomodel.treedatalikelihood.ProcessOnTreeDelegate;
import dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate;
import dr.inference.model.Model;
import dr.math.matrixAlgebra.WrappedVector;
import java.util.List;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/preorder/AbstractDiscreteTraitDelegate.class */
/**
 * Base class for BEAGLE-backed delegates that compute pre-order partials for a discrete-trait
 * likelihood. It exposes per-branch derivatives as two tree traits, "Gradient" and "Hessian".
 * Each trait holds one entry per non-root node (i.e. per branch).
 *
 * <p>Subclasses decide which differential matrices are loaded into BEAGLE by implementing
 * {@link #cacheDifferentialMassMatrix(Tree, boolean)}. They may also override
 * {@link #getFirstDerivativeMatrixBufferIndex(int)} and {@link #getSecondDerivativeMatrixBufferIndex(int)}.
 */
public abstract class AbstractDiscreteTraitDelegate extends ProcessSimulationDelegate.AbstractDelegate {
    private static final String GRADIENT_TRAIT_NAME = "Gradient";
    private static final String HESSIAN_TRAIT_NAME = "Hessian";

    /** When true, dumps a transition matrix and its BEAGLE transpose before each pre-order pass. */
    private static final boolean DEBUG_TRANSPOSE = false;

    /** When true, maintains the operation counters reported by {@link #toString()}. */
    private static final boolean COUNT_TOTAL_OPERATIONS = true;

    /**
     * Scaling-buffer argument passed to {@code updatePrePartials}.
     * NOTE(review): presumably BEAGLE's NONE sentinel (no cumulative rescaling); confirm.
     */
    private static final int NO_SCALING_BUFFER = -1;

    protected final BeagleDataLikelihoodDelegate likelihoodDelegate;

    // Named "f7beagle" rather than "beagle" because that name collides with the root package
    // "beagle". The name is kept because subclasses may refer to it.
    protected final Beagle f7beagle;
    protected EvolutionaryProcessDelegate evolutionaryProcessDelegate;
    protected final SiteRateModel siteRateModel;
    protected final PatternList patternList;
    protected final int patternCount;
    protected final int stateCount;
    protected final int categoryCount;

    /** First BEAGLE partials buffer used for pre-order partials; node i uses offset + i. */
    private final int preOrderPartialOffset;

    /** Per-branch first derivatives, refreshed by every call to {@link #simulate}. */
    protected final double[] gradient;

    /** Whether the differential matrices are current; cleared on any model change or restore. */
    private boolean substitutionProcessKnown;

    // Operation counters (only updated when COUNT_TOTAL_OPERATIONS is true).
    private long simulateCount = 0L;
    private long getTraitCount = 0L;
    private long updatePrePartialCount = 0L;

    /**
     * Creates the delegate and registers it as a model/restore listener on the likelihood delegate.
     *
     * @param name               trait name forwarded to the superclass
     * @param tree               the tree whose branches receive derivatives
     * @param likelihoodDelegate BEAGLE likelihood delegate; must be configured for pre-order partials
     */
    public AbstractDiscreteTraitDelegate(String name, Tree tree, BeagleDataLikelihoodDelegate likelihoodDelegate) {
        super(name, tree);
        this.likelihoodDelegate = likelihoodDelegate;
        this.f7beagle = likelihoodDelegate.getBeagleInstance();

        assert this.likelihoodDelegate.isUsePreOrder()
                : "likelihood delegate must be configured to use pre-order partials";

        this.evolutionaryProcessDelegate = likelihoodDelegate.getEvolutionaryProcessDelegate();
        this.siteRateModel = likelihoodDelegate.getSiteRateModel();
        this.patternList = likelihoodDelegate.getPatternList();
        this.patternCount = this.patternList.getPatternCount();
        this.stateCount = this.patternList.getDataType().getStateCount();
        this.categoryCount = this.siteRateModel.getCategoryCount();
        // Pre-order partial buffers are placed after all the buffers the likelihood delegate owns.
        this.preOrderPartialOffset = likelihoodDelegate.getPartialBufferCount();
        this.gradient = new double[tree.getNodeCount() - 1];

        likelihoodDelegate.addModelListener(this);
        likelihoodDelegate.addModelRestoreListener(this);
        this.substitutionProcessKnown = false;
    }

    /** Debug helper: prints one stateCount x stateCount matrix per rate category to stderr. */
    private void printMatrix(double[] matrices) {
        final double[] row = new double[stateCount];
        for (int category = 0; category < categoryCount; category++) {
            System.err.println("\nRate = " + category);
            for (int state = 0; state < stateCount; state++) {
                System.arraycopy(matrices, (category * stateCount + state) * stateCount, row, 0, stateCount);
                System.err.println(new WrappedVector.Raw(row));
            }
        }
    }

    /**
     * Debug helper: prints the transition matrices at {@code operations[4]} and their BEAGLE transpose.
     * Slot 4 is the left child's matrix index in the layout written by {@link #vectorizeNodeOperations}.
     * NOTE(review): the transpose is written into BEAGLE matrix buffer 1 and overwrites it;
     * enable only for diagnostics.
     */
    private void debugMatrixTranspose(int[] operations) {
        final double[] matrices = new double[stateCount * stateCount * categoryCount];
        final int matrixIndex = operations[4];
        f7beagle.getTransitionMatrix(matrixIndex, matrices);
        printMatrix(matrices);
        f7beagle.transposeTransitionMatrices(new int[]{matrixIndex}, new int[]{1}, 1);
        f7beagle.getTransitionMatrix(1, matrices);
        printMatrix(matrices);
    }

    /**
     * Seeds the root pre-order partial, runs the BEAGLE pre-order pass and refreshes {@link #gradient}.
     *
     * @param operations     vectorized operations (see {@link #vectorizeNodeOperations})
     * @param operationCount number of operations in {@code operations}
     * @param rootNodeNumber node number of the root
     */
    @Override // dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate.AbstractDelegate, dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate
    public void simulate(int[] operations, int operationCount, int rootNodeNumber) {
        simulateRoot(rootNodeNumber);
        if (DEBUG_TRANSPOSE) {
            debugMatrixTranspose(operations);
        }
        f7beagle.updatePrePartials(operations, operationCount, NO_SCALING_BUFFER);
        getNodeDerivatives(tree, gradient, null);
        if (COUNT_TOTAL_OPERATIONS) {
            simulateCount++;
            updatePrePartialCount += operationCount;
        }
    }

    @Override // dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate.AbstractDelegate
    public void setupStatistics() {
        throw new RuntimeException("Not used (?) with BEAGLE");
    }

    /**
     * Fills the root's pre-order partial with the root state frequencies.
     * The same frequencies are repeated for every (pattern, category) block.
     */
    @Override // dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate.AbstractDelegate
    protected void simulateRoot(int rootNodeNumber) {
        final double[] frequencies = evolutionaryProcessDelegate.getRootStateFrequencies();
        final int blockCount = patternCount * categoryCount;
        final double[] rootPartials = new double[stateCount * blockCount];
        for (int block = 0; block < blockCount; block++) {
            System.arraycopy(frequencies, 0, rootPartials, block * stateCount, stateCount);
        }
        f7beagle.setPartials(getPreOrderPartialIndex(rootNodeNumber), rootPartials);
    }

    /** Per-node simulation is done inside BEAGLE; this hook must not be reached. */
    @Override // dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate.AbstractDelegate
    protected void simulateNode(int i, int i2, int i3, int i4, int i5) {
        throw new RuntimeException("Not used with BEAGLE");
    }

    /** Name of the first-derivative trait; subclasses may override to avoid name clashes. */
    protected String getGradientTraitName() {
        return GRADIENT_TRAIT_NAME;
    }

    /** Name of the second-derivative trait; subclasses may override to avoid name clashes. */
    protected String getHessianTraitName() {
        return HESSIAN_TRAIT_NAME;
    }

    /** Registers the branch-intent gradient and Hessian traits with {@code helper}. */
    @Override // dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate.AbstractDelegate
    protected void constructTraits(TreeTraitProvider.Helper helper) {
        helper.addTrait(new TreeTrait.DA() {
            @Override // dr.evolution.tree.TreeTrait
            public String getTraitName() {
                return AbstractDiscreteTraitDelegate.this.getGradientTraitName();
            }

            @Override // dr.evolution.tree.TreeTrait
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override // dr.evolution.tree.TreeTrait
            public double[] getTrait(Tree tree, NodeRef node) {
                return AbstractDiscreteTraitDelegate.this.getGradient(node);
            }
        });
        helper.addTrait(new TreeTrait.DA() {
            @Override // dr.evolution.tree.TreeTrait
            public String getTraitName() {
                return AbstractDiscreteTraitDelegate.this.getHessianTraitName();
            }

            @Override // dr.evolution.tree.TreeTrait
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override // dr.evolution.tree.TreeTrait
            public double[] getTrait(Tree tree, NodeRef node) {
                return AbstractDiscreteTraitDelegate.this.getHessian(tree, node);
            }
        });
    }

    /**
     * Computes per-branch second derivatives, one entry per non-root node of {@code tree}.
     * Public only because the decompiler widened it; intended for the Hessian trait.
     */
    public double[] getHessian(Tree tree, NodeRef nodeRef) {
        // Presumably brings the pre-order partials up to date before they are read.
        this.simulationProcess.cacheSimulatedTraits(nodeRef);
        final double[] second = new double[tree.getNodeCount() - 1];
        getNodeDerivatives(tree, null, second);
        return second;
    }

    /**
     * Returns a copy of the per-branch first derivatives.
     * Public only because the decompiler widened it; intended for the gradient trait.
     */
    public double[] getGradient(NodeRef nodeRef) {
        if (COUNT_TOTAL_OPERATIONS) {
            getTraitCount++;
        }
        this.simulationProcess.cacheSimulatedTraits(nodeRef);
        return gradient.clone();
    }

    /**
     * Loads the differential matrices used by {@link #getNodeDerivatives} into BEAGLE.
     *
     * @param tree             the tree being differentiated
     * @param secondDerivative true when second-derivative buffers are also required (Hessian path)
     */
    protected abstract void cacheDifferentialMassMatrix(Tree tree, boolean secondDerivative);

    /**
     * Fills per-branch derivative arrays (indexed by non-root node, in node-number order).
     *
     * @param tree   the tree
     * @param first  receives first derivatives, or null if not wanted
     * @param second receives second derivatives, or null if not wanted
     */
    private void getNodeDerivatives(Tree tree, double[] first, double[] second) {
        final int nodeCount = tree.getNodeCount();
        final int edgeCount = nodeCount - 1;
        final int[] postBufferIndices = new int[edgeCount];
        final int[] preBufferIndices = new int[edgeCount];
        final int[] firstDerivativeIndices = new int[edgeCount];
        final int[] secondDerivativeIndices = new int[edgeCount];

        // Refresh after any model change. Always refresh when second derivatives are requested,
        // because the cached state may only cover first-derivative matrices.
        if (!substitutionProcessKnown || second != null) {
            cacheDifferentialMassMatrix(tree, second != null);
            substitutionProcessKnown = true;
        }

        int edge = 0;
        for (int nodeNumber = 0; nodeNumber < nodeCount; nodeNumber++) {
            if (tree.isRoot(tree.getNode(nodeNumber))) {
                continue;
            }
            postBufferIndices[edge] = getPostOrderPartialIndex(nodeNumber);
            preBufferIndices[edge] = getPreOrderPartialIndex(nodeNumber);
            firstDerivativeIndices[edge] = getFirstDerivativeMatrixBufferIndex(nodeNumber);
            secondDerivativeIndices[edge] = getSecondDerivativeMatrixBufferIndex(nodeNumber);
            edge++;
        }

        // NOTE(review): assumes the BEAGLE signature
        // (..., count, outDerivatives, outSumDerivatives, outSumSquaredDerivatives); confirm.
        final double[] firstSquared = (second != null) ? new double[second.length] : null;
        f7beagle.calculateEdgeDifferentials(postBufferIndices, preBufferIndices, firstDerivativeIndices,
                new int[]{0}, edgeCount, null, first, firstSquared);

        if (second != null) {
            f7beagle.calculateEdgeDifferentials(postBufferIndices, preBufferIndices, secondDerivativeIndices,
                    new int[]{0}, edgeCount, null, second, null);
            // The second derivative of a log is (L''/L) - (L'/L)^2. Subtract the summed squared first terms.
            for (int i = 0; i < second.length; i++) {
                second[i] -= firstSquared[i];
            }
        }
    }

    /** BEAGLE matrix buffer holding the first-derivative (infinitesimal) matrix for a node. */
    protected int getFirstDerivativeMatrixBufferIndex(int nodeNumber) {
        return this.evolutionaryProcessDelegate.getInfinitesimalMatrixBufferIndex(nodeNumber);
    }

    /** BEAGLE matrix buffer holding the squared infinitesimal matrix for a node. */
    protected int getSecondDerivativeMatrixBufferIndex(int nodeNumber) {
        return this.evolutionaryProcessDelegate.getInfinitesimalSquaredMatrixBufferIndex(nodeNumber);
    }

    @Override // dr.inference.model.ModelListener
    public void modelChangedEvent(Model model, Object object, int index) {
        this.substitutionProcessKnown = false;
    }

    @Override // dr.inference.model.ModelListener
    public void modelRestored(Model model) {
        this.substitutionProcessKnown = false;
    }

    /**
     * Writes one 7-int BEAGLE pre-order operation per node operation into {@code operations}.
     *
     * @return the number of operations written
     */
    @Override // dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate
    public int vectorizeNodeOperations(List<ProcessOnTreeDelegate.NodeOperation> nodeOperations, int[] operations) {
        int k = 0;
        for (ProcessOnTreeDelegate.NodeOperation op : nodeOperations) {
            operations[k++] = getPreOrderPartialIndex(op.getLeftChild());      // partial being computed
            operations[k++] = -1;                                              // presumably: no write-scale buffer
            operations[k++] = -1;                                              // presumably: no read-scale buffer
            operations[k++] = getPreOrderPartialIndex(op.getNodeNumber());     // parent's pre-order partial
            operations[k++] = evolutionaryProcessDelegate.getMatrixIndex(op.getLeftChild());
            operations[k++] = getPostOrderPartialIndex(op.getRightChild());    // sibling's post-order partial
            operations[k++] = evolutionaryProcessDelegate.getMatrixIndex(op.getRightChild());
        }
        return nodeOperations.size();
    }

    /** Number of ints written per operation by {@link #vectorizeNodeOperations}. */
    @Override // dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate
    public int getSingleOperationSize() {
        return 7;
    }

    private int getPostOrderPartialIndex(int nodeNumber) {
        return this.likelihoodDelegate.getPartialBufferIndex(nodeNumber);
    }

    private int getPreOrderPartialIndex(int nodeNumber) {
        return this.preOrderPartialOffset + nodeNumber;
    }

    /** Reports the operation counters. */
    @Override
    public String toString() {
        return "\tsimulateCount = " + this.simulateCount + "\n\tgetTraitCount = " + this.getTraitCount + "\n\tupPrePartialCount = " + this.updatePrePartialCount + "\n";
    }
}
