package dr.inference.operators;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.distribution.NormalDistributionModel;
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.MaskedParameter;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.model.TransformedParameter;
import dr.inference.model.Variable;
import dr.inferencexml.operators.EllipticalSliceOperatorParser;
import dr.math.MathUtils;
import dr.math.distributions.CompoundGaussianProcess;
import dr.math.distributions.GaussianProcessRandomGenerator;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import dr.util.Attribute;
import dr.util.Transform;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.DecompositionFactory;
import org.ejml.interfaces.decomposition.QRDecomposition;
import org.ejml.ops.CommonOps;

/* loaded from: input_file:dr/inference/operators/EllipticalSliceOperator.class */
public class EllipticalSliceOperator extends SimpleMetropolizedGibbsOperator implements GibbsOperator {
    // Source of the auxiliary Gaussian draws (nextRandom) and of the Gaussian log-density
    // (logPdf / getLikelihood) used by this operator.
    private final GaussianProcessRandomGenerator gaussianProcess;
    // NOTE(review): not referenced anywhere else in this class.
    private static final boolean MINIMAL_EVALUATION = true;
    // Set to 1.0 in the constructor and overwritten by setPathParameter; passed to evaluate(...)
    // in drawFromSlice(Likelihood, double).
    private double pathParameter;
    // The parameter whose values this operator overwrites.
    private final Parameter variable;
    // NOTE(review): never read or written in this class.
    private int current;
    // Stored from the constructor's third argument; not read anywhere else in this class.
    private boolean drawByRow;
    // Selects which fireParameterChangedEvent overload setVariable calls after an update.
    private boolean signalConstituentParameters;
    // Centre of the ellipse passed to pointOnEllipse; the constructor sets it to null and nothing
    // in this class assigns another value.
    private double[] priorMean;
    // Set to true in the constructor; not read anywhere else in this class.
    private boolean center;
    // Width of the initial angle bracket; the constructor rejects values < 0 or >= 2*pi.
    // A value of exactly 0 makes drawFromSlice start from a full 2*pi bracket.
    private double bracketAngle;
    // When true, transformPoint re-centres each proposal (see translate).
    private boolean translationInvariant;
    // When true, transformPoint re-orients each proposal (see rotateNd).
    private boolean rotationInvariant;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dr/inference/operators/EllipticalSliceOperator$Interval.class */
    /**
     * Bracket of candidate angles that is shrunk towards zero every time a
     * proposed angle is rejected.
     */
    public class Interval {
        double lower;
        double upper;

        /**
         * @param lowerBound initial lower end of the bracket
         * @param upperBound initial upper end of the bracket
         */
        Interval(double lowerBound, double upperBound) {
            lower = lowerBound;
            upper = upperBound;
        }

        /**
         * Replaces one end of the bracket with a rejected angle: a positive
         * angle becomes the new upper end, a negative one the new lower end.
         *
         * @param rejected the angle that was just rejected
         * @throws RuntimeException if {@code rejected} is zero, i.e. the
         *                          bracket has collapsed onto the current position
         */
        void adjust(double rejected) {
            if (rejected > 0.0d) {
                upper = rejected;
                return;
            }
            if (rejected >= 0.0d) {
                throw new RuntimeException("Shrunk to current position; bad.");
            }
            lower = rejected;
        }

        /**
         * @return {@code lower} plus {@code MathUtils.nextDouble()} times the bracket width
         */
        double draw() {
            final double fraction = MathUtils.nextDouble();
            final double width = upper - lower;
            return lower + (fraction * width);
        }
    }

    /**
     * Convenience constructor: delegates to the seven-argument constructor with a
     * bracket angle of 0.0 and both invariance flags set to {@code false}.
     *
     * @param parameter                      the parameter this operator updates
     * @param gaussianProcessRandomGenerator source of Gaussian draws
     * @param z                              stored as {@code drawByRow}
     * @param z2                             stored as {@code signalConstituentParameters}
     */
    public EllipticalSliceOperator(Parameter parameter, GaussianProcessRandomGenerator gaussianProcessRandomGenerator, boolean z, boolean z2) {
        this(parameter, gaussianProcessRandomGenerator, z, z2, 0.0d, false, false);
    }

