package dr.evomodel.treedatalikelihood.continuous.cdi;

import dr.math.matrixAlgebra.missingData.MissingOps;
import java.util.Arrays;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/continuous/cdi/SafeMultivariateDiagonalActualizedWithDriftIntegrator.class */
/**
 * Drift integrator whose per-branch actualization is a diagonal matrix.
 *
 * <p>For every buffer, {@code dimTrait} values are stored in {@link #diagonal1mActualizations}, each equal
 * to {@code 1 - exp(-rate[k] * t)} (see {@link #computeOUDiagonal1mActualization}). The actualization
 * itself, {@code exp(-rate[k] * t)}, is recovered on demand by taking {@code 1 - value}. Storing the
 * complement via {@link Math#expm1} keeps precision when {@code rate * t} is small.
 *
 * <p>Branch variances are built from a per-diffusion "stationary variance" block (the inverse-diffusion
 * entries divided by {@code rate[i] + rate[j]}). The block falls back to {@code t * inverseDiffusion}
 * when that stationary entry is infinite or the actualization terms sum to zero.
 */
public class SafeMultivariateDiagonalActualizedWithDriftIntegrator extends SafeMultivariateWithDriftIntegrator {

    /** Enables diagnostic printing to {@code System.err}. */
    private static final boolean DEBUG = false;

    /** Per buffer, {@code dimTrait} entries of {@code 1 - exp(-rate[k] * t)}. */
    private double[] diagonal1mActualizations;

    /** Per diffusion, a {@code dimProcess x dimProcess} row-major stationary-variance block. */
    double[] stationaryVariances;

    /**
     * Scratch vectors of length {@code dimTrait}.
     *
     * <p>{@link #computePartialPrecision} fills both with actualization diagonals.
     * {@link #computeWeightedSum} then reads them, so the two must be called in that order.
     */
    private double[] vectorDiagQdi;
    private double[] vectorDiagQdj;

    /**
     * Creates the integrator and allocates its per-buffer and per-diffusion storage.
     *
     * <p>All integer arguments are forwarded unchanged to the superclass constructor. Their meaning is
     * defined there.
     */
    public SafeMultivariateDiagonalActualizedWithDriftIntegrator(PrecisionType precisionType, int i, int i2, int i3, int i4, int i5) {
        super(precisionType, i, i2, i3, i4, i5);
        allocateStorage();
        System.err.println("Trying SafeMultivariateDiagonalActualizedWithDriftIntegrator");
    }

    /**
     * Copies the stored {@code 1 - actualization} diagonal of one buffer.
     *
     * @param bufferIndex buffer whose values are copied; {@code -1} is not supported
     * @param destination receives {@code dimTrait} values starting at index 0; must have at least
     *                    {@code dimTrait} entries
     * @throws RuntimeException if {@code bufferIndex == -1}
     */
    @Override
    public void getBranch1mActualization(int bufferIndex, double[] destination) {
        if (bufferIndex == -1) {
            throw new RuntimeException("Not yet implemented");
        }
        assert destination != null;
        assert destination.length >= dimTrait;
        System.arraycopy(diagonal1mActualizations, bufferIndex * dimTrait, destination, 0, dimTrait);
    }

    /**
     * Copies the actualization diagonal ({@code exp(-rate * t)}) of one buffer.
     *
     * @param bufferIndex buffer whose values are copied
     * @param destination receives {@code dimTrait} values starting at index 0. Any entries beyond
     *                    {@code dimTrait} are left untouched.
     */
    @Override
    public void getBranchActualization(int bufferIndex, double[] destination) {
        getBranch1mActualization(bufferIndex, destination);
        // Only the first dimTrait entries were written above. Flipping the whole array (as the previous
        // version did via oneMinus) would corrupt trailing caller data when destination is longer.
        for (int k = 0; k < dimTrait; ++k) {
            destination[k] = 1.0d - destination[k];
        }
    }

    /**
     * Computes {@code expectation[k] = actualization[k] * parentValue[k] + displacement[k]} for every trait
     * dimension {@code k}.
     *
     * <p>All arrays must hold at least {@code dimTrait} entries.
     */
    @Override
    public void getBranchExpectation(double[] actualization, double[] parentValue, double[] displacement, double[] expectation) {
        assert expectation != null;
        assert expectation.length >= dimTrait;
        assert actualization != null;
        assert actualization.length >= dimTrait;
        assert parentValue != null;
        assert parentValue.length >= dimTrait;
        assert displacement != null;
        assert displacement.length >= dimTrait;

        for (int k = 0; k < dimTrait; ++k) {
            expectation[k] = (actualization[k] * parentValue[k]) + displacement[k];
        }
    }

