package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.BranchRates;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.math.KroneckerOperation;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import java.util.HashSet;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/continuous/MultivariateTraitDebugUtilities.class */
/**
 * Brute-force helpers for continuous-trait models on a tree.
 *
 * <p>The routines here build dense variance / drift matrices or walk the whole tree directly. The
 * variance builders loop over every pair of nodes, so they are at least quadratic in the node
 * count. Per the class name, they are presumably meant for debugging and cross-checking.
 */
public class MultivariateTraitDebugUtilities {

    /**
     * Post-order accumulators that sum entries of the tip variance matrix without building it.
     *
     * <p>A branch of rate-scaled length {@code l} above a clade of {@code n} tips contributes
     * {@code n * l} to the diagonal sum, because it lies on each of those tips' root-to-tip paths.
     * It contributes {@code n * (n - 1) * l} to the off-diagonal sum, because every ordered pair of
     * distinct tips in the clade shares it. Only bifurcating internal nodes are supported: children
     * 0 and 1 are the only ones visited.
     */
    private enum Accumulator {

        OFF_DIAGONAL {
            @Override
            BranchCumulant accumulate(BranchCumulant left, double leftLength,
                                      BranchCumulant right, double rightLength) {
                // The pair count is computed in double so it cannot overflow int for huge clades.
                // It is exact (and identical to the int product) for any realistic clade size.
                return new BranchCumulant(
                        left.nTaxa + right.nTaxa,
                        left.sharedLength + right.sharedLength
                                + ((double) (left.nTaxa - 1) * left.nTaxa * leftLength)
                                + ((double) (right.nTaxa - 1) * right.nTaxa * rightLength));
            }
        },

        DIAGONAL {
            @Override
            BranchCumulant accumulate(BranchCumulant left, double leftLength,
                                      BranchCumulant right, double rightLength) {
                return new BranchCumulant(
                        left.nTaxa + right.nTaxa,
                        left.sharedLength + right.sharedLength
                                + (left.nTaxa * leftLength)
                                + (right.nTaxa * rightLength));
            }
        };

        /** Tip count below a node plus the partial sum accumulated over that subtree. */
        private static final class BranchCumulant {
            final int nTaxa;
            final double sharedLength;

            BranchCumulant(int nTaxa, double sharedLength) {
                this.nTaxa = nTaxa;
                this.sharedLength = sharedLength;
            }
        }

        /**
         * Combines two child subtrees, each reached by a branch of the given rate-scaled length,
         * into the cumulant of their parent.
         */
        abstract BranchCumulant accumulate(BranchCumulant left, double leftLength,
                                           BranchCumulant right, double rightLength);

        /**
         * Recursively folds the subtree rooted at {@code node} into a single cumulant.
         *
         * @param tree        the tree (assumed bifurcating)
         * @param node        subtree root; a tip yields {@code (1, 0.0)}
         * @param branchRates optional rate model; when {@code null}, raw branch lengths are used
         * @return tip count and accumulated sum for the subtree
         */
        private BranchCumulant postOrderAccumulation(Tree tree, NodeRef node, BranchRates branchRates) {
            if (tree.isExternal(node)) {
                return new BranchCumulant(1, 0.0);
            }
            final NodeRef leftChild = tree.getChild(node, 0);
            final NodeRef rightChild = tree.getChild(node, 1);

            final BranchCumulant left = postOrderAccumulation(tree, leftChild, branchRates);
            final BranchCumulant right = postOrderAccumulation(tree, rightChild, branchRates);

            double leftLength = tree.getBranchLength(leftChild);
            double rightLength = tree.getBranchLength(rightChild);
            if (branchRates != null) {
                leftLength *= branchRates.getBranchRate(tree, leftChild);
                rightLength *= branchRates.getBranchRate(tree, rightChild);
            }
            return accumulate(left, leftLength, right, rightLength);
        }
    }

    /**
     * Returns the rate-scaled path length from {@code node} up to the root.
     *
     * @param tree        the tree
     * @param branchRates optional rate model; when {@code null}, every rate is taken as 1
     * @param node        starting node; the root itself has length 0
     * @return sum over the branches on the path of {@code rate * branchLength}
     */
    public static double getLengthToRoot(Tree tree, BranchRates branchRates, NodeRef node) {
        if (tree.isRoot(node)) {
            return 0.0;
        }
        final NodeRef parent = tree.getParent(node);
        final double rate = (branchRates == null) ? 1.0 : branchRates.getBranchRate(tree, node);
        // Recursion keeps the same (root-most first) summation order as before.
        return rate * tree.getBranchLength(node) + getLengthToRoot(tree, branchRates, parent);
    }

    /** Most recent common ancestor of the taxa with indices {@code first} and {@code second}. */
    private static NodeRef findMRCA(Tree tree, int first, int second) {
        final HashSet<String> taxa = new HashSet<>();
        taxa.add(tree.getTaxonId(first));
        taxa.add(tree.getTaxonId(second));
        return TreeUtils.getCommonAncestorNode(tree, taxa);
    }

