package dr.app.beagle.tools;

import dr.evolution.alignment.SimpleAlignment;
import dr.evolution.datatype.Codons;
import dr.evolution.datatype.DataType;
import dr.evolution.sequence.Sequence;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.siteratemodel.GammaSiteRateModel;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evoxml.AlignmentParser;
import dr.evoxml.util.GraphMLUtils;
import dr.inference.markovjumps.MarkovJumpsRegisterAcceptor;
import dr.inference.markovjumps.MarkovJumpsType;
import dr.inference.markovjumps.StateHistory;
import dr.inference.model.Parameter;
import dr.math.MathUtils;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Simulates complete state histories along a tree and stores the resulting tip sequences in
 * this alignment.
 *
 * <p>For each of {@code nReplications} sites, a rate category is drawn from the site model's
 * category proportions and a root state from the substitution model's frequency model. States
 * are then propagated from the root down every branch with
 * {@link StateHistory#simulateUnconditionalOnEndingState}. Tip sequences are added to this
 * alignment.
 *
 * <p>Registered Markov-jump processes (counts or rewards) are evaluated on every simulated branch
 * history. They are exposed, together with the optional per-node sequences, as tree traits. Those
 * traits are written into the Newick tree produced by {@link #toString()}.
 */
public class CompleteHistorySimulator extends SimpleAlignment implements MarkovJumpsRegisterAcceptor, TreeTraitProvider {
    /** Number of sites simulated per node. */
    protected int nReplications;
    protected Tree tree;
    protected GammaSiteRateModel siteModel;
    protected BranchRateModel branchRateModel;
    int categoryCount;
    int stateCount;

    /** True when a per-branch parameter value is substituted before simulating each branch. */
    private boolean branchSpecificLambda = false;
    private Parameter branchVariableParameter = null;
    private Parameter branchPossibleValuesParameter = null;
    private DataType dataType;

    // Parallel lists, one entry per registered jump process (index = register number).
    // They stay null until the first addRegister() call.
    protected List<double[]> registers;
    protected List<String> jumpTags;
    protected List<MarkovJumpsType> jumpTypes;
    /** Per register: [node number][site] realized count or reward on the branch above the node. */
    protected List<double[][]> realizedJumps;

    /** If true, jump traits are reported as a single sum over sites instead of a per-site list. */
    protected boolean sumAcrossSites;
    /** Simulated sequence for every node, keyed by node number; filled only when saveAlignment is set. */
    private Map<Integer, Sequence> alignmentTraitList;
    private NumberFormat format;
    protected int nJumpProcesses = 0;
    /** Taxon id -> 1-based taxon index, used when writing the Newick tree. */
    private final Map<String, Integer> idMap = new HashMap<String, Integer>();
    private boolean saveAlignment = false;
    private boolean alignmentOnly = false;
    protected TreeTraitProvider.Helper treeTraits = new TreeTraitProvider.Helper();

    /**
     * Creates a simulator for {@code nReplications} sites on {@code tree}.
     *
     * <p>The branch-specific mode is enabled only when both optional parameters are non-null.
     * If only one of them is supplied, it is silently ignored.
     *
     * @param tree                          the tree to simulate along
     * @param gammaSiteRateModel            supplies the substitution model, rate categories and
     *                                      category proportions
     * @param branchRateModel               supplies the per-branch rate multiplier
     * @param nReplications                 number of sites to simulate
     * @param sumAcrossSites                if true, jump traits are reported summed over sites
     * @param branchVariableParameter       optional 1-dimensional parameter that is overwritten
     *                                      before each branch is simulated
     * @param branchPossibleValuesParameter optional parameter with one value per node, indexed by
     *                                      node number
     * @throws IllegalArgumentException if the optional parameters have the wrong dimension
     */
    public CompleteHistorySimulator(Tree tree, GammaSiteRateModel gammaSiteRateModel, BranchRateModel branchRateModel,
                                    int nReplications, boolean sumAcrossSites,
                                    Parameter branchVariableParameter, Parameter branchPossibleValuesParameter) {
        this.tree = tree;
        this.siteModel = gammaSiteRateModel;
        this.branchRateModel = branchRateModel;
        this.nReplications = nReplications;
        this.dataType = gammaSiteRateModel.getSubstitutionModel().getDataType();
        this.stateCount = this.dataType.getStateCount();
        this.categoryCount = gammaSiteRateModel.getCategoryCount();
        this.sumAcrossSites = sumAcrossSites;

        // Taxa are numbered from 1 in taxon order; toString() passes this map to TreeUtils.newick.
        for (int taxon = 0; taxon < tree.getTaxonCount(); taxon++) {
            this.idMap.put(tree.getTaxon(taxon).getId(), taxon + 1);
        }

        this.format = NumberFormat.getNumberInstance(Locale.ENGLISH);
        this.format.setMaximumFractionDigits(3);

        if (branchVariableParameter != null && branchPossibleValuesParameter != null) {
            if (branchVariableParameter.getDimension() != 1) {
                throw new IllegalArgumentException("branchVariableParameter has the wrong dimension; should be 1");
            }
            if (branchPossibleValuesParameter.getDimension() != tree.getNodeCount()) {
                throw new IllegalArgumentException("branchPossibleValuesParameter has the wrong dimension; should be " + tree.getNodeCount());
            }
            this.branchSpecificLambda = true;
            this.branchPossibleValuesParameter = branchPossibleValuesParameter;
            this.branchVariableParameter = branchVariableParameter;
            Logger.getLogger("dr.app.beagle.tools").info("Doing a complete history simulation using branch-specific variables\n\tReplacing variable '" + branchVariableParameter.getId() + "' with values from '" + branchPossibleValuesParameter.getId() + "'");
        }
        this.alignmentTraitList = new HashMap<Integer, Sequence>(tree.getNodeCount());
    }

    /**
     * Converts the first {@code nReplications} state indices into a sequence for {@code nodeRef}.
     * Codon states are rendered with {@code getTriplet}; all other data types use {@code getCode}.
     *
     * @param iArr    state index per site; must have at least {@code nReplications} entries
     * @param nodeRef the node whose taxon is attached to the sequence. For an internal node the
     *                taxon may presumably be null; TODO confirm.
     * @return the new sequence
     */
    Sequence intArray2Sequence(int[] iArr, NodeRef nodeRef) {
        final boolean codons = this.dataType instanceof Codons;
        final StringBuilder residues = new StringBuilder();
        for (int site = 0; site < this.nReplications; site++) {
            if (codons) {
                residues.append(this.dataType.getTriplet(iArr[site]));
            } else {
                residues.append(this.dataType.getCode(iArr[site]));
            }
        }
        return new Sequence(this.tree.getNodeTaxon(nodeRef), residues.toString());
    }

    /**
     * Registers a NODE tree trait named {@link AlignmentParser#ALIGNMENT} that holds each node's
     * simulated sequence. The sequences are recorded only by calls to {@link #simulate()} made
     * after this method.
     */
    public void addAlignmentTrait() {
        this.saveAlignment = true;
        this.treeTraits.addTrait(new TreeTrait.S() {
            @Override
            public String getTraitName() {
                return AlignmentParser.ALIGNMENT;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.NODE;
            }

            @Override
            public String getTrait(Tree tree, NodeRef nodeRef) {
                Sequence sequence = CompleteHistorySimulator.this.alignmentTraitList.get(nodeRef.getNumber());
                if (sequence == null) {
                    throw new IllegalStateException("No simulated sequence stored for node " + nodeRef.getNumber()
                            + "; call simulate() after addAlignmentTrait()");
                }
                return sequence.getSequenceString();
            }
        });
    }

    /**
     * Registers a Markov-jump process and a matching NODE tree trait named after the parameter's
     * id.
     *
     * <p>The register values are read via {@code getParameterValues()} once, at registration time.
     * Whether that call returns a copy is up to the Parameter implementation.
     *
     * @param parameter       supplies the register values and the trait name
     * @param markovJumpsType COUNTS or REWARDS. Any other type makes {@link #simulate()} throw.
     * @param ignoredFlag     not used by this simulator
     */
    @Override
    public void addRegister(Parameter parameter, MarkovJumpsType markovJumpsType, boolean ignoredFlag) {
        if (this.registers == null) {
            this.registers = new ArrayList<double[]>();
        }
        if (this.jumpTags == null) {
            this.jumpTags = new ArrayList<String>();
        }
        if (this.jumpTypes == null) {
            this.jumpTypes = new ArrayList<MarkovJumpsType>();
        }
        if (this.realizedJumps == null) {
            this.realizedJumps = new ArrayList<double[][]>();
        }
        final String id = parameter.getId();
        this.registers.add(parameter.getParameterValues());
        this.jumpTags.add(id);
        this.jumpTypes.add(markovJumpsType);
        this.realizedJumps.add(new double[this.tree.getNodeCount()][this.nReplications]);

        // Captured by the trait below so it keeps pointing at this register.
        final int registerIndex = this.nJumpProcesses;
        this.treeTraits.addTrait(new TreeTrait.S() {
            @Override
            public String getTraitName() {
                return id;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.NODE;
            }

            @Override
            public String getTrait(Tree tree, NodeRef nodeRef) {
                return CompleteHistorySimulator.this.formattedValue(tree, nodeRef, registerIndex);
            }
        });
        this.nJumpProcesses++;
    }

    /**
     * Returns the realized per-site values of register {@code i} on the branch above
     * {@code nodeRef}. The returned array is the internal storage, not a copy.
     *
     * @throws IllegalArgumentException if {@code tree} is not the simulator's tree
     */
    public double[] getMarkovJumpsForNodeAndRegister(Tree tree, NodeRef nodeRef, int i) {
        if (this.tree != tree) {
            throw new IllegalArgumentException("Wrong tree!");
        }
        return this.realizedJumps.get(i)[nodeRef.getNumber()];
    }

    /** Returns the number of registered Markov-jump processes. */
    public int getNumberOfJumpProcess() {
        return this.nJumpProcesses;
    }

    @Override
    public TreeTrait[] getTreeTraits() {
        return this.treeTraits.getTreeTraits();
    }

    @Override
    public TreeTrait getTreeTrait(String str) {
        return this.treeTraits.getTreeTrait(str);
    }

    /**
     * Formats register {@code i} for {@code nodeRef}. When {@code sumAcrossSites} is set, the
     * output is the sum over sites. Otherwise it is the comma-separated per-site values, wrapped
     * in {@link GraphMLUtils#START_SECTION} / {@link GraphMLUtils#END_SECTION}.
     *
     * <p>NOTE(review): the decompiler reports this method was originally private. It is kept
     * public for compatibility.
     */
    public String formattedValue(Tree tree, NodeRef nodeRef, int i) {
        final double[] values = getMarkovJumpsForNodeAndRegister(tree, nodeRef, i);
        final StringBuilder out = new StringBuilder();
        if (this.sumAcrossSites) {
            double total = 0.0d;
            for (double value : values) {
                total += value;
            }
            out.append(total);
        } else {
            out.append(GraphMLUtils.START_SECTION);
            for (int site = 0; site < values.length; site++) {
                if (site > 0) {
                    out.append(",");
                }
                out.append(values[site]);
            }
            out.append(GraphMLUtils.END_SECTION);
        }
        return out.toString();
    }

    /**
     * Renders the simulated alignment. Unless {@link #setAlignmentOnly()} was called, the output
     * also includes the Newick tree annotated with this object's traits.
     *
     * <p>In alignment-only mode this calls {@code setReportCountStatistics(false)} as a side
     * effect.
     */
    @Override
    public String toString() {
        // StringBuffer (not StringBuilder) because TreeUtils.newick appends into it.
        StringBuffer buffer = new StringBuffer();
        if (this.alignmentOnly) {
            setReportCountStatistics(false);
            buffer.append(super.toString());
            buffer.append("\n");
        } else {
            buffer.append("alignment\n");
            buffer.append(super.toString());
            buffer.append("\n");
            buffer.append("tree\n");
            // Only offer this object as a trait provider when it actually has traits to write.
            TreeTraitProvider[] traitProviders = (this.nJumpProcesses > 0 || this.saveAlignment)
                    ? new TreeTraitProvider[]{this} : null;
            TreeUtils.newick(this.tree, this.tree.getRoot(), true, TreeUtils.BranchLengthType.LENGTHS_AS_TIME,
                    this.format, null, traitProviders, this.idMap, buffer);
            buffer.append("\n");
        }
        return buffer.toString();
    }

    /**
     * Runs one simulation. For every site it draws a rate category and then a root state, in that
     * order, using {@link MathUtils#randomChoicePDF}. It then simulates down the whole tree,
     * adding tip sequences to this alignment.
     */
    public void simulate() {
        double[] infinitesimalMatrix = new double[this.stateCount * this.stateCount];
        if (!this.branchSpecificLambda) {
            // One shared rate matrix. In branch-specific mode traverse() refills it per branch.
            this.siteModel.getSubstitutionModel().getInfinitesimalMatrix(infinitesimalMatrix);
        }
        NodeRef root = this.tree.getRoot();

        // Categories are drawn for all sites before any root state, matching the original random stream.
        double[] categoryProportions = this.siteModel.getCategoryProportions();
        int[] categories = new int[this.nReplications];
        for (int site = 0; site < this.nReplications; site++) {
            categories[site] = MathUtils.randomChoicePDF(categoryProportions);
        }

        FrequencyModel frequencyModel = this.siteModel.getSubstitutionModel().getFrequencyModel();
        double[] rootFrequencies = frequencyModel.getFrequencies();
        int[] rootStates = new int[this.nReplications];
        for (int site = 0; site < this.nReplications; site++) {
            rootStates[site] = MathUtils.randomChoicePDF(rootFrequencies);
        }

        setDataType(this.siteModel.getSubstitutionModel().getDataType());
        traverse(root, rootStates, categories, this, infinitesimalMatrix);
    }

    /**
     * Recursively simulates every branch below {@code nodeRef}.
     *
     * <p>In branch-specific mode, before each child branch, the single value of
     * {@code branchVariableParameter} is set from {@code branchPossibleValuesParameter} at the
     * child's node number, and {@code dArr} is refilled. NOTE(review): the parameter is not
     * restored afterwards.
     *
     * @param nodeRef         node whose outgoing branches are simulated
     * @param nodeStates      simulated state per site at {@code nodeRef}
     * @param categories      rate category per site
     * @param simpleAlignment receives the sequences of tip nodes
     * @param dArr            rate matrix buffer of size stateCount * stateCount
     */
    private void traverse(NodeRef nodeRef, int[] nodeStates, int[] categories, SimpleAlignment simpleAlignment, double[] dArr) {
        if (this.saveAlignment) {
            this.alignmentTraitList.put(nodeRef.getNumber(), intArray2Sequence(nodeStates, nodeRef));
        }
        for (int c = 0; c < this.tree.getChildCount(nodeRef); c++) {
            NodeRef child = this.tree.getChild(nodeRef, c);
            if (this.branchSpecificLambda) {
                this.branchVariableParameter.setParameterValue(0, this.branchPossibleValuesParameter.getParameterValue(child.getNumber()));
                this.siteModel.getSubstitutionModel().getInfinitesimalMatrix(dArr);
            }
            int[] childStates = new int[this.nReplications];
            StateHistory[] histories = new StateHistory[this.nReplications];
            for (int site = 0; site < this.nReplications; site++) {
                histories[site] = simulateAlongBranch(this.tree, child, categories[site], nodeStates[site], dArr);
                childStates[site] = histories[site].getEndingState();
            }
            processHistory(child, histories);
            if (this.tree.getChildCount(child) == 0) {
                simpleAlignment.addSequence(intArray2Sequence(childStates, child));
            }
            traverse(child, childStates, categories, simpleAlignment, dArr);
        }
    }

    /**
     * Records, for every registered jump process, the per-site counts or rewards of the histories
     * simulated on the branch above {@code nodeRef}.
     *
     * @throws IllegalStateException if a register has a type other than COUNTS or REWARDS
     */
    protected void processHistory(NodeRef nodeRef, StateHistory[] stateHistoryArr) {
        for (int r = 0; r < this.nJumpProcesses; r++) {
            final double[] register = this.registers.get(r);
            final MarkovJumpsType type = this.jumpTypes.get(r);
            final double[] realized = this.realizedJumps.get(r)[nodeRef.getNumber()];
            for (int site = 0; site < this.nReplications; site++) {
                if (type == MarkovJumpsType.COUNTS) {
                    realized[site] = stateHistoryArr[site].getTotalRegisteredCounts(register);
                } else if (type == MarkovJumpsType.REWARDS) {
                    realized[site] = stateHistoryArr[site].getTotalReward(register);
                } else {
                    // Was IllegalAccessError (a linkage Error); this is a state/configuration problem.
                    throw new IllegalStateException("Unsupported MarkovJumps type: " + type);
                }
            }
        }
    }

    /**
     * Simulates one site along the branch above {@code nodeRef}. The branch's time span is
     * multiplied by the branch rate and by the rate of the site's category.
     *
     * @throws RuntimeException if the rate-scaled branch length is negative
     */
    private StateHistory simulateAlongBranch(Tree tree, NodeRef nodeRef, int category, int startingState, double[] dArr) {
        final double branchTime = tree.getNodeHeight(tree.getParent(nodeRef)) - tree.getNodeHeight(nodeRef);
        final double branchLength = this.branchRateModel.getBranchRate(tree, nodeRef) * branchTime;
        if (branchLength < 0.0d) {
            throw new RuntimeException("Negative branch length: " + branchLength);
        }
        return StateHistory.simulateUnconditionalOnEndingState(0.0d, startingState,
                this.siteModel.getRateForCategory(category) * branchLength, dArr, this.stateCount);
    }

    /** Makes {@link #toString()} emit only the alignment, without the annotated tree. */
    public void setAlignmentOnly() {
        this.alignmentOnly = true;
    }
}
