package dr.evolution.tree.treemetrics;

import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.tree.treemetrics.TreeMetric;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:dr/evolution/tree/treemetrics/KendallColijnPathDifferenceMetric.class */
/**
 * Kendall-Colijn path-difference metric between two trees on the same taxa.
 *
 * <p>Each tree is summarised by two vectors indexed by tip pairs {@code (i, j)} with
 * {@code i <= j}. They are stored row-major in arrays of length {@code tipCount * tipCount}
 * (entry {@code i * tipCount + j}); the lower triangle is never written or read.
 * <ul>
 *   <li>{@code m} ("small M"): for {@code i != j}, the number of edges from the root to the node
 *       where tips i and j split; for {@code i == j}, 1.</li>
 *   <li>{@code M} ("large M"): for {@code i != j}, the summed branch length from the root to that
 *       node; for {@code i == j}, the tip's own branch length.</li>
 * </ul>
 * The distance is the Euclidean norm of the difference between
 * {@code (1 - lambda) * m + lambda * M} for the two trees.
 *
 * <p>The vectors of the first ("focal") tree passed to {@link #getMetric(Tree, Tree)} are cached.
 * They are reused while the same tree object is passed again; a different first tree replaces
 * the cache. Instances therefore hold mutable state and are not synchronized.
 *
 * <p>NOTE(review): internal nodes are visited through children 0 and 1 only, so trees are
 * presumably expected to be strictly binary. Tip node numbers are used directly as array indices
 * and are assumed to lie in {@code [0, tipCount)} and to denote the same taxa in both trees;
 * confirm against {@code TreeMetric.Utils.checkTreeTaxa}.
 */
public class KendallColijnPathDifferenceMetric implements TreeMetric {
    public static TreeMetric.Type TYPE = TreeMetric.Type.KENDALL_COLIJN;

    /** Tree whose vectors are currently held in {@link #focalSmallM} / {@link #focalLargeM}. */
    private Tree focalTree;
    /** Length of the cached vectors: the focal tree's tip count squared. */
    private int dim;
    /** Cached edge-count vector (m) of the focal tree. */
    private double[] focalSmallM;
    /** Cached path-length vector (M) of the focal tree. */
    private double[] focalLargeM;
    // NOTE(review): always false. The focal tree given to the two-argument constructor therefore
    // only primes the cache; getMetric re-targets the cache whenever a different first tree is
    // passed instead of rejecting it.
    private final boolean fixedFocalTree = false;
    /** Weight on path lengths (M); {@code 1 - lambda} is the weight on edge counts (m). */
    private final double lambda;

    /**
     * Creates a metric whose focal tree is taken from the first call to {@link #getMetric}.
     *
     * @param lambda weight on path lengths; edge counts are weighted by {@code 1 - lambda}
     */
    public KendallColijnPathDifferenceMetric(double lambda) {
        this.lambda = lambda;
    }

    /**
     * Creates a metric and immediately computes and caches the vectors of {@code tree}.
     *
     * @param lambda weight on path lengths; edge counts are weighted by {@code 1 - lambda}
     * @param tree   tree to cache as the focal tree
     */
    public KendallColijnPathDifferenceMetric(double lambda, Tree tree) {
        this.lambda = lambda;
        cacheFocalTree(tree);
    }

    /**
     * Returns the path-difference distance between {@code tree1} and {@code tree2} for this
     * instance's lambda.
     *
     * <p>Taxa are first validated with {@code TreeMetric.Utils.checkTreeTaxa}. {@code tree1}
     * becomes (or stays) the cached focal tree; {@code tree2}'s vectors are recomputed on every
     * call.
     *
     * @param tree1 focal tree
     * @param tree2 tree compared against the focal tree
     * @return the Euclidean distance between the weighted vectors of the two trees
     */
    @Override // dr.evolution.tree.treemetrics.TreeMetric
    public double getMetric(Tree tree1, Tree tree2) {
        TreeMetric.Utils.checkTreeTaxa(tree1, tree2);
        if (tree1 != this.focalTree) {
            if (this.fixedFocalTree) {
                throw new RuntimeException("Focal tree is different from that set in the constructor.");
            }
            cacheFocalTree(tree1);
        }
        double[] smallM = new double[this.dim];
        double[] largeM = new double[this.dim];
        traverse(tree2, tree2.getRoot(), 0.0d, 0, largeM, smallM);
        return calculateMetric(this.focalSmallM, this.focalLargeM, smallM, largeM,
                tree1.getExternalNodeCount(), this.lambda);
    }

