package dr.app.beagle.tools;

import beagle.Beagle;
import beagle.BeagleFactory;
import dr.app.bss.Utils;
import dr.evolution.datatype.Codons;
import dr.evolution.datatype.DataType;
import dr.evolution.sequence.Sequence;
import dr.evolution.tree.NodeRef;
import dr.evolution.util.Taxon;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.siteratemodel.GammaSiteRateModel;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.BufferIndexHelper;
import dr.evomodel.treelikelihood.SubstitutionModelDelegate;
import dr.math.MathUtils;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.commons.math.random.MersenneTwister;

/* loaded from: input_file:dr/app/beagle/tools/Partition.class */
public class Partition {
    private static final boolean DEBUG = false;
    public int from;
    public int to;
    public int every;
    private BranchModel branchModel;
    private TreeModel treeModel;
    private GammaSiteRateModel siteRateModel;
    private BranchRateModel branchRateModel;
    private FrequencyModel freqModel;
    private BufferIndexHelper partialBufferHelper;
    private BufferIndexHelper scaleBufferHelper;
    private BufferIndexHelper matrixBufferHelper;
    private Beagle beagleInstance;
    private SubstitutionModelDelegate substitutionModelDelegate;
    /** Label used only in diagnostics; may be unset (null). */
    private Integer partitionNumber;
    private int partitionSiteCount;
    private int nodeCount;
    private int tipCount;
    private int internalNodeCount;
    private int stateCount;
    private int compactPartialsCount;
    private int patternCount;
    private int siteRateCategoryCount;
    /** Simulated sequences (one state index per site), in insertion order. */
    private LinkedHashMap<Taxon, int[]> alignmentMap;
    private DataType dataType;
    private boolean hasRootSequence;
    private Sequence rootSequence;
    private boolean outputAncestralSequences;
    private MersenneTwister random;

    /**
     * Creates a partition simulator and loads its BEAGLE instance.
     *
     * @param treeModel         tree along which sequences are simulated
     * @param branchModel       branch model handed to the substitution model delegate
     * @param siteRateModel     supplies category rates and category proportions
     * @param branchRateModel   per-branch rate multiplier applied to branch lengths
     * @param freqModel         root state frequencies; also the data-type fallback
     * @param from              first site of the partition
     * @param to                last site of the partition
     * @param every             step between sites; the site count is {@code (to - from) / every + 1}
     * @param dataType          data type to simulate; when {@code null}, the frequency model's type is used
     */
    public Partition(TreeModel treeModel, BranchModel branchModel, GammaSiteRateModel siteRateModel,
                     BranchRateModel branchRateModel, FrequencyModel freqModel,
                     int from, int to, int every, DataType dataType) {
        this.hasRootSequence = false;
        this.rootSequence = null;
        this.outputAncestralSequences = false;
        this.treeModel = treeModel;
        this.siteRateModel = siteRateModel;
        this.freqModel = freqModel;
        this.branchModel = branchModel;
        this.branchRateModel = branchRateModel;
        this.from = from;
        this.to = to;
        this.every = every;
        this.dataType = (dataType == null) ? freqModel.getDataType() : dataType;
        this.partitionSiteCount = getPartitionSiteCount();
        setBufferHelpers();
        setSubstitutionModelDelegate();
        loadBeagleInstance();
        this.alignmentMap = new LinkedHashMap<>();
        this.random = new MersenneTwister(MathUtils.nextLong());
    }

    /** Same as the full constructor, with the data type taken from {@code freqModel}. */
    public Partition(TreeModel treeModel, BranchModel branchModel, GammaSiteRateModel siteRateModel,
                     BranchRateModel branchRateModel, FrequencyModel freqModel,
                     int from, int to, int every) {
        this(treeModel, branchModel, siteRateModel, branchRateModel, freqModel, from, to, every, null);
    }

    private void setSubstitutionModelDelegate() {
        this.substitutionModelDelegate = new SubstitutionModelDelegate(this.treeModel, this.branchModel);
    }

    /** Records node counts from the tree and sizes the matrix, partials and scale buffer helpers. */
    private void setBufferHelpers() {
        this.nodeCount = this.treeModel.getNodeCount();
        this.matrixBufferHelper = new BufferIndexHelper(this.nodeCount, 0);
        this.tipCount = this.treeModel.getExternalNodeCount();
        this.internalNodeCount = this.treeModel.getInternalNodeCount();
        this.partialBufferHelper = new BufferIndexHelper(this.nodeCount, this.tipCount);
        this.scaleBufferHelper = new BufferIndexHelper(this.internalNodeCount + 1, 0);
    }

