package dr.inference.operators;

import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inferencexml.operators.SelectorOperatorParser;
import dr.math.MathUtils;
import java.util.ArrayList;
import java.util.List;

/**
 * MCMC operator that proposes a new value for a single entry of an integer-valued "selector"
 * parameter.
 *
 * <p>Entries are read by rounding the parameter's values to the nearest integer. Non-negative
 * values act as labels {@code 0..max}; negative values are treated as "unused" (the operator
 * itself only ever proposes {@code -1}). The move rules appear designed to keep the number of
 * entries per label non-increasing in the label and to never empty a label below {@code max}.
 */
public class SelectorOperator extends SimpleMCMCOperator {

    /** The parameter whose entries this operator changes. */
    private final Parameter selector;

    /** {@code np_m1[i] = npos(size, i)} for {@code i = 0..size}; not read elsewhere in this class. */
    private final int[] np_m1;

    /**
     * {@code np_m2[0] = 1} and {@code np_m2[k + 1] = sum over n = 1..size of npos(n, k)}. Indexed by
     * {@code max + 1} when computing the proposal ratio in {@code doOperation()}.
     */
    private final int[] np_m2;

    /**
     * Creates the operator and precomputes the {@code npos} tables for the parameter's dimension.
     *
     * @param parameter the selector parameter to operate on
     */
    public SelectorOperator(Parameter parameter) {
        this.selector = parameter;
        final int size = parameter.getSize();

        this.np_m1 = new int[size + 1];
        for (int i = 0; i <= size; i++) {
            this.np_m1[i] = npos(size, i);
        }

        this.np_m2 = new int[size + 1];
        this.np_m2[0] = 1;
        for (int k = 0; k < size; k++) {
            int total = 0;
            for (int n = 1; n <= size; n++) {
                total += npos(n, k);
            }
            this.np_m2[k + 1] = total;
        }
    }

    @Override
    public String getOperatorName() {
        return SelectorOperatorParser.SELECTOR_OPERATOR + "(" + this.selector.getParameterName() + ")";
    }

    /**
     * Picks one allowed single-entry change uniformly at random, applies it to the selector and
     * returns the log of the proposal ratio (presumably used by the caller as the log Hastings
     * ratio).
     *
     * @return log of {@code count_sr_m2(current, proposed)} times the forward/reverse move-count
     *         factor
     */
    @Override
    public double doOperation() {
        final int[] current = vals();
        final List<Integer> forwardMoves = movesFrom_m2(current);

        // Moves are stored as flat (index, newValue) pairs.
        final int choice = MathUtils.nextInt(forwardMoves.size() / 2);
        final int index = forwardMoves.get(2 * choice);
        final int[] proposed = current.clone();
        proposed[index] = forwardMoves.get(2 * choice + 1);

        final List<Integer> reverseMoves = movesFrom_m2(proposed);
        // Must be computed in floating point: with int arithmetic the quotient truncates (often to
        // 0, making log() return -Infinity) and the products can overflow.
        final double moveFactor = ((double) forwardMoves.size() * np_m2[max(current) + 1])
                / ((double) reverseMoves.size() * np_m2[max(proposed) + 1]);
        final double ratio = count_sr_m2(current, proposed) * moveFactor;

        this.selector.setParameterValue(index, proposed[index]);
        return Math.log(ratio);
    }

    /** No tuning suggestion is offered for this operator. */
    public String getPerformanceSuggestion() {
        return null;
    }

    /** Current selector values rounded to integers. */
    private int[] vals() {
        return intVals(this.selector);
    }

    /**
     * Rounds every entry of {@code variable} to the nearest integer, rounding halves away from zero.
     *
     * @param variable the variable to read
     * @return one rounded value per entry
     */
    public static int[] intVals(Variable<Double> variable) {
        final int[] result = new int[variable.getSize()];
        for (int i = 0; i < result.length; i++) {
            final double value = variable.getValue(i).doubleValue();
            // The (int) cast truncates toward zero, so shift by +/-0.5 first.
            result[i] = (int) (value + (value >= 0.0d ? 0.5d : -0.5d));
        }
        return result;
    }