    /**
     * Creates the operator and checks that its configuration is usable.
     *
     * @param parameter                      the parameter this operator updates
     * @param gaussianProcessRandomGenerator source of Gaussian draws; one draw is taken here
     *                                       to compare its length with the parameter dimension
     * @param z                              stored as {@code drawByRow}
     * @param z2                             stored as {@code signalConstituentParameters}
     * @param d                              bracket angle; must not be below 0 nor reach 2*pi
     * @param z3                             stored as {@code translationInvariant}
     * @param z4                             stored as {@code rotationInvariant}
     * @throws IllegalArgumentException if the bracket angle is out of range or the draw length
     *                                  differs from the parameter dimension
     */
    public EllipticalSliceOperator(Parameter parameter, GaussianProcessRandomGenerator gaussianProcessRandomGenerator, boolean z, boolean z2, double d, boolean z3, boolean z4) {
        // Defaults that no constructor argument controls.
        this.pathParameter = 1.0d;
        this.priorMean = null;
        this.center = true;

        // Collaborators and flags supplied by the caller.
        this.variable = parameter;
        this.gaussianProcess = gaussianProcessRandomGenerator;
        this.drawByRow = z;
        this.signalConstituentParameters = z2;
        this.bracketAngle = d;
        this.translationInvariant = z3;
        this.rotationInvariant = z4;

        final boolean angleOutOfRange = d < 0.0d || d >= 2.0d * Math.PI;
        if (angleOutOfRange) {
            throw new IllegalArgumentException("Invalid bracket angle");
        }

        final int variableDim = parameter.getDimension();
        final double[] sampleDraw = (double[]) gaussianProcessRandomGenerator.nextRandom();
        final int drawDim = sampleDraw.length;
        if (variableDim == drawDim) {
            return;
        }
        throw new IllegalArgumentException("Dimension of variable (" + variableDim + ") does not match dimension of Gaussian process draw (" + drawDim + ")");
    }

    /**
     * @return the parameter this operator was constructed with
     */
    public Variable<Double> getVariable() {
        return this.variable;
    }

    /**
     * Log-density of the Gaussian process at the current parameter values.
     * If the process exposes no likelihood object, the density is computed
     * directly via {@code logPdf}; otherwise that likelihood is queried.
     */
    private double getLogGaussianPrior() {
        if (this.gaussianProcess.getLikelihood() == null) {
            final double[] currentValues = this.variable.getParameterValues();
            return this.gaussianProcess.logPdf(currentValues);
        }
        return this.gaussianProcess.getLikelihood().getLogLikelihood();
    }

    /**
     * Recursively flattens {@code likelihood}: every component that is not itself a
     * {@link CompoundLikelihood} is appended to {@code list}, depth first.
     *
     * @param likelihood the (possibly nested) likelihood to flatten
     * @param list       receives the leaf likelihoods
     */
    private void unwindCompoundLikelihood(Likelihood likelihood, List<Likelihood> list) {
        if (likelihood instanceof CompoundLikelihood) {
            for (Likelihood child : ((CompoundLikelihood) likelihood).getLikelihoods()) {
                unwindCompoundLikelihood(child, list);
            }
        } else {
            list.add(likelihood);
        }
    }

    /**
     * Flattens a (possibly nested) likelihood into its non-compound components.
     *
     * @param likelihood the likelihood to flatten
     * @return a new list holding every leaf likelihood, depth first
     */
    private List<Likelihood> unwindCompoundLikelihood(Likelihood likelihood) {
        // Parameterised list: the raw ArrayList used previously relied on an unchecked conversion.
        List<Likelihood> leaves = new ArrayList<Likelihood>();
        unwindCompoundLikelihood(likelihood, leaves);
        return leaves;
    }

