package dr.evolution.parsimony;

import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.Patterns;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import java.util.Iterator;
import java.util.TreeSet;

/**
 * Weighted (Sankoff) parsimony for a {@link PatternList} on a {@link Tree}.
 *
 * <p>For each pattern the computation is restricted to the states that actually occur in that
 * pattern, after expanding every observed character through the data type's state sets.
 *
 * <p>A bottom-up pass computes, for every node and candidate state, the minimum cost of the
 * subtree below that node given the node is in that state. Leaves score 0 for compatible states
 * and +infinity otherwise. A top-down pass then reconstructs one minimum-cost assignment of
 * states to all nodes. Ties are broken in favour of the lowest state index.
 *
 * <p>Results are cached. They are recomputed only when a different {@link Tree} instance is
 * passed in; modifying a tree in place after it has been scored is not detected.
 */
public class SankoffParsimony implements ParsimonyCriterion {

    private final int stateCount;
    private final PatternList patterns;
    /** costMatrix[parentState][childState]: cost of a change along a branch. */
    private final double[][] costMatrix;
    /** Minimum score per pattern; refilled on recomputation and returned by reference. */
    private final double[] siteScores;

    /** stateSets[pattern]: ascending distinct states occurring in that pattern. */
    private int[][] stateSets;
    /** nodeScores[nodeNumber][pattern][state]: minimum subtree cost given the node's state. */
    private double[][][] nodeScores;
    /** nodeStates[nodeNumber][pattern]: reconstructed state of each node. */
    private int[][] nodeStates;

    private Tree tree = null;
    private boolean hasCalculatedSteps = false;
    private boolean hasReconstructedStates = false;

    /**
     * Creates a Sankoff parsimony criterion with unit costs.
     *
     * <p>Every change between two different states costs 1; staying in the same state costs 0.
     *
     * @param patternList the site patterns to score; must not be {@code null}
     * @throws IllegalArgumentException if {@code patternList} is {@code null}
     */
    public SankoffParsimony(PatternList patternList) {
        if (patternList == null) {
            throw new IllegalArgumentException("The patterns cannot be null");
        }
        this.stateCount = patternList.getDataType().getStateCount();
        this.costMatrix = new double[stateCount][stateCount];
        for (int from = 0; from < stateCount; from++) {
            for (int to = 0; to < stateCount; to++) {
                costMatrix[from][to] = (from == to) ? 0.0d : 1.0d;
            }
        }
        this.patterns = patternList;
        this.siteScores = new double[patternList.getPatternCount()];
    }

    /**
     * Creates a Sankoff parsimony criterion with a user-supplied cost matrix.
     *
     * @param patternList the site patterns to score; must not be {@code null}
     * @param costMatrix  a {@code stateCount x stateCount} matrix where {@code costMatrix[i][j]}
     *                    is the cost of a parent in state {@code i} having a child in state
     *                    {@code j}. It is copied, so later changes by the caller have no effect.
     * @throws IllegalArgumentException if {@code patternList} is {@code null}, or the matrix is
     *                                  {@code null} or not square of the data type's state count
     */
    public SankoffParsimony(PatternList patternList, double[][] costMatrix) {
        if (patternList == null) {
            throw new IllegalArgumentException("The patterns cannot be null");
        }
        this.stateCount = patternList.getDataType().getStateCount();
        if (costMatrix == null) {
            throw new IllegalArgumentException("The cost matrix cannot be null");
        }
        String dimensionError = "The cost matrix is of the wrong dimension: expecting " + this.stateCount + " square";
        if (costMatrix.length != stateCount) {
            throw new IllegalArgumentException(dimensionError);
        }
        this.costMatrix = new double[stateCount][];
        for (int row = 0; row < stateCount; row++) {
            // Every row must be checked, not just the first, or a ragged matrix slips through.
            if (costMatrix[row] == null || costMatrix[row].length != stateCount) {
                throw new IllegalArgumentException(dimensionError);
            }
            this.costMatrix[row] = costMatrix[row].clone();
        }
        this.patterns = patternList;
        this.siteScores = new double[patternList.getPatternCount()];
    }

    /**
     * Returns the minimum parsimony score of each pattern on {@code tree}.
     *
     * @param tree the tree to score; must not be {@code null}
     * @return the internal per-pattern score array (callers should not modify it)
     * @throws IllegalArgumentException if {@code tree} is {@code null}, or a tip taxon of the tree
     *                                  is absent from the patterns
     */
    @Override // dr.evolution.parsimony.ParsimonyCriterion
    public double[] getSiteScores(Tree tree) {
        if (tree == null) {
            throw new IllegalArgumentException("The tree cannot be null");
        }
        if (this.tree != tree) {
            this.tree = tree;
            initialize();
        }
        if (!hasCalculatedSteps) {
            NodeRef root = tree.getRoot();
            calculateSteps(tree, root);
            double[][] rootScores = nodeScores[root.getNumber()];
            for (int i = 0; i < siteScores.length; i++) {
                siteScores[i] = minScore(rootScores[i], stateSets[i]);
            }
            hasCalculatedSteps = true;
        }
        return siteScores;
    }

