package dr.evolution.parsimony;

import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.Patterns;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxon;

import java.util.Arrays;

/* loaded from: input_file:dr/evolution/parsimony/FitchParsimony.class */
/**
 * Fitch parsimony scoring and ancestral-state reconstruction for a {@link PatternList} on a
 * {@link Tree}.
 *
 * <p>Step counts come from a post-order pass. Each internal node's state set is the set of states
 * admitted by the largest number of its children, and the node contributes
 * {@code childCount - thatNumber} steps. On binary nodes this is exactly Fitch's
 * intersection-else-union rule. On multifurcating nodes it is Hartigan's (1973) generalisation,
 * which, unlike a plain union/intersection, counts every extra change needed at a polytomy.
 *
 * <p>Results are cached for the last {@link Tree} instance seen (compared by reference). If that
 * tree is modified in place, call {@link #initialize(Tree)} before querying again. Instances hold
 * mutable caches and perform no synchronisation.
 */
public class FitchParsimony implements ParsimonyCriterion {
    /**
     * Number of states scored: the data type's state count, plus one extra (gap) slot when
     * {@link #gapsAreStates} is true.
     */
    private final int stateCount;
    /** If true, gaps are scored as their own state stored in slot {@code stateCount - 1}. */
    private final boolean gapsAreStates;
    /** Per node (indexed by {@link NodeRef#getNumber()}) and per pattern: the node's state set. */
    private boolean[][][] stateSets;
    /** Per node (indexed by node number) and per pattern: the reconstructed state index. */
    private int[][] states;
    private final PatternList patterns;
    /** Per-pattern step counts for {@link #tree}; handed out directly (not copied). */
    private final double[] siteScores;
    /** Tree the caches were built for; compared by reference. */
    private Tree tree = null;
    private boolean hasCalculatedSteps = false;
    private boolean hasReconstructedStates = false;

    /**
     * @return the patterns this instance scores
     */
    public PatternList getPatterns() {
        return this.patterns;
    }

    /**
     * @param patternList   the site patterns to score; must not be null
     * @param gapsAreStates if true, gaps are treated as an additional character state
     * @throws IllegalArgumentException if {@code patternList} is null
     */
    public FitchParsimony(PatternList patternList, boolean gapsAreStates) {
        if (patternList == null) {
            throw new IllegalArgumentException("The patterns cannot be null");
        }
        this.gapsAreStates = gapsAreStates;
        int dataStateCount = patternList.getDataType().getStateCount();
        this.stateCount = gapsAreStates ? dataStateCount + 1 : dataStateCount;
        this.patterns = patternList;
        this.siteScores = new double[patternList.getPatternCount()];
    }

    /**
     * Returns the (unweighted) number of parsimony steps for each pattern on {@code tree}.
     *
     * <p>The returned array is this instance's internal cache and is overwritten on the next
     * recomputation; copy it if you need to keep it.
     *
     * @param tree the tree to score; must not be null
     * @return per-pattern step counts, indexed like the pattern list
     * @throws IllegalArgumentException if {@code tree} is null
     */
    @Override
    public double[] getSiteScores(Tree tree) {
        if (tree == null) {
            throw new IllegalArgumentException("The tree cannot be null");
        }
        if (this.tree != tree) {
            initialize(tree);
        }
        if (!this.hasCalculatedSteps) {
            Arrays.fill(this.siteScores, 0.0d);
            calculateSteps(tree, tree.getRoot(), this.patterns);
            this.hasCalculatedSteps = true;
        }
        return this.siteScores;
    }

    /**
     * Returns the total parsimony score: per-pattern steps weighted by each pattern's weight.
     *
     * @param tree the tree to score; must not be null
     * @return the weighted sum of site scores
     * @throws IllegalArgumentException if {@code tree} is null
     */
    @Override
    public double getScore(Tree tree) {
        double[] scores = getSiteScores(tree);
        double total = 0.0d;
        for (int i = 0; i < this.patterns.getPatternCount(); i++) {
            total += scores[i] * this.patterns.getPatternWeight(i);
        }
        return total;
    }