    /** Allocates actualization, stationary-variance and scratch storage sized from superclass dimensions. */
    private void allocateStorage() {
        diagonal1mActualizations = new double[dimTrait * bufferCount];
        stationaryVariances = new double[dimProcess * dimProcess * diffusionCount];
        vectorDiagQdi = new double[dimTrait];
        vectorDiagQdj = new double[dimTrait];
    }

    /**
     * Recomputes the stationary-variance block of one diffusion.
     *
     * <p>Each entry is the corresponding inverse-diffusion entry divided by {@code rate[i] + rate[j]}.
     *
     * @param precisionIndex diffusion index; selects the {@code dimProcess x dimProcess} block
     * @param diagonalRates  length must equal {@code dimProcess}
     * @param rotation       forwarded to {@link #setStationaryVariance}; not read by this class's
     *                       implementation (presumably used by subclasses — TODO confirm)
     */
    @Override
    public void setDiffusionStationaryVariance(int precisionIndex, double[] diagonalRates, double[] rotation) {
        assert stationaryVariances != null;
        assert dimProcess == diagonalRates.length;

        final int matrixSize = dimProcess * dimProcess;
        final int blockOffset = matrixSize * precisionIndex;
        final double[] rateSums = new double[matrixSize];
        scalingMatrix(diagonalRates, rateSums);
        setStationaryVariance(blockOffset, rateSums, matrixSize, rotation);
    }

    /** Divides the inverse-diffusion block at {@code offset} element-wise by {@code rateSums}. */
    void setStationaryVariance(int offset, double[] rateSums, int matrixSize, double[] rotation) {
        scaleInv(inverseDiffusions, offset, rateSums, stationaryVariances, offset, matrixSize);
    }

    /**
     * Element-wise division: {@code destination[destOffset + k] = source[sourceOffset + k] / divisor[k]}
     * for {@code k < length}.
     */
    public static void scaleInv(double[] source, int sourceOffset, double[] divisor, double[] destination, int destOffset, int length) {
        for (int k = 0; k < length; ++k) {
            destination[destOffset + k] = source[sourceOffset + k] / divisor[k];
        }
    }

    /**
     * Fills {@code destination} (row-major, {@code n x n}) with {@code values[r] + values[c]}, where
     * {@code n = values.length}.
     */
    private static void scalingMatrix(double[] values, double[] destination) {
        final int n = values.length;
        for (int r = 0; r < n; ++r) {
            for (int c = 0; c < n; ++c) {
                destination[(r * n) + c] = values[r] + values[c];
            }
        }
    }

    /**
     * Updates actualizations, branch variances/precisions and displacements for {@code updateCount}
     * buffers. This runs after the superclass update.
     *
     * @param precisionIndex     diffusion index whose stationary variance / inverse diffusion is used
     * @param probabilityIndices buffer index per update
     * @param edgeLengths        branch length {@code t} per update
     * @param optima             {@code dimProcess} values per update, combined with {@code 1 - actualization}
     *                           to form displacements (presumably OU optimum values — TODO confirm)
     * @param diagonalRates      per-dimension rate used in {@code exp(-rate * t)}
     * @param rotation           forwarded to the superclass and {@link #computeOUActualization}; unused here
     * @param updateCount        number of buffers to update
     */
    @Override
    public void updateOrnsteinUhlenbeckDiffusionMatrices(int precisionIndex, int[] probabilityIndices, double[] edgeLengths, double[] optima, double[] diagonalRates, double[] rotation, int updateCount) {
        assert diffusions != null;
        assert probabilityIndices.length >= updateCount;
        assert edgeLengths.length >= updateCount;

        super.updateOrnsteinUhlenbeckDiffusionMatrices(precisionIndex, probabilityIndices, edgeLengths, optima, diagonalRates, rotation, updateCount);

        if (DEBUG) {
            System.err.println("Matrices (safe with actualized drift):");
        }

        final int matrixTraitSize = dimTrait * dimTrait;
        final int matrixProcessOffset = dimProcess * dimProcess * precisionIndex;

        // Pass 1: actualizations for every updated buffer (later passes read them).
        for (int up = 0; up < updateCount; ++up) {
            final double edgeLength = edgeLengths[up];
            final int scalarTraitOffset = dimTrait * probabilityIndices[up];
            computeOUActualization(diagonalRates, rotation, edgeLength, scalarTraitOffset, dimTrait * scalarTraitOffset);
        }

        // Pass 2: branch variances, then their inverses into the precision buffers.
        for (int up = 0; up < updateCount; ++up) {
            final double edgeLength = edgeLengths[up];
            final int matrixTraitOffset = matrixTraitSize * probabilityIndices[up];
            computeOUVarianceBranch(matrixProcessOffset, matrixTraitOffset, dimTrait * probabilityIndices[up], edgeLength);
            invertVectorSymmPosDef(variances, precisions, matrixTraitOffset, dimProcess);
        }

        assert optima != null;
        assert displacements != null;
        assert optima.length >= updateCount * dimProcess;

        // Pass 3: displacements; optima are packed dimProcess values per update.
        int optimaOffset = 0;
        for (int up = 0; up < updateCount; ++up) {
            computeOUActualizedDisplacement(optima, optimaOffset, matrixTraitSize * probabilityIndices[up], dimTrait * probabilityIndices[up]);
            optimaOffset += dimProcess;
        }
    }

