package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.hmc.PrecisionMatrixVectorProductProvider;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.AbstractParticleOperator;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.WrappedVector;
import dr.xml.Reportable;
import java.util.Arrays;

/* loaded from: input_file:dr/inference/operators/hmc/IrreversibleZigZagOperator.class */
/**
 * Zig-zag operator with a velocity of independent ±1 entries.
 *
 * <p>At travel time {@code t} the switching rate of coordinate {@code i} is
 * {@code max(0, t * v_i * (Pv)_i - v_i * g_i)}, where {@code Pv} is the precision-vector product ("action")
 * and {@code g} the current gradient vector. Because the gradient moves linearly along a straight-line segment
 * ({@code g(t) = g - t * Pv}), every rate is piecewise linear in {@code t}. The next gradient event is
 * therefore drawn by inverting the integrated merged (summed) rate exactly, and the coordinate that flips is
 * chosen proportionally to its rate at that time.
 */
public class IrreversibleZigZagOperator extends AbstractZigZagOperator implements Reportable {

    /** Flag for a native (C++) next-event computation; it is not read anywhere in this class. */
    static final boolean CPP_NEXT_BOUNCE = false;

    /**
     * One linear piece of the merged switching rate on the travel-time interval {@code [c0, c1]}.
     *
     * <p>{@code f0} and {@code f1} are the merged rate at {@code c0} and {@code c1}, and {@code slope} is its
     * constant derivative on the piece. For the unbounded final piece {@code c0 == c1} and only the anchor
     * value is meaningful.
     */
    private static final class PiecewiseLinearEndpoints {
        final double c0;
        final double c1;
        final double f0;
        final double f1;
        final double slope;

        private PiecewiseLinearEndpoints(double c0, double c1, double f0, double f1, double slope) {
            this.c0 = c0;
            this.c1 = c1;
            this.f0 = f0;
            this.f1 = f1;
            this.slope = slope;
        }
    }

    /**
     * Creates the operator; every argument is forwarded unchanged to {@link AbstractZigZagOperator}.
     */
    public IrreversibleZigZagOperator(GradientWrtParameterProvider gradientProvider,
                                      PrecisionMatrixVectorProductProvider productProvider,
                                      PrecisionColumnProvider columnProvider,
                                      double d,
                                      AbstractParticleOperator.Options options,
                                      Parameter parameter,
                                      int i) {
        super(gradientProvider, productProvider, columnProvider, d, options, parameter, i);
    }

    /**
     * Returns an empty placeholder vector: this operator tracks only a ±1 velocity, never a momentum.
     */
    @Override
    WrappedVector drawInitialMomentum() {
        return new WrappedVector.Raw(null, 0, 0);
    }

    /**
     * Draws a velocity with independent, equiprobable ±1 entries, one per dimension of the preconditioning
     * mass, and applies the mask when one is configured.
     *
     * @param momentum ignored
     * @return the freshly drawn velocity
     */
    @Override
    WrappedVector drawInitialVelocity(WrappedVector momentum) {
        final int dim = preconditioning.mass.getDim();
        final double[] velocity = new double[dim];
        for (int i = 0; i < dim; i++) {
            velocity[i] = MathUtils.nextDouble() > 0.5d ? 1.0d : -1.0d;
        }
        if (mask != null) {
            applyMask(velocity);
        }
        return new WrappedVector.Raw(velocity);
    }

    /**
     * Moves {@code position} along a piecewise-linear zig-zag path until the drawn total travel time is used.
     *
     * @param position position vector, updated in place
     * @return always {@code 0.0}
     */
    @Override
    double integrateTrajectory(WrappedVector position) {
        final WrappedVector momentum = drawInitialMomentum();
        final WrappedVector velocity = drawInitialVelocity(momentum);
        final WrappedVector gradient = getInitialGradient();
        final WrappedVector action = getPrecisionProduct(velocity);

        AbstractParticleOperator.BounceState bounceState =
                new AbstractParticleOperator.BounceState(drawTotalTravelTime());

        timer.startTimer("integrateTrajectory");
        while (bounceState.isTimeRemaining()) {
            timer.startTimer("getNext");
            final MinimumTravelInformation nextEvent = getNextBounce(position, velocity, action, gradient);
            timer.stopTimer("getNext");
            bounceState = doBounce(bounceState, nextEvent, position, velocity, action, gradient, momentum);
        }
        timer.stopTimer("integrateTrajectory");
        return 0.0d;
    }

