diff --git a/historydag/beast_loader.py b/historydag/beast_loader.py index 79e1a38..9e6c20f 100644 --- a/historydag/beast_loader.py +++ b/historydag/beast_loader.py @@ -34,7 +34,6 @@ def dag_from_beast_trees( include_sequence_names_in_labels: If True, augment leaf node labels with a ``name`` attribute containing the name of the corresponding sequence. Useful for distinguishing leaves when observed sequences are not unique. - """ dp_trees = load_beast_trees( beast_xml_file, diff --git a/historydag/dag.py b/historydag/dag.py index 06b4605..9cbb29c 100644 --- a/historydag/dag.py +++ b/historydag/dag.py @@ -26,7 +26,7 @@ from collections import Counter, namedtuple from copy import deepcopy from historydag import utils -from historydag.utils import Weight, Label, UALabel, prod +from historydag.utils import Weight, Label, UALabel, prod, TaxaError from historydag.counterops import counter_sum, counter_prod import historydag.parsimony_utils as parsimony_utils from historydag.dag_node import ( @@ -1977,7 +1977,7 @@ def default_accum_above_edge(subtree_weight, edge_weight): return downward_weights, upward_weights - def count_nodes(self, collapse=False) -> Dict[HistoryDagNode, int]: + def count_nodes(self, collapse=False, rooted=True) -> Dict[HistoryDagNode, int]: """Counts the number of trees each node takes part in. For node supports with respect to a uniform distribution on trees, use @@ -1987,6 +1987,11 @@ def count_nodes(self, collapse=False) -> Dict[HistoryDagNode, int]: collapse: A flag that when set to true, treats nodes as clade unions and ignores label information. Then, the returned dictionary is keyed by clade union sets. + rooted: A flag which is ignored unless ``collapse`` is ``True``. When ``rooted`` is also ``False``, + the returned dictionary is keyed by splits -- that is, sets containing each clade + union and its complement, with values the number of (rooted) trees in the DAG containing + each split. Splits are not double-counted when a tree has a bifurcating root. + If False, dag is expected to have trees all on the same set of leaf labels. Returns: A dictionary mapping each node in the DAG to the number of trees @@ -2031,7 +2036,37 @@ def count_nodes(self, collapse=False) -> Dict[HistoryDagNode, int]: collapsed_n2c[clade] = 0 collapsed_n2c[clade] += node2count[node] - return collapsed_n2c + if rooted: + # Remove the UA node clade union from N + try: + collapsed_n2c.pop(frozenset()) + except KeyError: + pass + return collapsed_n2c + else: + # Create dictionary counting in how many trees each split + # occurs as child of bifurcating root + split2adjustment = {} + all_taxa = next(self.dagroot.children()).clade_union() + if any(all_taxa != n.clade_union() for n in self.dagroot.children()): + raise TaxaError( + "Unrooted splits cannot be counted properly because" + " trees in this dag are on different sets of taxa." + ) + for treeroot in self.dagroot.children(): + if len(treeroot.clades) == 2: + split = frozenset(treeroot.clades.keys()) + before = split2adjustment.get(split, 0) + split2adjustment[split] = before + node2count[treeroot] + split2count = {} + for clade, count in collapsed_n2c.items(): + split = frozenset({clade, all_taxa - clade}) + before = split2count.get(split, 0) + split2count[split] = before + count + for split, adjustment in split2adjustment.items(): + split2count[split] -= adjustment + split2count.pop(frozenset({all_taxa, frozenset()}), None) + return split2count else: return node2count @@ -2258,6 +2293,9 @@ def overestimate_rf_diameter(self): def optimal_sum_rf_distance( self, reference_dag: "HistoryDag", + rooted: bool = True, + one_sided: str = None, + one_sided_coefficients: Tuple[float, float] = (1, 1), optimal_func: Callable[[List[Weight]], Weight] = min, ): """Returns the optimal (min or max) summed rooted RF distance to all @@ -2269,17 +2307,28 @@ def optimal_sum_rf_distance( instead of making multiple calls to this method with the same reference history DAG. """ - kwargs = utils.sum_rfdistance_funcs(reference_dag) + kwargs = utils.sum_rfdistance_funcs( + reference_dag, + rooted=rooted, + one_sided=one_sided, + one_sided_coefficients=one_sided_coefficients, + ) return self.optimal_weight_annotate(**kwargs, optimal_func=optimal_func) def trim_optimal_sum_rf_distance( self, reference_dag: "HistoryDag", + rooted: bool = True, + one_sided: str = None, + one_sided_coefficients: Tuple[float, float] = (1, 1), optimal_func: Callable[[List[Weight]], Weight] = min, ): """Trims the DAG to contain only histories with the optimal (min or max) sum rooted RF distance to the given reference DAG. + See :meth:`utils.sum_rfdistance_funcs` for detailed documentation of + arguments. + Trimming to the minimum sum RF distance is equivalent to finding 'median' topologies, and trimming to maximum sum rf distance is equivalent to finding topological outliers. @@ -2289,18 +2338,28 @@ def trim_optimal_sum_rf_distance( instead of making multiple calls to this method with the same reference history. """ - kwargs = utils.sum_rfdistance_funcs(reference_dag) + kwargs = utils.sum_rfdistance_funcs( + reference_dag, + rooted=rooted, + one_sided=one_sided, + one_sided_coefficients=one_sided_coefficients, + ) return self.trim_optimal_weight(**kwargs, optimal_func=optimal_func) def trim_optimal_rf_distance( self, history: "HistoryDag", rooted: bool = False, + one_sided: str = None, + one_sided_coefficients: Tuple[float, float] = (1, 1), optimal_func: Callable[[List[Weight]], Weight] = min, ): """Trims this history DAG to the optimal (min or max) RF distance to a given history. + See :meth:`utils.make_rfdistance_countfuncs` for detailed documentation of + arguments. + Also returns that optimal RF distance The given history must be on the same taxa as all trees in the DAG. @@ -2309,43 +2368,81 @@ def trim_optimal_rf_distance( instead of making multiple calls to this method with the same reference history. """ - kwargs = utils.make_rfdistance_countfuncs(history, rooted=rooted) + kwargs = utils.make_rfdistance_countfuncs( + history, + rooted=rooted, + one_sided=one_sided, + one_sided_coefficients=one_sided_coefficients, + ) return self.trim_optimal_weight(**kwargs, optimal_func=optimal_func) def optimal_rf_distance( self, history: "HistoryDag", rooted: bool = False, + one_sided: str = None, + one_sided_coefficients: Tuple[float, float] = (1, 1), optimal_func: Callable[[List[Weight]], Weight] = min, ): """Returns the optimal (min or max) RF distance to a given history. + See :meth:`utils.make_rfdistance_countfuncs` for detailed documentation of + arguments. + The given history must be on the same taxa as all trees in the DAG. Since computing reference splits is expensive, it is better to use :meth:`optimal_weight_annotate` and :meth:`utils.make_rfdistance_countfuncs` instead of making multiple calls to this method with the same reference history. """ - kwargs = utils.make_rfdistance_countfuncs(history, rooted=rooted) + kwargs = utils.make_rfdistance_countfuncs( + history, + rooted=rooted, + one_sided=one_sided, + one_sided_coefficients=one_sided_coefficients, + ) return self.optimal_weight_annotate(**kwargs, optimal_func=optimal_func) - def count_rf_distances(self, history: "HistoryDag", rooted: bool = False): + def count_rf_distances( + self, + history: "HistoryDag", + rooted: bool = False, + one_sided: str = None, + one_sided_coefficients: Tuple[float, float] = (1, 1), + ): """Returns a Counter containing all RF distances to a given history. The given history must be on the same taxa as all trees in the DAG. + See :meth:`utils.make_rfdistance_countfuncs` for detailed documentation of + arguments. + Since computing reference splits is expensive, it is better to use :meth:`weight_count` and :meth:`utils.make_rfdistance_countfuncs` instead of making multiple calls to this method with the same reference history. """ - kwargs = utils.make_rfdistance_countfuncs(history, rooted=rooted) + kwargs = utils.make_rfdistance_countfuncs( + history, + rooted=rooted, + one_sided=one_sided, + one_sided_coefficients=one_sided_coefficients, + ) return self.weight_count(**kwargs) - def count_sum_rf_distances(self, reference_dag: "HistoryDag", rooted: bool = False): + def count_sum_rf_distances( + self, + reference_dag: "HistoryDag", + rooted: bool = True, + one_sided: str = None, + one_sided_coefficients: Tuple[float, float] = (1, 1), + ): """Returns a Counter containing all sum RF distances to a given reference DAG. + See :meth:`utils.sum_rfdistance_funcs` for detailed documentation of + arguments. + The given history DAG must be on the same taxa as all trees in the DAG. Since computing reference splits is expensive, it is better to use @@ -2353,14 +2450,42 @@ def count_sum_rf_distances(self, reference_dag: "HistoryDag", rooted: bool = Fal instead of making multiple calls to this method with the same reference history DAG. """ - kwargs = utils.sum_rfdistance_funcs(reference_dag) + kwargs = utils.sum_rfdistance_funcs( + reference_dag, + rooted=rooted, + one_sided=one_sided, + one_sided_coefficients=one_sided_coefficients, + ) return self.weight_count(**kwargs) - def sum_rf_distances(self, reference_dag: "HistoryDag" = None): - r"""Computes the sum of all Robinson-Foulds distances between a history - in this DAG and a history in the reference DAG. + def sum_rf_distances( + self, + reference_dag: "HistoryDag" = None, + rooted: bool = True, + one_sided: str = None, + one_sided_coefficients: Tuple[float, float] = (1, 1), + ): + r"""Computes the sum of Robinson-Foulds distances over all pairs of + histories in this DAG and the provided reference DAG. + + Args: + reference_dag: If None, the sum of pairwise distances between histories in this DAG + is computed. If provided, the sum is over pairs containing one history in this DAG and + one from ``reference_dag``. + rooted: If False, use edges' splits for RF distance computation. Otherwise, use + the clade below each edge. + one_sided: May be 'left', 'right', or None. 'left' means that we count + splits (or clades, in the rooted case) which are in the reference trees but not + in the DAG tree, especially useful if trees in the DAG might be resolutions of + multifurcating trees in the reference DAG. 'right' means that we count splits or clades in + the DAG tree which are not in the reference trees, useful if the reference trees + are possibly resolutions of multifurcating trees in the DAG. If not None, + one_sided_coefficients are ignored. + one_sided_coefficients: coefficients for non-standard symmetric difference calculations. + See :meth:`utils.make_rfdistance_countfuncs` for more details. - This is rooted RF distance. + Returns: + An integer sum of RF distances. If T is the set of histories in the reference DAG, and T' is the set of histories in this DAG, then the returned sum is: @@ -2372,22 +2497,16 @@ def sum_rf_distances(self, reference_dag: "HistoryDag" = None): That is, since RF distance is symmetric, when T = T' (such as when ``reference_dag=None``), or when the intersection of T and T' is nonempty, some distances are counted twice. - Args: - reference_dag: If None, the sum of pairwise distances between histories in this DAG - is computed. If provided, the sum is over pairs containing one history in this DAG and - one from ``reference_dag``. - - Returns: - An integer sum of RF distances. + Note that when computing one-sided distances, or when the one_sided_coefficients values are not + equal, this 'distance' is no longer symmetric. """ + s, t, _ = utils._process_rf_one_sided_coefficients( + one_sided, one_sided_coefficients + ) def get_data(dag): n_histories = dag.count_histories() - N = dag.count_nodes(collapse=True) - try: - N.pop(frozenset()) - except KeyError: - pass + N = dag.count_nodes(collapse=True, rooted=rooted) clade_count_sum = sum(N.values()) return (n_histories, N, clade_count_sum) @@ -2410,13 +2529,13 @@ def get_data(dag): ) return ( - n_histories * clade_count_sum_prime - + n_histories_prime * clade_count_sum - - 2 * intersection_term + t * n_histories * clade_count_sum_prime + + s * n_histories_prime * clade_count_sum + - (s + t) * intersection_term ) def average_pairwise_rf_distance( - self, reference_dag: "HistoryDag" = None, non_identical=True + self, reference_dag: "HistoryDag" = None, non_identical=True, **kwargs ): """Return the average Robinson-Foulds distance between pairs of histories. @@ -2425,6 +2544,7 @@ def average_pairwise_rf_distance( reference_dag: A history DAG from which to take the second history in each pair. If None, ``self`` will be used as the reference. non_identical: If True, mean divisor will be the number of non-identical pairs. + kwargs: See :meth:`historydag.sum_rf_distances` for additional keyword arguments Returns: The average rf-distance between pairs of histories, where the first history @@ -2433,7 +2553,9 @@ def average_pairwise_rf_distance( ``non_identical`` is True, in which case the number of histories which appear in both DAGs is subtracted from this constant. """ - sum_pairwise_distance = self.sum_rf_distances(reference_dag=reference_dag) + sum_pairwise_distance = self.sum_rf_distances( + reference_dag=reference_dag, **kwargs + ) if reference_dag is None: # ignore the diagonal in the distance matrix, since it contains # zeros: diff --git a/historydag/utils.py b/historydag/utils.py index b4aaf39..619c835 100644 --- a/historydag/utils.py +++ b/historydag/utils.py @@ -30,6 +30,10 @@ F = TypeVar("F", bound=Callable[..., Any]) +class TaxaError(ValueError): + pass + + class UALabel(str): _fields: Tuple = tuple() @@ -541,63 +545,176 @@ def natural_edge_probability(parent, child): according to the natural distribution induced by the DAG topology.""" -def sum_rfdistance_funcs(reference_dag: "HistoryDag"): +def _process_rf_one_sided_coefficients(one_sided, one_sided_coefficients): + rf_type_suffix = "distance" + if one_sided_coefficients != (1, 1): + rf_type_suffix = "nonstandard" + + if one_sided is None: + pass + elif one_sided.lower() == "left": + one_sided_coefficients = (1, 0) + rf_type_suffix = "left_difference" + elif one_sided.lower() == "right": + one_sided_coefficients = (0, 1) + rf_type_suffix = "right_difference" + else: + raise ValueError( + f"Argument `one_sided` must have value 'left', 'right', or None, not {one_sided}" + ) + + s, t = one_sided_coefficients + return s, t, rf_type_suffix + + +def sum_rfdistance_funcs( + reference_dag: "HistoryDag", + rooted: bool = True, + one_sided: str = None, + one_sided_coefficients: Tuple[float, float] = (1, 1), +): """Provides functions to compute the sum over all histories in the provided reference DAG, of rooted RF distances to those histories. Args: reference_dag: The reference DAG. The sum will be computed over all RF distances to histories in this DAG + rooted: If False, use edges' splits for RF distance computation. Otherwise, use + the clade below each edge. + one_sided: May be 'left', 'right', or None. 'left' means that we count + splits (or clades, in the rooted case) which are in the reference trees but not + in the DAG tree, especially useful if trees in the DAG might be resolutions of + multifurcating trees in the reference DAG. 'right' means that we count splits or clades in + the DAG tree which are not in the reference trees, useful if the reference trees + are possibly resolutions of multifurcating trees in the DAG. If not None, + one_sided_coefficients are ignored. + one_sided_coefficients: coefficients for non-standard symmetric difference calculations + (explained in notes below) The reference DAG must have the same taxa as all the trees in the DAG on which these count - functions are used. + functions are used. If this is not true, methods using the keyword arguments produced by this + function may fail silently, returning values which mean nothing. + + This function allows computation of sums of a Robinson-Foulds distance generalized by the + coefficients ``(s, t)`` provided to the ``one_sided_coefficients`` argument (or implicitly + set by the ``one_sided`` argument). Given a tree in the DAG with set of clades (or splits) A, and + a tree in the reference DAG with set of clades B, this distance is given by: + + ``d_{s,t}(A, B) = s|B - A| + t|A - B|`` + + Notice that when s and t are both 1, this is the symmetric difference of A and B, the standard RF + distance. - The edge weight is computed using the expression 2 * N[c_e] - |T| where c_e is the clade under - the relevant edge, and |T| is the number of trees in the reference dag. This provide rooted RF - distances, meaning that the clade below each edge is used for RF distance computation. + For each tree A in a DAG, the AddFuncDict returned by this function computes the sum of this distance + over all trees B in the reference DAG. - The weights are represented by an IntState object and are shifted by a constant K, + + Note that when computing unrooted weights, the sums are over all rooted trees in the reference + DAG, so a single unrooted tree contained twice in the reference DAG with different rootings + will be counted twice. + + Weights are represented by an IntState object and are shifted by a constant K, which is the sum of number of clades in each tree in the DAG. """ - N = reference_dag.count_nodes(collapse=True) + s, t, rf_type_suffix = _process_rf_one_sided_coefficients( + one_sided, one_sided_coefficients + ) - # Remove the UA node clade union from N - try: - N.pop(frozenset()) - except KeyError: - pass + N = reference_dag.count_nodes(collapse=True, rooted=rooted) # K is the constant that the weights are shifted by - K = sum(N.values()) + K = s * sum(N.values()) - num_trees = reference_dag.count_histories() + # We also scale num_trees by s... + num_trees = t * reference_dag.count_histories() - def make_intstate(n): - return IntState(n + K, state=n) + if rooted: - def edge_func(n1, n2): - clade = n2.clade_union() - if clade in N: - weight = num_trees - (2 * N[n2.clade_union()]) - else: - # This clade's count should then just be 0: - weight = num_trees - return make_intstate(weight) + def make_intstate(n): + return IntState(n + K, state=n) + + def edge_func(n1, n2): + clade = n2.clade_union() + clade_count = N.get(clade, 0) + weight = num_trees - ((s + t) * clade_count) + return make_intstate(weight) + + kwargs = AddFuncDict( + { + "start_func": lambda n: make_intstate(0), + "edge_weight_func": edge_func, + "accum_func": lambda wlist: make_intstate( + sum(w.state for w in wlist) + ), # summation over edge weights + }, + name="RF_rooted_sum_" + rf_type_suffix, + ) + + else: + taxa = next(reference_dag.dagroot.children()).clade_union() + n_taxa = len(taxa) + + def is_history_root(n): + # TODO this is slow and dirty! Make more efficient + return len(list(n.clade_union())) == n_taxa + + def split(node): + cu = node.clade_union() + return frozenset({cu, taxa - cu}) + + # We accumulate tuples, where the first number contains the weight, + # except any contribution of a split below a bifurcating root node + # is contained in the second number. This way its contribution can be + # added exactly once + + def make_intstate(tup): + return IntState(tup[0] + tup[1] + K, state=tup) + + def summer(tupseq): + tupseq = list(tupseq) + a = 0 + for ia, _ in tupseq: + a += ia + # second value should only be counted once. Any nonzero + # values of the second value will always be identical + if len(tupseq) == 0: + b = 0 + else: + b = max(tupseq, key=lambda tup: abs(tup[1]))[1] + return (a, b) + + def edge_func(n1, n2): + spl = split(n2) + spl_count = N.get(spl, 0) + if n1.is_ua_node(): + return make_intstate((0, 0)) + else: + val = num_trees - ((s + t) * spl_count) + if len(n1.clades) == 2 and is_history_root(n1): + return make_intstate((0, val)) + else: + return make_intstate((val, 0)) + + kwargs = AddFuncDict( + { + "start_func": lambda n: make_intstate((0, 0)), + "edge_weight_func": edge_func, + "accum_func": lambda wlist: make_intstate( + summer(w.state for w in wlist) + ), # summation over edge weights + }, + name="RF_unrooted_sum_" + rf_type_suffix, + ) - kwargs = AddFuncDict( - { - "start_func": lambda n: make_intstate(0), - "edge_weight_func": edge_func, - "accum_func": lambda wlist: make_intstate( - sum(w.state for w in wlist) - ), # summation over edge weights - }, - name="RF_rooted_sum", - ) return kwargs -def make_rfdistance_countfuncs(ref_tree: "HistoryDag", rooted: bool = False): +def make_rfdistance_countfuncs( + ref_tree: "HistoryDag", + rooted: bool = False, + one_sided: str = None, + one_sided_coefficients: Tuple[float, float] = (1, 1), +): """Provides functions to compute Robinson-Foulds (RF) distances of trees in a DAG, relative to a fixed reference tree. @@ -613,13 +730,23 @@ def make_rfdistance_countfuncs(ref_tree: "HistoryDag", rooted: bool = False): ref_tree: A tree with respect to which Robinson-Foulds distance will be computed. rooted: If False, use edges' splits for RF distance computation. Otherwise, use the clade below each edge. + one_sided: May be 'left', 'right', or None. 'left' means that we count + splits (or clades, in the rooted case) which are in the reference tree but not + in the DAG tree, especially useful if trees in the DAG might be resolutions of + a multifurcating reference. 'right' means that we count splits or clades in + the DAG tree which are not in the reference tree, useful if the reference tree + is possibly a resolution of multifurcating trees in the DAG. If not None, + one_sided_coefficients are ignored. + one_sided_coefficients: coefficients for non-standard symmetric difference calculations + (explained in notes below) The reference tree must have the same taxa as all the trees in the DAG. This calculation relies on the observation that the symmetric distance between - the splits A in a tree in the DAG, and the splits B in the reference tree, can - be computed as: - ``|A ^ B| = |A U B| - |A n B| = |A - B| + |B| - |A n B|`` + the splits (or clades, in the rooted case) A in a tree in the DAG, and the splits + (or clades) B in the reference tree, can be computed as: + + ``|B ^ A| = |B - A| + |A - B| = |B| - |A n B| + |A - B|`` As long as tree edges are in bijection with splits, this can be computed without constructing the set A by considering each edge's split independently. @@ -627,12 +754,33 @@ def make_rfdistance_countfuncs(ref_tree: "HistoryDag", rooted: bool = False): In order to accommodate multiple edges with the same split in a tree with root bifurcation, we keep track of the contribution of such edges separately. + One-sided RF distances are computed in this framework by introducing a pair of + ``one_sided_coefficients`` ``(s, t)``, which affect how much weight is given to + the right and left differences in the RF distance calculation: + + ``|B ^ A| = s|B - A| + t|A - B| = s(|B| - |A n B|) + t|A - B|`` + + When both ``s`` and ``t`` are 1, we get the standard RF distance. + When ``s=1`` and ``t=0``, then we have a one-sided "left" RF difference, counting + the number of splits in the reference tree which are not in each DAG tree. When + ``one_sided`` is set to `left`, then these coefficients will be used, regardless of + the values passed. + When ``s=0`` and ``t=1``, then we have a one-sided "right" RF difference, counting + the number of splits in each DAG tree which are not in the reference. When + ``one_sided`` is set to `right`, these coefficients will be used, regardless of + the values passed. + The weight type is a tuple wrapped in an IntState object. The first tuple value `a` is the contribution of edges which are not part of a root bifurcation, where edges whose splits are in B - contribute `-1`, and edges whose splits are not in B contribute `-1`, and the second tuple + contribute `-1`, and edges whose splits are not in B contribute `1`, and the second tuple value `b` is the contribution of the edges which are part of a root bifurcation. The value of the IntState is computed as `a + sign(b) + |B|`, which on the UA node of the hDAG gives RF distance. """ + + s, t, rf_type_suffix = _process_rf_one_sided_coefficients( + one_sided, one_sided_coefficients + ) + taxa = frozenset(n.label for n in ref_tree.get_leaves()) if not rooted: @@ -646,15 +794,18 @@ def split(node): ref_splits = ref_splits - { frozenset({taxa, frozenset()}), } - shift = len(ref_splits) + shift = s * len(ref_splits) n_taxa = len(taxa) def is_history_root(n): + # TODO this is slow and dirty! Make more efficient return len(list(n.clade_union())) == n_taxa def sign(n): - return (-1) * (n < 0) + (n > 0) + # Should return the value of a single term corresponding + # to the identical root splits below a bifurcating root + return (-s) * (n < 0) + t * (n > 0) def summer(tupseq): a, b = 0, 0 @@ -677,9 +828,9 @@ def edge_func(n1, n2): return make_intstate((0, 1)) else: if spl in ref_splits: - return make_intstate((-1, 0)) + return make_intstate((-s, 0)) else: - return make_intstate((1, 0)) + return make_intstate((t, 0)) kwargs = AddFuncDict( { @@ -689,23 +840,24 @@ def edge_func(n1, n2): summer(w.state for w in wlist) ), }, - name="RF_unrooted_distance", + name="RF_unrooted_distance_" + rf_type_suffix, ) else: ref_cus = frozenset( node.clade_union() for node in ref_tree.preorder(skip_ua_node=True) ) - shift = len(ref_cus) + shift = s * len(ref_cus) def make_intstate(n): return IntState(n + shift, state=n) def edge_func(n1, n2): if n2.clade_union() in ref_cus: - return make_intstate(-1) + inval = 1 else: - return make_intstate(1) + inval = 0 + return make_intstate(t - (s + t) * inval) kwargs = AddFuncDict( { @@ -713,7 +865,7 @@ def edge_func(n1, n2): "edge_weight_func": edge_func, "accum_func": lambda wlist: make_intstate(sum(w.state for w in wlist)), }, - name="RF_rooted_distance", + name="RF_rooted_" + rf_type_suffix, ) return kwargs diff --git a/tests/test_factory.py b/tests/test_factory.py index 95258f9..d502921 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -487,12 +487,83 @@ def test_remove_label_fields(): assert old_fieldset == new_fieldset +# ############# RF Distance Tests: ############### def rooted_rf_distance(history1, history2): cladeset1 = {n.clade_union() for n in history1.preorder(skip_ua_node=True)} cladeset2 = {n.clade_union() for n in history2.preorder(skip_ua_node=True)} return len(cladeset1 ^ cladeset2) +def test_right_left_rf_add_correctly(): + # In both the rooted and unrooted cases, left and right RF distances should + # sum to the normal RF distance. + for rooted in (True, False): + for dag in dags: + ref_tree = dag.sample() + left_kwargs = dagutils.make_rfdistance_countfuncs( + ref_tree, rooted=rooted, one_sided="left" + ) + right_kwargs = dagutils.make_rfdistance_countfuncs( + ref_tree, rooted=rooted, one_sided="right" + ) + kwargs = dagutils.make_rfdistance_countfuncs(ref_tree, rooted=rooted) + + for tree in dag: + assert tree.optimal_weight_annotate( + **left_kwargs + ) + tree.optimal_weight_annotate( + **right_kwargs + ) == tree.optimal_weight_annotate( + **kwargs + ) + + +def test_right_left_rf_collapse(): + """ + When one tree is a resolution of another, one-sided RF distance should be + able to detect this with a distance of 0. The relevant descriptions from + the docstring: + + one_sided: May be 'left', 'right', or None. 'left' means that we count + splits (or clades, in the rooted case) which are in the reference tree but not + in the DAG tree, especially useful if trees in the DAG might be resolutions of + a multifurcating reference. 'right' means that we count splits or clades in + the DAG tree which are not in the reference tree, useful if the reference tree + is possibly a resolution of multifurcating trees in the DAG. If not None, + one_sided_coefficients are ignored. + """ + for rooted in (True, False): + count = 0 + for dag in dags: + for tree in dag: + ctree = tree.copy() + ctree.convert_to_collapsed() + kwargs = dagutils.make_rfdistance_countfuncs(ctree, rooted=rooted) + if tree.optimal_weight_annotate(**kwargs) == 0: + # Then they're the same topology (when unrooted, a simple + # node count isn't enough to identify this) + continue + else: + count += 1 + left_kwargs = dagutils.make_rfdistance_countfuncs( + ctree, rooted=rooted, one_sided="left" + ) + assert tree.optimal_weight_annotate(**left_kwargs) == 0 + oleft_kwargs = dagutils.make_rfdistance_countfuncs( + tree, rooted=rooted, one_sided="left" + ) + assert ctree.optimal_weight_annotate(**oleft_kwargs) > 0 + right_kwargs = dagutils.make_rfdistance_countfuncs( + ctree, rooted=rooted, one_sided="right" + ) + assert tree.optimal_weight_annotate(**right_kwargs) > 0 + oright_kwargs = dagutils.make_rfdistance_countfuncs( + tree, rooted=rooted, one_sided="right" + ) + assert ctree.optimal_weight_annotate(**oright_kwargs) == 0 + assert count > 0 + + def test_rf_rooted_distances(): for dag in dags: ref_tree = dag.sample() @@ -560,35 +631,62 @@ def rf_distance(intree): def test_optimal_sum_rf_distance(): - for dag_idx, ref_dag in enumerate(dags): + # Can only use unrooted sum RF distances on dags containing trees all on + # the same taxon set. + + def one_taxon_set(dag): + return len({n.clade_union() for n in dag.dagroot.children()}) == 1 + + dags_to_test = [dag for dag in dags if one_taxon_set(dag)] + assert len(dags_to_test) > 5 + + def other(side): + return {"right": "left", "left": "right", None: None}[side] + + for dag_idx, ref_dag in enumerate(dags_to_test): print("dagnum ", dag_idx) # let's just do this test for three trees in each dag: for tree_idx, tree in zip(range(3), ref_dag): - print("treenum ", tree_idx) - # First let's just make sure that when the ref_dag is just a single - # tree, optimal_sum_rf_distance agrees with normal rf_distance. - single_tree_dag = ref_dag[0] - # Here we get all the distances between trees in 'single_tree_dag' and the - # reference tree 'tree' (there's only one, since 'single_tree_dag' - # only contains one tree: - expected = single_tree_dag.count_rf_distances(tree, rooted=True) - expected_sum = sum(expected.elements()) - calculated_sum = tree.optimal_sum_rf_distance(single_tree_dag) - assert calculated_sum == expected_sum - - # Now let's try computing the summed rf distance on tree relative - # to ref_dag... - - # Here we get all the distances between trees in 'dag' and the - # reference tree 'tree': - expected = dag.count_rf_distances(tree, rooted=True) - # Here we sum all elements in the counter, with multiplicity: - # in other words we sum all distances from trees in 'dag' to 'tree' - expected_sum = sum(expected.elements()) - # This should calculate the sum RF distance from 'tree' to all - # trees in 'dag': - calculated_sum = tree.optimal_sum_rf_distance(dag) - assert calculated_sum == expected_sum + for one_sided in ("left", "right", None): + for rooted in (True, False): + print("treenum ", tree_idx) + print("one_side ", one_sided) + print("rooted ", rooted) + # First let's just make sure that when the ref_dag is just a single + # tree, optimal_sum_rf_distance agrees with normal rf_distance. + single_tree_dag = ref_dag[0] + # Here we get all the distances between trees in 'single_tree_dag' and the + # reference tree 'tree' (there's only one, since 'single_tree_dag' + # only contains one tree: + expected = single_tree_dag.count_rf_distances( + tree, rooted=rooted, one_sided=one_sided + ) + expected_sum = sum(expected.elements()) + calculated_sum = tree.optimal_sum_rf_distance( + single_tree_dag, rooted=rooted, one_sided=other(one_sided) + ) + assert calculated_sum == expected_sum + + # Now let's try computing the summed rf distance on tree relative + # to ref_dag... + + # Here we get all the distances between trees in 'dag' and the + # reference tree 'tree': + expected = ref_dag.count_rf_distances( + tree, rooted=rooted, one_sided=one_sided + ) + # Here we sum all elements in the counter, with multiplicity: + # in other words we sum all distances from trees in 'dag' to 'tree' + expected_sum = sum(expected.elements()) + # This should calculate the sum RF distance from 'tree' to all + # trees in 'dag': + calculated_sum = tree.optimal_sum_rf_distance( + ref_dag, rooted=rooted, one_sided=other(one_sided) + ) + assert calculated_sum == expected_sum + + +# ############# END RF Distance Tests: ############### def test_trim_range(): @@ -692,18 +790,44 @@ def test_weight_range_annotate(): def test_sum_all_pair_rf_distance(): dag = dags[-1] - # check 0 on single-tree dag vs itself: - assert dag[0].sum_rf_distances() == 0 - assert dag[0].sum_rf_distances(reference_dag=dag[0]) == 0 + small_dag_1 = dag[0] | (dag[i] for i in range(1, 7)) + small_dag_2 = dag[-1] | (dag[i] for i in range(60, 67)) + small_dag_1.summary() + small_dag_2.summary() + for rooted in (False, True): + for one_sided in ("left", "right", None): + # check 0 on single-tree dag vs itself: + assert dag[0].sum_rf_distances(rooted=rooted, one_sided=one_sided) == 0 + assert ( + dag[0].sum_rf_distances( + reference_dag=dag[0], rooted=rooted, one_sided=one_sided + ) + == 0 + ) + + # check matches single rf distance between two single-tree dags: + udag = dag.unlabel() + assert udag[0].sum_rf_distances( + reference_dag=udag[-1], rooted=rooted, one_sided=one_sided + ) == udag[0].optimal_rf_distance( + udag[-1], rooted=rooted, one_sided=one_sided + ) - # check matches single rf distance between two single-tree dags: - udag = dag.unlabel() - assert udag[0].sum_rf_distances(reference_dag=udag[-1]) == udag[ - 0 - ].optimal_rf_distance(udag[-1]) + # check matches truth on whole DAG vs self: + assert dag.sum_rf_distances(rooted=rooted, one_sided=one_sided) == sum( + dag.count_sum_rf_distances( + dag, rooted=rooted, one_sided=one_sided + ).elements() + ) - # check matches truth on whole DAG vs self: - assert dag.sum_rf_distances() == sum(dag.count_sum_rf_distances(dag).elements()) + # check matches truth on dag1 vs dag2 + assert small_dag_1.sum_rf_distances( + reference_dag=small_dag_2, rooted=rooted, one_sided=one_sided + ) == sum( + small_dag_1.count_sum_rf_distances( + small_dag_2, rooted=rooted, one_sided=one_sided + ).elements() + ) def test_sum_weight(): @@ -851,6 +975,22 @@ def test_count_nodes(): for edge in edge_counts: assert edge_counts[edge] == round(edge_supports[edge] * n_histories) + # Now counting splits: + dag = dags[-1].copy() + + def history_to_splits(history): + splits = set() + all_taxa = next(history.dagroot.children()).clade_union() + for node in history.preorder(skip_ua_node=True): + node_clade = node.clade_union() + splits.add(frozenset({node_clade, all_taxa - node_clade})) + return splits - frozenset({all_taxa, frozenset()}) + + split_sets = [history_to_splits(history) for history in dag] + split_counts = dag.count_nodes(collapse=True, rooted=False) + for split, count in split_counts.items(): + assert count == sum(1 for s in split_sets if split in s) + def test_likelihoods(): dag = dags[-1]