    /**
     * Tells whether {@code likelihood} belongs to this operator's Gaussian process:
     * for a {@link CompoundGaussianProcess} membership is delegated to {@code contains},
     * otherwise the two references must be the same object.
     */
    private boolean containsGaussianProcess(Likelihood likelihood) {
        if (!(this.gaussianProcess instanceof CompoundGaussianProcess)) {
            return this.gaussianProcess == likelihood;
        }
        CompoundGaussianProcess compound = (CompoundGaussianProcess) this.gaussianProcess;
        return compound.contains(likelihood);
    }

    /**
     * Returns {@code evaluate(likelihood, power)} minus the Gaussian log-prior scaled by
     * {@code power}. Not called from any other member of this class.
     *
     * @param likelihood the likelihood handed to {@code evaluate}
     * @param power      the value passed to {@code evaluate} and used to scale the prior
     */
    private double evaluateDensity(Likelihood likelihood, double power) {
        final double evaluated = evaluate(likelihood, power);
        final double scaledLogPrior = getLogGaussianPrior() * power;
        return evaluated - scaledLogPrior;
    }

    /**
     * Performs one elliptical-slice update of the variable.
     *
     * <p>The supplied likelihood is flattened, every component that belongs to the
     * Gaussian process is dropped, and the remaining components are wrapped in a new
     * {@link CompoundLikelihood}. The slice cut-off is that compound's current
     * log-likelihood plus {@code MathUtils.randomLogDouble()}.</p>
     *
     * @param likelihood the full (possibly compound) likelihood
     * @return always {@code 0.0}
     */
    @Override // dr.inference.operators.SimpleMetropolizedGibbsOperator
    public double doOperation(Likelihood likelihood) {
        // Typed list instead of a raw ArrayList, so no unchecked conversion is needed below.
        List<Likelihood> nonGaussian = new ArrayList<Likelihood>();
        for (Likelihood component : unwindCompoundLikelihood(likelihood)) {
            if (!containsGaussianProcess(component)) {
                nonGaussian.add(component);
            }
        }
        CompoundLikelihood remainder = new CompoundLikelihood(nonGaussian);
        // The log-likelihood is read before the random term is drawn, as before.
        double cutoff = remainder.getLogLikelihood() + MathUtils.randomLogDouble();
        drawFromSlice(remainder, cutoff);
        return 0.0d;
    }

    /**
     * Computes the point at {@code angle} on the ellipse spanned by two vectors.
     * Without a centre the result is {@code current*cos + auxiliary*sin}; with a centre
     * both vectors are shifted by it before mixing and the centre is added back.
     *
     * @param current   first vector (angle 0 reproduces it)
     * @param auxiliary second vector
     * @param angle     position on the ellipse, in radians
     * @param mean      ellipse centre, or {@code null} for the origin
     * @return a newly allocated array of the same length as {@code current}
     */
    private double[] pointOnEllipse(double[] current, double[] auxiliary, double angle, double[] mean) {
        final int n = current.length;
        final double c = Math.cos(angle);
        final double s = Math.sin(angle);
        final double[] result = new double[n];

        if (mean != null) {
            for (int k = 0; k < n; k++) {
                result[k] = ((current[k] - mean[k]) * c) + ((auxiliary[k] - mean[k]) * s) + mean[k];
            }
            return result;
        }

        for (int k = 0; k < n; k++) {
            result[k] = (current[k] * c) + (auxiliary[k] * s);
        }
        return result;
    }