    /**
     * Writes the coupling between {@code parent} and {@code child} into {@code precision}, then
     * descends into {@code child}'s subtree.
     *
     * <p>Both {@code precision[parent][child]} and {@code precision[child][parent]} are set to
     * {@code -1 / (branchLength(child) * normalization)}. Diagonal entries are left untouched.
     * Branch rates are NOT applied here: raw branch lengths are used.
     *
     * @param precision node-by-node matrix indexed by {@link NodeRef#getNumber()}; modified in place
     */
    public static void insertPrecision(Tree tree, NodeRef parent, NodeRef child,
                                       double[][] precision, double normalization) {
        final int parentIndex = parent.getNumber();
        final int childIndex = child.getNumber();
        final double value = -1.0 / (tree.getBranchLength(child) * normalization);
        precision[childIndex][parentIndex] = value;
        precision[parentIndex][childIndex] = value;
        recurseGraph(tree, child, precision, normalization);
    }

    /**
     * Fills the parent/child precision couplings for every branch below {@code node}.
     * Assumes bifurcating internal nodes (children 0 and 1 only). Tips are a no-op.
     */
    public static void recurseGraph(Tree tree, NodeRef node, double[][] precision, double normalization) {
        if (tree.isExternal(node)) {
            return;
        }
        insertPrecision(tree, node, tree.getChild(node, 0), precision, normalization);
        insertPrecision(tree, node, tree.getChild(node, 1), precision, normalization);
    }

    /**
     * Builds the dense variance matrix over ALL nodes (tips and internal nodes), indexed by
     * {@code tree.getNode(i)}.
     *
     * <p>Entry {@code [i][j]} is {@code normalization * getLengthToRoot(MRCA(i, j))}; the diagonal
     * uses node {@code i} itself. {@code 1 / priorSampleSize} is then added to every entry, unless
     * {@code priorSampleSize} is infinite.
     */
    public static double[][] getGraphVariance(Tree tree, BranchRates branchRates,
                                              double normalization, double priorSampleSize) {
        final int nodeCount = tree.getNodeCount();
        final double[][] variance = new double[nodeCount][nodeCount];
        for (int i = 0; i < nodeCount; i++) {
            variance[i][i] = getLengthToRoot(tree, branchRates, tree.getNode(i)) * normalization;
            for (int j = i + 1; j < nodeCount; j++) {
                final NodeRef mrca = TreeUtils.getCommonAncestorSafely(tree, tree.getNode(i), tree.getNode(j));
                variance[i][j] = getLengthToRoot(tree, branchRates, mrca) * normalization;
            }
        }
        makeSymmetric(variance);
        addPrior(variance, priorSampleSize);
        return variance;
    }

    /**
     * Builds the dense tip-by-tip variance matrix, indexed by external-node / taxon index.
     *
     * <p>Entry {@code [i][j]} is {@code normalization * getLengthToRoot(MRCA(i, j))}, with the
     * MRCA found from taxon ids. The diagonal uses external node {@code i} itself. Then
     * {@code 1 / priorSampleSize} is added to every entry, unless {@code priorSampleSize} is
     * infinite.
     */
    public static double[][] getTreeVariance(Tree tree, BranchRates branchRates,
                                             double normalization, double priorSampleSize) {
        final int tipCount = tree.getExternalNodeCount();
        final double[][] variance = new double[tipCount][tipCount];
        for (int i = 0; i < tipCount; i++) {
            variance[i][i] = getLengthToRoot(tree, branchRates, tree.getExternalNode(i)) * normalization;
            for (int j = i + 1; j < tipCount; j++) {
                variance[i][j] = getLengthToRoot(tree, branchRates, findMRCA(tree, i, j)) * normalization;
            }
        }
        makeSymmetric(variance);
        addPrior(variance, priorSampleSize);
        return variance;
    }

    /** Copies the strict upper triangle of {@code matrix} onto its lower triangle, in place. */
    private static void makeSymmetric(double[][] matrix) {
        for (int i = 0; i < matrix.length; i++) {
            for (int j = i + 1; j < matrix[i].length; j++) {
                matrix[j][i] = matrix[i][j];
            }
        }
    }

    /**
     * Adds {@code 1 / priorSampleSize} to every entry of {@code variance}, in place.
     * An infinite {@code priorSampleSize} is skipped, since it would contribute nothing.
     */
    private static void addPrior(double[][] variance, double priorSampleSize) {
        if (Double.isInfinite(priorSampleSize)) {
            return;
        }
        final double priorVariance = 1.0 / priorSampleSize;
        for (double[] row : variance) {
            for (int j = 0; j < row.length; j++) {
                row[j] += priorVariance;
            }
        }
    }

