package dr.evomodel.substmodel;

import dr.evomodel.substmodel.DifferentialMassProvider;
import dr.math.matrixAlgebra.ReadableMatrix;
import dr.math.matrixAlgebra.WrappedMatrix;

/**
 * Static helpers for building parameter derivatives of substitution-model matrices.
 *
 * <p>The class holds no state. Methods that take a {@link WrappedMatrix} argument overwrite
 * that matrix in place; see each method's documentation.
 */
public class DifferentiableSubstitutionModelUtil {

    /**
     * Computes the differential mass matrix for a branch of length {@code time}.
     *
     * <p>Let {@code V}, {@code V^-1} and {@code lambda} be the eigenvectors, inverse eigenvectors
     * and eigenvalues of {@code eigenDecomposition} (only the leading
     * {@code stateCount x stateCount} block and the first {@code stateCount} eigenvalues are read).
     * Let {@code G} be the incoming {@code differentialMatrix}. The result is
     * {@code V * S * V^-1}, where {@code S[i][j] = (V^-1 * G * V)[i][j] * f(i, j)} and:
     * <ul>
     *   <li>{@code f(i, j) = time} when {@code i == j} or {@code lambda_i == lambda_j};</li>
     *   <li>{@code f(i, j) = (1 - exp((lambda_j - lambda_i) * time)) / (lambda_i - lambda_j)}
     *       otherwise.</li>
     * </ul>
     * The off-diagonal factor is evaluated as {@code expm1(delta * time) / delta}, with
     * {@code delta = lambda_j - lambda_i}. This is the same quantity, computed without
     * cancellation when {@code delta * time} is close to zero.
     *
     * <p>NOTE(review): the eigenvalues are treated as real numbers. It is presumably the caller's
     * job to supply a real eigen system; confirm before using this with complex decompositions.
     *
     * <p><b>Side effect:</b> {@code differentialMatrix} is overwritten with the result.
     *
     * @param time               branch length / elapsed time {@code t}
     * @param stateCount         number of states (matrix dimension)
     * @param differentialMatrix on entry the differential {@code G}; on exit the result
     * @param eigenDecomposition eigen system used for the basis change, presumably that of the
     *                           model's rate matrix
     * @return a new array holding the result, read from {@code differentialMatrix} by its linear
     *     index {@code 0 .. stateCount*stateCount-1}
     */
    public static double[] getDifferentialMassMatrix(double time, int stateCount,
                                                     WrappedMatrix differentialMatrix,
                                                     EigenDecomposition eigenDecomposition) {
        double[] eigenValues = eigenDecomposition.getEigenValues();
        WrappedMatrix.Raw eigenVectors =
                new WrappedMatrix.Raw(eigenDecomposition.getEigenVectors(), 0, stateCount, stateCount);
        WrappedMatrix.Raw inverseEigenVectors =
                new WrappedMatrix.Raw(eigenDecomposition.getInverseEigenVectors(), 0, stateCount, stateCount);

        // Move the differential into the eigenbasis: V^-1 * G * V.
        getTripleMatrixMultiplication(stateCount, inverseEigenVectors, differentialMatrix, eigenVectors);

        for (int row = 0; row < stateCount; row++) {
            for (int col = 0; col < stateCount; col++) {
                double scale;
                if (row == col || eigenValues[row] == eigenValues[col]) {
                    // Limit of the general factor as lambda_col -> lambda_row.
                    scale = time;
                } else {
                    // (1 - exp(delta*t)) / (lambda_row - lambda_col) == expm1(delta*t) / delta.
                    // 1 - exp(x) loses all precision for tiny x (it rounds to 0), whereas expm1 does not.
                    double delta = eigenValues[col] - eigenValues[row];
                    scale = Math.expm1(delta * time) / delta;
                }
                differentialMatrix.set(row, col, differentialMatrix.get(row, col) * scale);
            }
        }

        // Return to the original basis: V * S * V^-1.
        getTripleMatrixMultiplication(stateCount, eigenVectors, differentialMatrix, inverseEigenVectors);

        int size = stateCount * stateCount;
        double[] result = new double[size];
        for (int k = 0; k < size; k++) {
            result[k] = differentialMatrix.get(k);
        }
        return result;
    }

