package dr.evomodel.tree;

import dr.evolution.MetagenomeData;
import dr.evolution.alignment.Alignment;
import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.AminoAcids;
import dr.evolution.datatype.Codons;
import dr.evolution.datatype.DataType;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.evomodel.tipstatesmodel.TipStatesModel;
import dr.evomodelxml.tree.HiddenLinkageModelParser;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.math.MathUtils;
import dr.oldevomodel.substmodel.SubstitutionModel;
import dr.oldevomodel.treelikelihood.GeneralLikelihoodCore;
import dr.oldevomodel.treelikelihood.LikelihoodCore;
import dr.oldevomodel.treelikelihood.NativeAminoAcidLikelihoodCore;
import dr.oldevomodel.treelikelihood.NativeNucleotideLikelihoodCore;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:dr/evomodel/tree/HiddenLinkageModel.class */
public class HiddenLinkageModel extends TipStatesModel implements PatternList {
    /** Number of linkage groups; set from the constructor argument. */
    int linkageGroupCount;
    /** One set of read taxa per linkage group, indexed by group number. */
    ArrayList<HashSet<Taxon>> groups;
    /** Metagenome data supplying the alignment, the reference taxa and the read taxa. */
    MetagenomeData data;
    /** Reference taxa followed by one synthetic "LinkageGroup_k" taxon per linkage group. */
    ArrayList<Taxon> alignmentTaxa;
    /** Current tip partials, one array of siteCount * stateCount values per entry of alignmentTaxa. */
    double[][] tipPartials;
    /** Backup buffers that are swapped with tipPartials when an entry is first recomputed or restored. */
    double[][] storedTipPartials;
    /** Marks entries of tipPartials recomputed since the last store/accept/restore. */
    boolean[] dirtyTipPartials;
    /** Likelihood core used to combine the reads of a linkage group into a single set of partials. */
    LikelihoodCore core;
    /** Initialised to 0.001 in the constructor; setupMatrices derives the tipMatrix entries from it. */
    double blen;
    /** Not read or written anywhere in the visible code. */
    SubstitutionModel substitutionModel;
    /** Flattened stateCount x stateCount matrix installed on the read nodes of the core. */
    double[] tipMatrix;
    /** Flattened stateCount x stateCount matrix installed on the combining nodes of the core. */
    double[] internalMatrix;
    /** Read moves made since the last store/accept, replayed backwards by restoreState. */
    ArrayList<Move> movesMade;
    /** Maps a tree node number to an index into alignmentTaxa; rebuilt by taxaChanged. */
    int[] nodeIdToMyTaxaMap;

    /* loaded from: input_file:dr/evomodel/tree/HiddenLinkageModel$Move.class */
    private class Move {
        Taxon read;
        int fromGroup;
        int toGroup;

        public Move(Taxon taxon, int i, int i2) {
            this.read = taxon;
            this.fromGroup = i;
            this.toGroup = i2;
        }
    }