    /**
     * Re-centres a flat array of points in place: the array is read as consecutive
     * points of {@code dim} coordinates each, and the per-coordinate mean over all
     * complete points is subtracted from every point.
     *
     * @param values flat coordinates, modified in place
     * @param dim    number of coordinates per point
     */
    private static void translate(double[] values, int dim) {
        final double[] mean = new double[dim];
        final int pointCount = values.length / dim;

        // Accumulate the coordinate sums.
        for (int point = 0; point < pointCount; point++) {
            final int base = point * dim;
            for (int k = 0; k < dim; k++) {
                mean[k] += values[base + k];
            }
        }

        // Turn the sums into means.
        for (int k = 0; k < dim; k++) {
            mean[k] /= pointCount;
        }

        // Shift every point by the mean.
        for (int point = 0; point < pointCount; point++) {
            final int base = point * dim;
            for (int k = 0; k < dim; k++) {
                values[base + k] -= mean[k];
            }
        }
    }

    /**
     * Re-orients a flat array of points in place using a QR decomposition of the matrix
     * whose columns are the first {@code dim} points. Every point is multiplied by the
     * transpose of the resulting Q factor.
     *
     * @param values flat coordinates ({@code dim} per point), modified in place
     * @param dim    number of coordinates per point
     */
    private static void rotateNd(double[] values, int dim) {
        // Column j of the basis holds the coordinates of point j.
        DenseMatrix64F basis = new DenseMatrix64F(dim, dim);
        for (int row = 0; row < dim; row++) {
            for (int col = 0; col < dim; col++) {
                basis.set(row, col, values[(col * dim) + row]);
            }
        }

        QRDecomposition<DenseMatrix64F> decomposition = DecompositionFactory.qr(dim, dim);
        decomposition.decompose(basis);
        DenseMatrix64F qFactor = decomposition.getQ(null, true);
        DenseMatrix64F rFactor = decomposition.getR(null, true);

        // Fix the overall sign so that the leading entry of R is not negative.
        final boolean flipSign = rFactor.get(0, 0) < 0.0d;
        if (flipSign) {
            CommonOps.scale(-1.0d, rFactor);
            CommonOps.scale(-1.0d, qFactor);
        }

        DenseMatrix64F qTransposed = new DenseMatrix64F(dim, dim);
        CommonOps.transpose(qFactor, qTransposed);

        // Apply Q^T to each point; source and destination are the same view of the array.
        final int pointCount = values.length / dim;
        for (int point = 0; point < pointCount; point++) {
            WrappedVector.Raw coordinates = new WrappedVector.Raw(values, point * dim, dim);
            MissingOps.matrixVectorMultiple(qTransposed, coordinates, coordinates, dim);
        }
    }

    /**
     * Re-orients the flat point array in place; simply forwards to {@code rotateNd}.
     *
     * @param dArr flat coordinates, modified in place
     * @param i    number of coordinates per point
     */
    private static void rotate(double[] dArr, int i) {
        rotateNd(dArr, i);
    }

    /**
     * Applies the requested normalisations to a flat point array, in place:
     * first the translation (if {@code z}), then the rotation (if {@code z2}).
     *
     * @param dArr flat coordinates, modified in place
     * @param z    whether to call {@code translate}
     * @param z2   whether to call {@code rotate}
     * @param i    number of coordinates per point
     */
    public static void transformPoint(double[] dArr, boolean z, boolean z2, int i) {
        if (z) {
            translate(dArr, i);
        }
        if (z2) {
            rotate(dArr, i);
        }
    }

    /**
     * Applies this operator's translation/rotation settings to {@code dArr} in place.
     * NOTE(review): the point dimension is hard-coded to 2 here.
     */
    private void transformPoint(double[] dArr) {
        transformPoint(dArr, this.translationInvariant, this.rotationInvariant, 2);
    }

