package dr.evomodel.arg.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.arg.ARGModel;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.logging.Logger;

/* loaded from: input_file:dr/evomodel/arg/branchratemodel/ARGDiscretizedBranchRates.class */
/**
 * Discretized relaxed-clock branch-rate model for an {@link ARGModel}.
 *
 * <p>The rate distribution is cut into {@code categoryCount} equal-probability bins. Each bin's
 * rate is the distribution quantile at the bin's probability midpoint. A branch's rate is looked
 * up by using the (truncated) node rate stored on the tree as an index into that table.
 */
public class ARGDiscretizedBranchRates extends AbstractBranchRateModel {
    public static final String DISCRETIZED_BRANCH_RATES = "argDiscretizedBranchRates";
    public static final String DISTRIBUTION = "distribution";
    public static final String NUM_RATE_CATEGORIES = "numRateCategories";
    public static final String SINGLE_ROOT_RATE = "singleRootRate";

    /** Distribution whose quantiles define the per-category rates. */
    private ParametricDistributionModel distributionModel;
    /** The ARG this model is attached to (registered as a sub-model). */
    private ARGModel tree;
    /**
     * Per-node rate-category parameter used only by {@link #shuffleIndices()}.
     * NOTE(review): nothing in this class assigns it, so it is always null here.
     */
    private Parameter rateCategoryParameter;
    /** Root node number last seen by {@link #shuffleIndices()}. */
    private int rootNodeNumber;
    /** Saved copy of {@link #rootNodeNumber} for MCMC store/restore. */
    private int storedRootNodeNumber;
    /** Number of discrete rate categories; always at least 1. */
    private final int categoryCount;
    /** Probability width of each category bin, {@code 1 / categoryCount}. */
    private final double step;
    /** Cached category rates, indexed by category number. */
    private final double[] rates;
    /** Whether {@link #rates} reflects the current state of {@link #distributionModel}. */
    private boolean ratesKnown;

    /** XML parser for the {@code argDiscretizedBranchRates} element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = {
                new ElementRule(ARGModel.class),
                new ElementRule(ARGDiscretizedBranchRates.DISTRIBUTION, ParametricDistributionModel.class,
                        "The distribution model for rates among branches", false),
                AttributeRule.newIntegerRule(ARGDiscretizedBranchRates.NUM_RATE_CATEGORIES)
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return ARGDiscretizedBranchRates.DISCRETIZED_BRANCH_RATES;
        }

        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            ARGModel arg = (ARGModel) xMLObject.getChild(ARGModel.class);
            ParametricDistributionModel distribution =
                    (ParametricDistributionModel) xMLObject.getChild(ARGDiscretizedBranchRates.DISTRIBUTION);
            int categories = xMLObject.getIntegerAttribute(ARGDiscretizedBranchRates.NUM_RATE_CATEGORIES);

            Logger logger = Logger.getLogger("dr.evomodel");
            logger.info("Using discretized relaxed clock model.");
            logger.info("  parametric model = " + distribution.getModelName());
            logger.info("   rate categories = " + categories);
            // The attribute is accepted for compatibility, but it has no effect in this model.
            if (xMLObject.hasAttribute(ARGDiscretizedBranchRates.SINGLE_ROOT_RATE)) {
                logger.warning("   WARNING: single root rate is not implemented!");
            }
            return new ARGDiscretizedBranchRates(arg, categories, distribution);
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "This element returns an discretized relaxed clock model.The branch rates are drawn from a discretized parametric distribution.";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return ARGDiscretizedBranchRates.class;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Creates a discretized branch-rate model.
     *
     * @param aRGModel the ARG whose node rates select a rate category
     * @param i the number of rate categories; must be in {@code [1, aRGModel.getNodeCount() - 1]}
     * @param parametricDistributionModel the distribution to discretize into categories
     * @throws IllegalArgumentException if {@code i} is outside the allowed range
     */
    public ARGDiscretizedBranchRates(ARGModel aRGModel, int i, ParametricDistributionModel parametricDistributionModel) {
        super(DISCRETIZED_BRANCH_RATES);
        // Validate before deriving step/rates. Otherwise a non-positive count yields an
        // infinite step or a NegativeArraySizeException instead of a clear error.
        if (i < 1) {
            throw new IllegalArgumentException("The number of rate categories must be at least 1, got " + i);
        }
        int maxCategories = aRGModel.getNodeCount() - 1;
        if (i > maxCategories) {
            throw new IllegalArgumentException("The number of rate categories (" + i
                    + ") must not exceed nodeCount-1 (" + maxCategories + ")");
        }
        this.tree = aRGModel;
        this.categoryCount = i;
        this.step = 1.0d / this.categoryCount;
        this.rates = new double[this.categoryCount];
        this.distributionModel = parametricDistributionModel;
        this.ratesKnown = false;
        addModel(parametricDistributionModel);
        addModel(aRGModel);
        this.rootNodeNumber = aRGModel.getRoot().getNumber();
        this.storedRootNodeNumber = this.rootNodeNumber;
    }