    /**
     * Enumerates every allowed single-entry change of {@code vals} when negative ("unused")
     * entries are permitted.
     *
     * @param vals current integer values
     * @return flat list of {@code (index, newValue)} pairs: element {@code 2k} is an entry index
     *         and element {@code 2k + 1} its proposed value
     */
    private List<Integer> movesFrom_m2(int[] vals) {
        final int max = max(vals);
        final int[] counts = counts_used_m2(vals, max);
        final List<Integer> moves = new ArrayList<Integer>(5);
        for (int i = 0; i < vals.length; i++) {
            final int from = vals[i];
            if (from < 0) {
                // An unused entry may join label 0, open the next label max + 1 (when some label is
                // in use), or join any label whose count would not then exceed the one below it.
                addMove(moves, i, 0);
                if (max >= 0) {
                    addMove(moves, i, max + 1);
                }
                for (int to = 1; to < max + 1; to++) {
                    if (counts[to] + 1 <= counts[to - 1]) {
                        addMove(moves, i, to);
                    }
                }
            } else if (isMovable(from, counts, max)) {
                addMovesBetweenLabels(moves, i, from, counts, max);
                // A used entry that may leave its label may always become unused.
                addMove(moves, i, -1);
            }
        }
        return moves;
    }

    /**
     * Enumerates every allowed single-entry change of {@code vals} when all entries are labels
     * (no unused entries). Not referenced elsewhere in this class.
     *
     * @param vals current integer values; a negative value makes {@code counts_m1} throw
     * @return flat list of {@code (index, newValue)} pairs
     */
    private List<Integer> movesFrom_m1(int[] vals) {
        final int max = max(vals);
        final int[] counts = counts_m1(vals, max);
        final List<Integer> moves = new ArrayList<Integer>(5);
        for (int i = 0; i < vals.length; i++) {
            final int from = vals[i];
            if (isMovable(from, counts, max)) {
                addMovesBetweenLabels(moves, i, from, counts, max);
            }
        }
        return moves;
    }

    /** Appends the pair {@code (index, to)} to a flat move list. */
    private static void addMove(List<Integer> moves, int index, int to) {
        moves.add(index);
        moves.add(to);
    }

    /**
     * Whether an entry labelled {@code from} (non-negative) may leave its label: always when
     * {@code from >= max}; otherwise only when its label has a count that is neither 1 nor equal
     * to the count of label {@code from + 1}.
     */
    private static boolean isMovable(int from, int[] counts, int max) {
        return from >= max || (counts[from] != 1 && counts[from] != counts[from + 1]);
    }

    /**
     * Appends the moves of entry {@code index} from label {@code from} to each other label in
     * {@code 0..max} it may join, then to a new label {@code max + 1} if {@code from} has more than
     * one member.
     */
    private static void addMovesBetweenLabels(
            List<Integer> moves, int index, int from, int[] counts, int max) {
        for (int to = 0; to < max + 1; to++) {
            if (canJoin(from, to, counts)) {
                addMove(moves, index, to);
            }
        }
        if (counts[from] > 1) {
            addMove(moves, index, max + 1);
        }
    }

    /**
     * Whether an entry leaving label {@code from} may join the existing label {@code to}.
     * For a higher label, both the shrunken source and the label below {@code to} must still
     * exceed the enlarged target. For a lower label, {@code to} must be 0 or must not exceed the
     * label below it after growing.
     */
    private static boolean canJoin(int from, int to, int[] counts) {
        if (to > from) {
            return counts[from] - 1 >= counts[to] + 1 && counts[to - 1] >= counts[to] + 1;
        }
        if (to < from) {
            return to == 0 || counts[to] + 1 <= counts[to - 1];
        }
        return false;
    }

    /** Shorthand for {@code npos(n, k, 1)}. */
    private static int npos(int n, int k) {
        return npos(n, k, 1);
    }

    /**
     * Recursive counting helper.
     *
     * <p>Returns 1 when {@code k == 0} or {@code n == 0}. Otherwise it sums
     * {@code npos(n - m * (k + 1), k - 1, 0)} over {@code m = start, start + 1, ...}, stopping
     * once {@code m >= 1 + n / k} or the remainder becomes negative. NOTE(review): this looks like
     * a partition count; confirm its exact combinatorial meaning.
     */
    private static int npos(int n, int k, int start) {
        if (k == 0 || n == 0) {
            return 1;
        }
        int total = 0;
        for (int m = start; m < 1 + (n / k); m++) {
            final int remainder = n - (m * (k + 1));
            if (remainder < 0) {
                break;
            }
            total += npos(remainder, k - 1, 0);
        }
        return total;
    }