    /**
     * Stores {@code 1 - exp(-rate * edgeLength)} for one buffer.
     *
     * <p>The rotation argument and the matrix offset are not used by this diagonal implementation.
     */
    void computeOUActualization(double[] diagonalRates, double[] rotation, double edgeLength, int scalarTraitOffset, int matrixTraitOffset) {
        computeOUDiagonal1mActualization(diagonalRates, edgeLength, dimTrait, diagonal1mActualizations, scalarTraitOffset);
    }

    /**
     * Writes {@code destination[offset + k] = 1 - exp(-rates[k] * edgeLength)} for {@code k < dim}.
     *
     * <p>Uses {@link Math#expm1} to stay accurate for small exponents.
     */
    public static void computeOUDiagonal1mActualization(double[] rates, double edgeLength, int dim, double[] destination, int offset) {
        for (int k = 0; k < dim; ++k) {
            destination[offset + k] = -Math.expm1((-rates[k]) * edgeLength);
        }
    }

    /**
     * Builds the {@code dimTrait x dimTrait} branch variance at {@code destinationOffset} of
     * {@code variances}.
     *
     * <p>The inputs are the stationary variance and inverse diffusion at {@code sourceOffset}, plus the
     * actualization complements at {@code actualizationOffset}.
     */
    void computeOUVarianceBranch(int sourceOffset, int destinationOffset, int actualizationOffset, double edgeLength) {
        scalingActualizationMatrix(diagonal1mActualizations, actualizationOffset, stationaryVariances, sourceOffset, variances, destinationOffset, dimTrait, edgeLength, inverseDiffusions, sourceOffset);
    }

    /**
     * For each cell {@code (r, c)}, writes the branch variance.
     *
     * <p>With {@code q = 1 - actualization}, the value is
     * {@code stationary * (q_r + q_c - q_r * q_c)}, which equals
     * {@code stationary * (1 - a_r * a_c)}. It falls back to {@code edgeLength * inverseDiffusion} when the
     * stationary entry is infinite or {@code q_r + q_c == 0}.
     */
    private static void scalingActualizationMatrix(double[] oneMinusActualizations, int actualizationOffset, double[] stationary, int stationaryOffset, double[] destination, int destinationOffset, int dim, double edgeLength, double[] inverseDiffusion, int inverseOffset) {
        for (int r = 0; r < dim; ++r) {
            final double qr = oneMinusActualizations[actualizationOffset + r];
            for (int c = 0; c < dim; ++c) {
                final double qc = oneMinusActualizations[actualizationOffset + c];
                final int cell = (r * dim) + c;
                final double stationaryEntry = stationary[stationaryOffset + cell];
                if (Double.isInfinite(stationaryEntry) || qr + qc == 0.0d) {
                    destination[destinationOffset + cell] = edgeLength * inverseDiffusion[inverseOffset + cell];
                } else {
                    destination[destinationOffset + cell] = stationaryEntry * (((-qr) * qc) + qr + qc);
                }
            }
        }
    }

    /** Inverts the {@code dim x dim} block at {@code offset} of {@code source} into the same offset of {@code destination}. */
    private static void invertVectorSymmPosDef(double[] source, double[] destination, int offset, int dim) {
        final DenseMatrix64F block = MissingOps.wrap(source, offset, dim, dim);
        final DenseMatrix64F inverse = new DenseMatrix64F(dim, dim);
        MissingOps.symmPosDefInvert(block, inverse);
        MissingOps.unwrap(inverse, destination, offset);
    }

    /**
     * Sets {@code displacements[scalarOffset + k] = optima[optimaOffset + k] * (1 - actualization[k])}.
     *
     * <p>The matrix offset is not used by this diagonal implementation.
     */
    void computeOUActualizedDisplacement(double[] optima, int optimaOffset, int matrixTraitOffset, int scalarTraitOffset) {
        for (int k = 0; k < dimTrait; ++k) {
            displacements[scalarTraitOffset + k] = optima[optimaOffset + k] * diagonal1mActualizations[scalarTraitOffset + k];
        }
    }