    /**
     * Asks the native zig-zag implementation for the next event. Not called from this class; kept as the
     * native counterpart of {@link #getNextBounce}.
     */
    private MinimumTravelInformation testNative(WrappedVector wrappedVector, WrappedVector wrappedVector2,
                                                WrappedVector wrappedVector3, WrappedVector wrappedVector4) {
        this.timer.startTimer("getNextC++");
        MinimumTravelInformation nextEventIrreversible = this.nativeZigZag.getNextEventIrreversible(
                wrappedVector.getBuffer(), wrappedVector2.getBuffer(),
                wrappedVector3.getBuffer(), wrappedVector4.getBuffer());
        this.timer.stopTimer("getNextC++");
        return nextEventIrreversible;
    }

    /**
     * Finds the earlier of the next gradient (velocity-switch) event and the next boundary hit.
     *
     * <p>Both random draws (the exponential target area, then the uniform used to pick the event coordinate)
     * are made on every call, in that order, even when the boundary event wins.
     *
     * @return the event time, its coordinate index and whether it is a GRADIENT or BOUNDARY event
     */
    private MinimumTravelInformation getNextBounce(WrappedVector position, WrappedVector velocity,
                                                   WrappedVector action, WrappedVector gradient) {
        final double[] positionBuffer = position.getBuffer();
        final double[] velocityBuffer = velocity.getBuffer();
        final double[] actionBuffer = action.getBuffer();
        final double[] gradientBuffer = gradient.getBuffer();

        // Breakpoints of the piecewise-linear merged rate, in increasing order.
        final double[] sortedRoots = getRoots(actionBuffer, gradientBuffer);
        Arrays.sort(sortedRoots);

        final double gradientTime =
                getSwitchTimeByMergedProcesses(actionBuffer, gradientBuffer, velocityBuffer, sortedRoots);
        final int gradientIndex = getEventDimension(velocityBuffer, gradientBuffer, actionBuffer, gradientTime);

        double boundaryTime = Double.POSITIVE_INFINITY;
        int boundaryIndex = -1;
        for (int i = 0, dim = position.getDim(); i < dim; i++) {
            final double candidate = findBoundaryTime(i, positionBuffer[i], velocityBuffer[i]);
            if (candidate < boundaryTime) {
                boundaryTime = candidate;
                boundaryIndex = i;
            }
        }

        if (gradientTime < boundaryTime) {
            return new MinimumTravelInformation(gradientTime, gradientIndex, AbstractParticleOperator.Type.GRADIENT);
        }
        return new MinimumTravelInformation(boundaryTime, boundaryIndex, AbstractParticleOperator.Type.BOUNDARY);
    }

    /**
     * Returns, per coordinate, the travel time {@code gradient[i] / action[i]} at which its linear rate
     * crosses zero. Negative and NaN values are clamped to 0; a zero action entry can yield +Infinity.
     */
    private double[] getRoots(double[] action, double[] gradient) {
        final double[] roots = new double[action.length];
        for (int i = 0; i < action.length; i++) {
            final double root = gradient[i] / action[i];
            roots[i] = root >= 0.0d ? root : 0.0d;
        }
        return roots;
    }

