package dr.evomodel.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:dr/evomodel/continuous/RestrictedPartialsModel.class */
/**
 * Tracks a set of {@link RestrictedPartials} ("clamps") that all belong to one tree.
 *
 * <p>Each restriction is keyed by the set of tip node numbers it covers. On every re-scan of the
 * tree, each internal node whose descendant tips exactly equal a restriction's tip set is bound to
 * that restriction. Each restriction is assigned one extra partial buffer index, numbered
 * consecutively from the supplied starting partial count. One further spare partial index is
 * reserved after them.
 */
public class RestrictedPartialsModel extends AbstractModel {
    private final TreeModel treeModel;
    private final List<RestrictedPartials> restrictedPartialsList;
    private final ContinuousDiffusionIntegrator cdi;
    private boolean updateTreeMapping;
    private boolean updateRestrictedPartials;
    private final int startingPartialCount;
    private final int startingMatrixCount;
    private int partialsCount;
    private int matrixCount;
    private final int sparePartialIndex;
    protected Map<BitSet, RestrictedPartials> clampList;
    protected Map<NodeRef, RestrictedPartials> nodeToClampMap;
    protected boolean anyClamps;

    /**
     * Creates the model without a diffusion integrator.
     *
     * <p>Tree-to-clamp mapping works normally. However, {@link #updatePartialRestrictions()} fails
     * with {@link IllegalStateException} when it needs to push restricted partial values, because
     * there is no integrator to push them into.
     *
     * @param str model name
     * @param list restrictions to manage; must be non-empty and share a single tree model
     * @param i first partial buffer index available to this model
     * @param i2 first matrix buffer index available to this model
     * @throws IllegalArgumentException if {@code list} is null or empty, if the restrictions use
     *     different tree models, or if two restrictions cover the same set of tips
     */
    public RestrictedPartialsModel(String str, List<RestrictedPartials> list, int i, int i2) {
        this(str, list, i, i2, null);
    }

    /**
     * Creates the model with the integrator that receives the restricted partial values.
     *
     * @param str model name
     * @param list restrictions to manage; must be non-empty and share a single tree model
     * @param i first partial buffer index available to this model
     * @param i2 first matrix buffer index available to this model
     * @param cdi integrator whose post-order partials are overwritten by
     *     {@link #updatePartialRestrictions()}; may be null (see the 4-argument constructor)
     * @throws IllegalArgumentException if {@code list} is null or empty, if the restrictions use
     *     different tree models, or if two restrictions cover the same set of tips
     */
    public RestrictedPartialsModel(String str, List<RestrictedPartials> list, int i, int i2,
                                   ContinuousDiffusionIntegrator cdi) {
        super(str);
        if (list == null || list.isEmpty()) {
            // Without at least one restriction there is no tree model to attach to.
            throw new IllegalArgumentException("At least one restricted partial is required");
        }
        this.updateTreeMapping = true;
        this.updateRestrictedPartials = true;
        this.clampList = null;
        this.nodeToClampMap = null;
        this.anyClamps = false;
        this.treeModel = validateTreeModel(list);
        this.startingPartialCount = i;
        this.startingMatrixCount = i2;
        this.partialsCount = 0;
        this.matrixCount = 0;
        for (RestrictedPartials restrictedPartials : list) {
            // Each restriction owns one extra partial buffer, allocated consecutively from i.
            restrictedPartials.setIndex(this.partialsCount + i);
            addRestrictedPartials(restrictedPartials);
            this.partialsCount++;
        }
        // Reserve one additional spare partial buffer after all restriction buffers.
        this.sparePartialIndex = this.partialsCount + i;
        this.partialsCount++;
        // Private helper: avoid calling the overridable setupClamps() from a constructor.
        rebuildClampMapping();
        this.restrictedPartialsList = list;
        this.cdi = cdi;
    }

    /**
     * Returns the number of extra partial buffers this model needs: one per restriction plus
     * one spare.
     */
    public int getExtraPartialBufferCount() {
        return this.partialsCount;
    }

    /** Returns the number of extra matrix buffers this model needs (currently never incremented). */
    public int getExtraMatrixBufferCount() {
        return this.matrixCount;
    }

    /**
     * Applies pending updates.
     *
     * <p>If the tree or a restriction changed since the last call, re-scans the tree for clamped
     * nodes and/or re-pushes restricted partial values into the integrator. A flag is cleared only
     * after its update succeeds.
     *
     * @throws IllegalStateException if partial values must be pushed but no integrator was
     *     supplied at construction
     */
    public void updatePartialRestrictions() {
        if (this.updateTreeMapping) {
            rebuildClampMapping();
            this.updateTreeMapping = false;
        }
        if (this.updateRestrictedPartials) {
            setupRestrictedPartials();
            this.updateRestrictedPartials = false;
        }
    }

    /** Returns true if the most recent tree scan bound at least one restriction to a node. */
    public boolean hasAnyPartialRestrictions() {
        return this.anyClamps;
    }

