package dr.evomodel.branchratemodel;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.parsimony.FitchParsimony;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.util.TaxonList;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.util.Arrays;

/**
 * Branch-rate model whose rate on each branch is derived from the discrete states associated
 * with that branch.
 *
 * <p>For every branch, {@link #getProcessValues} produces one non-negative weight per state.
 * These weights are normalised to proportions and combined with the rate parameters in
 * {@link #getRawBranchRate}. How the weights are obtained depends on the {@link Mode}:
 * <ul>
 *   <li>{@link Mode#NODE_STATES}: half of the branch length goes to the state read at the child
 *       node and half to the state read at its parent;</li>
 *   <li>{@link Mode#PARSIMONY}: the same split, using (possibly ambiguous) Fitch parsimony
 *       state sets, rescaled so the weights sum to the branch length;</li>
 *   <li>{@link Mode#MARKOV_JUMP_PROCESS}: the first value of one reward trait per state.</li>
 * </ul>
 *
 * <p>Rates are cached per node number. The cache is invalidated whenever any registered model
 * (the tree, trait providers) or any registered variable (the rate parameters) changes.
 */
public class DiscreteTraitBranchRateModel extends AbstractBranchRateModel {
    /** When true, rates are memoised per node number. */
    private static final boolean CACHING_RATES = true;
    public static final String DISCRETE_TRAIT_BRANCH_RATE_MODEL = "discreteTraitRateModel";

    /** Single trait passed to the TreeTrait constructors; read as an int[] in NODE_STATES mode. */
    protected TreeTrait trait;
    /** Either one rate per state, or (with relative rates) an overall rate at index 0. */
    private Parameter rateParameter;
    /** Optional per-state multipliers of rateParameter[0]. */
    private Parameter relativeRatesParameter;
    /** Optional per-state multipliers applied on top of the relative rates. */
    private Parameter indicatorParameter;
    /** Element of the int[] node-state trait that is read in NODE_STATES mode. */
    protected int traitIndex;

    private double[] rates;
    private double[] storedRates;
    private boolean[] rateKnown;
    private boolean[] storedRateKnown;

    /** One reward trait per state (per-state-trait constructor only; null otherwise). */
    private TreeTrait[] traits;
    private FitchParsimony fitchParsimony;
    /** True when the tree may have changed, so Fitch states must be recomputed before use. */
    private boolean treeChanged;
    private Mode mode;
    private DataType dataType;

    /** How the per-state weights of a branch are obtained. */
    public enum Mode {
        /** Half the branch to the child's state, half to the parent's state. */
        NODE_STATES,
        /** One reward trait per state. */
        MARKOV_JUMP_PROCESS,
        /** Not selected by any constructor of this class. */
        MARKOV_JUMP_COUNT,
        /** Fitch parsimony state sets at the child and the parent. */
        PARSIMONY
    }

    /**
     * Creates a model in {@link Mode#PARSIMONY} mode, reconstructing states by Fitch parsimony.
     *
     * @param treeModel     tree whose branches receive rates
     * @param patternList   discrete data; must contain exactly the taxa of {@code treeModel}
     * @param traitIndex    stored as {@link #traitIndex}
     * @param rateParameter one rate per state; its dimension is set to the data's state count
     * @throws IllegalArgumentException if the tree and the pattern list have different taxa
     */
    public DiscreteTraitBranchRateModel(TreeModel treeModel, PatternList patternList, int traitIndex,
                                        Parameter rateParameter) {
        this(treeModel, traitIndex, rateParameter, null, null);
        if (!TaxonList.Utils.getTaxonListIdSet(treeModel).equals(TaxonList.Utils.getTaxonListIdSet(patternList))) {
            throw new IllegalArgumentException("Tree model and pattern list must have the same list of taxa!");
        }
        rateParameter.setDimension(patternList.getDataType().getStateCount());
        this.fitchParsimony = new FitchParsimony(patternList, false);
        this.mode = Mode.PARSIMONY;
    }

    /**
     * Creates a trait-driven model with an overall rate, per-state relative rates and indicators.
     *
     * @param traitProvider          registered as a sub-model if it is a {@link Model}
     * @param dataType               discrete data type of the trait
     * @param treeModel              tree whose branches receive rates
     * @param trait                  a trait named "states" selects NODE_STATES, anything else
     *                               MARKOV_JUMP_PROCESS
     * @param traitIndex             element of the node-state array to read
     * @param rateParameter          overall rate (index 0 is read)
     * @param relativeRatesParameter per-state relative rates; its dimension is set to the state count
     * @param indicatorParameter     per-state multipliers
     */
    public DiscreteTraitBranchRateModel(TreeTraitProvider traitProvider, DataType dataType, TreeModel treeModel,
                                        TreeTrait trait, int traitIndex, Parameter rateParameter,
                                        Parameter relativeRatesParameter, Parameter indicatorParameter) {
        this(treeModel, traitIndex, rateParameter, relativeRatesParameter, indicatorParameter);
        selectTrait(dataType, trait);
        relativeRatesParameter.setDimension(dataType.getStateCount());
        registerTraitModels(traitProvider, trait);
    }

