package dr.evomodel.operators;

import dr.evolution.tree.Clade;
import dr.evolution.tree.MutableTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.AbstractCladeImportanceDistribution;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Likelihood;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.GeneralOperator;
import dr.inference.operators.OperatorSchedule;
import dr.inference.operators.SimpleMCMCOperator;
import dr.inference.operators.SimpleOperatorSchedule;
import dr.math.MathUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;

@Deprecated
/* loaded from: input_file:dr/evomodel/operators/AbstractImportanceDistributionOperator.class */
public abstract class AbstractImportanceDistributionOperator extends SimpleMCMCOperator implements GeneralOperator {
    // Schedule of conventional tree operators built by getOperatorSchedule(); doUnguidedOperation() draws from it.
    private OperatorSchedule schedule;
    // Tree that this operator reads and edits.
    protected TreeModel tree;
    // Clade importance distribution used for sampling and for proposing trees.
    // NOTE(review): never assigned in this class — presumably set by subclasses; confirm.
    protected AbstractCladeImportanceDistribution probabilityEstimater;
    // A tree is handed to probabilityEstimater on every sampleEvery-th call during the sampling phase.
    private int sampleEvery;
    // Number of trees to collect; the sampling phase lasts samples * sampleEvery calls.
    private int samples;
    // Number of calls to doOperation(Likelihood) counted so far during the sampling phase.
    private int sampleCount;
    // Queue of internal nodes, refilled by fillInternalNodes() and consumed by createTree().
    private Queue<NodeRef> internalNodes;
    // External nodes keyed by their node number; filled by fillExternalNodes().
    private Map<Integer, NodeRef> externalNodes;
    // Decompiler-exposed assertion flag, initialised in the static block at the end of the class.
    static final /* synthetic */ boolean $assertionsDisabled;
    // Transition counter exposed through getTransitions()/setTransitions() and cleared by reset().
    private long transitions = 0;
    // When true, doOperation(Likelihood) only performs unguided moves. Never set to true in this class.
    private boolean burnin = false;

    /**
     * Creates the operator with the default sampling settings:
     * 10000 samples, one taken every 10 calls.
     *
     * @param treeModel the tree this operator works on
     * @param d         the operator weight
     */
    public AbstractImportanceDistributionOperator(TreeModel treeModel, double d) {
        this(treeModel, d, 10000, 10);
    }

    /**
     * Creates the operator with explicit sampling settings.
     *
     * @param treeModel the tree this operator works on
     * @param d         the operator weight
     * @param i         number of trees to sample before guided proposals start
     * @param i2        sampling interval: a tree is recorded every {@code i2}-th call
     */
    public AbstractImportanceDistributionOperator(TreeModel treeModel, double d, int i, int i2) {
        this.tree = treeModel;
        setWeight(d);
        this.sampleEvery = i2;
        this.samples = i;
        // init() needs this.tree to be set: it builds the schedule and walks the tree.
        init();
    }

    /**
     * Shared constructor tail: builds the unguided operator schedule, resets the
     * sample counter, allocates the node containers and indexes the external
     * nodes of the current tree by node number.
     */
    private void init() {
        this.schedule = getOperatorSchedule(this.tree);
        this.sampleCount = 0;
        // Parameterized instead of raw types so the assignments are checked by the compiler.
        this.internalNodes = new LinkedList<NodeRef>();
        this.externalNodes = new HashMap<Integer, NodeRef>();
        // Must run after externalNodes is allocated: it fills that map.
        fillExternalNodes(this.tree.getRoot());
    }

    /**
     * Likelihood-free entry point; performs no move.
     *
     * @return always {@code 0.0}
     */
    @Override
    public double doOperation() {
        return 0.0;
    }