    /** Sum of all elements. */
    private static int sum(int[] values) {
        int total = 0;
        for (int v : values) {
            total += v;
        }
        return total;
    }

    /** Largest element; {@code values} must be non-empty. */
    private static int max(int[] values) {
        int largest = values[0];
        for (int i = 1; i < values.length; i++) {
            if (largest < values[i]) {
                largest = values[i];
            }
        }
        return largest;
    }

    /** Count of each value {@code 0..max}; a negative value throws. */
    private static int[] counts_m1(int[] values, int max) {
        final int[] counts = new int[max + 1];
        for (int v : values) {
            counts[v]++;
        }
        return counts;
    }

    /** Count of each value {@code -1..max}, stored at index {@code value + 1}. */
    private static int[] counts_m2(int[] values, int max) {
        final int[] counts = new int[max + 2];
        for (int v : values) {
            counts[v + 1]++;
        }
        return counts;
    }

    /** Count of each value {@code -1..max(values)}, stored at index {@code value + 1}. */
    static int[] counts_m2(int[] values) {
        return counts_m2(values, max(values));
    }

    /** Count of each non-negative value {@code 0..max}; negative values are ignored. */
    private static int[] counts_used_m2(int[] values, int max) {
        final int[] counts = new int[max + 1];
        for (int v : values) {
            if (v >= 0) {
                counts[v]++;
            }
        }
        return counts;
    }

    /**
     * Count of each non-negative value {@code 0..max(values)}; negative values are ignored.
     *
     * @param values integer selector values (non-empty)
     * @return per-label counts
     */
    public static int[] counts_used_m2(int[] values) {
        return counts_used_m2(values, max(values));
    }

    /**
     * Binomial coefficient {@code C(n, k)}, computed in floating point and rounded to the nearest
     * long. Returns 1 when {@code k >= n}.
     */
    private static long choose(int n, int k) {
        double result = 1.0d;
        for (int m = n; m > k; m--) {
            result = (result * m) / (m - k);
        }
        return (long) (result + 0.5d);
    }

    /**
     * For counts {@code c}, entry {@code i} is {@code C(remaining, c[i])}, where
     * {@code remaining} is the total minus the counts before {@code i}. Entries after the total is
     * exhausted stay 0.
     */
    private static long[] countl_m1(int[] counts) {
        final long[] terms = new long[counts.length];
        int remaining = sum(counts);
        for (int i = 0; remaining > 0; i++) {
            terms[i] = choose(remaining, counts[i]);
            remaining -= counts[i];
        }
        return terms;
    }

    /** Ratio of the {@code countl_m1} products for {@code from} over those for {@code to}. */
    private static double count_sr_m1(int[] from, int[] to) {
        return ratioOfProducts(
                countl_m1(counts_m1(from, max(from))), countl_m1(counts_m1(to, max(to))));
    }

    /**
     * Like {@code countl_m1}, but the first term is always computed and a single-element
     * {@code counts} yields {@code {1}}.
     */
    private static long[] countl_m2(int[] counts) {
        if (counts.length == 1) {
            return new long[]{1};
        }
        final long[] terms = new long[counts.length];
        final int total = sum(counts);
        terms[0] = choose(total, counts[0]);
        int remaining = total - counts[0];
        for (int i = 1; remaining > 0; i++) {
            terms[i] = choose(remaining, counts[i]);
            remaining -= counts[i];
        }
        return terms;
    }

    /** Ratio of the {@code countl_m2} products for {@code from} over those for {@code to}. */
    private static double count_sr_m2(int[] from, int[] to) {
        return ratioOfProducts(countl_m2(counts_m2(from)), countl_m2(counts_m2(to)));
    }

    /**
     * Computes (product of {@code numerator}) / (product of {@code denominator}). Paired terms are
     * divided as they are multiplied, so intermediate values stay moderate.
     */
    private static double ratioOfProducts(long[] numerator, long[] denominator) {
        final int common = Math.min(numerator.length, denominator.length);
        double ratio = 1.0d;
        for (int i = 0; i < common; i++) {
            ratio = (ratio * numerator[i]) / denominator[i];
        }
        for (int i = common; i < numerator.length; i++) {
            ratio *= numerator[i];
        }
        for (int i = common; i < denominator.length; i++) {
            ratio /= denominator[i];
        }
        return ratio;
    }
}