    /**
     * Computes the vectors of {@code tree} into the focal cache and makes it the focal tree.
     *
     * <p>The cache arrays are reallocated whenever the tip count changes, so a single instance can
     * be reused across tree pairs of different sizes. Without that reallocation the old-sized
     * arrays would be indexed out of bounds.
     */
    private void cacheFocalTree(Tree tree) {
        int tipCount = tree.getExternalNodeCount();
        int requiredDim = tipCount * tipCount;
        if (this.focalSmallM == null || this.focalSmallM.length != requiredDim) {
            this.focalSmallM = new double[requiredDim];
            this.focalLargeM = new double[requiredDim];
        }
        this.dim = requiredDim;
        // Invalidate first: if the traversal fails part-way, the half-written arrays must not be
        // mistaken for a valid cache of either the old or the new tree.
        this.focalTree = null;
        traverse(tree, tree.getRoot(), 0.0d, 0, this.focalLargeM, this.focalSmallM);
        this.focalTree = tree;
    }

    /**
     * Euclidean distance between {@code (1 - lambda) * m + lambda * M} of two trees.
     * Only the upper triangle including the diagonal ({@code j >= i}) is summed.
     *
     * @param smallM1  edge-count vector of the first tree
     * @param largeM1  path-length vector of the first tree
     * @param smallM2  edge-count vector of the second tree
     * @param largeM2  path-length vector of the second tree
     * @param tipCount number of tips (row length of the square layout)
     * @param lambda   weight on path lengths
     * @return the square root of the summed squared differences
     */
    private static double calculateMetric(double[] smallM1, double[] largeM1, double[] smallM2,
                                          double[] largeM2, int tipCount, double lambda) {
        double sumOfSquares = 0.0d;
        for (int i = 0; i < tipCount; i++) {
            for (int j = i; j < tipCount; j++) {
                int k = i * tipCount + j;
                double v1 = (1.0d - lambda) * smallM1[k] + lambda * largeM1[k];
                double v2 = (1.0d - lambda) * smallM2[k] + lambda * largeM2[k];
                double diff = v1 - v2;
                sumOfSquares += diff * diff;
            }
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Post-order walk that fills the m / M entries for every tip pair below {@code node}.
     *
     * <p>Each pair is written at the node where it splits, which is the node being visited.
     *
     * @param tree        tree being walked
     * @param node        internal node to visit (children 0 and 1 are used)
     * @param pathLength  summed branch length from the root down to {@code node}
     * @param edgeCount   number of edges from the root down to {@code node}
     * @param largeM      receives path lengths (M)
     * @param smallM      receives edge counts (m)
     * @return the tips below {@code node}
     */
    private static Set<NodeRef> traverse(Tree tree, NodeRef node, double pathLength, int edgeCount,
                                         double[] largeM, double[] smallM) {
        int tipCount = tree.getExternalNodeCount();
        Set<NodeRef> leftTips = visitChild(tree, tree.getChild(node, 0), pathLength, edgeCount, largeM, smallM);
        Set<NodeRef> rightTips = visitChild(tree, tree.getChild(node, 1), pathLength, edgeCount, largeM, smallM);

        // Every (left, right) tip pair splits here; store it in the upper triangle (row < column).
        for (NodeRef left : leftTips) {
            int a = left.getNumber();
            for (NodeRef right : rightTips) {
                int b = right.getNumber();
                int k = Math.min(a, b) * tipCount + Math.max(a, b);
                largeM[k] = pathLength;
                smallM[k] = edgeCount;
            }
        }

        Set<NodeRef> tips = new HashSet<>(leftTips);
        tips.addAll(rightTips);
        return tips;
    }

    /**
     * Handles one child of the node being traversed.
     *
     * <p>A tip gets its diagonal entries (m = 1, M = its branch length). An internal child is
     * recursed into one edge deeper.
     *
     * @return the tips below {@code child}
     */
    private static Set<NodeRef> visitChild(Tree tree, NodeRef child, double parentPathLength,
                                           int parentEdgeCount, double[] largeM, double[] smallM) {
        if (tree.isExternal(child)) {
            int tipCount = tree.getExternalNodeCount();
            int k = child.getNumber() * tipCount + child.getNumber();
            largeM[k] = tree.getBranchLength(child);
            smallM[k] = 1.0d;
            return Collections.singleton(child);
        }
        return traverse(tree, child, parentPathLength + tree.getBranchLength(child),
                parentEdgeCount + 1, largeM, smallM);
    }

    /**
     * Legacy multi-lambda variant of {@link #getMetric(Tree, Tree)} against the cached focal tree.
     *
     * <p>{@code tree}'s vectors are built by the legacy common-ancestor algorithm in the same
     * layout as the cached focal vectors. Previously a different, packed layout was used, which
     * made the element-wise comparison meaningless.
     *
     * @param tree      tree compared against the focal tree
     * @param arrayList lambda values to evaluate
     * @return one distance per lambda, in the order given
     * @throws IllegalStateException if no focal tree has been cached
     * @throws RuntimeException      if the tip count or the taxon at any tip index differs
     * @deprecated use {@link #getMetric(Tree, Tree)} with one instance per lambda
     */
    @Deprecated
    public ArrayList<Double> getMetric_old(Tree tree, ArrayList<Double> arrayList) {
        if (this.focalTree == null) {
            throw new IllegalStateException("No focal tree has been set.");
        }
        checkSameTaxa(this.focalTree, tree);
        double[] smallM = new double[this.dim];
        double[] largeM = new double[this.dim];
        fillLegacy(tree, smallM, largeM);
        return distancesForLambdas(this.focalSmallM, this.focalLargeM, smallM, largeM,
                tree.getExternalNodeCount(), arrayList);
    }

    /**
     * Legacy multi-lambda distance between two trees, computed without touching the focal cache.
     *
     * <p>The previous packed layout placed tip entries at {@code (n-1)(n-2)}. For n ≤ 3 those
     * entries overwrote pair entries; both trees now use the collision-free square layout.
     *
     * @param tree      first tree
     * @param tree2     second tree
     * @param arrayList lambda values to evaluate
     * @return one distance per lambda, in the order given
     * @throws RuntimeException if the tip count or the taxon at any tip index differs
     * @deprecated use {@link #getMetric(Tree, Tree)} with one instance per lambda
     */
    @Deprecated
    public ArrayList<Double> getMetric_old(Tree tree, Tree tree2, ArrayList<Double> arrayList) {
        checkSameTaxa(tree, tree2);
        int tipCount = tree.getExternalNodeCount();
        int size = tipCount * tipCount;
        double[] smallM1 = new double[size];
        double[] largeM1 = new double[size];
        fillLegacy(tree, smallM1, largeM1);
        double[] smallM2 = new double[size];
        double[] largeM2 = new double[size];
        fillLegacy(tree2, smallM2, largeM2);
        return distancesForLambdas(smallM1, largeM1, smallM2, largeM2, tipCount, arrayList);
    }

    /** Throws if the trees differ in tip count or in the taxon id at any tip index. */
    private static void checkSameTaxa(Tree tree1, Tree tree2) {
        if (tree1.getExternalNodeCount() != tree2.getExternalNodeCount()) {
            throw new RuntimeException("Different number of taxa in both trees.");
        }
        for (int i = 0; i < tree1.getExternalNodeCount(); i++) {
            String id1 = tree1.getNodeTaxon(tree1.getExternalNode(i)).getId();
            String id2 = tree2.getNodeTaxon(tree2.getExternalNode(i)).getId();
            if (!id1.equals(id2)) {
                throw new RuntimeException("Mismatch between taxa in both trees: " + id1 + " vs. " + id2);
            }
        }
    }

    /**
     * Legacy vector construction in the same square layout as {@link #traverse}.
     *
     * <p>For each tip pair it walks from the common ancestor up to the root. Branch lengths are
     * taken as parent height minus node height, as the original legacy code did. Unlike
     * {@link #traverse}, this does not rely on nodes having exactly two children.
     */
    private static void fillLegacy(Tree tree, double[] smallM, double[] largeM) {
        int tipCount = tree.getExternalNodeCount();
        NodeRef root = tree.getRoot();
        for (int i = 0; i < tipCount; i++) {
            NodeRef tipI = tree.getExternalNode(i);
            int diagonal = tipI.getNumber() * tipCount + tipI.getNumber();
            smallM[diagonal] = 1.0d;
            largeM[diagonal] = tree.getNodeHeight(tree.getParent(tipI)) - tree.getNodeHeight(tipI);
            for (int j = i + 1; j < tipCount; j++) {
                NodeRef tipJ = tree.getExternalNode(j);
                int edges = 0;
                double length = 0.0d;
                NodeRef node = TreeUtils.getCommonAncestor(tree, tipI, tipJ);
                while (node != root) {
                    NodeRef parent = tree.getParent(node);
                    edges++;
                    length += tree.getNodeHeight(parent) - tree.getNodeHeight(node);
                    node = parent;
                }
                int a = tipI.getNumber();
                int b = tipJ.getNumber();
                int k = Math.min(a, b) * tipCount + Math.max(a, b);
                smallM[k] = edges;
                largeM[k] = length;
            }
        }
    }

    /** Evaluates {@link #calculateMetric} once per lambda, preserving the input order. */
    private static ArrayList<Double> distancesForLambdas(double[] smallM1, double[] largeM1,
                                                         double[] smallM2, double[] largeM2,
                                                         int tipCount, ArrayList<Double> lambdas) {
        ArrayList<Double> distances = new ArrayList<>(lambdas.size());
        for (Double value : lambdas) {
            distances.add(calculateMetric(smallM1, largeM1, smallM2, largeM2, tipCount, value));
        }
        return distances;
    }

    /** Small demonstration on 4- and 5-taxon trees, followed by a timing loop. */
    public static void main(String[] strArr) {
        try {
            Tree tree1 = new NewickImporter("(('A':1.2,'B':0.8):0.5,('C':0.8,'D':1.0):1.1)").importNextTree();
            System.out.println("4-taxa tree 1: " + tree1);
            Tree tree2 = new NewickImporter("((('A':0.8,'B':1.4):0.3,'C':0.7):0.9,'D':1.0)").importNextTree();
            System.out.println("4-taxa tree 2: " + tree2);
            System.out.println();
            printDemoDistances(tree1, tree2);
            System.out.println();

            Tree tree3 = new NewickImporter("(((('A':0.6,'B':0.6):0.1,'C':0.5):0.4,'D':0.7):0.1,'E':1.3)").importNextTree();
            System.out.println("5-taxa tree 1: " + tree3);
            Tree tree4 = new NewickImporter("((('A':0.8,'B':1.4):0.1,'C':0.7):0.2,('D':1.0,'E':0.9):1.3)").importNextTree();
            System.out.println("5-taxa tree 2: " + tree4);
            System.out.println();
            printDemoDistances(tree3, tree4);

            long start = System.currentTimeMillis();
            for (int i = 0; i < 1000000; i++) {
                new KendallColijnPathDifferenceMetric(0.5d).getMetric(tree3, tree4);
            }
            System.out.println("New algorithm, 1M reps: " + (System.currentTimeMillis() - start) + " ms");
        } catch (Importer.ImportException e) {
            System.err.println(e);
        } catch (IOException e2) {
            System.err.println(e2);
        }
    }

    /**
     * Prints the distance for lambda 0.0, 0.5 and 1.0 twice. The first set uses the lazily-cached
     * constructor ("Paired trees") and the second uses the focal-tree constructor ("Focal trees").
     */
    private static void printDemoDistances(Tree tree1, Tree tree2) {
        double[] lambdas = {0.0d, 0.5d, 1.0d};
        System.out.println("Paired trees:");
        for (double l : lambdas) {
            System.out.println("lambda (" + l + ") = " + new KendallColijnPathDifferenceMetric(l).getMetric(tree1, tree2));
        }
        System.out.println();
        System.out.println("Focal trees:");
        for (double l : lambdas) {
            System.out.println("lambda (" + l + ") = " + new KendallColijnPathDifferenceMetric(l, tree1).getMetric(tree1, tree2));
        }
        System.out.println();
    }

    @Override // dr.evolution.tree.treemetrics.TreeMetric
    public TreeMetric.Type getType() {
        return TYPE;
    }

    /** Returns the metric's short name followed by its lambda, e.g. {@code name(0.5)}. */
    @Override
    public String toString() {
        return getType().getShortName() + "(" + this.lambda + ")";
    }
}