    /**
     * Draws the time of the next gradient event. It finds the {@code t} at which the merged rate integrated
     * over {@code [0, t]} first equals an Exp(1) draw.
     *
     * <p>Integration starts at time 0, so the segment before the smallest positive root is included. It then
     * walks the pieces between consecutive distinct positive roots, and finally extrapolates the rate that is
     * active beyond the last finite root.
     *
     * @param action      precision-vector product
     * @param gradient    current gradient
     * @param velocity    current ±1 velocity
     * @param sortedRoots clamped roots in ascending order
     * @return the event time, or +Infinity when the target area is never reached
     */
    private double getSwitchTimeByMergedProcesses(double[] action, double[] gradient, double[] velocity,
                                                  double[] sortedRoots) {
        final double target = MathUtils.nextExponential(1.0d);
        double start = 0.0d;
        double accumulated = 0.0d;

        for (double end : sortedRoots) {
            if (!(end > start)) {
                continue; // zero roots and repeated roots add no new breakpoint
            }
            if (Double.isInfinite(end)) {
                break; // no finite piece left; handled by the tail below
            }
            // The set of active coordinates is constant strictly between two breakpoints, so probing the
            // midpoint decides it without the ambiguity of testing exactly at a root.
            final PiecewiseLinearEndpoints piece =
                    getEndpointInfo(start, end, start + 0.5d * (end - start), velocity, gradient, action);
            final double area = getTrapezoidArea(piece);
            if (accumulated + area > target) {
                return integrateLinearFunctionToArea(piece, target - accumulated, true);
            }
            accumulated += area;
            start = end;
        }

        // Beyond the last finite root; probe well past it so roots sitting at 'start' are classified robustly.
        final PiecewiseLinearEndpoints tail =
                getEndpointInfo(start, start, 2.0d * start + 1.0d, velocity, gradient, action);
        return integrateLinearFunctionToArea(tail, target - accumulated, false);
    }

    /**
     * Builds the linear piece of the merged rate on {@code [c0, c1]}. It sums
     * {@code v_i * action_i * t - v_i * gradient_i} over the coordinates whose rate is positive at
     * {@code probeTime}, which must lie strictly inside the piece.
     */
    private PiecewiseLinearEndpoints getEndpointInfo(double c0, double c1, double probeTime,
                                                     double[] velocity, double[] gradient, double[] action) {
        double slope = 0.0d;
        double f0 = 0.0d;
        double f1 = 0.0d;
        for (int i = 0; i < action.length; i++) {
            final double a = velocity[i] * action[i];
            final double b = (-velocity[i]) * gradient[i];
            if (a * probeTime + b > 0.0d) {
                slope += a;
                f0 += a * c0 + b;
                f1 += a * c1 + b;
            }
        }
        return new PiecewiseLinearEndpoints(c0, c1, f0, f1, slope);
    }

    /** Area under the linear piece between its endpoints. */
    private double getTrapezoidArea(PiecewiseLinearEndpoints piece) {
        return ((piece.f0 + piece.f1) * (piece.c1 - piece.c0)) / 2.0d;
    }

    /**
     * Solves {@code ∫_anchor^{anchor+τ} (f + slope·(u - anchor)) du = area} for the first {@code τ ≥ 0}.
     *
     * <p>It uses the cancellation-free form {@code τ = 2·area / (f + sqrt(f² + 2·slope·area))}. This form
     * stays valid for a zero slope (giving {@code area / f}) and picks the first crossing when the slope is
     * negative.
     *
     * @param piece       linear piece to integrate
     * @param area        remaining area to accumulate
     * @param withinPiece {@code true} to anchor at {@code c0} (the caller knows the area is reached inside the
     *                    piece); {@code false} to extrapolate from {@code c1}
     * @return the absolute event time, or +Infinity when the area is never reached
     */
    private double integrateLinearFunctionToArea(PiecewiseLinearEndpoints piece, double area,
                                                 boolean withinPiece) {
        final double anchor = withinPiece ? piece.c0 : piece.c1;
        final double rate = withinPiece ? piece.f0 : piece.f1;
        if (!(area > 0.0d)) {
            return anchor;
        }
        double discriminant = rate * rate + 2.0d * piece.slope * area;
        if (discriminant < 0.0d) {
            if (!withinPiece) {
                return Double.POSITIVE_INFINITY; // a decreasing rate never accumulates the area
            }
            discriminant = 0.0d; // round-off only: the area is known to be reachable inside this piece
        }
        final double denominator = rate + Math.sqrt(discriminant);
        if (!(denominator > 0.0d)) {
            return Double.POSITIVE_INFINITY;
        }
        return anchor + 2.0d * area / denominator;
    }