    /**
     * Returns the total parsimony score of {@code tree}.
     *
     * <p>The total is the sum of the per-pattern scores, each multiplied by its pattern weight.
     *
     * @param tree the tree to score; must not be {@code null}
     * @return the weighted total score
     */
    @Override // dr.evolution.parsimony.ParsimonyCriterion
    public double getScore(Tree tree) {
        getSiteScores(tree);
        double total = 0.0d;
        for (int i = 0; i < patterns.getPatternCount(); i++) {
            total += siteScores[i] * patterns.getPatternWeight(i);
        }
        return total;
    }

    /**
     * Returns the reconstructed state of {@code nodeRef} for every pattern.
     *
     * <p>The root takes a minimum-score state. Every other node takes the state minimising its
     * subtree score plus the cost of the change from its parent's reconstructed state.
     *
     * @param tree    the tree; must not be {@code null}
     * @param nodeRef a node of {@code tree}
     * @return the internal per-pattern state array for the node (callers should not modify it)
     */
    @Override // dr.evolution.parsimony.ParsimonyCriterion
    public int[] getStates(Tree tree, NodeRef nodeRef) {
        getSiteScores(tree);
        if (!hasReconstructedStates) {
            NodeRef root = tree.getRoot();
            int rootNumber = root.getNumber();
            for (int i = 0; i < patterns.getPatternCount(); i++) {
                nodeStates[rootNumber][i] = minState(nodeScores[rootNumber][i], stateSets[i]);
            }
            reconstructStates(tree, root, nodeStates[rootNumber]);
            hasReconstructedStates = true;
        }
        return nodeStates[nodeRef.getNumber()];
    }

    /** Allocates per-tree work arrays, builds per-pattern state sets and fills leaf scores. */
    private void initialize() {
        hasCalculatedSteps = false;
        hasReconstructedStates = false;

        int patternCount = patterns.getPatternCount();
        int nodeCount = tree.getNodeCount();
        stateSets = new int[patternCount][];
        // Allocated for every node; internal entries are overwritten by calculateSteps.
        nodeScores = new double[nodeCount][patternCount][stateCount];
        nodeStates = new int[nodeCount][patternCount];

        // Resolve each tip's node number and pattern row once, rather than once per pattern.
        int externalCount = tree.getExternalNodeCount();
        int[] leafNumbers = new int[externalCount];
        int[] leafTaxonIndices = new int[externalCount];
        for (int leaf = 0; leaf < externalCount; leaf++) {
            NodeRef node = tree.getExternalNode(leaf);
            String id = tree.getNodeTaxon(node).getId();
            int taxonIndex = patterns.getTaxonIndex(id);
            if (taxonIndex < 0) {
                throw new IllegalArgumentException("Taxon " + id + " in the tree is not present in the patterns");
            }
            leafNumbers[leaf] = node.getNumber();
            leafTaxonIndices[leaf] = taxonIndex;
        }

        for (int i = 0; i < patternCount; i++) {
            int[] pattern = patterns.getPattern(i);
            stateSets[i] = collectStates(pattern);
            for (int leaf = 0; leaf < externalCount; leaf++) {
                boolean[] allowed = patterns.getDataType().getStateSet(pattern[leafTaxonIndices[leaf]]);
                // Indexed by node number, matching how calculateSteps reads child scores.
                double[] scores = nodeScores[leafNumbers[leaf]][i];
                for (int s = 0; s < stateCount; s++) {
                    scores[s] = allowed[s] ? 0.0d : Double.POSITIVE_INFINITY;
                }
            }
        }
    }

    /**
     * Collects the distinct states occurring in one pattern.
     *
     * <p>Each character is expanded through the data type's state set before collecting.
     *
     * @param pattern the per-taxon characters of one pattern
     * @return the distinct states, in ascending order
     */
    private int[] collectStates(int[] pattern) {
        TreeSet<Integer> present = new TreeSet<>();
        for (int character : pattern) {
            boolean[] stateSet = patterns.getDataType().getStateSet(character);
            for (int s = 0; s < stateSet.length; s++) {
                if (stateSet[s]) {
                    present.add(s);
                }
            }
        }
        int[] states = new int[present.size()];
        int k = 0;
        for (int s : present) {
            states[k++] = s;
        }
        return states;
    }

    /**
     * Bottom-up (post-order) pass filling {@code nodeScores} for every internal node below
     * {@code node}. This covers {@code node} itself when it is internal.
     */
    private void calculateSteps(Tree tree, NodeRef node) {
        if (tree.isExternal(node)) {
            return;
        }
        int childCount = tree.getChildCount(node);
        double[][][] childScores = new double[childCount][][];
        for (int c = 0; c < childCount; c++) {
            NodeRef child = tree.getChild(node, c);
            calculateSteps(tree, child);
            childScores[c] = nodeScores[child.getNumber()];
        }

        double[][] scores = nodeScores[node.getNumber()];
        for (int i = 0; i < patterns.getPatternCount(); i++) {
            int[] states = stateSets[i];
            double[] patternScores = scores[i];
            for (int s : states) {
                // Subtree cost in state s: per child, the cheapest (change cost + child subtree cost).
                double total = 0.0d;
                for (int c = 0; c < childCount; c++) {
                    total += minCost(s, childScores[c][i], states);
                }
                patternScores[s] = total;
            }
        }
    }

