package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.evomodel.tree.randomlocalmodel.RandomLocalTreeVariable;
import dr.evomodelxml.branchratemodel.RandomLocalClockModelParser;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;

/* loaded from: input_file:dr/evomodel/branchratemodel/RandomLocalClockModel.class */
/**
 * Random local clock branch-rate model.
 *
 * <p>Starting from an unscaled rate of {@code 1.0} at the root, each branch inherits the rate of
 * its parent branch. At every non-root node whose indicator value is strictly greater than the
 * threshold, the inherited rate changes. It is either multiplied by that node's rate value (when
 * rates are multipliers) or replaced by it. The unscaled rates are then rescaled so that their
 * branch-length-weighted mean is 1. When a mean-rate parameter is supplied, the weighted mean is
 * that parameter's first value instead.
 *
 * <p>Scaled rates are cached. Any change to the tree, the indicator or rate parameters, or the
 * mean-rate parameter, and any state restore, marks the cache stale; it is rebuilt on the next
 * call to {@link #getBranchRate}.
 */
public class RandomLocalClockModel extends AbstractBranchRateModel implements RandomLocalTreeVariable, Citable {
    /** Factor applied to every unscaled rate to obtain the reported branch rate. */
    private double scaleFactor;
    /** Tree on which rates are computed; also registered as a sub-model. */
    private final TreeModel treeModel;
    /** If true, a selected node's rate multiplies the inherited rate; otherwise it replaces it. */
    private final boolean ratesAreMultipliers;
    /** Unscaled rate of the branch above each node, indexed by node number. */
    private final double[] unscaledBranchRates;
    /** Optional mean rate (may be null); its first value multiplies the scale factor. */
    private final Parameter meanRateParameter;
    /** Per-node indicator values deciding where the rate changes. */
    private final TreeParameterModel indicators;
    /** Per-node rate values applied at selected nodes. */
    private final TreeParameterModel rates;
    /** True when the cached rates may be stale and must be rebuilt before use. */
    private boolean recalculationNeeded = true;
    /** Indicator values strictly above this cut-off select a rate change at a node. */
    private final double threshold;
    public static Citation CITATION = new Citation(new Author[]{new Author("AJ", "Drummond"), new Author("MA", "Suchard")}, "Bayesian random local clocks, or one rate to rule them all", 2010, "BMC Biology", "8: 114", "10.1186/1741-7007-8-114");

    /**
     * Creates a random local clock over the given tree and computes the initial rates.
     *
     * @param treeModel the tree whose branch rates are modelled; this model listens to it for changes
     * @param meanRateParameter optional mean rate, may be {@code null}; when present its first value
     *     multiplies every branch rate
     * @param rateIndicatorParameter per-node indicators deciding where the rate changes
     * @param ratesParameter per-node rate values; bounded below by 0 and reset to 1.0 here
     * @param ratesAreMultipliers {@code true} to multiply the inherited rate at a selected node,
     *     {@code false} to replace it
     * @param threshold indicator cut-off. Pass {@link Double#NaN} for binary indicators, which are
     *     then bounded to [0, 1], reset to 0 and compared against 0.5.
     */
    public RandomLocalClockModel(TreeModel treeModel, Parameter meanRateParameter,
            Parameter rateIndicatorParameter, Parameter ratesParameter,
            boolean ratesAreMultipliers, double threshold) {
        super(RandomLocalClockModelParser.LOCAL_BRANCH_RATES);
        this.ratesAreMultipliers = ratesAreMultipliers;
        // NOTE(review): the trailing `false` presumably excludes the root from the parameter
        // mapping (the root never takes a rate change below). Confirm against TreeParameterModel.
        this.indicators = new TreeParameterModel(treeModel, rateIndicatorParameter, false);
        this.rates = new TreeParameterModel(treeModel, ratesParameter, false);

        if (Double.isNaN(threshold)) {
            // Binary indicators: bounded to [0, 1], all switched off, selected when above 0.5.
            rateIndicatorParameter.addBounds(
                    new Parameter.DefaultBounds(1.0d, 0.0d, rateIndicatorParameter.getDimension()));
            this.threshold = 0.5d;
            for (int i = 0; i < rateIndicatorParameter.getDimension(); i++) {
                rateIndicatorParameter.setParameterValue(i, 0.0d);
            }
        } else {
            // Continuous indicators: effectively unbounded, compared against the caller's cut-off.
            rateIndicatorParameter.addBounds(new Parameter.DefaultBounds(
                    Double.MAX_VALUE, -Double.MAX_VALUE, rateIndicatorParameter.getDimension()));
            this.threshold = threshold;
        }

        ratesParameter.addBounds(new Parameter.DefaultBounds(
                Double.POSITIVE_INFINITY, 0.0d, ratesParameter.getDimension()));
        // Reset every rate to the neutral value. The loop is bounded by the rates parameter's own
        // dimension, so all entries are covered even if it differs in size from the indicators.
        for (int i = 0; i < ratesParameter.getDimension(); i++) {
            ratesParameter.setParameterValue(i, 1.0d);
        }

        this.meanRateParameter = meanRateParameter;
        this.treeModel = treeModel;
        addModel(treeModel);
        addModel(this.indicators);
        addModel(this.rates);
        if (meanRateParameter != null) {
            addVariable(meanRateParameter);
        }

        this.unscaledBranchRates = new double[treeModel.getNodeCount()];
        // Report the effective threshold (0.5 in binary mode) rather than the raw NaN argument.
        Logger.getLogger("dr.evomodel").info("  indicator parameter name is '"
                + rateIndicatorParameter.getId() + "' with threshold = " + this.threshold);
        recalculateScaleFactor();
    }

