package dr.app.seqgen;

import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.oldevomodel.sitemodel.GammaSiteModel;
import dr.oldevomodel.sitemodel.SiteModel;
import dr.oldevomodel.substmodel.FrequencyModel;
import dr.oldevomodel.substmodel.HKY;
import dr.oldevomodel.substmodel.SubstitutionModel;
import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import jebl.evolution.alignments.Alignment;
import jebl.evolution.alignments.BasicAlignment;
import jebl.evolution.io.FastaExporter;
import jebl.evolution.sequences.BasicSequence;
import jebl.evolution.sequences.NucleotideState;
import jebl.evolution.sequences.Nucleotides;
import jebl.evolution.sequences.Sequence;
import jebl.evolution.sequences.SequenceType;
import jebl.evolution.sequences.State;
import jebl.evolution.taxa.Taxon;
import jebl.math.Random;

/* loaded from: input_file:dr/app/seqgen/SeqGen.class */
public class SeqGen {
    /** Number of sites in each simulated sequence. */
    final int length;
    /** Overall rate multiplier applied to every site-model category rate. */
    final double substitutionRate;
    /** Equilibrium frequencies used to draw the root sequence. */
    final FrequencyModel freqModel;
    /** Model supplying per-branch transition probabilities. */
    final SubstitutionModel substModel;
    /** Site model supplying rate categories and their proportions. */
    final SiteModel siteModel;
    /**
     * Rate of the leaf damage process; a leaf site is damaged with probability
     * {@code 1 - exp(-damageRate * leafHeight)}. Values {@code <= 0} disable damage.
     */
    final double damageRate;

    /**
     * Replacement states used by {@link #damageSequence}: A->G, C->T, G->A, T->C.
     * Private and never mutated after class initialisation.
     */
    private static final Map<State, State[]> DAMAGE_MAP = createDamageMap();

    /**
     * Creates a simulator.
     *
     * @param length            number of sites per simulated sequence
     * @param substitutionRate  multiplier applied to each site-model category rate
     * @param frequencyModel    frequencies used to draw the root sequence
     * @param substitutionModel model providing transition probabilities along branches
     * @param siteModel         model providing rate categories and their proportions
     * @param damageRate        leaf damage rate; {@code <= 0} disables damage
     */
    public SeqGen(int length, double substitutionRate, FrequencyModel frequencyModel,
                  SubstitutionModel substitutionModel, SiteModel siteModel, double damageRate) {
        this.length = length;
        this.substitutionRate = substitutionRate;
        this.freqModel = frequencyModel;
        this.substModel = substitutionModel;
        this.siteModel = siteModel;
        this.damageRate = damageRate;
    }

    private static Map<State, State[]> createDamageMap() {
        Map<State, State[]> map = new HashMap<>();
        map.put(Nucleotides.A_STATE, new State[]{Nucleotides.G_STATE});
        map.put(Nucleotides.C_STATE, new State[]{Nucleotides.T_STATE});
        map.put(Nucleotides.G_STATE, new State[]{Nucleotides.A_STATE});
        map.put(Nucleotides.T_STATE, new State[]{Nucleotides.C_STATE});
        return map;
    }

    /**
     * Simulates one alignment on the given tree.
     *
     * <p>The method works in five steps:
     * <ol>
     *   <li>Draw a root sequence from the frequency model.</li>
     *   <li>Assign every site a rate category.</li>
     *   <li>Evolve sequences down every branch.</li>
     *   <li>If damage is enabled, damage each leaf sequence in proportion to its node height.</li>
     *   <li>Return one sequence per external node.</li>
     * </ol>
     * As a side effect, each leaf taxon receives an {@code int[]} attribute {@code "seq"}
     * holding its state indices.
     *
     * @param tree the tree to simulate along
     * @return an alignment containing one nucleotide sequence per leaf
     */
    public Alignment simulate(Tree tree) {
        int[] rootSequence = new int[this.length];
        drawSequence(rootSequence, this.freqModel);

        int[] siteCategories = new int[this.length];
        drawSiteCategories(this.siteModel, siteCategories);

        double[] categoryRates = new double[this.siteModel.getCategoryCount()];
        for (int c = 0; c < categoryRates.length; c++) {
            categoryRates[c] = this.siteModel.getRateForCategory(c) * this.substitutionRate;
        }

        NodeRef root = tree.getRoot();
        if (tree.isExternal(root)) {
            // Single-node tree: the only leaf carries the root sequence unchanged. Without this the
            // leaf would have no "seq" attribute and the loop below would throw a NullPointerException.
            tree.getNodeTaxon(root).setAttribute("seq", rootSequence);
        }
        for (int i = 0; i < tree.getChildCount(root); i++) {
            evolveSequences(rootSequence, tree, tree.getChild(root, i), this.substModel, siteCategories, categoryRates);
        }

        int leafCount = tree.getExternalNodeCount();
        Sequence[] sequences = new Sequence[leafCount];
        List<NucleotideState> canonicalStates = Nucleotides.getCanonicalStates();
        for (int leaf = 0; leaf < leafCount; leaf++) {
            NodeRef node = tree.getExternalNode(leaf);
            int[] stateIndices = (int[]) tree.getNodeTaxon(node).getAttribute("seq");
            State[] states = new State[stateIndices.length];
            for (int site = 0; site < states.length; site++) {
                states[site] = canonicalStates.get(stateIndices[site]);
            }
            if (this.damageRate > 0.0d) {
                damageSequence(states, this.damageRate, tree.getNodeHeight(node), DAMAGE_MAP);
            }
            sequences[leaf] = new BasicSequence(SequenceType.NUCLEOTIDE,
                    Taxon.getTaxon(tree.getNodeTaxon(node).getId()), states);
        }
        return new BasicAlignment(sequences);
    }