    /**
     * Writes {@code values} into the variable using the "quietly" setters: a
     * {@link MatrixParameterInterface} receives the whole array in one call,
     * any other parameter is filled element by element.
     */
    private void setAllParameterValues(double[] values) {
        if (!(this.variable instanceof MatrixParameterInterface)) {
            int index = 0;
            for (double value : values) {
                this.variable.setParameterValueQuietly(index, value);
                index++;
            }
            return;
        }
        MatrixParameterInterface matrix = (MatrixParameterInterface) this.variable;
        matrix.setAllParameterValuesQuietly(values, 0);
    }

    /**
     * Installs a proposed point: normalises it in place, stores it in the variable
     * and then fires a single change event. Which event overload is used depends on
     * {@code signalConstituentParameters}.
     */
    private void setVariable(double[] point) {
        transformPoint(point);
        setAllParameterValues(point);
        if (!this.signalConstituentParameters) {
            this.variable.fireParameterChangedEvent(-1, Variable.ChangeType.ALL_VALUES_CHANGED);
            return;
        }
        this.variable.fireParameterChangedEvent();
    }

    /**
     * Slice-samples an angle on the ellipse through the current values and a fresh
     * Gaussian draw, shrinking the angle bracket after every rejection. A point is
     * accepted once {@code evaluate(likelihood, pathParameter)} minus the Gaussian
     * log-prior exceeds {@code cutoff}; the variable is left at that point.
     * Not called from any other member of this class.
     *
     * @param likelihood likelihood handed to {@code evaluate}
     * @param cutoff     acceptance threshold
     */
    private void drawFromSlice(Likelihood likelihood, double cutoff) {
        final double[] start = this.variable.getParameterValues();
        final double[] auxiliary = (double[]) this.gaussianProcess.nextRandom();

        final Interval bracket;
        double angle;
        if (this.bracketAngle != 0.0d) {
            // Bracket of the configured width, placed at a random offset that still covers angle 0.
            final double lowerEdge = (-this.bracketAngle) * MathUtils.nextDouble();
            bracket = new Interval(lowerEdge, lowerEdge + this.bracketAngle);
            angle = bracket.draw();
        } else {
            // Full-circle bracket ending at the first proposed angle.
            angle = MathUtils.nextDouble() * 2.0d * Math.PI;
            bracket = new Interval(angle - 2.0d * Math.PI, angle);
        }

        while (true) {
            setVariable(pointOnEllipse(start, auxiliary, angle, this.priorMean));
            if (evaluate(likelihood, this.pathParameter) - getLogGaussianPrior() > cutoff) {
                return;
            }
            bracket.adjust(angle);
            angle = bracket.draw();
        }
    }

    /**
     * Slice-samples an angle on the ellipse through the current values and a fresh
     * Gaussian draw, shrinking the angle bracket after every rejection. A point is
     * accepted once the compound log-likelihood exceeds {@code cutoff}; the variable
     * is left at that point.
     *
     * @param compoundLikelihood likelihood queried at each proposed point
     * @param cutoff             acceptance threshold
     */
    private void drawFromSlice(CompoundLikelihood compoundLikelihood, double cutoff) {
        final double[] start = this.variable.getParameterValues();
        final double[] auxiliary = (double[]) this.gaussianProcess.nextRandom();

        final Interval bracket;
        double angle;
        if (this.bracketAngle != 0.0d) {
            // Bracket of the configured width, placed at a random offset that still covers angle 0.
            final double lowerEdge = (-this.bracketAngle) * MathUtils.nextDouble();
            bracket = new Interval(lowerEdge, lowerEdge + this.bracketAngle);
            angle = bracket.draw();
        } else {
            // Full-circle bracket ending at the first proposed angle.
            angle = MathUtils.nextDouble() * 2.0d * Math.PI;
            bracket = new Interval(angle - 2.0d * Math.PI, angle);
        }

        while (true) {
            setVariable(pointOnEllipse(start, auxiliary, angle, this.priorMean));
            if (compoundLikelihood.getLogLikelihood() > cutoff) {
                return;
            }
            bracket.adjust(angle);
            angle = bracket.draw();
        }
    }