    /**
     * Creates a trait-driven model with one rate per state.
     *
     * @param traitProvider registered as a sub-model if it is a {@link Model}
     * @param dataType      discrete data type of the trait
     * @param treeModel     tree whose branches receive rates
     * @param trait         a trait named "states" selects NODE_STATES, anything else MARKOV_JUMP_PROCESS
     * @param traitIndex    element of the node-state array to read
     * @param rateParameter one rate per state; its dimension is set to the state count
     */
    public DiscreteTraitBranchRateModel(TreeTraitProvider traitProvider, DataType dataType, TreeModel treeModel,
                                        TreeTrait trait, int traitIndex, Parameter rateParameter) {
        this(treeModel, traitIndex, rateParameter, null, null);
        selectTrait(dataType, trait);
        rateParameter.setDimension(dataType.getStateCount());
        registerTraitModels(traitProvider, trait);
    }

    /**
     * Creates a {@link Mode#MARKOV_JUMP_PROCESS} model driven by one reward trait per state.
     *
     * @param traitProvider registered as a sub-model if it is a {@link Model}
     * @param traits        one {@code TreeTrait.DA} per state; element 0 of each value is read
     * @param treeModel     tree whose branches receive rates
     * @param rateParameter one rate per trait; its dimension is set to {@code traits.length}
     */
    public DiscreteTraitBranchRateModel(TreeTraitProvider traitProvider, TreeTrait[] traits, TreeModel treeModel,
                                        Parameter rateParameter) {
        this(treeModel, 0, rateParameter, null, null);
        this.traits = traits;
        this.mode = Mode.MARKOV_JUMP_PROCESS;
        rateParameter.setDimension(traits.length);
        if (traitProvider instanceof Model) {
            addModel((Model) traitProvider);
        }
    }

    /** Shared initialisation: registers the tree and parameters and allocates the rate caches. */
    private DiscreteTraitBranchRateModel(TreeModel treeModel, int traitIndex, Parameter rateParameter,
                                         Parameter relativeRatesParameter, Parameter indicatorParameter) {
        super(DISCRETE_TRAIT_BRANCH_RATE_MODEL);
        this.trait = null;
        this.treeChanged = true;
        addModel(treeModel);
        this.traitIndex = traitIndex;
        this.rateParameter = rateParameter;
        addVariable(rateParameter);
        this.relativeRatesParameter = relativeRatesParameter;
        if (relativeRatesParameter != null) {
            addVariable(relativeRatesParameter);
        }
        this.indicatorParameter = indicatorParameter;
        if (indicatorParameter != null) {
            addVariable(indicatorParameter);
        }
        final int nodeCount = treeModel.getNodeCount();
        this.rates = new double[nodeCount];
        this.storedRates = new double[nodeCount];
        this.rateKnown = new boolean[nodeCount];
        this.storedRateKnown = new boolean[nodeCount];
    }

    /** Stores the trait and data type and picks the mode from the trait's name. */
    private void selectTrait(DataType dataType, TreeTrait trait) {
        this.trait = trait;
        this.dataType = dataType;
        this.mode = trait.getTraitName().equals("states") ? Mode.NODE_STATES : Mode.MARKOV_JUMP_PROCESS;
    }

    /** Registers the trait provider and the trait itself as sub-models when they are models. */
    private void registerTraitModels(TreeTraitProvider traitProvider, TreeTrait trait) {
        if (traitProvider instanceof Model) {
            addModel((Model) traitProvider);
        }
        if (trait instanceof Model) {
            addModel((Model) trait);
        }
    }

    /** Any sub-model change invalidates all cached rates and forces Fitch states to be recomputed. */
    @Override
    public void handleModelChangedEvent(Model model, Object object, int index) {
        Arrays.fill(rateKnown, false);
        treeChanged = true;
        fireModelChanged();
    }