    /**
     * Returns one most-parsimonious reconstructed state per pattern for {@code nodeRef}.
     *
     * <p>State indices refer to positions in the state sets; when gaps are states, index
     * {@code stateCount - 1} denotes a gap. The returned array is internal and must not be
     * modified.
     *
     * @param tree    the (binary) tree; must not be null
     * @param nodeRef a node of {@code tree}
     * @return reconstructed state index for every pattern
     * @throws IllegalArgumentException if {@code tree} is null or not binary
     */
    @Override
    public int[] getStates(Tree tree, NodeRef nodeRef) {
        if (tree == null) {
            throw new IllegalArgumentException("The tree cannot be null");
        }
        if (!TreeUtils.isBinary(tree)) {
            throw new IllegalArgumentException("The Fitch algorithm can only reconstruct ancestral states on binary trees");
        }
        getSiteScores(tree);
        if (!this.hasReconstructedStates) {
            reconstructStates(tree, tree.getRoot(), null);
            this.hasReconstructedStates = true;
        }
        return this.states[nodeRef.getNumber()];
    }

    /**
     * Binds this instance to {@code tree}, discarding any cached scores and reconstructions, and
     * loads the tip state sets from the patterns.
     *
     * <p>All new state is built before any field is touched, so a failure leaves the previous
     * cache intact.
     *
     * @param tree the tree whose tips are looked up, by taxon id, in the patterns
     * @throws IllegalArgumentException if a tip taxon does not occur in the patterns
     */
    public void initialize(Tree tree) {
        int patternCount = this.patterns.getPatternCount();
        int tipCount = tree.getExternalNodeCount();

        // Resolve each tip's node number and pattern column once, instead of once per pattern.
        int[] tipNumbers = new int[tipCount];
        int[] taxonIndices = new int[tipCount];
        for (int j = 0; j < tipCount; j++) {
            NodeRef tip = tree.getExternalNode(j);
            String taxonId = tree.getNodeTaxon(tip).getId();
            int taxonIndex = this.patterns.getTaxonIndex(taxonId);
            if (taxonIndex < 0) {
                throw new IllegalArgumentException("Taxon " + taxonId + " in the tree was not found in the patterns");
            }
            tipNumbers[j] = tip.getNumber();
            taxonIndices[j] = taxonIndex;
        }

        // Index by node number to match calculateSteps/reconstructStates.
        boolean[][][] newStateSets = new boolean[tree.getNodeCount()][patternCount][];
        for (int i = 0; i < patternCount; i++) {
            int[] pattern = this.patterns.getPattern(i);
            for (int j = 0; j < tipCount; j++) {
                newStateSets[tipNumbers[j]][i] = tipStateSet(pattern[taxonIndices[j]]);
            }
        }

        this.stateSets = newStateSets;
        this.states = new int[tree.getNodeCount()][patternCount];
        this.tree = tree;
        this.hasCalculatedSteps = false;
        this.hasReconstructedStates = false;
    }

    /** Builds the state set for one observed tip state, adding the extra gap slot if required. */
    private boolean[] tipStateSet(int state) {
        if (!this.gapsAreStates) {
            return this.patterns.getDataType().getStateSet(state);
        }
        boolean[] set = new boolean[this.stateCount];
        if (this.patterns.getDataType().isGapState(state)) {
            set[this.stateCount - 1] = true;
        } else {
            boolean[] dataSet = this.patterns.getDataType().getStateSet(state);
            System.arraycopy(dataSet, 0, set, 0, dataSet.length);
        }
        return set;
    }

    /**
     * Post-order pass: fills internal-node state sets and accumulates steps into siteScores.
     *
     * <p>For each pattern, each state is counted by how many children admit it. The node keeps the
     * states with the highest count {@code best} and adds {@code childCount - best} steps. On a
     * binary node this is Fitch's rule: intersection with 0 steps if non-empty, otherwise union
     * with 1 step.
     */
    private void calculateSteps(Tree tree, NodeRef node, PatternList patternList) {
        if (tree.isExternal(node)) {
            return;
        }
        int childCount = tree.getChildCount(node);
        int[] childNumbers = new int[childCount];
        for (int c = 0; c < childCount; c++) {
            NodeRef child = tree.getChild(node, c);
            calculateSteps(tree, child, patternList);
            childNumbers[c] = child.getNumber();
        }

        boolean[][] nodeSets = this.stateSets[node.getNumber()];
        for (int i = 0; i < patternList.getPatternCount(); i++) {
            int width = this.stateSets[childNumbers[0]][i].length;
            int[] counts = new int[width];
            for (int childNumber : childNumbers) {
                boolean[] childSet = this.stateSets[childNumber][i];
                for (int s = 0; s < width; s++) {
                    if (childSet[s]) {
                        counts[s]++;
                    }
                }
            }

            int best = 0;
            for (int count : counts) {
                best = Math.max(best, count);
            }

            boolean[] nodeSet = new boolean[width];
            for (int s = 0; s < width; s++) {
                nodeSet[s] = counts[s] == best;
            }
            nodeSets[i] = nodeSet;
            this.siteScores[i] += childCount - best;
        }
    }