    /**
     * Assigns a rate category to every site, drawn from the site model's category proportions.
     *
     * @param siteModel      source of the category proportions
     * @param siteCategories output array; each entry receives a category index
     */
    void drawSiteCategories(SiteModel siteModel, int[] siteCategories) {
        double[] proportions = siteModel.getCategoryProportions();
        double[] cumulative = new double[proportions.length];
        double running = 0.0;
        for (int c = 0; c < proportions.length; c++) {
            running += proportions[c];
            cumulative[c] = running;
        }
        for (int site = 0; site < siteCategories.length; site++) {
            siteCategories[site] = draw(cumulative);
        }
    }

    /**
     * Fills {@code sequence} with state indices drawn independently from the model's frequencies.
     *
     * @param sequence       output array of state indices
     * @param frequencyModel model whose cumulative frequencies are sampled
     */
    public void drawSequence(int[] sequence, FrequencyModel frequencyModel) {
        double[] cumulative = frequencyModel.getCumulativeFrequencies();
        for (int site = 0; site < sequence.length; site++) {
            sequence[site] = draw(cumulative);
        }
    }

    /**
     * Evolves {@code parentSequence} along the branch above {@code node}, then recurses into the
     * node's children. Leaf results are stored on the leaf taxon as the {@code "seq"} attribute.
     *
     * @param parentSequence    state indices at the parent end of the branch
     * @param tree              the tree being simulated
     * @param node              node at the child end of the branch
     * @param substitutionModel model providing transition probabilities
     * @param siteCategories    rate category index for each site
     * @param categoryRates     scaled rate for each category
     */
    void evolveSequences(int[] parentSequence, Tree tree, NodeRef node, SubstitutionModel substitutionModel,
                         int[] siteCategories, double[] categoryRates) {
        double[][][] cumulativeProbabilities =
                cumulativeTransitionProbabilities(substitutionModel, tree.getBranchLength(node), categoryRates);
        int[] childSequence = new int[parentSequence.length];
        evolveSequence(parentSequence, siteCategories, cumulativeProbabilities, childSequence);

        if (tree.isExternal(node)) {
            tree.getNodeTaxon(node).setAttribute("seq", childSequence);
            return;
        }
        for (int i = 0; i < tree.getChildCount(node); i++) {
            evolveSequences(childSequence, tree, tree.getChild(node, i), substitutionModel, siteCategories, categoryRates);
        }
    }

    /**
     * Builds, for each rate category, the row-wise cumulative transition matrix for one branch.
     * The result is indexed {@code [category][fromState][toState]}. It is sized by category count,
     * not site count: the previous per-site sizing overflowed when there were fewer sites than
     * categories.
     */
    private static double[][][] cumulativeTransitionProbabilities(SubstitutionModel substitutionModel,
                                                                  double branchLength, double[] categoryRates) {
        int stateCount = substitutionModel.getDataType().getStateCount();
        double[][][] cumulative = new double[categoryRates.length][stateCount][stateCount];
        for (int c = 0; c < categoryRates.length; c++) {
            // Flat matrix, read row by row: entry [from * stateCount + to].
            double[] matrix = new double[stateCount * stateCount];
            substitutionModel.getTransitionProbabilities(branchLength * categoryRates[c], matrix);
            for (int from = 0; from < stateCount; from++) {
                double running = 0.0;
                for (int to = 0; to < stateCount; to++) {
                    running += matrix[from * stateCount + to];
                    cumulative[c][from][to] = running;
                }
            }
        }
        return cumulative;
    }

    /** Draws each child site from the cumulative row selected by its category and parent state. */
    private void evolveSequence(int[] parentSequence, int[] siteCategories, double[][][] cumulativeProbabilities,
                                int[] childSequence) {
        for (int site = 0; site < parentSequence.length; site++) {
            childSequence[site] = draw(cumulativeProbabilities[siteCategories[site]][parentSequence[site]]);
        }
    }