    /**
     * Returns the tree model shared by every restriction.
     *
     * @throws IllegalArgumentException if any restriction uses a different tree model instance
     */
    private TreeModel validateTreeModel(List<RestrictedPartials> list) {
        TreeModel shared = list.get(0).getTreeModel();
        for (RestrictedPartials restrictedPartials : list) {
            // Identity comparison: the restrictions must refer to the very same tree model object.
            if (restrictedPartials.getTreeModel() != shared) {
                throw new IllegalArgumentException("All tree models must be the same");
            }
        }
        return shared;
    }

    /**
     * Registers a restriction under a snapshot of its tip set and listens to it for changes.
     *
     * @throws IllegalArgumentException if another restriction already covers the same tips
     */
    private void addRestrictedPartials(RestrictedPartials restrictedPartials) {
        if (this.clampList == null) {
            this.clampList = new HashMap<>();
        }
        // Store a copy: BitSet is mutable, and mutating a HashMap key after insertion breaks lookups.
        BitSet tips = (BitSet) restrictedPartials.getTipBitSet().clone();
        if (this.clampList.containsKey(tips)) {
            // Silently overwriting would drop a restriction that still holds a partial buffer.
            throw new IllegalArgumentException(
                    "Multiple restricted partials defined for the same set of tips: " + tips);
        }
        this.clampList.put(tips, restrictedPartials);
        addModel(restrictedPartials);
        System.err.println("Added a CLAMP!");
    }

    /**
     * Writes each restriction's partial values into the integrator's post-order partial buffer at
     * that restriction's assigned index.
     *
     * @throws IllegalStateException if no integrator was supplied at construction
     */
    private void setupRestrictedPartials() {
        if (this.cdi == null) {
            throw new IllegalStateException(
                    "No ContinuousDiffusionIntegrator set; cannot push restricted partials");
        }
        for (RestrictedPartials restrictedPartials : this.clampList.values()) {
            this.cdi.setPostOrderPartial(
                    restrictedPartials.getIndex(), restrictedPartials.getRestrictedPartials());
        }
    }

    /**
     * Marks cached state stale in response to a change in a child model.
     *
     * <p>Tree or node changes on the tree model trigger a re-scan of the clamp mapping. A change in
     * any restriction triggers a re-push of the partial values.
     *
     * @throws RuntimeException if {@code model} is neither the tree model nor a restriction
     */
    @Override // dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
        if (model == this.treeModel) {
            if (obj instanceof TreeChangedEvent) {
                TreeChangedEvent treeChangedEvent = (TreeChangedEvent) obj;
                if (treeChangedEvent.isTreeChanged() || treeChangedEvent.isNodeChanged()) {
                    this.updateTreeMapping = true;
                }
            }
        } else if (model instanceof RestrictedPartials) {
            this.updateRestrictedPartials = true;
        } else {
            throw new RuntimeException("Unknown model");
        }
    }

    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
        // Nothing is stored: restoreState() forces a full recompute instead.
    }

    /** Forces both the clamp mapping and the partial values to be recomputed on next update. */
    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        this.updateTreeMapping = true;
        this.updateRestrictedPartials = true;
    }

    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
    }

    @Override // dr.inference.model.AbstractModel
    protected void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
    }

    /** Re-scans the tree and rebinds every restriction whose tip set matches an internal node. */
    public void setupClamps() {
        rebuildClampMapping();
    }

    /** Implementation of {@link #setupClamps()}; private so the constructor can call it safely. */
    private void rebuildClampMapping() {
        if (this.nodeToClampMap == null) {
            this.nodeToClampMap = new HashMap<>();
        }
        this.nodeToClampMap.clear();
        // NOTE(review): a restriction that no longer matches any node keeps the node from its last
        // setNode() call; confirm whether it should be explicitly unbound here.
        recursiveSetupClamp(this.treeModel, this.treeModel.getRoot(), new BitSet());
        this.anyClamps = !this.nodeToClampMap.isEmpty();
    }

    /**
     * Fills {@code tips} with the numbers of all external nodes below {@code node} (post-order).
     * Binds {@code node} to the restriction whose tip set equals that set, if one exists.
     */
    private void recursiveSetupClamp(Tree tree, NodeRef node, BitSet tips) {
        if (tree.isExternal(node)) {
            tips.set(node.getNumber());
            return;
        }
        for (int c = 0; c < tree.getChildCount(node); c++) {
            // Each child gets a fresh set so it can be matched against its own exact tip set.
            BitSet childTips = new BitSet();
            recursiveSetupClamp(tree, tree.getChild(node, c), childTips);
            tips.or(childTips);
        }
        RestrictedPartials clamp = this.clampList.get(tips);
        if (clamp != null) {
            clamp.setNode(node);
            this.nodeToClampMap.put(node, clamp);
        }
    }
}