    /**
     * Returns, for each tip {@code i}, the drift vector reported by
     * {@code diffusionProcessDelegate.getAccumulativeDrift} for {@code tree.getExternalNode(i)}.
     *
     * @param priorMean passed through unchanged to {@code getAccumulativeDrift}
     *                  (presumably the root mean — TODO confirm)
     * @return array of length {@code tree.getExternalNodeCount()}; rows are the delegate's arrays
     */
    public static double[][] getTreeDrift(Tree tree, double[] priorMean,
                                          ContinuousDiffusionIntegrator cdi,
                                          DiffusionProcessDelegate diffusionProcessDelegate) {
        final int dimTrait = cdi.getDimTrait();
        final int tipCount = tree.getExternalNodeCount();
        // Rows are supplied by the delegate, so only the outer array is allocated here.
        final double[][] drift = new double[tipCount][];
        for (int i = 0; i < tipCount; i++) {
            drift[i] = diffusionProcessDelegate.getAccumulativeDrift(
                    tree.getExternalNode(i), priorMean, cdi, dimTrait);
        }
        return drift;
    }

    /**
     * Returns, for every node {@code i} (tips and internal), the drift vector reported by
     * {@code diffusionProcessDelegate.getAccumulativeDrift} for {@code tree.getNode(i)}.
     * A zero vector of length {@code dimTrait} is passed as the starting value.
     */
    public static double[][] getGraphDrift(Tree tree, ContinuousDiffusionIntegrator cdi,
                                           DiffusionProcessDelegate diffusionProcessDelegate) {
        final int dimTrait = cdi.getDimTrait();
        final int nodeCount = tree.getNodeCount();
        final double[] zeroMean = new double[dimTrait];
        final double[][] drift = new double[nodeCount][];
        for (int i = 0; i < nodeCount; i++) {
            drift[i] = diffusionProcessDelegate.getAccumulativeDrift(
                    tree.getNode(i), zeroMean, cdi, dimTrait);
        }
        return drift;
    }

    /**
     * Builds the joint (tree x trait) variance matrix from {@code diffusionProcessDelegate}.
     *
     * <p>Without actualization, the delegate is called with {@code treeVariance} in both tree
     * slots and {@code loadingsVariance} as the trait variance, and the result is returned as-is.
     * With actualization, it is called with {@code treeSharedLengths} and {@code diffusionVariance}.
     * That result is then projected as {@code B * V * B^T}, where
     * {@code B = I_n (x) loadingsTranspose} and {@code n = treeSharedLengths[0].length}.
     * NOTE(review): argument names are inferred from how each array is used — confirm with callers.
     *
     * @throws IllegalStateException if the matrix dimensions are incompatible for the projection
     *                               (previously this printed a stack trace and returned {@code null})
     */
    public static Matrix getJointVarianceFactor(double priorSampleSize,
                                                double[][] treeVariance, double[][] treeSharedLengths,
                                                double[][] loadingsVariance, double[][] diffusionVariance,
                                                DiffusionProcessDelegate diffusionProcessDelegate,
                                                Matrix loadingsTranspose) {
        if (!diffusionProcessDelegate.hasActualization()) {
            return new Matrix(diffusionProcessDelegate.getJointVariance(
                    priorSampleSize, treeVariance, treeVariance, loadingsVariance));
        }
        final Matrix jointVariance = new Matrix(diffusionProcessDelegate.getJointVariance(
                priorSampleSize, treeVariance, treeSharedLengths, diffusionVariance));
        final Matrix blockLoadings = new Matrix(KroneckerOperation.product(
                KroneckerOperation.makeIdentityMatrixArray(treeSharedLengths[0].length),
                loadingsTranspose.toComponents()));
        try {
            return blockLoadings.product(jointVariance.product(blockLoadings.transpose()));
        } catch (IllegalDimension e) {
            throw new IllegalStateException(
                    "Incompatible dimensions while projecting the joint variance through the loadings", e);
        }
    }

    /**
     * Sum of the off-diagonal entries of the tip variance matrix, scaled by {@code normalization},
     * computed in one post-order pass. Up to rounding, this equals the off-diagonal sum of
     * {@code getTreeVariance(tree, branchRates, normalization, Double.POSITIVE_INFINITY)}.
     */
    public static double getVarianceOffDiagonalSum(Tree tree, BranchRates branchRates, double normalization) {
        return Accumulator.OFF_DIAGONAL.postOrderAccumulation(tree, tree.getRoot(), branchRates).sharedLength
                * normalization;
    }

    /**
     * Sum of the diagonal entries (root-to-tip lengths) of the tip variance matrix, scaled by
     * {@code normalization}, computed in one post-order pass.
     */
    public static double getVarianceDiagonalSum(Tree tree, BranchRates branchRates, double normalization) {
        return Accumulator.DIAGONAL.postOrderAccumulation(tree, tree.getRoot(), branchRates).sharedLength
                * normalization;
    }
}