    /**
     * Performs one move. While {@code burnin} is set, or while the sampling phase
     * ({@code samples * sampleEvery} calls) is still running, an unguided move from
     * the internal schedule is made. During the sampling phase every
     * {@code sampleEvery}-th call also hands the current tree to the importance
     * distribution and the accept/reject/transition counters are zeroed. Once
     * sampling has finished, the guided importance-distribution move is used.
     *
     * @param likelihood passed through to the guided move
     * @return the value returned by whichever move was performed
     */
    @Override
    public double doOperation(Likelihood likelihood) {
        if (!this.burnin) {
            final boolean samplingFinished = this.sampleCount >= this.samples * this.sampleEvery;
            if (samplingFinished) {
                return doImportanceDistributionOperation(likelihood);
            }

            this.sampleCount += 1;
            final boolean takeSample = (this.sampleCount % this.sampleEvery) == 0;
            if (takeSample) {
                this.probabilityEstimater.addTree(this.tree);
            }

            // Statistics gathered while still sampling are discarded.
            setAcceptCount(0L);
            setRejectCount(0L);
            setTransitions(0);
        }
        return doUnguidedOperation();
    }

    /**
     * Proposes a completely new tree from the importance distribution.
     * <p>
     * The current tree's probability and the log chance of its node-height
     * ordering are computed first. The topology is then rebuilt by recursively
     * splitting the clade of all taxa, and the sorted set of existing internal
     * node heights is re-assigned in a random order that respects the new topology.
     *
     * @param likelihood not used by this implementation
     * @return (current tree probability + log chance of the current height order)
     *         minus (log chance of the new topology + log chance of the new height order)
     * @throws RuntimeException wrapping the {@link MutableTree.InvalidTreeException}
     *                          (kept as the cause) if the rebuilt tree is invalid
     */
    protected double doImportanceDistributionOperation(Likelihood likelihood) {
        final NodeRef root = this.tree.getRoot();

        // Bits 0 .. (nodeCount + 1) / 2 - 1 are set; for a strictly bifurcating tree
        // that is one bit per external node.
        final BitSet allTaxa = new BitSet();
        allTaxa.set(0, (this.tree.getNodeCount() + 1) / 2);
        final Clade rootClade = new Clade(allTaxa, this.tree.getNodeHeight(root));

        // fillInternalNodes() enqueues the root first; the poll() discards it so the
        // queue only holds the internal nodes that createTree() may hand out.
        this.internalNodes.clear();
        fillInternalNodes(root);
        this.internalNodes.poll();
        this.externalNodes.clear();
        fillExternalNodes(root);

        final double currentTreeProbability = this.probabilityEstimater.getTreeProbability(this.tree);
        try {
            this.tree.beginTreeEdit();

            final List<Clade> currentClades = new ArrayList<Clade>();
            extractClades(this.tree, this.tree.getRoot(), currentClades, null);
            final double[] sortedHeights = getAbsoluteNodeHeights(currentClades);
            Arrays.sort(sortedHeights);

            final double backward = currentTreeProbability + getChanceForNodeHeights(sortedHeights);

            final double newTopologyChance = createTree(root, rootClade);
            // Temporary heights applied before the real heights are assigned below.
            assignDummyHeights(root);
            final double forward = newTopologyChance + setNodeHeights(sortedHeights);

            this.tree.endTreeEdit();
            this.tree.checkTreeIsValid();
            this.tree.pushTreeChangedEvent(root);
            return backward - forward;
        } catch (MutableTree.InvalidTreeException e) {
            // Keep the original exception as the cause so its stack trace is not lost.
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Scales the height of {@code nodeRef} by the number of internal nodes and
     * gives every internal child half of that value, halving again at each level
     * further down (see the two-argument overload).
     *
     * @param nodeRef the node to start from
     */
    private void assignDummyHeights(NodeRef nodeRef) {
        final double currentHeight = this.tree.getNodeHeight(nodeRef);
        final int internalCount = this.tree.getInternalNodeCount();
        final double scaledHeight = currentHeight * internalCount;
        this.tree.setNodeHeight(nodeRef, scaledHeight);

        final int children = this.tree.getChildCount(nodeRef);
        for (int idx = 0; idx < children; idx++) {
            final NodeRef descendant = this.tree.getChild(nodeRef, idx);
            if (this.tree.isExternal(descendant)) {
                continue;
            }
            assignDummyHeights(descendant, scaledHeight / 2.0d);
        }
    }

    /**
     * Sets the height of an internal node to {@code d} and recursively gives each
     * internal child half of its parent's value.
     *
     * @param nodeRef an internal node (checked when assertions are enabled)
     * @param d       the height to assign to {@code nodeRef}
     */
    private void assignDummyHeights(NodeRef nodeRef, double d) {
        // Hand-expanded form of: assert !tree.isExternal(nodeRef)
        if (!$assertionsDisabled) {
            if (this.tree.isExternal(nodeRef)) {
                throw new AssertionError();
            }
        }
        this.tree.setNodeHeight(nodeRef, d);

        final int children = this.tree.getChildCount(nodeRef);
        int idx = 0;
        while (idx < children) {
            final NodeRef descendant = this.tree.getChild(nodeRef, idx);
            if (!this.tree.isExternal(descendant)) {
                assignDummyHeights(descendant, d / 2.0d);
            }
            idx++;
        }
    }

    /**
     * Rebuilds the subtree below {@code nodeRef} so that it spans {@code clade}.
     * <p>
     * A clade of size two is resolved directly to its two external nodes. A larger
     * clade is split by the importance distribution; each part of size one maps to
     * an external node, every other part takes the next node from the
     * {@code internalNodes} queue, gets half the height of {@code nodeRef} and is
     * built recursively. Finally the old children of {@code nodeRef} are removed,
     * the two chosen nodes are detached from their current parents and attached
     * to {@code nodeRef}.
     *
     * @param nodeRef the node that becomes the root of the rebuilt subtree
     * @param clade   the clade the subtree must contain
     * @return the sum of the values returned by {@code splitClade} for every split
     *         performed in this subtree ({@code 0.0} for a clade of size two)
     * @throws MutableTree.InvalidTreeException declared for the tree-editing calls
     */
    private double createTree(NodeRef nodeRef, Clade clade) throws MutableTree.InvalidTreeException {
        final NodeRef first;
        final NodeRef second;
        double splitChance = 0.0d;

        if (clade.getSize() == 2) {
            final int firstTaxon = clade.getBits().nextSetBit(0);
            final int secondTaxon = clade.getBits().nextSetBit(firstTaxon + 1);
            first = this.externalNodes.get(Integer.valueOf(firstTaxon));
            second = this.externalNodes.get(Integer.valueOf(secondTaxon));
        } else {
            final Clade[] parts = new Clade[2];
            splitChance = splitClade(clade, parts);

            // The first part is fully built before the second one is looked at, so
            // the order in which internal nodes leave the queue is fixed.
            if (parts[0].getSize() != 1) {
                first = this.internalNodes.poll();
                this.tree.setNodeHeight(first, this.tree.getNodeHeight(nodeRef) * 0.5d);
                splitChance += createTree(first, parts[0]);
            } else {
                first = this.externalNodes.get(Integer.valueOf(parts[0].getBits().nextSetBit(0)));
            }

            if (parts[1].getSize() != 1) {
                second = this.internalNodes.poll();
                this.tree.setNodeHeight(second, this.tree.getNodeHeight(nodeRef) * 0.5d);
                splitChance += createTree(second, parts[1]);
            } else {
                second = this.externalNodes.get(Integer.valueOf(parts[1].getBits().nextSetBit(0)));
            }
        }

        // Re-wiring shared by both cases: clear nodeRef, free both nodes, attach them.
        removeChildren(nodeRef);

        final NodeRef firstParent = this.tree.getParent(first);
        if (firstParent != null) {
            this.tree.removeChild(firstParent, first);
        }
        final NodeRef secondParent = this.tree.getParent(second);
        if (secondParent != null) {
            this.tree.removeChild(secondParent, second);
        }

        this.tree.addChild(nodeRef, first);
        this.tree.addChild(nodeRef, second);
        return splitChance;
    }

    /**
     * Detaches whatever is stored in child slots 0 and 1 of {@code nodeRef}.
     * Slot 1 is looked up only after slot 0 has been handled.
     *
     * @param nodeRef the node whose children are removed
     */
    private void removeChildren(NodeRef nodeRef) {
        for (int slot = 0; slot <= 1; slot++) {
            final NodeRef occupant = this.tree.getChild(nodeRef, slot);
            if (occupant != null) {
                this.tree.removeChild(nodeRef, occupant);
            }
        }
    }

    /**
     * Delegates the split of {@code clade} to the importance distribution.
     *
     * @param clade    the clade to split
     * @param cladeArr receives the parts; createTree() passes an array of length two
     * @return the value reported by the importance distribution for this split
     */
    private double splitClade(Clade clade, Clade[] cladeArr) {
        final double chance = this.probabilityEstimater.splitClade(clade, cladeArr);
        return chance;
    }

    /**
     * Collects one {@link Clade} per internal node of the subtree below
     * {@code nodeRef}; each clade is appended after those of its descendants.
     *
     * @param tree    the tree to read
     * @param nodeRef the root of the subtree to process
     * @param list    receives a clade (taxon bits plus node height) per internal node
     * @param bitSet  if non-null, the taxon bits of this subtree are OR-ed into it
     */
    private void extractClades(Tree tree, NodeRef nodeRef, List<Clade> list, BitSet bitSet) {
        final BitSet members = new BitSet();

        if (!tree.isExternal(nodeRef)) {
            int childIdx = 0;
            while (childIdx < tree.getChildCount(nodeRef)) {
                extractClades(tree, tree.getChild(nodeRef, childIdx), list, members);
                childIdx++;
            }
            list.add(new Clade(members, tree.getNodeHeight(nodeRef)));
        } else {
            members.set(nodeRef.getNumber());
        }

        if (bitSet != null) {
            bitSet.or(members);
        }
    }

    /**
     * Walks the subtree below {@code nodeRef}, children first. Whenever the clade
     * of an internal node is a key of {@code hashMap}, the node receives the mapped
     * height and the entry is removed from the map.
     *
     * @param nodeRef the root of the subtree to process
     * @param hashMap clade-to-height lookup; matched entries are consumed
     * @param bitSet  if non-null, the taxon bits of this subtree are OR-ed into it
     */
    private void assignCladeHeights(NodeRef nodeRef, HashMap<Clade, Double> hashMap, BitSet bitSet) {
        final BitSet members = new BitSet();

        if (!this.tree.isExternal(nodeRef)) {
            int childIdx = 0;
            while (childIdx < this.tree.getChildCount(nodeRef)) {
                assignCladeHeights(this.tree.getChild(nodeRef, childIdx), hashMap, members);
                childIdx++;
            }
            final Clade key = new Clade(members, this.tree.getNodeHeight(nodeRef));
            if (hashMap.containsKey(key)) {
                final double storedHeight = hashMap.get(key).doubleValue();
                this.tree.setNodeHeight(nodeRef, storedHeight);
                hashMap.remove(key);
            }
        } else {
            members.set(nodeRef.getNumber());
        }

        if (bitSet != null) {
            bitSet.or(members);
        }
    }

    /**
     * Returns, for every internal node, its height divided by its parent's height.
     * NOTE(review): the parent is looked up for every internal node, including the
     * root — confirm how {@code Tree} handles a node without a parent.
     *
     * @param tree the tree to read
     * @return one ratio per internal node, in internal-node index order
     */
    private double[] getRelativeNodeHeights(Tree tree) {
        final int count = tree.getInternalNodeCount();
        final double[] ratios = new double[count];
        int idx = 0;
        while (idx < count) {
            final NodeRef node = tree.getInternalNode(idx);
            final double ownHeight = tree.getNodeHeight(node);
            final double parentHeight = tree.getNodeHeight(tree.getParent(node));
            ratios[idx] = ownHeight / parentHeight;
            idx++;
        }
        return ratios;
    }

    /**
     * Returns the height of every internal node of {@code tree}.
     *
     * @param tree the tree to read
     * @return one height per internal node, in internal-node index order
     */
    private double[] getAbsoluteNodeHeights(Tree tree) {
        final int count = tree.getInternalNodeCount();
        final double[] heights = new double[count];
        int idx = 0;
        while (idx < count) {
            final NodeRef node = tree.getInternalNode(idx);
            heights[idx] = tree.getNodeHeight(node);
            idx++;
        }
        return heights;
    }

    /**
     * Copies the heights stored in the given clades into an array.
     *
     * @param list the clades to read
     * @return the clade heights in list iteration order
     */
    private double[] getAbsoluteNodeHeights(List<Clade> list) {
        final double[] heights = new double[list.size()];
        int next = 0;
        for (Clade clade : list) {
            heights[next++] = clade.getHeight();
        }
        return heights;
    }

    /**
     * Log chance of the current tree's node-height ordering; delegates to the
     * permutation-based computation.
     *
     * @param dArr the internal node heights, sorted ascending by the caller
     * @return the value computed by {@code getChanceOfPermuation}
     */
    private double getChanceForNodeHeights(double[] dArr) {
        final double logChance = getChanceOfPermuation(dArr);
        return logChance;
    }

    /**
     * Adds up {@code log(1 / height(nodeRef))} for every internal child of
     * {@code nodeRef}, plus the same quantity computed recursively below each
     * such child.
     *
     * @param nodeRef a node with children at index 0 and 1
     * @return the accumulated log value; {@code 0.0} if both children are external
     */
    private double getChanceOfUniformNodeHeights(NodeRef nodeRef) {
        final NodeRef[] children = {
            this.tree.getChild(nodeRef, 0),
            this.tree.getChild(nodeRef, 1)
        };

        double logChance = 0.0d;
        for (NodeRef descendant : children) {
            if (this.tree.isExternal(descendant)) {
                continue;
            }
            // Left-to-right summation is kept on purpose so rounding stays the same.
            logChance = logChance
                    + Math.log(1.0d / this.tree.getNodeHeight(nodeRef))
                    + getChanceOfUniformNodeHeights(descendant);
        }
        return logChance;
    }

    /**
     * Computes the log chance of the current ordering of internal node heights.
     * <p>
     * Starting from the internal children of the root, the highest node among the
     * current candidates is repeatedly removed; each removal contributes
     * {@code log(1 / candidates)}. The removed node is assigned the next height
     * from {@code dArr}, walking downwards from the second-to-last entry, and its
     * internal children become candidates.
     *
     * @param dArr internal node heights; the caller sorts them ascending, and the
     *             last entry is not handed out by this method
     * @return the sum of {@code log(1 / candidates)} over all removals
     */
    private double getChanceOfPermuation(double[] dArr) {
        // Parameterized list: a raw list would make remove(int) return Object,
        // which cannot be assigned to a NodeRef without a cast.
        final LinkedList<NodeRef> candidates = new LinkedList<NodeRef>();

        final NodeRef root = this.tree.getRoot();
        final NodeRef rootChild0 = this.tree.getChild(root, 0);
        final NodeRef rootChild1 = this.tree.getChild(root, 1);
        if (!this.tree.isExternal(rootChild0)) {
            candidates.add(rootChild0);
        }
        if (!this.tree.isExternal(rootChild1)) {
            candidates.add(rootChild1);
        }

        int heightIndex = dArr.length - 2;
        double logChance = 0.0d;
        while (!candidates.isEmpty()) {
            final int highest = getHighestNode(candidates);
            logChance += Math.log(1.0d / candidates.size());

            final NodeRef picked = candidates.remove(highest);
            this.tree.setNodeHeight(picked, dArr[heightIndex]);
            heightIndex--;

            final NodeRef child0 = this.tree.getChild(picked, 0);
            final NodeRef child1 = this.tree.getChild(picked, 1);
            if (!this.tree.isExternal(child0)) {
                candidates.add(child0);
            }
            if (!this.tree.isExternal(child1)) {
                candidates.add(child1);
            }
        }
        return logChance;
    }

    /**
     * Finds the position of the node with the greatest height in {@code list}.
     * Only heights strictly above the running maximum (initially {@code 0.0})
     * count, so the first of several equal maxima wins and index 0 is returned
     * when no height is positive.
     *
     * @param list the candidate nodes
     * @return the index of the highest node, or 0 if none exceeds {@code 0.0}
     */
    private int getHighestNode(List<NodeRef> list) {
        int bestIndex = 0;
        double bestHeight = 0.0d;
        final int size = list.size();
        for (int pos = 0; pos < size; pos++) {
            final double height = this.tree.getNodeHeight(list.get(pos));
            if (height > bestHeight) {
                bestHeight = height;
                bestIndex = pos;
            }
        }
        return bestIndex;
    }

    /**
     * Assigns the given heights to the internal nodes of the rebuilt tree;
     * delegates to the random-permutation assignment.
     *
     * @param dArr the internal node heights, sorted ascending by the caller
     * @return the value computed by {@code assignPermutedNodeHeights}
     */
    private double setNodeHeights(double[] dArr) {
        final double logChance = assignPermutedNodeHeights(dArr);
        return logChance;
    }

    /**
     * Gives every internal child of {@code nodeRef} a height of
     * {@code height(nodeRef) * MathUtils.nextDouble()} and recurses into it,
     * accumulating {@code log(1 / height(nodeRef))} per child handled.
     *
     * @param nodeRef a node with children at index 0 and 1
     * @return the accumulated log value; {@code 0.0} if both children are external
     */
    private double setUniformNodeHeights(NodeRef nodeRef) {
        final NodeRef[] children = {
            this.tree.getChild(nodeRef, 0),
            this.tree.getChild(nodeRef, 1)
        };

        double logChance = 0.0d;
        for (NodeRef descendant : children) {
            if (this.tree.isExternal(descendant)) {
                continue;
            }
            final double parentHeight = this.tree.getNodeHeight(nodeRef);
            // The child is drawn before its subtree is visited, fixing the random-number order.
            this.tree.setNodeHeight(descendant, parentHeight * MathUtils.nextDouble());
            logChance = logChance
                    + Math.log(1.0d / parentHeight)
                    + setUniformNodeHeights(descendant);
        }
        return logChance;
    }

    /**
     * Hands out the given heights to the internal nodes in a random order that
     * keeps every parent above its children.
     * <p>
     * Starting from the internal children of the root, a candidate is picked
     * uniformly at random; each pick contributes {@code log(1 / candidates)}. The
     * picked node gets the next height from {@code dArr}, walking downwards from
     * the second-to-last entry, and its internal children become candidates.
     *
     * @param dArr internal node heights; the caller sorts them ascending, and the
     *             last entry is not handed out by this method
     * @return the sum of {@code log(1 / candidates)} over all picks
     */
    private double assignPermutedNodeHeights(double[] dArr) {
        // Parameterized list, so no downcast is needed when a node is removed.
        final LinkedList<NodeRef> candidates = new LinkedList<NodeRef>();

        final NodeRef root = this.tree.getRoot();
        final NodeRef rootChild0 = this.tree.getChild(root, 0);
        final NodeRef rootChild1 = this.tree.getChild(root, 1);
        if (!this.tree.isExternal(rootChild0)) {
            candidates.add(rootChild0);
        }
        if (!this.tree.isExternal(rootChild1)) {
            candidates.add(rootChild1);
        }

        int heightIndex = dArr.length - 2;
        double logChance = 0.0d;
        while (!candidates.isEmpty()) {
            final int pick = MathUtils.nextInt(candidates.size());
            logChance += Math.log(1.0d / candidates.size());

            final NodeRef picked = candidates.remove(pick);
            this.tree.setNodeHeight(picked, dArr[heightIndex]);
            heightIndex--;

            final NodeRef child0 = this.tree.getChild(picked, 0);
            final NodeRef child1 = this.tree.getChild(picked, 1);
            if (!this.tree.isExternal(child0)) {
                candidates.add(child0);
            }
            if (!this.tree.isExternal(child1)) {
                candidates.add(child1);
            }
        }
        return logChance;
    }

    /**
     * Processes the subtree below {@code nodeRef} children first, then draws a
     * height for {@code nodeRef} uniformly between its lowest child and its
     * parent. If the parent is not above that lowest child, the root height is
     * used as the upper bound instead.
     * NOTE(review): values returned by the recursive calls are discarded, so the
     * result only covers {@code nodeRef} itself — confirm this is intended.
     *
     * @param nodeRef the node to process
     * @return {@code log(1 / (upper - lower))} for an internal node, {@code 0.0}
     *         for an external node
     */
    private double setMissingNodeHeights(NodeRef nodeRef) {
        if (this.tree.isExternal(nodeRef)) {
            return 0.0d;
        }

        int childIdx = 0;
        while (childIdx < this.tree.getChildCount(nodeRef)) {
            setMissingNodeHeights(this.tree.getChild(nodeRef, childIdx));
            childIdx++;
        }

        final double lower = getMinNodeHeight(nodeRef);
        double upper = getMaxNodeHeight(nodeRef);
        if (upper <= lower) {
            upper = this.tree.getNodeHeight(this.tree.getRoot());
        }

        final double logChance = 0.0d + Math.log(1.0d / (upper - lower));
        this.tree.setNodeHeight(nodeRef, lower + (MathUtils.nextDouble() * (upper - lower)));
        return logChance;
    }

    /**
     * Returns the smallest height among the children of {@code nodeRef}.
     *
     * @param nodeRef the node whose children are inspected
     * @return the lowest child height, or {@code Double.MAX_VALUE} if there are no children
     */
    private double getMinNodeHeight(NodeRef nodeRef) {
        double lowest = Double.MAX_VALUE;
        int childIdx = 0;
        while (childIdx < this.tree.getChildCount(nodeRef)) {
            final NodeRef descendant = this.tree.getChild(nodeRef, childIdx);
            final double height = this.tree.getNodeHeight(descendant);
            if (height < lowest) {
                lowest = height;
            }
            childIdx++;
        }
        return lowest;
    }

    /**
     * Returns the height of the parent of {@code nodeRef}, used as the upper
     * bound for the node's own height.
     *
     * @param nodeRef the node whose parent is inspected
     * @return the parent's height
     */
    private double getMaxNodeHeight(NodeRef nodeRef) {
        final NodeRef parent = this.tree.getParent(nodeRef);
        return this.tree.getNodeHeight(parent);
    }

    /**
     * Appends every internal node of the subtree below {@code nodeRef} to the
     * {@code internalNodes} queue, each node before its descendants.
     *
     * @param nodeRef the root of the subtree to enqueue
     */
    private void fillInternalNodes(NodeRef nodeRef) {
        if (!this.tree.isExternal(nodeRef)) {
            this.internalNodes.add(nodeRef);
            final int children = this.tree.getChildCount(nodeRef);
            int idx = 0;
            while (idx < children) {
                fillInternalNodes(this.tree.getChild(nodeRef, idx));
                idx++;
            }
        }
    }

    /**
     * Registers every external node of the subtree below {@code nodeRef} in the
     * {@code externalNodes} map under its node number.
     *
     * @param nodeRef the root of the subtree to index
     */
    private void fillExternalNodes(NodeRef nodeRef) {
        if (!this.tree.isExternal(nodeRef)) {
            final int children = this.tree.getChildCount(nodeRef);
            int idx = 0;
            while (idx < children) {
                fillExternalNodes(this.tree.getChild(nodeRef, idx));
                idx++;
            }
        } else {
            this.externalNodes.put(Integer.valueOf(nodeRef.getNumber()), nodeRef);
        }
    }

    /**
     * Builds the schedule of conventional tree operators used for unguided moves:
     * two exchange operators (modes 0 and 1), subtree slide, NNI, Wilson-Balding
     * and FNPR, added in that order.
     *
     * @param treeModel the tree the operators act on
     * @return a new schedule holding the six operators
     */
    private OperatorSchedule getOperatorSchedule(TreeModel treeModel) {
        // All operators are constructed before any of them is registered.
        final ExchangeOperator exchangeMode0 = new ExchangeOperator(0, treeModel, 10.0d);
        final ExchangeOperator exchangeMode1 = new ExchangeOperator(1, treeModel, 3.0d);
        final SubtreeSlideOperator subtreeSlide = new SubtreeSlideOperator(
                treeModel, 10.0d, 1.0d, true, false, false, false, AdaptationMode.ADAPTATION_ON, 0.234d);
        final NNI nearestNeighbour = new NNI(treeModel, 10.0d);
        final WilsonBalding wilsonBaldingMove = new WilsonBalding(treeModel, 3.0d);
        final FNPR fixedNodePruneRegraft = new FNPR(treeModel, 5.0d);

        final SimpleOperatorSchedule unguided = new SimpleOperatorSchedule();
        unguided.addOperator(exchangeMode0);
        unguided.addOperator(exchangeMode1);
        unguided.addOperator(subtreeSlide);
        unguided.addOperator(nearestNeighbour);
        unguided.addOperator(wilsonBaldingMove);
        unguided.addOperator(fixedNodePruneRegraft);
        return unguided;
    }

    /**
     * Performs one move with the next operator chosen by the internal schedule.
     *
     * @return the value returned by the chosen operator's {@code doOperation()}
     */
    protected double doUnguidedOperation() {
        final int nextIndex = this.schedule.getNextOperatorIndex();
        final SimpleMCMCOperator chosen = (SimpleMCMCOperator) this.schedule.getOperator(nextIndex);
        return chosen.doOperation();
    }

    /**
     * @return the current value of the transition counter
     */
    public long getTransitions() {
        return transitions;
    }

    /**
     * Overwrites the transition counter.
     *
     * @param i the new counter value (widened to {@code long})
     */
    public void setTransitions(int i) {
        transitions = (long) i;
    }

    /**
     * Returns the fraction of attempted moves (accepted plus rejected) that were
     * counted as transitions.
     * <p>
     * The division is done in floating point. {@code getTransitions()} returns a
     * {@code long}, and the accept/reject counters are set with {@code long}
     * arguments elsewhere in this class, so without the cast the quotient would
     * presumably be truncated by integer division and a zero attempt count would
     * throw {@link ArithmeticException}.
     *
     * @return transitions divided by attempts; {@code NaN} when both are zero
     */
    public double getTransistionProbability() {
        return (double) getTransitions() / (getAcceptCount() + getRejectCount());
    }

    /**
     * Resets the inherited operator state and then clears the transition counter.
     */
    @Override
    public void reset() {
        super.reset();
        transitions = 0L;
    }

    /**
     * @return the lower bound of the acceptable acceptance rate, {@code 0.5}
     */
    public double getMinimumAcceptanceLevel() {
        return 0.5;
    }

    /**
     * @return the upper bound of the acceptable acceptance rate, {@code 1.0}
     */
    public double getMaximumAcceptanceLevel() {
        return 1.0;
    }

    /**
     * @return the lower bound of the "good" acceptance rate, {@code 0.75}
     */
    public double getMinimumGoodAcceptanceLevel() {
        return 0.75;
    }

    /**
     * @return the upper bound of the "good" acceptance rate, {@code 1.0}
     */
    public double getMaximumGoodAcceptanceLevel() {
        return 1.0;
    }

    /**
     * Supplied by concrete subclasses.
     *
     * @return the name of this operator
     */
    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.MCMCOperator
    public abstract String getOperatorName();

    /**
     * Supplied by concrete subclasses.
     *
     * @return a performance suggestion for this operator — presumably a
     *         human-readable tuning hint; confirm against the operator framework
     */
    public abstract String getPerformanceSuggestion();

    // Caches whether assertions are disabled for this class; the flag is read by
    // the manual assertion check in assignDummyHeights(NodeRef, double).
    static {
        $assertionsDisabled = !AbstractImportanceDistributionOperator.class.desiredAssertionStatus();
    }
}