    /**
     * Damages a leaf sequence in place. Each site is damaged with probability
     * {@code 1 - exp(-rate * time)}. A damaged site is replaced by a uniformly chosen state from
     * {@code damageMap}. States with no (or an empty) mapping, such as gaps or ambiguity codes,
     * are left unchanged.
     */
    private void damageSequence(State[] sequence, double rate, double time, Map<State, State[]> damageMap) {
        double pUndamaged = Math.exp(-rate * time);
        for (int i = 0; i < sequence.length; i++) {
            if (Random.nextDouble() >= pUndamaged) {
                State[] replacements = damageMap.get(sequence[i]);
                // Previously an empty array was indexed at [0] (always throwing), and an unmapped
                // state caused a NullPointerException.
                if (replacements != null && replacements.length > 0) {
                    sequence[i] = replacements[Random.nextInt(replacements.length)];
                }
            }
        }
    }

    /**
     * Samples an index from a cumulative distribution: returns the first index whose cumulative
     * value exceeds a uniform draw in [0, 1).
     *
     * @param cumulative non-decreasing cumulative probabilities that should end at (about) 1.0
     * @return a sampled index in {@code [0, cumulative.length)}
     * @throws IllegalArgumentException if {@code cumulative} is empty
     */
    private int draw(double[] cumulative) {
        if (cumulative.length == 0) {
            throw new IllegalArgumentException("cannot draw from an empty distribution");
        }
        double u = Random.nextDouble();
        for (int i = 0; i < cumulative.length; i++) {
            if (u < cumulative[i]) {
                return i;
            }
        }
        // Floating-point summation can leave the final cumulative value a hair below 1.0. Give that
        // residual mass to the last index instead of returning -1 (which used to happen whenever
        // assertions were disabled). A genuinely short distribution is still flagged under -ea.
        double total = cumulative[cumulative.length - 1];
        assert total > 1.0 - 1e-9 : "cumulative distribution sums to " + total;
        return cumulative.length - 1;
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code SeqGen <treeFile> <outputFileStem> [substitutionRate] [categoryCount] [damageRate]}.
     * The optional arguments default to 0.001, 8 and 0.0. Each Newick tree in {@code treeFile} is
     * simulated under HKY (kappa 10, equal frequencies) with 500 sites. Its alignment is written as
     * FASTA to {@code <stem>-<rate>.fasta}; when there are several trees, the name is
     * {@code <stem>-<rate>-<index>.fasta} so that earlier files are not overwritten.
     */
    public static void main(String[] args) {
        if (args.length < 2) {
            System.err.println("Usage: SeqGen <treeFile> <outputFileStem> [substitutionRate] [categoryCount] [damageRate]");
            return;
        }
        String treeFileName = args[0];
        String outputFileStem = args[1];
        double[] frequencies = {0.25d, 0.25d, 0.25d, 0.25d};
        double substitutionRate = args.length < 3 ? 0.001d : Double.parseDouble(args[2]);
        int categoryCount = args.length < 4 ? 8 : Integer.parseInt(args[3]);
        double damageRate = args.length < 5 ? 0.0d : Double.parseDouble(args[4]);
        System.out.println("substitutionRate = " + substitutionRate + "; categoryCount = " + categoryCount + "; damageRate = " + damageRate);

        FrequencyModel frequencyModel = new FrequencyModel(dr.evolution.datatype.Nucleotides.INSTANCE, frequencies);
        HKY hky = new HKY(10.0d, frequencyModel);
        GammaSiteModel siteModel = categoryCount > 1 ? new GammaSiteModel(hky, 0.5d, categoryCount) : new GammaSiteModel(hky);

        List<Tree> trees = new ArrayList<>();
        try (FileReader reader = new FileReader(treeFileName)) {
            NewickImporter importer = new NewickImporter(reader);
            while (importer.hasTree()) {
                Tree tree = importer.importNextTree();
                trees.add(tree);
                System.out.println("tree height = " + tree.getNodeHeight(tree.getRoot()) + "; leave nodes = " + tree.getExternalNodeCount());
            }
        } catch (Importer.ImportException e) {
            e.printStackTrace();
            return;
        } catch (IOException e) {
            e.printStackTrace();
            return;
        }

        SeqGen seqGen = new SeqGen(500, substitutionRate, frequencyModel, hky, siteModel, damageRate);
        int fileIndex = 1;
        for (Tree tree : trees) {
            Alignment alignment = seqGen.simulate(tree);
            String outputFileName = trees.size() > 1
                    ? outputFileStem + "-" + substitutionRate + "-" + fileIndex + ".fasta"
                    : outputFileStem + "-" + substitutionRate + ".fasta";
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputFileName))) {
                new FastaExporter(writer).exportSequences(alignment.getSequenceList());
            } catch (IOException e) {
                e.printStackTrace();
                return;
            }
            System.out.println("Write " + fileIndex + "th sequence file : " + outputFileName);
            fileIndex++;
        }
    }
}