    /** Any parameter change invalidates all cached rates. */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        Arrays.fill(rateKnown, false);
        fireModelChanged();
    }

    /** Saves the cached rates together with which of them were actually computed. */
    @Override
    protected void storeState() {
        System.arraycopy(rates, 0, storedRates, 0, rates.length);
        System.arraycopy(rateKnown, 0, storedRateKnown, 0, rateKnown.length);
    }

    /**
     * Swaps the stored cache back in.
     *
     * <p>Only rates that were known at store time are treated as known. The original code marked
     * every entry known, which could return never-computed zeros. Fitch states may have been
     * computed for the rejected tree, so they are marked stale.
     */
    @Override
    protected void restoreState() {
        final double[] swapRates = rates;
        rates = storedRates;
        storedRates = swapRates;

        final boolean[] swapKnown = rateKnown;
        rateKnown = storedRateKnown;
        storedRateKnown = swapKnown;

        treeChanged = true;
    }

    @Override
    protected void acceptState() {
    }

    /**
     * Returns the number of discrete states for the current mode.
     *
     * @return the state count, or 0 if no mode is set
     */
    protected int getStateCount() {
        if (mode == Mode.NODE_STATES) {
            return dataType.getStateCount();
        } else if (mode == Mode.MARKOV_JUMP_PROCESS) {
            // The per-state-trait constructor sets no DataType; there is one trait per state.
            return traits != null ? traits.length : dataType.getStateCount();
        } else if (mode == Mode.PARSIMONY) {
            return fitchParsimony.getPatterns().getStateCount();
        }
        return 0;
    }

    /** Returns the (cached) rate of the branch above {@code nodeRef}. */
    @Override
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        if (!CACHING_RATES) {
            return getRawBranchRate(tree, nodeRef);
        }
        final int index = nodeRef.getNumber();
        if (!rateKnown[index]) {
            rates[index] = getRawBranchRate(tree, nodeRef);
            rateKnown[index] = true;
        }
        return rates[index];
    }

    /**
     * Computes the rate of the branch above {@code nodeRef} as a proportion-weighted average.
     *
     * <ul>
     *   <li>relative rates and indicators: {@code rate[0] * rel[s] * p[s] * ind[s]}, summed over s;</li>
     *   <li>relative rates only: {@code rate[0] * rel[s] * p[s]}, summed over s;</li>
     *   <li>otherwise: {@code rate[s] * p[s]}, summed over s.</li>
     * </ul>
     * Here {@code p[s]} is the fraction of the branch's process weight in state {@code s}. The last
     * case previously used raw weights rather than proportions, so it returned rate times branch
     * length instead of a rate. For a branch whose weights sum to zero, all {@code p[s]} are 0.
     */
    protected double getRawBranchRate(Tree tree, NodeRef nodeRef) {
        final int stateCount = getStateCount();
        final double[] proportions = toProportions(getProcessValues(tree, nodeRef), stateCount);

        double rate = 0.0;
        if (relativeRatesParameter != null && indicatorParameter != null) {
            final double overallRate = rateParameter.getParameterValue(0);
            for (int s = 0; s < stateCount; s++) {
                rate += overallRate * relativeRatesParameter.getParameterValue(s) * proportions[s]
                        * indicatorParameter.getParameterValue(s);
            }
        } else if (relativeRatesParameter != null) {
            final double overallRate = rateParameter.getParameterValue(0);
            for (int s = 0; s < stateCount; s++) {
                rate += overallRate * relativeRatesParameter.getParameterValue(s) * proportions[s];
            }
        } else {
            for (int s = 0; s < stateCount; s++) {
                rate += rateParameter.getParameterValue(s) * proportions[s];
            }
        }
        return rate;
    }

    /** Normalises the first {@code stateCount} weights to sum to 1 (all zero if they sum to 0). */
    private static double[] toProportions(double[] weights, int stateCount) {
        double total = 0.0;
        for (int s = 0; s < stateCount; s++) {
            total += weights[s];
        }
        final double[] proportions = new double[stateCount];
        if (total != 0.0) {
            for (int s = 0; s < stateCount; s++) {
                proportions[s] = weights[s] / total;
            }
        }
        return proportions;
    }

    /**
     * Returns one weight per state for the branch above {@code nodeRef}, according to the mode.
     *
     * @throws IllegalStateException in MARKOV_JUMP_PROCESS mode when no per-state traits were given
     */
    private double[] getProcessValues(Tree tree, NodeRef nodeRef) {
        final int stateCount = getStateCount();
        final double halfLength = tree.getBranchLength(nodeRef) / 2.0;

        if (mode == Mode.MARKOV_JUMP_PROCESS) {
            if (traits == null) {
                // Reached via the single-trait constructors with a non-"states" trait; the original
                // code dereferenced the null array here.
                throw new IllegalStateException("Markov jump process mode requires one reward trait per state");
            }
            final double[] values = new double[stateCount];
            for (int s = 0; s < stateCount; s++) {
                values[s] = ((TreeTrait.DA) traits[s]).getTrait(tree, nodeRef)[0];
            }
            return values;
        }

        if (mode == Mode.PARSIMONY) {
            if (treeChanged) {
                fitchParsimony.initialize(tree);
                treeChanged = false;
            }
            final int[] childStates = fitchParsimony.getStates(tree, nodeRef);
            final int[] parentStates = fitchParsimony.getStates(tree, tree.getParent(nodeRef));
            final double[] values = new double[stateCount];
            for (int s : childStates) {
                values[s] += halfLength;
            }
            for (int s : parentStates) {
                values[s] += halfLength;
            }
            // Each candidate state received half the branch, so divide by (count / 2) to make the
            // weights sum to the branch length. Floating-point division avoids truncating odd counts.
            final double endpointShare = (childStates.length + parentStates.length) / 2.0;
            for (int s = 0; s < values.length; s++) {
                values[s] /= endpointShare;
            }
            return values;
        }

        if (mode == Mode.NODE_STATES) {
            final double[] values = new double[stateCount];
            final int childState = ((int[]) trait.getTrait(tree, nodeRef))[traitIndex];
            values[childState] += halfLength;
            final int parentState = ((int[]) trait.getTrait(tree, tree.getParent(nodeRef)))[traitIndex];
            values[parentState] += halfLength;
            return values;
        }

        return null;
    }
}