    /**
     * Invalidates the cached rates when the distribution changes, then propagates the change.
     */
    @Override // dr.inference.model.AbstractModel
    public void handleModelChangedEvent(Model model, Object obj, int i) {
        // The rates[] table depends only on distributionModel (see setupRates), so a change to
        // the tree does not require recomputing it.
        if (model == this.distributionModel) {
            this.ratesKnown = false;
        }
        fireModelChanged();
    }

    /** Propagates any variable change as a model change. */
    @Override // dr.inference.model.AbstractModel
    protected final void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        fireModelChanged();
    }

    /** Saves the tracked root node number. */
    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
        this.storedRootNodeNumber = this.rootNodeNumber;
    }

    /** Restores the tracked root node number and forces the rate table to be rebuilt on next use. */
    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        this.ratesKnown = false;
        this.rootNodeNumber = this.storedRootNodeNumber;
    }

    /** Nothing to do on accept; all state is already current. */
    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
    }

    /**
     * Returns the rate of the branch above {@code nodeRef}.
     *
     * @param tree the tree containing the node; its node rate is used as the category index
     * @param nodeRef a non-root node
     * @return the discretized rate for the node's category
     * @throws IllegalArgumentException if {@code nodeRef} is the root, or its category index is
     *     outside {@code [0, categoryCount)}
     */
    @Override // dr.evolution.tree.BranchRates
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        if (tree.isRoot(nodeRef)) {
            throw new IllegalArgumentException("root node doesn't have a rate!");
        }
        if (!this.ratesKnown) {
            setupRates();
            this.ratesKnown = true;
        }
        // The node "rate" stored on the tree encodes a category index (truncated toward zero).
        int category = (int) tree.getNodeRate(nodeRef);
        if (category < 0 || category >= this.categoryCount) {
            throw new IllegalArgumentException("Rate category " + category
                    + " is outside [0, " + this.categoryCount + ")");
        }
        return this.rates[category];
    }

    /**
     * Fills {@link #rates} with the distribution quantile at each category's probability midpoint,
     * i.e. at {@code (k + 0.5) / categoryCount} for category {@code k}.
     */
    private void setupRates() {
        for (int k = 0; k < this.categoryCount; k++) {
            // Compute each midpoint directly rather than accumulating step, so rounding error
            // does not build up across categories.
            double midpoint = (k + 0.5d) / this.categoryCount;
            this.rates[k] = this.distributionModel.quantile(midpoint);
        }
    }

    /**
     * Re-aligns {@link #rateCategoryParameter} after the root node number changes, then records
     * the new root number.
     *
     * <p>The entries between the old and new root indices are rotated by one. As a result, the
     * entry at the new root's index (clamped to the last index) moves to the old root's index
     * (also clamped).
     *
     * <p>NOTE(review): this method is not called anywhere in this class, and
     * {@link #rateCategoryParameter} is never assigned. The null guard keeps it safe if it is
     * wired up later without a parameter.
     */
    private void shuffleIndices() {
        int number = this.tree.getRoot().getNumber();
        if (this.rateCategoryParameter == null) {
            // No parameter to reshuffle; just track the new root.
            this.rootNodeNumber = number;
            return;
        }
        if (this.rootNodeNumber > number) {
            int round = (int) Math.round(this.rateCategoryParameter.getParameterValue(number));
            int min = Math.min(this.rateCategoryParameter.getDimension() - 1, this.rootNodeNumber);
            for (int i = number; i < min; i++) {
                this.rateCategoryParameter.setParameterValue(i, this.rateCategoryParameter.getParameterValue(i + 1));
            }
            this.rateCategoryParameter.setParameterValue(min, round);
        } else if (this.rootNodeNumber < number) {
            int min2 = Math.min(this.rateCategoryParameter.getDimension() - 1, number);
            int round2 = (int) Math.round(this.rateCategoryParameter.getParameterValue(min2));
            for (int i2 = min2; i2 > this.rootNodeNumber; i2--) {
                this.rateCategoryParameter.setParameterValue(i2, this.rateCategoryParameter.getParameterValue(i2 - 1));
            }
            this.rateCategoryParameter.setParameterValue(this.rootNodeNumber, round2);
        }
        this.rootNodeNumber = number;
    }
}