    /**
     * Picks the coordinate that switches at {@code time}, with probability proportional to its rate
     * {@code max(0, time * v_i * action_i - v_i * gradient_i)}. One uniform is always drawn.
     *
     * @return the chosen coordinate, the last positive-rate coordinate if round-off leaves the uniform above
     *         the cumulative total, or -1 when no rate is positive
     */
    private int getEventDimension(double[] velocity, double[] gradient, double[] action, double time) {
        final int dim = velocity.length;
        final double[] rates = new double[dim];
        double totalRate = 0.0d;
        for (int i = 0; i < dim; i++) {
            final double rate = ((time * velocity[i]) * action[i]) - (velocity[i] * gradient[i]);
            rates[i] = rate > 0.0d ? rate : 0.0d;
            totalRate += rates[i];
        }

        final double u = MathUtils.nextDouble();
        double cumulative = 0.0d;
        int lastPositive = -1;
        for (int i = 0; i < dim; i++) {
            if (rates[i] > 0.0d) {
                cumulative += rates[i] / totalRate;
                lastPositive = i;
                if (u <= cumulative) {
                    return i;
                }
            }
        }
        return lastPositive;
    }

    /**
     * Advances to the next event, or stops early when the remaining travel time runs out first.
     *
     * <p>On an event, it updates position, gradient and action over the event time and then reflects the
     * velocity of the event coordinate.
     */
    @Override
    AbstractParticleOperator.BounceState doBounce(AbstractParticleOperator.BounceState initialBounceState,
                                                  MinimumTravelInformation firstBounce,
                                                  WrappedVector position, WrappedVector velocity,
                                                  WrappedVector action, WrappedVector gradient,
                                                  WrappedVector momentum) {
        timer.startTimer("doBounce");
        final double remainingTime = initialBounceState.remainingTime;
        final double eventTime = firstBounce.time;

        final AbstractParticleOperator.BounceState finalBounceState;
        if (remainingTime < eventTime) {
            updatePosition(position, velocity, remainingTime);
            finalBounceState = new AbstractParticleOperator.BounceState(AbstractParticleOperator.Type.NONE, -1, 0.0d);
        } else {
            final AbstractParticleOperator.Type eventType = firstBounce.type;
            final int eventIndex = firstBounce.index;
            updateDynamics(position.getBuffer(), velocity.getBuffer(), action.getBuffer(), gradient.getBuffer(),
                    getPrecisionColumn(eventIndex).getBuffer(), eventTime, eventIndex);
            reflectVelocity(velocity, eventIndex);
            finalBounceState = new AbstractParticleOperator.BounceState(eventType, eventIndex, remainingTime - eventTime);
        }
        timer.stopTimer("doBounce");
        return finalBounceState;
    }

    /**
     * Moves every coordinate for {@code time}. Position gains {@code time * v}, the gradient loses
     * {@code time * action}, and the action is updated for the upcoming flip of {@code velocity[index]}
     * ({@code action -= 2 * v_index * column}).
     */
    private void updateDynamics(double[] position, double[] velocity, double[] action, double[] gradient,
                                double[] column, double time, int index) {
        final double twiceFlipped = 2.0d * velocity[index];
        for (int i = 0, len = position.length; i < len; i++) {
            final double actionI = action[i];
            position[i] += time * velocity[i];
            gradient[i] -= time * actionI;
            action[i] = actionI - twiceFlipped * column[i];
        }
    }

    @Override
    public String getOperatorName() {
        return "Irreversible zig-zag operator";
    }
}