    /**
     * Overwrites {@code middle} with the product {@code left * middle * right}.
     *
     * <p>Only the leading {@code stateCount x stateCount} entries of each operand are used.
     * {@code middle * right} is first accumulated into a temporary buffer, so {@code right} may be
     * the same object as {@code middle}. {@code left} must NOT alias {@code middle}: entries of
     * {@code middle} are written while {@code left} is still being read.
     *
     * @param stateCount matrix dimension
     * @param left       left factor (read only)
     * @param middle     middle factor; replaced by the triple product
     * @param right      right factor (read only)
     */
    public static void getTripleMatrixMultiplication(int stateCount, ReadableMatrix left,
                                                     WrappedMatrix middle, ReadableMatrix right) {
        // Buffer middle * right so that middle can be safely overwritten in the second pass.
        double[][] middleTimesRight = new double[stateCount][stateCount];
        for (int row = 0; row < stateCount; row++) {
            for (int col = 0; col < stateCount; col++) {
                double sum = 0.0d;
                for (int k = 0; k < stateCount; k++) {
                    sum += middle.get(row, k) * right.get(k, col);
                }
                middleTimesRight[row][col] = sum;
            }
        }

        for (int row = 0; row < stateCount; row++) {
            for (int col = 0; col < stateCount; col++) {
                double sum = 0.0d;
                for (int k = 0; k < stateCount; k++) {
                    sum += left.get(row, k) * middleTimesRight[k][col];
                }
                middle.set(row, col, sum);
            }
        }
    }

    /**
     * Builds the differential of the model's infinitesimal matrix with respect to
     * {@code wrtParameter}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Call {@code setupMatrix()} and keep its return value as the normalizing constant.</li>
     *   <li>Read the model's infinitesimal matrix {@code Q}. It is indexed row-major as
     *       {@code Q[row * stateCount + col]}.</li>
     *   <li>Ask the model for the differential rates and assemble them into a matrix {@code R'}
     *       via {@code setupQMatrix} and {@code makeValid}.</li>
     *   <li>Return {@code R' - Q * g}, where {@code g} is the model's weighted normalization
     *       gradient.</li>
     * </ol>
     * This presumably applies the quotient rule to a normalized rate matrix. Confirm against the
     * {@code DifferentiableSubstitutionModel} implementations.
     *
     * @param wrtParameter          parameter to differentiate with respect to
     * @param baseSubstitutionModel model to differentiate; must implement
     *                              {@code DifferentiableSubstitutionModel}
     * @return the {@code stateCount x stateCount} differential matrix
     * @throws UnsupportedOperationException if the model is not a
     *     {@code DifferentiableSubstitutionModel}
     */
    public static WrappedMatrix getInfinitesimalDifferentialMatrix(
            DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter,
            BaseSubstitutionModel baseSubstitutionModel) {
        if (!(baseSubstitutionModel instanceof DifferentiableSubstitutionModel)) {
            // UnsupportedOperationException is a RuntimeException, so existing catch sites still apply.
            throw new UnsupportedOperationException("Not supported!");
        }
        DifferentiableSubstitutionModel differentiableModel =
                (DifferentiableSubstitutionModel) baseSubstitutionModel;

        double normalizingConstant = baseSubstitutionModel.setupMatrix();
        int stateCount = baseSubstitutionModel.getDataType().getStateCount();
        int rateCount = baseSubstitutionModel.getRateCount(stateCount);

        double[] infinitesimalMatrix = new double[stateCount * stateCount];
        baseSubstitutionModel.getInfinitesimalMatrix(infinitesimalMatrix);

        double[] differentialRates = new double[rateCount];
        differentiableModel.setupDifferentialRates(wrtParameter, differentialRates, normalizingConstant);

        // Fetched once and shared by both consumers below.
        double[] frequencies = baseSubstitutionModel.getFrequencyModel().getFrequencies();

        double[][] differentialMatrix = new double[stateCount][stateCount];
        baseSubstitutionModel.setupQMatrix(differentialRates, frequencies, differentialMatrix);
        baseSubstitutionModel.makeValid(differentialMatrix, stateCount);

        double normalizationGradient =
                differentiableModel.getWeightedNormalizationGradient(wrtParameter, differentialMatrix, frequencies);

        for (int row = 0; row < stateCount; row++) {
            double[] differentialRow = differentialMatrix[row];
            for (int col = 0; col < stateCount; col++) {
                differentialRow[col] -= infinitesimalMatrix[row * stateCount + col] * normalizationGradient;
            }
        }
        return new WrappedMatrix.ArrayOfArray(differentialMatrix);
    }
}