    /**
     * Returns the rate value stored for the given node.
     *
     * @param tree the tree containing {@code nodeRef}
     * @param nodeRef the node to look up
     * @return the node's entry in the rates parameter
     */
    @Override // dr.evomodel.tree.randomlocalmodel.RandomLocalTreeVariable
    public final double getVariable(Tree tree, NodeRef nodeRef) {
        return this.rates.getNodeValue(tree, nodeRef);
    }

    /**
     * Returns whether a rate change is switched on at the given node.
     *
     * @param tree the tree containing {@code nodeRef}
     * @param nodeRef the node to test
     * @return {@code true} if the node's indicator value is strictly greater than the threshold
     */
    @Override // dr.evomodel.tree.randomlocalmodel.RandomLocalTreeVariable
    public final boolean isVariableSelected(Tree tree, NodeRef nodeRef) {
        return this.indicators.getNodeValue(tree, nodeRef) > this.threshold;
    }

    /** Marks cached rates stale after a change in the tree or the indicator or rate models. */
    @Override // dr.inference.model.AbstractModel
    public void handleModelChangedEvent(Model model, Object object, int index) {
        invalidateAndNotify();
    }

    /** Marks cached rates stale after a change in the mean-rate parameter. */
    @Override // dr.inference.model.AbstractModel
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        invalidateAndNotify();
    }

    /** Flags the cache for recomputation and tells listeners that this model changed. */
    private void invalidateAndNotify() {
        this.recalculationNeeded = true;
        fireModelChanged();
    }

    /** Nothing is stored; cached rates are rebuilt from scratch on restore instead. */
    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
    }

    /** Marks the cache stale so rates are recomputed from the restored state. */
    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        this.recalculationNeeded = true;
    }

    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
    }

    /**
     * Returns the scaled rate of the branch above the given node, rebuilding the cache if stale.
     *
     * <p>NOTE: the {@code tree} argument is not consulted. Rates are always computed on the tree
     * passed to the constructor, and {@code nodeRef} is used only for its node number.
     *
     * @param tree ignored
     * @param nodeRef the node whose parent branch's rate is requested
     * @return the unscaled rate for the node multiplied by the current scale factor
     */
    @Override // dr.evolution.tree.BranchRates
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        if (this.recalculationNeeded) {
            recalculateScaleFactor();
            this.recalculationNeeded = false;
        }
        return this.unscaledBranchRates[nodeRef.getNumber()] * this.scaleFactor;
    }

    /** Fills {@link #unscaledBranchRates} by propagating rates from the root (rate 1.0). */
    private void calculateUnscaledBranchRates(TreeModel tree) {
        recursivelyCompute(tree, tree.getRoot(), 1.0d);
    }

    /**
     * Assigns an unscaled rate to {@code node} and, recursively, to its descendants.
     *
     * @param tree the tree being traversed
     * @param node the current node
     * @param inheritedRate the unscaled rate of the parent branch
     */
    private void recursivelyCompute(TreeModel tree, NodeRef node, double inheritedRate) {
        double rate = inheritedRate;
        // The root never triggers a rate change; elsewhere a selected node multiplies or replaces.
        if (!tree.isRoot(node) && isVariableSelected(tree, node)) {
            double nodeRate = getVariable(tree, node);
            rate = this.ratesAreMultipliers ? rate * nodeRate : nodeRate;
        }
        this.unscaledBranchRates[node.getNumber()] = rate;
        for (int child = 0; child < tree.getChildCount(node); child++) {
            recursivelyCompute(tree, tree.getChild(node, child), rate);
        }
    }

    /**
     * Recomputes the unscaled rates and the scale factor.
     *
     * <p>The scale factor is (total branch length) / (sum of branch length * unscaled rate). This
     * makes the length-weighted mean of the scaled rates 1, or the mean-rate parameter's first
     * value when one is present.
     */
    private void recalculateScaleFactor() {
        calculateUnscaledBranchRates(this.treeModel);
        double totalLength = 0.0d;
        double totalRateTimesLength = 0.0d;
        for (int i = 0; i < this.treeModel.getNodeCount(); i++) {
            NodeRef node = this.treeModel.getNode(i);
            if (this.treeModel.isRoot(node)) {
                continue; // the root has no branch above it
            }
            double branchLength = this.treeModel.getNodeHeight(this.treeModel.getParent(node))
                    - this.treeModel.getNodeHeight(node);
            totalLength += branchLength;
            totalRateTimesLength += branchLength * this.unscaledBranchRates[node.getNumber()];
        }
        this.scaleFactor = totalLength / totalRateTimesLength;
        if (this.meanRateParameter != null) {
            this.scaleFactor *= this.meanRateParameter.getParameterValue(0);
        }
    }

    @Override // dr.util.Citable
    public Citation.Category getCategory() {
        return Citation.Category.MOLECULAR_CLOCK;
    }

    @Override // dr.util.Citable
    public String getDescription() {
        return "Local clock model";
    }

    @Override // dr.util.Citable
    public List<Citation> getCitations() {
        return Collections.singletonList(CITATION);
    }
}