    /**
     * Top-down (pre-order) pass assigning each node the state that minimises its subtree score
     * plus the cost of the change from its parent's state.
     *
     * <p>For the root, {@code parentStates} is the root's own state array, which is read and
     * written in place.
     */
    private void reconstructStates(Tree tree, NodeRef node, int[] parentStates) {
        double[][] scores = nodeScores[node.getNumber()];
        int[] states = nodeStates[node.getNumber()];
        for (int i = 0; i < patterns.getPatternCount(); i++) {
            int[] candidates = stateSets[i];
            // Read the parent's state before writing states[i]; the two may alias at the root.
            double[] costFromParent = costMatrix[parentStates[i]];
            int best = candidates[0];
            double bestCost = scores[i][best] + costFromParent[best];
            for (int k = 1; k < candidates.length; k++) {
                int s = candidates[k];
                double cost = scores[i][s] + costFromParent[s];
                if (cost < bestCost) {
                    best = s;
                    bestCost = cost;
                }
            }
            states[i] = best;
        }
        for (int c = 0; c < tree.getChildCount(node); c++) {
            reconstructStates(tree, tree.getChild(node, c), states);
        }
    }

    /** Returns the first state in {@code states} with the lowest score (ties go to the earlier state). */
    private int minState(double[] scores, int[] states) {
        int best = states[0];
        for (int k = 1; k < states.length; k++) {
            if (scores[states[k]] < scores[best]) {
                best = states[k];
            }
        }
        return best;
    }

    /** Returns the lowest score among the given states. */
    private double minScore(double[] scores, int[] states) {
        double best = scores[states[0]];
        for (int k = 1; k < states.length; k++) {
            if (scores[states[k]] < best) {
                best = scores[states[k]];
            }
        }
        return best;
    }

    /**
     * Returns the minimum, over child states in {@code states}, of the cost of a change from
     * {@code parentState} plus the child's subtree score.
     */
    private double minCost(int parentState, double[] childScores, int[] states) {
        double[] costFromParent = costMatrix[parentState];
        double best = costFromParent[states[0]] + childScores[states[0]];
        for (int k = 1; k < states.length; k++) {
            double cost = costFromParent[states[k]] + childScores[states[k]];
            if (cost < best) {
                best = cost;
            }
        }
        return best;
    }

    /** Small demonstration comparing Fitch and Sankoff parsimony on a five-taxon tree. */
    public static void main(String[] args) {
        FlexibleNode tip1 = new FlexibleNode(new Taxon("tip1"));
        FlexibleNode tip2 = new FlexibleNode(new Taxon("tip2"));
        FlexibleNode tip3 = new FlexibleNode(new Taxon("tip3"));
        FlexibleNode tip4 = new FlexibleNode(new Taxon("tip4"));
        FlexibleNode tip5 = new FlexibleNode(new Taxon("tip5"));

        FlexibleNode node1 = new FlexibleNode();
        node1.addChild(tip1);
        node1.addChild(tip2);
        FlexibleNode node2 = new FlexibleNode();
        node2.addChild(tip4);
        node2.addChild(tip5);
        FlexibleNode node3 = new FlexibleNode();
        node3.addChild(tip3);
        node3.addChild(node2);
        FlexibleNode root = new FlexibleNode();
        root.addChild(node1);
        root.addChild(node3);

        FlexibleTree tree = new FlexibleTree(root);
        Patterns patterns = new Patterns(Nucleotides.INSTANCE, tree);
        patterns.addPattern(new int[]{2, 3, 1, 3, 3});

        FitchParsimony fitch = new FitchParsimony(patterns, false);
        SankoffParsimony sankoff = new SankoffParsimony(patterns);

        for (int i = 0; i < patterns.getPatternCount(); i++) {
            double[] fitchScores = fitch.getSiteScores(tree);
            System.out.println("Pattern = " + i);
            System.out.println("Fitch:");
            System.out.println("  No. Steps = " + fitchScores[i]);
            System.out.println("    state(node1) = " + fitch.getStates(tree, node1)[i]);
            System.out.println("    state(node2) = " + fitch.getStates(tree, node2)[i]);
            System.out.println("    state(node3) = " + fitch.getStates(tree, node3)[i]);
            System.out.println("    state(root) = " + fitch.getStates(tree, root)[i]);

            double[] sankoffScores = sankoff.getSiteScores(tree);
            System.out.println("Sankoff:");
            System.out.println("  No. Steps = " + sankoffScores[i]);
            System.out.println("    state(node1) = " + sankoff.getStates(tree, node1)[i]);
            System.out.println("    state(node2) = " + sankoff.getStates(tree, node2)[i]);
            System.out.println("    state(node3) = " + sankoff.getStates(tree, node3)[i]);
            System.out.println("    state(root) = " + sankoff.getStates(tree, root)[i]);
            System.out.println();
        }
    }
}
