package dr.evolution.tree;

import dr.evolution.distance.DistanceMatrix;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:dr/evolution/tree/RzhetskyNeiBranchLengthsTree.class */
/**
 * A {@link SimpleTree} copy of a source tree that carries a distance matrix and
 * the set of external-node numbers found below its root.
 *
 * <p>The class also contains a (currently uncalled) routine that assigns branch
 * lengths from averaged pairwise distances between clusters of taxa. Judging by
 * the class name this is presumably the ordinary-least-squares estimator of
 * Rzhetsky &amp; Nei (1993) -- confirm against the paper before relying on it.
 *
 * <p>Taxa are identified by {@code NodeRef.getNumber()} of the external nodes,
 * and the same numbers are used as indices into the distance matrix.
 */
public class RzhetskyNeiBranchLengthsTree extends SimpleTree {
    /** Pairwise distances, read through {@code getElement(i, j)}. */
    private final DistanceMatrix distanceMatrix;

    /** Numbers of every external node reachable from the root of this tree. */
    private final Set<Integer> allTaxonSet;

    /**
     * Copies {@code tree} and records the numbers of all its external nodes.
     *
     * <p>The constructor does not modify any branch length;
     * {@code calculateBranchLengths} is not invoked from here.
     *
     * @param tree           the tree to copy
     * @param distanceMatrix pairwise distances used by the branch length routine
     */
    public RzhetskyNeiBranchLengthsTree(Tree tree, DistanceMatrix distanceMatrix) {
        super(tree);
        this.distanceMatrix = distanceMatrix;
        Map<NodeRef, Set<Integer>> taxonSets = new HashMap<NodeRef, Set<Integer>>();
        this.allTaxonSet = new HashSet<Integer>(getTaxonSets(this, getRoot(), taxonSets));
    }

    /**
     * Recursively collects the external-node numbers below {@code node} and
     * stores the result for every visited node in {@code taxonSets}.
     *
     * @param tree      tree queried for externality and child counts; children
     *                  themselves are fetched from this instance, so the two are
     *                  expected to be the same tree (as in the constructor call)
     * @param node      root of the subtree to scan
     * @param taxonSets output map from each visited node to its set of taxa
     * @return the set of external-node numbers below (or at) {@code node}
     */
    private Set<Integer> getTaxonSets(Tree tree, NodeRef node, Map<NodeRef, Set<Integer>> taxonSets) {
        Set<Integer> taxa = new HashSet<Integer>();
        if (tree.isExternal(node)) {
            taxa.add(node.getNumber());
        } else {
            assert tree.getChildCount(node) == 2 : "Must be a strictly bifurcating tree";
            for (int i = 0; i < tree.getChildCount(node); i++) {
                taxa.addAll(getTaxonSets(tree, getChild(node, i), taxonSets));
            }
        }
        taxonSets.put(node, taxa);
        return taxa;
    }

    /**
     * Sets the length of the branch above {@code node}, after first doing the
     * same for both of its children when {@code node} is internal.
     *
     * <p>Cluster naming used below: C (and D) are the taxa under {@code node}
     * (under its two children when internal), B the taxa under {@code sibling},
     * and A every remaining taxon. Each term is the mean pairwise distance
     * between two clusters, i.e. the summed distance divided by the product of
     * the two cluster sizes.
     *
     * <p>NOTE(review): a {@code null} sibling (e.g. when called on the root) is
     * not handled -- the map lookup would presumably return {@code null} and the
     * subsequent {@code removeAll} would fail. Likewise, when A is empty (which
     * looks to be the case for the two children of the root) the divisions by
     * its size yield non-finite values. Confirm the intended treatment before
     * wiring this method in.
     *
     * @param tree      tree queried for externality
     * @param node      node whose parent branch is being estimated
     * @param sibling   the other child of {@code node}'s parent
     * @param taxonSets map from node to its taxa, as filled by {@code getTaxonSets}
     */
    private void calculateBranchLengths(Tree tree, NodeRef node, NodeRef sibling, Map<NodeRef, Set<Integer>> taxonSets) {
        double length;
        if (tree.isExternal(node)) {
            Set<Integer> taxaC = taxonSets.get(node);
            Set<Integer> taxaB = taxonSets.get(sibling);
            Set<Integer> taxaA = new HashSet<Integer>(this.allTaxonSet);
            taxaA.removeAll(taxaC);
            taxaA.removeAll(taxaB);

            double nA = taxaA.size();
            double nB = taxaB.size();

            // C is a single taxon here, so its size (1) is omitted from the denominators.
            length = 0.5d * (getSumOfDistances(taxaC, taxaA) / nA
                    + getSumOfDistances(taxaC, taxaB) / nB
                    - getSumOfDistances(taxaA, taxaB) / (nA * nB));
        } else {
            NodeRef left = getChild(node, 0);
            NodeRef right = getChild(node, 1);
            calculateBranchLengths(tree, left, right, taxonSets);
            calculateBranchLengths(tree, right, left, taxonSets);

            Set<Integer> taxaC = taxonSets.get(left);
            Set<Integer> taxaD = taxonSets.get(right);
            Set<Integer> taxaB = taxonSets.get(sibling);
            Set<Integer> taxaA = new HashSet<Integer>(this.allTaxonSet);
            taxaA.removeAll(taxaC);
            taxaA.removeAll(taxaD);
            taxaA.removeAll(taxaB);

            double nA = taxaA.size();
            double nB = taxaB.size();
            double nC = taxaC.size();
            double nD = taxaD.size();

            double gamma = ((nB * nC) + (nA * nD)) / ((nA + nB) * (nC + nD));

            // Each sum is divided by the PRODUCT of the two cluster sizes. The
            // previous form "sum / nX * nY" multiplied by nY because of operator
            // precedence, which disagreed with the external-branch case above.
            double meanAC = getSumOfDistances(taxaA, taxaC) / (nA * nC);
            double meanBD = getSumOfDistances(taxaB, taxaD) / (nB * nD);
            double meanBC = getSumOfDistances(taxaB, taxaC) / (nB * nC);
            double meanAD = getSumOfDistances(taxaA, taxaD) / (nA * nD);
            double meanAB = getSumOfDistances(taxaA, taxaB) / (nA * nB);
            double meanCD = getSumOfDistances(taxaC, taxaD) / (nC * nD);

            length = 0.5d * (gamma * (meanAC + meanBD)
                    + (1.0d - gamma) * (meanBC + meanAD)
                    - meanAB
                    - meanCD);
        }
        setBranchLength(node, length);
    }

    /**
     * Sums the matrix distance over every ordered pair (i, j) with i drawn from
     * {@code set} and j from {@code set2}.
     *
     * @param set  first group of taxon numbers
     * @param set2 second group of taxon numbers
     * @return the total distance; {@code 0.0} if either group is empty
     */
    private double getSumOfDistances(Set<Integer> set, Set<Integer> set2) {
        double total = 0.0d;
        for (int i : set) {
            for (int j : set2) {
                total += this.distanceMatrix.getElement(i, j);
            }
        }
        return total;
    }
}