    /**
     * @return always 1
     */
    @Override // dr.inference.operators.SimpleMetropolizedGibbsOperator
    public int getStepCount() {
        return 1;
    }

    /**
     * Stores the path parameter; it is later passed to {@code evaluate} by
     * {@code drawFromSlice(Likelihood, double)}.
     *
     * @param d the new path parameter
     */
    @Override // dr.inference.operators.PathDependent
    public void setPathParameter(double d) {
        this.pathParameter = d;
    }

    /**
     * @return the operator name constant defined by {@link EllipticalSliceOperatorParser}
     */
    @Override // dr.inference.operators.SimpleMetropolizedGibbsOperator, dr.inference.operators.SimpleOperator, dr.inference.operators.MCMCOperator
    public String getOperatorName() {
        return EllipticalSliceOperatorParser.ELLIPTICAL_SLICE_SAMPLER;
    }

    /**
     * Small manual check of the operator. Builds a two-element parameter, a normal
     * data likelihood whose two arguments are masked views of that parameter (the
     * second one wrapped in a log transform), and a bivariate normal on the raw
     * parameter. It then runs the operator 100000 times and prints, for each of the
     * two monitored views, the sample mean and the square root of the sample variance.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        final int iterations = 100000;

        Parameter.Default joint = new Parameter.Default(new double[]{1.0d, 0.0d});
        MaskedParameter firstView = new MaskedParameter(joint, new Parameter.Default(new double[]{1.0d, 0.0d}), true);
        MaskedParameter secondView = new MaskedParameter(joint, new Parameter.Default(new double[]{0.0d, 1.0d}), true);
        TransformedParameter transformedView = new TransformedParameter(secondView, new Transform.LogTransform(), true);

        DistributionLikelihood dataLikelihood = new DistributionLikelihood((ParametricDistributionModel) new NormalDistributionModel(firstView, transformedView, true));

        // Proper two-dimensional literal; the decompiled "(double[][]) new double[]{...}" form does not compile.
        double[][] matrix = new double[][]{
                {0.001d, 0.0d},
                {0.0d, 0.001d}
        };
        MultivariateNormalDistribution gaussian = new MultivariateNormalDistribution(new double[]{0.0d, 0.0d}, matrix);
        MultivariateDistributionLikelihood gaussianLikelihood = new MultivariateDistributionLikelihood(gaussian);
        gaussianLikelihood.addData((Parameter) joint);
        dataLikelihood.addData(new Attribute.Default("Data", new double[]{1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d, 8.0d, 9.0d}));

        // Typed list instead of a raw ArrayList.
        List<Likelihood> components = new ArrayList<Likelihood>();
        components.add(dataLikelihood);
        components.add(gaussianLikelihood);
        CompoundLikelihood posterior = new CompoundLikelihood(0, components);

        EllipticalSliceOperator operator = new EllipticalSliceOperator(joint, gaussian, false, true);

        int dimension = joint.getDimension();
        double[] mean = new double[dimension];
        double[] variance = new double[dimension];
        Parameter[] monitored = new Parameter[dimension];
        monitored[0] = firstView;
        monitored[1] = transformedView;

        // Accumulate first and second moments of each monitored value.
        for (int iteration = 0; iteration < iterations; iteration++) {
            operator.doOperation(posterior);
            for (int k = 0; k < dimension; k++) {
                double value = monitored[k].getValue(0).doubleValue();
                mean[k] += value;
                variance[k] += value * value;
            }
        }

        // Convert the sums into mean and variance (E[x^2] - E[x]^2).
        for (int k = 0; k < dimension; k++) {
            mean[k] /= iterations;
            variance[k] /= iterations;
            variance[k] -= mean[k] * mean[k];
        }

        System.out.println("E(x)\tStErr(x)");
        for (int k = 0; k < dimension; k++) {
            System.out.println(mean[k] + " " + Math.sqrt(variance[k]));
        }
    }
}