    /**
     * Scales {@code precision} on both sides by the actualization diagonal of the buffer at
     * {@code actualizationOffset}, using {@code workspace} as scratch.
     *
     * <p>This is presumably {@code D * P * D} — depends on the {@code MissingOps.diagMult} overloads.
     * The two middle offsets are not used here.
     */
    @Override
    void actualizePrecision(DenseMatrix64F precision, DenseMatrix64F workspace, int unusedOffsetA, int unusedOffsetB, int actualizationOffset) {
        final double[] actualization = vectorDiagQdj;
        System.arraycopy(diagonal1mActualizations, actualizationOffset, actualization, 0, dimTrait);
        oneMinus(actualization);
        MissingOps.diagMult(actualization, precision, workspace);
        MissingOps.diagMult(workspace, actualization, precision);
    }

    /**
     * Scales {@code variance} in place on both sides by the actualization diagonal of the buffer at
     * {@code actualizationOffset}.
     *
     * <p>The two middle offsets are not used here.
     */
    @Override
    void actualizeVariance(DenseMatrix64F variance, int unusedOffsetA, int unusedOffsetB, int actualizationOffset) {
        final double[] actualization = vectorDiagQdi;
        System.arraycopy(diagonal1mActualizations, actualizationOffset, actualization, 0, dimTrait);
        oneMinus(actualization);
        diagonalDoubleProduct(variance, actualization, variance);
    }

    /**
     * Updates the pre-order partial mean in place:
     * {@code mean[k] = actualization[k] * mean[k] + displacement[k]}.
     */
    @Override
    void scaleAndDriftMean(int meanOffset, int unusedOffset, int actualizationOffset) {
        for (int k = 0; k < dimTrait; ++k) {
            final double actualization = 1.0d - diagonal1mActualizations[actualizationOffset + k];
            preOrderPartials[meanOffset + k] = (actualization * preOrderPartials[meanOffset + k]) + displacements[actualizationOffset + k];
        }
    }

    /** Returns the stationary-variance block for {@code index} as extracted by {@code getMatrixProcess}. */
    public double[] getStationaryVariance(int index) {
        assert stationaryVariances != null;
        return getMatrixProcess(index, stationaryVariances);
    }

    /**
     * Computes {@code result = Qi * Pip * Qi + Qj * Pjp * Qj}.
     *
     * <p>Here {@code Qi} and {@code Qj} are the actualization diagonals at {@code offsetI} and
     * {@code offsetJ}. Leaves those diagonals in {@link #vectorDiagQdi} / {@link #vectorDiagQdj} for the
     * subsequent {@link #computeWeightedSum} call. Uses {@code matrix0}/{@code matrix1} as scratch.
     */
    @Override
    void computePartialPrecision(int offsetI, int offsetJ, int unusedOffsetA, int unusedOffsetB, DenseMatrix64F pip, DenseMatrix64F pjp, DenseMatrix64F result) {
        final double[] qdi = vectorDiagQdi;
        System.arraycopy(diagonal1mActualizations, offsetI, qdi, 0, dimTrait);
        oneMinus(qdi);

        final double[] qdj = vectorDiagQdj;
        System.arraycopy(diagonal1mActualizations, offsetJ, qdj, 0, dimTrait);
        oneMinus(qdj);

        final DenseMatrix64F qdiPipQdi = matrix0;
        final DenseMatrix64F qdjPjpQdj = matrix1;
        diagonalDoubleProduct(pip, qdi, qdiPipQdi);
        diagonalDoubleProduct(pjp, qdj, qdjPjpQdj);
        CommonOps.add(qdiPipQdi, qdjPjpQdj, result);

        if (DEBUG) {
            System.err.println("Qdi: " + Arrays.toString(qdi));
            System.err.println("\tQdiPipQdi: " + qdiPipQdi);
            System.err.println("\tQdj: " + Arrays.toString(qdj));
            System.err.println("\tQdjPjpQdj: " + qdjPjpQdj);
        }
    }

    /**
     * Delegates to {@code MissingOps.weightedSumActualized} with the diagonals left in the scratch vectors
     * by {@link #computePartialPrecision}.
     */
    @Override
    void computeWeightedSum(double[] ip, double[] jp, int dim, double[] out) {
        MissingOps.weightedSumActualized(ip, 0, matrixPip, vectorDiagQdi, 0, jp, 0, matrixPjp, vectorDiagQdj, 0, dim, out);
    }

    /**
     * Scales {@code source} on both sides by {@code diagonal} into {@code destination}.
     *
     * <p>This is presumably {@code D * source * D} — depends on the {@code MissingOps.diagMult} overloads.
     */
    private static void diagonalDoubleProduct(DenseMatrix64F source, double[] diagonal, DenseMatrix64F destination) {
        MissingOps.diagMult(source, diagonal, destination);
        MissingOps.diagMult(diagonal, destination);
    }

    /** Replaces every entry {@code x} of {@code values} with {@code 1 - x}, in place. */
    public static void oneMinus(double[] values) {
        for (int k = 0; k < values.length; ++k) {
            values[k] = 1.0d - values[k];
        }
    }
}