    /** Creates the BEAGLE instance sized for this partition's tree, states, sites and rate categories. */
    public void loadBeagleInstance() {
        this.compactPartialsCount = this.tipCount;
        this.stateCount = this.dataType.getStateCount();
        this.patternCount = this.partitionSiteCount;
        this.siteRateCategoryCount = this.siteRateModel.getCategoryCount();
        this.beagleInstance = BeagleFactory.loadBeagleInstance(
                this.tipCount,
                this.partialBufferHelper.getBufferCount(),
                this.compactPartialsCount,
                this.stateCount,
                this.patternCount,
                this.substitutionModelDelegate.getEigenBufferCount(),
                this.substitutionModelDelegate.getMatrixBufferCount(),
                this.siteRateCategoryCount,
                this.scaleBufferHelper.getBufferCount(),
                new int[]{0}, 0L, 0L);
    }

    /**
     * Simulates sequences for this partition and stores them in the taxon map.
     *
     * <p>Exceptions raised during simulation are printed and not rethrown. Any other
     * {@code Throwable} terminates the JVM. The BEAGLE instance is finalized afterwards
     * whether or not the simulation succeeded, so this method can only be run once per instance.
     */
    public void simulatePartition() {
        try {
            try {
                runSimulation();
            } finally {
                // Release the native instance even on failure so an error does not leak it.
                if (this.beagleInstance != null) {
                    this.beagleInstance.finalize();
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        } catch (Throwable th) {
            System.err.println("BeagleException: " + th.getMessage());
            System.exit(-1);
        }
    }

    /** Samples site categories and root states, then walks the tree from the root. */
    private void runSimulation() {
        this.beagleInstance.setCategoryRates(this.siteRateModel.getCategoryRates());
        int[] categories = sampleSiteCategories();
        int[] rootStates = sampleRootStates();
        this.substitutionModelDelegate.updateSubstitutionModels(this.beagleInstance);
        traverse(this.treeModel.getRoot(), rootStates, categories);
    }

    /** Draws one rate-category index per site from the site rate model's category proportions. */
    private int[] sampleSiteCategories() {
        double[] proportions = this.siteRateModel.getCategoryProportions();
        int[] categories = new int[this.partitionSiteCount];
        for (int site = 0; site < this.partitionSiteCount; site++) {
            categories[site] = randomChoicePDF(proportions, this.partitionNumber, "categories");
        }
        return categories;
    }

    /**
     * Returns the root state per site.
     *
     * <p>The states come from the root sequence if one was set. Otherwise they are drawn
     * from the frequency model.
     *
     * @throws RuntimeException if the root sequence length does not match the site count
     *                          (three characters per site for codon data)
     */
    private int[] sampleRootStates() {
        if (!this.hasRootSequence) {
            double[] frequencies = this.freqModel.getFrequencies();
            int[] states = new int[this.partitionSiteCount];
            for (int site = 0; site < this.partitionSiteCount; site++) {
                states[site] = randomChoicePDF(frequencies, this.partitionNumber, "root");
            }
            return states;
        }
        // sequence2intArray consumes three characters per site for codons, one otherwise.
        int expectedLength = (this.dataType instanceof Codons)
                ? 3 * this.partitionSiteCount
                : this.partitionSiteCount;
        if (this.rootSequence.getLength() != expectedLength) {
            throw new RuntimeException("Ancestral sequence length of " + this.rootSequence.getLength()
                    + " does not match partition site count of " + this.partitionSiteCount + ".");
        }
        return sequence2intArray(this.rootSequence);
    }

    /**
     * Recursively simulates each child's states from its parent's states, then descends into the child.
     *
     * @param node         node whose children are simulated
     * @param parentStates state index per site at {@code node}
     * @param categories   rate-category index per site, shared by all branches
     */
    private void traverse(NodeRef node, int[] parentStates, int[] categories) {
        int childCount = this.treeModel.getChildCount(node);
        for (int c = 0; c < childCount; c++) {
            NodeRef child = this.treeModel.getChild(node, c);
            double[][] probabilities = getTransitionProbabilities(child);
            int[] childStates = new int[this.partitionSiteCount];
            double[] row = new double[this.stateCount];
            for (int site = 0; site < this.partitionSiteCount; site++) {
                // Row `parentState` of the transition matrix for this site's rate category.
                System.arraycopy(probabilities[categories[site]], parentStates[site] * this.stateCount,
                        row, 0, this.stateCount);
                childStates[site] = randomChoicePDF(row, this.partitionNumber, "seq");
            }
            if (this.treeModel.getChildCount(child) == 0) {
                this.alignmentMap.put(this.treeModel.getNodeTaxon(child), childStates);
            } else if (this.outputAncestralSequences) {
                this.alignmentMap.put(new Taxon("internalNodeHeight" + this.treeModel.getNodeHeight(child)),
                        childStates);
            }
            traverse(child, childStates, categories);
        }
    }

    /**
     * Computes transition matrices for the branch above {@code node}, one per rate category.
     *
     * @return array indexed by category, each of length {@code stateCount * stateCount}
     * @throws IllegalStateException if the branch time (length times rate) is negative or NaN
     */
    private double[][] getTransitionProbabilities(NodeRef node) {
        int matrixSize = this.stateCount * this.stateCount;
        int nodeNumber = node.getNumber();
        this.matrixBufferHelper.flipOffset(nodeNumber);

        double branchTime = this.treeModel.getBranchLength(node)
                * this.branchRateModel.getBranchRate(this.treeModel, node);
        if (!(branchTime >= 0.0)) { // also rejects NaN
            throw new IllegalStateException("Invalid branch time " + branchTime + " for node "
                    + nodeNumber + " in partition " + this.partitionNumber + ".");
        }

        this.substitutionModelDelegate.updateTransitionMatrices(this.beagleInstance,
                new int[]{nodeNumber}, new double[]{branchTime}, 1);
        double[] flat = new double[this.siteRateCategoryCount * matrixSize];
        this.beagleInstance.getTransitionMatrix(nodeNumber, flat);

        double[][] byCategory = new double[this.siteRateCategoryCount][matrixSize];
        for (int category = 0; category < this.siteRateCategoryCount; category++) {
            System.arraycopy(flat, category * matrixSize, byCategory[category], 0, matrixSize);
        }
        return byCategory;
    }

    /**
     * Converts a sequence into one state index per site.
     *
     * <p>For codon data, each site is read from three consecutive characters. For other
     * data types, each site is one character.
     */
    private int[] sequence2intArray(Sequence sequence) {
        int[] states = new int[this.partitionSiteCount];
        if (this.dataType instanceof Codons) {
            Codons codons = (Codons) this.dataType;
            for (int site = 0; site < this.partitionSiteCount; site++) {
                int offset = 3 * site;
                states[site] = codons.getState(sequence.getChar(offset), sequence.getChar(offset + 1),
                        sequence.getChar(offset + 2));
            }
        } else {
            for (int site = 0; site < this.partitionSiteCount; site++) {
                states[site] = this.dataType.getState(sequence.getChar(site));
            }
        }
        return states;
    }

    /**
     * Draws an index from {@code pdf} by inverse-CDF sampling with a single uniform draw.
     *
     * <p>The method assumes the entries sum to (approximately) one. If rounding leaves the
     * cumulative sum at or below the draw, the last index with positive probability is
     * returned instead of an invalid index.
     *
     * @param pdf       probability per index
     * @param partition partition label, used only in the error message (may be null)
     * @param label     what is being sampled, used only in the error message
     * @throws IllegalStateException if {@code pdf} has no positive entry
     */
    private int randomChoicePDF(double[] pdf, Integer partition, String label) {
        double u = this.random.nextDouble();
        double cumulative = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < pdf.length; i++) {
            cumulative += pdf[i];
            if (u < cumulative) {
                return i;
            }
            if (pdf[i] > 0.0) {
                lastPositive = i;
            }
        }
        if (lastPositive < 0) {
            throw new IllegalStateException("Cannot sample " + label + " for partition " + partition
                    + ": probability vector has no positive entries.");
        }
        return lastPositive;
    }

    public void setPartitionNumber(Integer num) {
        this.partitionNumber = num;
    }

    /**
     * Sets the sequence used as the root states. Passing {@code null} clears it, so root
     * states are drawn from the frequency model.
     */
    public void setRootSequence(Sequence sequence) {
        this.rootSequence = sequence;
        this.hasRootSequence = sequence != null;
    }

    /** When true, internal-node sequences are also added to the taxon map. */
    public void setOutputAncestralSequences(boolean z) {
        this.outputAncestralSequences = z;
    }

    public TreeModel getTreeModel() {
        return this.treeModel;
    }

    /** Returns the number of sites, {@code (to - from) / every + 1} with integer division. */
    public int getPartitionSiteCount() {
        return ((this.to - this.from) / this.every) + 1;
    }

    public BranchModel getBranchModel() {
        return this.branchModel;
    }

    public FrequencyModel getFreqModel() {
        return this.freqModel;
    }

    public Integer getPartitionNumber() {
        return this.partitionNumber;
    }

    public DataType getDataType() {
        return this.dataType;
    }

    /** Returns the live, insertion-ordered map of simulated state-index sequences. */
    public Map<Taxon, int[]> getTaxonSequencesMap() {
        return this.alignmentMap;
    }

    public Sequence getRootSequence() {
        return this.rootSequence;
    }

    public void printSequences() {
        System.out.println("partition " + this.partitionNumber);
        Utils.printMap(this.alignmentMap);
    }
}