    /**
     * Pre-order pass: each node keeps its parent's state when its own set admits it, otherwise it
     * takes the lowest-index state in its set. The root has no parent, so it always takes the
     * lowest-index state.
     */
    private void reconstructStates(Tree tree, NodeRef node, int[] parentStates) {
        int[] nodeStates = this.states[node.getNumber()];
        boolean[][] nodeSets = this.stateSets[node.getNumber()];
        for (int i = 0; i < nodeStates.length; i++) {
            int inherited = parentStates == null ? -1 : parentStates[i];
            // Guard the lookup: firstIndex yields -1 for an empty set.
            if (inherited >= 0 && nodeSets[i][inherited]) {
                nodeStates[i] = inherited;
            } else {
                nodeStates[i] = firstIndex(nodeSets[i]);
            }
        }
        for (int c = 0; c < tree.getChildCount(node); c++) {
            reconstructStates(tree, tree.getChild(node, c), nodeStates);
        }
    }

    /** Returns the index of the first {@code true} entry, or -1 if there is none. */
    private static int firstIndex(boolean[] set) {
        for (int i = 0; i < set.length; i++) {
            if (set[i]) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Small demo on a five-tip tree. It prints the Fitch score and reconstructed states, then
     * compares them with the {@code Parsimony} static methods.
     */
    public static void main(String[] args) {
        FlexibleNode tip1 = new FlexibleNode(new Taxon("tip1"));
        FlexibleNode tip2 = new FlexibleNode(new Taxon("tip2"));
        FlexibleNode tip3 = new FlexibleNode(new Taxon("tip3"));
        FlexibleNode tip4 = new FlexibleNode(new Taxon("tip4"));
        FlexibleNode tip5 = new FlexibleNode(new Taxon("tip5"));

        // Topology: ((tip1,tip2)node1,(tip3,(tip4,tip5)node2)node3)root
        FlexibleNode node1 = new FlexibleNode();
        node1.addChild(tip1);
        node1.addChild(tip2);
        FlexibleNode node2 = new FlexibleNode();
        node2.addChild(tip4);
        node2.addChild(tip5);
        FlexibleNode node3 = new FlexibleNode();
        node3.addChild(tip3);
        node3.addChild(node2);
        FlexibleNode root = new FlexibleNode();
        root.addChild(node1);
        root.addChild(node3);
        FlexibleTree tree = new FlexibleTree(root);

        Patterns patterns = new Patterns(Nucleotides.INSTANCE, tree);
        patterns.addPattern(new int[]{1, 0, 1, 2, 2});

        FitchParsimony fitch = new FitchParsimony(patterns, false);
        System.out.println("No. Steps = " + fitch.getScore(tree));
        System.out.println(" state(node1) = " + fitch.getStates(tree, node1)[0]);
        System.out.println(" state(node2) = " + fitch.getStates(tree, node2)[0]);
        System.out.println(" state(node3) = " + fitch.getStates(tree, node3)[0]);
        System.out.println(" state(root) = " + fitch.getStates(tree, root)[0]);

        System.out.println("\nParsimony static methods:");
        System.out.println("No. Steps = " + Parsimony.getParsimonySteps(tree, patterns));
        Parsimony.reconstructParsimonyStates(tree, patterns);
        System.out.println(" state(node1) = " + tree.getNodeAttribute(node1, "rstate1"));
        System.out.println(" state(node2) = " + tree.getNodeAttribute(node2, "rstate1"));
        System.out.println(" state(node3) = " + tree.getNodeAttribute(node3, "rstate1"));
        System.out.println(" state(root) = " + tree.getNodeAttribute(root, "rstate1"));
    }
}