    public HiddenLinkageModel(int i, MetagenomeData metagenomeData) {
        super(HiddenLinkageModelParser.NAME, metagenomeData.getReferenceTaxa(), metagenomeData.getReadsTaxa());
        this.linkageGroupCount = 0;
        this.groups = null;
        this.data = null;
        this.blen = 0.001d;
        this.movesMade = new ArrayList<>();
        this.linkageGroupCount = i;
        this.data = metagenomeData;
        this.groups = new ArrayList<>(i);
        for (int i2 = 0; i2 < i; i2++) {
            this.groups.add(new HashSet<>());
        }
        TaxonList readsTaxa = metagenomeData.getReadsTaxa();
        for (int i3 = 0; i3 < readsTaxa.getTaxonCount(); i3++) {
            this.groups.get(MathUtils.nextInt(i)).add(readsTaxa.getTaxon(i3));
        }
        this.alignmentTaxa = new ArrayList<>(metagenomeData.getReferenceTaxa().asList());
        for (int i4 = 0; i4 < i; i4++) {
            this.alignmentTaxa.add(new Taxon("LinkageGroup_" + i4));
        }
        int siteCount = metagenomeData.getAlignment().getSiteCount() * metagenomeData.getAlignment().getStateCount();
        this.tipPartials = new double[this.alignmentTaxa.size()][siteCount];
        this.storedTipPartials = new double[this.alignmentTaxa.size()][siteCount];
        this.dirtyTipPartials = new boolean[this.alignmentTaxa.size()];
        initCore();
        setupMatrices();
        for (int i5 = 0; i5 < this.tipPartials.length; i5++) {
            computeTipPartials(i5);
        }
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean areUnique() {
        return false;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean areUncertain() {
        return false;
    }

    /**
     * Creates and initialises the likelihood core used to combine reads.
     *
     * The core is sized for twice the number of reads: nodes
     * {@code [0, readCount)} carry the observed states of each read and nodes
     * {@code [readCount, 2 * readCount)} get partials buffers.
     *
     * @throws IllegalArgumentException if the alignment's data type is not
     *         nucleotide, amino acid or codon data (previously this surfaced
     *         as a NullPointerException on the uninitialised core)
     */
    private void initCore() {
        final Alignment alignment = this.data.getAlignment();
        final DataType dataType = alignment.getDataType();
        final TaxonList reads = this.data.getReadsTaxa();
        final int readCount = reads.getTaxonCount();
        final int siteCount = alignment.getSiteCount();

        // Checked in this order so that, should a type match more than one
        // test, the same core wins as with the original sequence of ifs.
        if (dataType instanceof Codons) {
            this.core = new GeneralLikelihoodCore(alignment.getStateCount());
        } else if (dataType instanceof AminoAcids) {
            this.core = new NativeAminoAcidLikelihoodCore();
        } else if (dataType instanceof Nucleotides) {
            this.core = new NativeNucleotideLikelihoodCore();
        } else {
            throw new IllegalArgumentException("Unsupported data type for HiddenLinkageModel: "
                    + (dataType == null ? "null" : dataType.getClass().getName()));
        }

        this.core.initialize(readCount * 2, siteCount, 1, false);

        // Load the observed states of every read into its own node.
        for (int read = 0; read < readCount; read++) {
            final int row = alignment.getTaxonIndex(reads.getTaxon(read));
            final int[] states = new int[siteCount];
            for (int site = 0; site < states.length; site++) {
                states[site] = alignment.getState(row, site);
            }
            this.core.setNodeStates(read, states);
        }

        // Allocate partials for the nodes used when combining reads.
        for (int k = 0; k < readCount; k++) {
            this.core.createNodePartials(readCount + k);
        }
    }

    /**
     * Builds the two flattened stateCount x stateCount matrices and installs
     * them on the core.
     *
     * The tip matrix has {@code 1 - blen} on the diagonal and
     * {@code blen / (stateCount - 1)} elsewhere; the internal matrix has
     * 0.99999999999999 on the diagonal and the remainder spread evenly over
     * the off-diagonal entries. Read nodes get the tip matrix, the following
     * block of nodes gets the internal matrix.
     */
    private void setupMatrices() {
        final int states = this.data.getAlignment().getStateCount();
        final int readCount = this.data.getReadsTaxa().getTaxonCount();

        final double internalStay = 0.99999999999999d;
        final double tipStay = 1.0d - this.blen;
        final double tipChange = this.blen / (states - 1);
        final double internalChange = (1.0d - internalStay) / (states - 1);

        this.tipMatrix = new double[states * states];
        this.internalMatrix = new double[states * states];
        for (int row = 0; row < states; row++) {
            for (int col = 0; col < states; col++) {
                final int cell = row * states + col;
                final boolean diagonal = row == col;
                this.tipMatrix[cell] = diagonal ? tipStay : tipChange;
                this.internalMatrix[cell] = diagonal ? internalStay : internalChange;
            }
        }

        for (int node = 0; node < readCount; node++) {
            this.core.setNodeMatrix(node, 0, this.tipMatrix);
        }
        for (int node = readCount; node < 2 * readCount; node++) {
            this.core.setNodeMatrix(node, 0, this.internalMatrix);
        }
    }

    /**
     * @return the number of linkage groups this model was created with
     */
    public int getLinkageGroupCount() {
        return linkageGroupCount;
    }

    /**
     * @return the metagenome data this model was created with
     */
    public MetagenomeData getData() {
        return data;
    }

    /**
     * Finds the linkage group that currently contains a read.
     *
     * @param taxon the read to look up
     * @return the index of the first group containing {@code taxon}; if no
     *         group contains it, the number of groups (one past the last valid index)
     */
    public int getLinkageGroupId(Taxon taxon) {
        int index = 0;
        for (HashSet<Taxon> group : groups) {
            if (group.contains(taxon)) {
                break;
            }
            index++;
        }
        return index;
    }

    /**
     * Commits the current state: forgets the recorded moves and clears every
     * dirty flag.
     */
    @Override
    protected void acceptState() {
        movesMade.clear();
        int tip = 0;
        while (tip < dirtyTipPartials.length) {
            dirtyTipPartials[tip++] = false;
        }
    }

    /**
     * Intentionally does nothing.
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
        // no-op
    }

    /**
     * Intentionally does nothing.
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        // no-op
    }

    /**
     * Rolls back to the last stored state: undoes the recorded moves in
     * reverse order, then swaps the backup partials back in for every dirty
     * entry and clears its flag.
     */
    @Override
    protected void restoreState() {
        for (int k = movesMade.size() - 1; k >= 0; k--) {
            final Move undone = movesMade.get(k);
            groups.get(undone.toGroup).remove(undone.read);
            groups.get(undone.fromGroup).add(undone.read);
        }
        movesMade.clear();

        for (int tip = 0; tip < dirtyTipPartials.length; tip++) {
            if (!dirtyTipPartials[tip]) {
                continue;
            }
            swapTipPartials(tip);
            dirtyTipPartials[tip] = false;
        }
    }

    /**
     * Makes the current state the restore point: forgets the recorded moves
     * and clears every dirty flag.
     */
    @Override
    protected void storeState() {
        movesMade.clear();
        int tip = 0;
        while (tip < dirtyTipPartials.length) {
            dirtyTipPartials[tip++] = false;
        }
    }

    /**
     * Returns the reads of one linkage group.
     *
     * The returned set is the model's own set, not a copy.
     *
     * @param i linkage group index
     * @return the set of reads in that group
     */
    public Set<Taxon> getGroup(int i) {
        final Set<Taxon> members = groups.get(i);
        return members;
    }

    /**
     * Moves a read from one linkage group to another, records the move so it
     * can be undone, recomputes the tip partials of both groups and fires a
     * model-changed event for each of the two linkage-group taxa.
     *
     * Both groups are looked up before anything is modified, so an
     * out-of-range index leaves the model untouched (previously a bad
     * destination index removed the read from its source group before failing).
     *
     * @param taxon the read to move
     * @param i     index of the group currently holding the read
     * @param i2    index of the destination group
     * @throws IndexOutOfBoundsException if either group index is out of range
     * @throws RuntimeException          if the read is not in group {@code i}
     */
    public void moveReadGroup(Taxon taxon, int i, int i2) {
        final HashSet<Taxon> source = this.groups.get(i);
        final HashSet<Taxon> destination = this.groups.get(i2);

        if (!source.remove(taxon)) {
            throw new RuntimeException("Error, could not find read " + taxon + " in linkage group " + i);
        }
        destination.add(taxon);
        this.movesMade.add(new Move(taxon, i, i2));

        // Linkage-group tips follow the reference taxa in the partials arrays.
        final int referenceCount = this.data.getReferenceTaxa().getTaxonCount();
        computeTipPartials(referenceCount + i);
        computeTipPartials(referenceCount + i2);

        final int firstGroupTaxon = this.alignmentTaxa.size() - this.groups.size();
        fireModelChanged(this.alignmentTaxa.get(firstGroupTaxon + i));
        fireModelChanged(this.alignmentTaxa.get(firstGroupTaxon + i2));
    }

    /**
     * Exchanges the current and the backup partials buffer of one entry.
     *
     * @param i index into the tip partials arrays
     */
    private void swapTipPartials(int i) {
        final double[] current = tipPartials[i];
        tipPartials[i] = storedTipPartials[i];
        storedTipPartials[i] = current;
    }

    /**
     * Recomputes the tip partials of one entry of the taxon list.
     *
     * On the first recomputation since the last store/accept/restore the
     * buffers are swapped first, so the previous values survive in the backup
     * buffer. Entries below the reference-taxon count are filled from the
     * alignment row with the same index: 1 for the observed state, or 1 for
     * every state when the state code is not below the state count. Other
     * entries are linkage groups: an empty group gets all ones, a single read
     * is handled by {@link #getPartialsForGroupSizeOne}, and larger groups
     * are folded together read by read through the likelihood core.
     *
     * @param i index into the tip partials arrays
     */
    private void computeTipPartials(int i) {
        if (!dirtyTipPartials[i]) {
            swapTipPartials(i);
            dirtyTipPartials[i] = true;
        }

        final double[] partials = tipPartials[i];
        final Alignment alignment = data.getAlignment();
        final int stateCount = alignment.getStateCount();
        for (int k = 0; k < partials.length; k++) {
            partials[k] = 0.0d;
        }

        final int referenceCount = data.getReferenceTaxa().getTaxonCount();
        if (i < referenceCount) {
            // Reference taxon: read its states straight from alignment row i.
            for (int site = 0; site < alignment.getSiteCount(); site++) {
                final int offset = site * stateCount;
                final int state = alignment.getState(i, site);
                if (state < stateCount) {
                    partials[offset + state] = 1.0d;
                } else {
                    for (int s = 0; s < stateCount; s++) {
                        partials[offset + s] = 1.0d;
                    }
                }
            }
            return;
        }

        final HashSet<Taxon> members = groups.get(i - referenceCount);
        if (members.isEmpty()) {
            for (int k = 0; k < partials.length; k++) {
                partials[k] = 1.0d;
            }
            return;
        }

        final Iterator<Taxon> remaining = members.iterator();
        final Taxon first = remaining.next();
        if (!remaining.hasNext()) {
            getPartialsForGroupSizeOne(first, partials);
            return;
        }

        // Two or more reads: combine the first two, then fold each further
        // read into the previous result, using one new core node per step.
        final TaxonList reads = data.getReadsTaxa();
        int node = reads.getTaxonCount();
        int left = reads.getTaxonIndex(first);
        do {
            final int right = reads.getTaxonIndex(remaining.next());
            core.setNodePartialsForUpdate(node);
            core.calculatePartials(left, right, node);
            left = node;
            node++;
        } while (remaining.hasNext());

        core.getPartials(node - 1, partials);
    }

    /**
     * Fills the partials of a linkage group that holds exactly one read.
     *
     * For each site, a state code not below the state count yields 1 for
     * every state; otherwise the row of the internal matrix belonging to the
     * observed state is copied in.
     *
     * @param taxon the single read of the group
     * @param dArr  destination, laid out as stateCount values per site
     */
    private void getPartialsForGroupSizeOne(Taxon taxon, double[] dArr) {
        final Alignment alignment = data.getAlignment();
        final int stateCount = alignment.getStateCount();
        final int row = alignment.getTaxonIndex(taxon);

        for (int site = 0; site < alignment.getSiteCount(); site++) {
            final int offset = site * stateCount;
            final int state = alignment.getState(row, site);
            if (state < stateCount) {
                System.arraycopy(internalMatrix, state * stateCount, dArr, offset, stateCount);
            } else {
                for (int s = 0; s < stateCount; s++) {
                    dArr[offset + s] = 1.0d;
                }
            }
        }
    }

    /**
     * Not implemented.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always; this is a
     *         {@link RuntimeException}, so existing handlers keep working
     */
    public int newGroup() {
        throw new UnsupportedOperationException("Not implemented!");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always; this is a
     *         {@link RuntimeException}, so existing handlers keep working
     */
    public void deleteGroup() {
        throw new UnsupportedOperationException("Not implemented!");
    }

    /**
     * @return {@code TipStatesModel.Type.PARTIALS}
     */
    @Override
    public TipStatesModel.Type getModelType() {
        final TipStatesModel.Type type = TipStatesModel.Type.PARTIALS;
        return type;
    }

    /**
     * Always fails; use {@link #getTipPartials(int, double[])} instead.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public void getTipStates(int i, int[] iArr) {
        final String message = "This model emits only tip partials";
        throw new IllegalArgumentException(message);
    }

    /**
     * Copies the tip partials for a tree node into the given array.
     *
     * The node's number is translated through the map built by
     * {@code taxaChanged()}; exactly {@code dArr.length} values are copied.
     *
     * @param i    tree node index
     * @param dArr destination array
     */
    @Override
    public void getTipPartials(int i, double[] dArr) {
        final int nodeNumber = tree.getNode(i).getNumber();
        final double[] source = tipPartials[nodeIdToMyTaxaMap[nodeNumber]];
        System.arraycopy(source, 0, dArr, 0, dArr.length);
    }

    /**
     * Rebuilds the map from tree node numbers to indices in this model's
     * taxon list.
     *
     * For every index below the tree's node count whose taxon is non-null,
     * the first taxon-list entry with the same id (ignoring case) is recorded
     * under the number of the tree's external node at that index. Indices
     * without a taxon, without an id, or without a match keep the default 0.
     *
     * Taxa or ids that are null are treated as non-matching; the previous
     * code printed a debug string to stderr and then threw a
     * NullPointerException in those cases.
     */
    @Override
    protected void taxaChanged() {
        this.nodeIdToMyTaxaMap = new int[this.tree.getNodeCount()];
        for (int i = 0; i < this.nodeIdToMyTaxaMap.length; i++) {
            final Taxon treeTaxon = this.tree.getTaxon(i);
            if (treeTaxon == null) {
                continue;
            }
            final String treeId = treeTaxon.getId();
            if (treeId == null) {
                continue;
            }
            for (int j = 0; j < this.alignmentTaxa.size(); j++) {
                final Taxon candidate = this.alignmentTaxa.get(j);
                // equalsIgnoreCase(null) is false, so a null candidate id never matches.
                if (candidate != null && treeId.equalsIgnoreCase(candidate.getId())) {
                    this.nodeIdToMyTaxaMap[this.tree.getExternalNode(i).getNumber()] = j;
                    break;
                }
            }
        }
    }

    /**
     * @return the data type of the underlying alignment
     */
    @Override
    public DataType getDataType() {
        final Alignment alignment = data.getAlignment();
        return alignment.getDataType();
    }

    /**
     * @param i pattern index
     * @return the pattern as reported by the underlying alignment
     */
    @Override
    public int[] getPattern(int i) {
        final Alignment alignment = data.getAlignment();
        return alignment.getPattern(i);
    }

    /**
     * Uncertain patterns are not provided by this model.
     *
     * @param i pattern index (ignored)
     * @return an empty two-dimensional array
     */
    @Override // dr.evolution.alignment.PatternList
    public double[][] getUncertainPattern(int i) {
        // Must be an array of double[]; a plain double[0] is not assignable to double[][].
        return new double[0][];
    }

    /**
     * @return the pattern count of the underlying alignment
     */
    @Override
    public int getPatternCount() {
        final Alignment alignment = data.getAlignment();
        return alignment.getPatternCount();
    }

    /**
     * @return the pattern length of the underlying alignment
     */
    @Override
    public int getPatternLength() {
        final Alignment alignment = data.getAlignment();
        return alignment.getPatternLength();
    }

    /**
     * Returns the state of a taxon in a pattern.
     *
     * @param i  taxon index
     * @param i2 pattern index
     * @return the alignment's pattern state when {@code i} is below the
     *         reference-taxon count, otherwise 0
     */
    @Override
    public int getPatternState(int i, int i2) {
        final boolean isReference = i < data.getReferenceTaxa().getTaxonCount();
        return isReference ? data.getAlignment().getPatternState(i, i2) : 0;
    }

    /**
     * Uncertain pattern states are not provided by this model.
     *
     * @return a new empty array
     */
    @Override
    public double[] getUncertainPatternState(int i, int i2) {
        final double[] none = new double[0];
        return none;
    }

    /**
     * @param i pattern index
     * @return the weight of that pattern in the underlying alignment
     */
    @Override
    public double getPatternWeight(int i) {
        final Alignment alignment = data.getAlignment();
        return alignment.getPatternWeight(i);
    }

    /**
     * @return the pattern weights as reported by the underlying alignment
     */
    @Override
    public double[] getPatternWeights() {
        final Alignment alignment = data.getAlignment();
        return alignment.getPatternWeights();
    }

    /**
     * @return the state count of the underlying alignment
     */
    @Override
    public int getStateCount() {
        final Alignment alignment = data.getAlignment();
        return alignment.getStateCount();
    }

    /**
     * @return the state frequencies as reported by the underlying alignment
     */
    @Override
    public double[] getStateFrequencies() {
        final Alignment alignment = data.getAlignment();
        return alignment.getStateFrequencies();
    }

    /**
     * Returns the model's taxon list (reference taxa followed by the
     * linkage-group taxa). The list itself is returned, not a copy.
     *
     * @return the taxon list
     */
    @Override
    public List<Taxon> asList() {
        final List<Taxon> taxa = alignmentTaxa;
        return taxa;
    }

    /**
     * @param i index into the model's taxon list
     * @return the taxon at that index
     */
    @Override
    public Taxon getTaxon(int i) {
        final Taxon taxon = alignmentTaxa.get(i);
        return taxon;
    }

    /**
     * @param i   index into the model's taxon list
     * @param str attribute name
     * @return the value of that attribute on the taxon at index {@code i}
     */
    @Override
    public Object getTaxonAttribute(int i, String str) {
        final Taxon taxon = alignmentTaxa.get(i);
        return taxon.getAttribute(str);
    }

    /**
     * @return the number of taxa in the model's taxon list
     */
    @Override
    public int getTaxonCount() {
        final int count = alignmentTaxa.size();
        return count;
    }

    /**
     * @param i index into the model's taxon list
     * @return the id of the taxon at that index
     */
    @Override
    public String getTaxonId(int i) {
        final Taxon taxon = alignmentTaxa.get(i);
        return taxon.getId();
    }

    /**
     * Looks a taxon up by id (case-sensitive).
     *
     * @param str the id to search for
     * @return the index of the first taxon whose id equals {@code str}, or -1
     */
    @Override
    public int getTaxonIndex(String str) {
        int position = 0;
        for (Taxon candidate : alignmentTaxa) {
            if (candidate.getId().equals(str)) {
                return position;
            }
            position++;
        }
        return -1;
    }

    /**
     * Looks a taxon up using {@code compareTo}.
     *
     * @param taxon the taxon to search for
     * @return the index of the first taxon that compares equal to
     *         {@code taxon}, or -1
     */
    @Override
    public int getTaxonIndex(Taxon taxon) {
        int position = 0;
        for (Taxon candidate : alignmentTaxa) {
            if (candidate.compareTo(taxon) == 0) {
                return position;
            }
            position++;
        }
        return -1;
    }

    /**
     * @return an iterator over the model's taxon list
     */
    @Override
    public Iterator<Taxon> iterator() {
        final Iterator<Taxon> taxa = alignmentTaxa.iterator();
        return taxa;
    }
}
