Skip to content

Commit

Permalink
Merge pull request #80 from matsengrp/wd-one-sided-rf
Browse files Browse the repository at this point in the history
One Sided RF distances
  • Loading branch information
marybarker authored Dec 13, 2023
2 parents 0a74f57 + 3707fea commit 262eae4
Show file tree
Hide file tree
Showing 4 changed files with 530 additions and 117 deletions.
1 change: 0 additions & 1 deletion historydag/beast_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def dag_from_beast_trees(
include_sequence_names_in_labels: If True, augment leaf node labels with a ``name`` attribute
containing the name of the corresponding sequence. Useful for distinguishing leaves when
observed sequences are not unique.
"""
dp_trees = load_beast_trees(
beast_xml_file,
Expand Down
186 changes: 154 additions & 32 deletions historydag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from collections import Counter, namedtuple
from copy import deepcopy
from historydag import utils
from historydag.utils import Weight, Label, UALabel, prod
from historydag.utils import Weight, Label, UALabel, prod, TaxaError
from historydag.counterops import counter_sum, counter_prod
import historydag.parsimony_utils as parsimony_utils
from historydag.dag_node import (
Expand Down Expand Up @@ -1977,7 +1977,7 @@ def default_accum_above_edge(subtree_weight, edge_weight):

return downward_weights, upward_weights

def count_nodes(self, collapse=False) -> Dict[HistoryDagNode, int]:
def count_nodes(self, collapse=False, rooted=True) -> Dict[HistoryDagNode, int]:
"""Counts the number of trees each node takes part in.
For node supports with respect to a uniform distribution on trees, use
Expand All @@ -1987,6 +1987,11 @@ def count_nodes(self, collapse=False) -> Dict[HistoryDagNode, int]:
collapse: A flag that when set to true, treats nodes as clade unions and
ignores label information. Then, the returned dictionary is keyed by
clade union sets.
rooted: A flag which is ignored unless ``collapse`` is ``True``. When ``rooted`` is also ``False``,
the returned dictionary is keyed by splits -- that is, sets containing each clade
union and its complement, with values the number of (rooted) trees in the DAG containing
each split. Splits are not double-counted when a tree has a bifurcating root.
If False, dag is expected to have trees all on the same set of leaf labels.
Returns:
A dictionary mapping each node in the DAG to the number of trees
Expand Down Expand Up @@ -2031,7 +2036,37 @@ def count_nodes(self, collapse=False) -> Dict[HistoryDagNode, int]:
collapsed_n2c[clade] = 0

collapsed_n2c[clade] += node2count[node]
return collapsed_n2c
if rooted:
# Remove the UA node clade union from N
try:
collapsed_n2c.pop(frozenset())
except KeyError:
pass
return collapsed_n2c
else:
# Create dictionary counting in how many trees each split
# occurs as child of bifurcating root
split2adjustment = {}
all_taxa = next(self.dagroot.children()).clade_union()
if any(all_taxa != n.clade_union() for n in self.dagroot.children()):
raise TaxaError(
"Unrooted splits cannot be counted properly because"
" trees in this dag are on different sets of taxa."
)
for treeroot in self.dagroot.children():
if len(treeroot.clades) == 2:
split = frozenset(treeroot.clades.keys())
before = split2adjustment.get(split, 0)
split2adjustment[split] = before + node2count[treeroot]
split2count = {}
for clade, count in collapsed_n2c.items():
split = frozenset({clade, all_taxa - clade})
before = split2count.get(split, 0)
split2count[split] = before + count
for split, adjustment in split2adjustment.items():
split2count[split] -= adjustment
split2count.pop(frozenset({all_taxa, frozenset()}), None)
return split2count
else:
return node2count

Expand Down Expand Up @@ -2258,6 +2293,9 @@ def overestimate_rf_diameter(self):
def optimal_sum_rf_distance(
self,
reference_dag: "HistoryDag",
rooted: bool = True,
one_sided: str = None,
one_sided_coefficients: Tuple[float, float] = (1, 1),
optimal_func: Callable[[List[Weight]], Weight] = min,
):
"""Returns the optimal (min or max) summed rooted RF distance to all
Expand All @@ -2269,17 +2307,28 @@ def optimal_sum_rf_distance(
instead of making multiple calls to this method with the same reference
history DAG.
"""
kwargs = utils.sum_rfdistance_funcs(reference_dag)
kwargs = utils.sum_rfdistance_funcs(
reference_dag,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients,
)
return self.optimal_weight_annotate(**kwargs, optimal_func=optimal_func)

def trim_optimal_sum_rf_distance(
self,
reference_dag: "HistoryDag",
rooted: bool = True,
one_sided: str = None,
one_sided_coefficients: Tuple[float, float] = (1, 1),
optimal_func: Callable[[List[Weight]], Weight] = min,
):
"""Trims the DAG to contain only histories with the optimal (min or
max) sum rooted RF distance to the given reference DAG.
See :meth:`utils.sum_rfdistance_funcs` for detailed documentation of
arguments.
Trimming to the minimum sum RF distance is equivalent to finding 'median' topologies,
and trimming to maximum sum rf distance is equivalent to finding topological outliers.
Expand All @@ -2289,18 +2338,28 @@ def trim_optimal_sum_rf_distance(
instead of making multiple calls to this method with the same reference
history.
"""
kwargs = utils.sum_rfdistance_funcs(reference_dag)
kwargs = utils.sum_rfdistance_funcs(
reference_dag,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients,
)
return self.trim_optimal_weight(**kwargs, optimal_func=optimal_func)

def trim_optimal_rf_distance(
self,
history: "HistoryDag",
rooted: bool = False,
one_sided: str = None,
one_sided_coefficients: Tuple[float, float] = (1, 1),
optimal_func: Callable[[List[Weight]], Weight] = min,
):
"""Trims this history DAG to the optimal (min or max) RF distance to a
given history.
See :meth:`utils.make_rfdistance_countfuncs` for detailed documentation of
arguments.
Also returns that optimal RF distance
The given history must be on the same taxa as all trees in the DAG.
Expand All @@ -2309,58 +2368,124 @@ def trim_optimal_rf_distance(
instead of making multiple calls to this method with the same reference
history.
"""
kwargs = utils.make_rfdistance_countfuncs(history, rooted=rooted)
kwargs = utils.make_rfdistance_countfuncs(
history,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients,
)
return self.trim_optimal_weight(**kwargs, optimal_func=optimal_func)

def optimal_rf_distance(
self,
history: "HistoryDag",
rooted: bool = False,
one_sided: str = None,
one_sided_coefficients: Tuple[float, float] = (1, 1),
optimal_func: Callable[[List[Weight]], Weight] = min,
):
"""Returns the optimal (min or max) RF distance to a given history.
See :meth:`utils.make_rfdistance_countfuncs` for detailed documentation of
arguments.
The given history must be on the same taxa as all trees in the DAG.
Since computing reference splits is expensive, it is better to use
:meth:`optimal_weight_annotate` and :meth:`utils.make_rfdistance_countfuncs`
instead of making multiple calls to this method with the same reference
history.
"""
kwargs = utils.make_rfdistance_countfuncs(history, rooted=rooted)
kwargs = utils.make_rfdistance_countfuncs(
history,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients,
)
return self.optimal_weight_annotate(**kwargs, optimal_func=optimal_func)

def count_rf_distances(self, history: "HistoryDag", rooted: bool = False):
def count_rf_distances(
self,
history: "HistoryDag",
rooted: bool = False,
one_sided: str = None,
one_sided_coefficients: Tuple[float, float] = (1, 1),
):
"""Returns a Counter containing all RF distances to a given history.
The given history must be on the same taxa as all trees in the DAG.
See :meth:`utils.make_rfdistance_countfuncs` for detailed documentation of
arguments.
Since computing reference splits is expensive, it is better to use
:meth:`weight_count` and :meth:`utils.make_rfdistance_countfuncs`
instead of making multiple calls to this method with the same reference
history.
"""
kwargs = utils.make_rfdistance_countfuncs(history, rooted=rooted)
kwargs = utils.make_rfdistance_countfuncs(
history,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients,
)
return self.weight_count(**kwargs)

def count_sum_rf_distances(self, reference_dag: "HistoryDag", rooted: bool = False):
def count_sum_rf_distances(
self,
reference_dag: "HistoryDag",
rooted: bool = True,
one_sided: str = None,
one_sided_coefficients: Tuple[float, float] = (1, 1),
):
"""Returns a Counter containing all sum RF distances to a given
reference DAG.
See :meth:`utils.sum_rfdistance_funcs` for detailed documentation of
arguments.
The given history DAG must be on the same taxa as all trees in the DAG.
Since computing reference splits is expensive, it is better to use
:meth:`weight_count` and :meth:`utils.sum_rfdistance_funcs`
instead of making multiple calls to this method with the same reference
history DAG.
"""
kwargs = utils.sum_rfdistance_funcs(reference_dag)
kwargs = utils.sum_rfdistance_funcs(
reference_dag,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients,
)
return self.weight_count(**kwargs)

def sum_rf_distances(self, reference_dag: "HistoryDag" = None):
r"""Computes the sum of all Robinson-Foulds distances between a history
in this DAG and a history in the reference DAG.
def sum_rf_distances(
self,
reference_dag: "HistoryDag" = None,
rooted: bool = True,
one_sided: str = None,
one_sided_coefficients: Tuple[float, float] = (1, 1),
):
r"""Computes the sum of Robinson-Foulds distances over all pairs of
histories in this DAG and the provided reference DAG.
Args:
reference_dag: If None, the sum of pairwise distances between histories in this DAG
is computed. If provided, the sum is over pairs containing one history in this DAG and
one from ``reference_dag``.
rooted: If False, use edges' splits for RF distance computation. Otherwise, use
the clade below each edge.
one_sided: May be 'left', 'right', or None. 'left' means that we count
splits (or clades, in the rooted case) which are in the reference trees but not
in the DAG tree, especially useful if trees in the DAG might be resolutions of
multifurcating trees in the reference DAG. 'right' means that we count splits or clades in
the DAG tree which are not in the reference trees, useful if the reference trees
are possibly resolutions of multifurcating trees in the DAG. If not None,
one_sided_coefficients are ignored.
one_sided_coefficients: coefficients for non-standard symmetric difference calculations.
See :meth:`utils.make_rfdistance_countfuncs` for more details.
This is rooted RF distance.
Returns:
An integer sum of RF distances.
If T is the set of histories in the reference DAG, and T' is the set of histories in
this DAG, then the returned sum is:
Expand All @@ -2372,22 +2497,16 @@ def sum_rf_distances(self, reference_dag: "HistoryDag" = None):
That is, since RF distance is symmetric, when T = T' (such as when ``reference_dag=None``),
or when the intersection of T and T' is nonempty, some distances are counted twice.
Args:
reference_dag: If None, the sum of pairwise distances between histories in this DAG
is computed. If provided, the sum is over pairs containing one history in this DAG and
one from ``reference_dag``.
Returns:
An integer sum of RF distances.
Note that when computing one-sided distances, or when the one_sided_coefficients values are not
equal, this 'distance' is no longer symmetric.
"""
s, t, _ = utils._process_rf_one_sided_coefficients(
one_sided, one_sided_coefficients
)

def get_data(dag):
n_histories = dag.count_histories()
N = dag.count_nodes(collapse=True)
try:
N.pop(frozenset())
except KeyError:
pass
N = dag.count_nodes(collapse=True, rooted=rooted)

clade_count_sum = sum(N.values())
return (n_histories, N, clade_count_sum)
Expand All @@ -2410,13 +2529,13 @@ def get_data(dag):
)

return (
n_histories * clade_count_sum_prime
+ n_histories_prime * clade_count_sum
- 2 * intersection_term
t * n_histories * clade_count_sum_prime
+ s * n_histories_prime * clade_count_sum
- (s + t) * intersection_term
)

def average_pairwise_rf_distance(
self, reference_dag: "HistoryDag" = None, non_identical=True
self, reference_dag: "HistoryDag" = None, non_identical=True, **kwargs
):
"""Return the average Robinson-Foulds distance between pairs of
histories.
Expand All @@ -2425,6 +2544,7 @@ def average_pairwise_rf_distance(
reference_dag: A history DAG from which to take the second history in
each pair. If None, ``self`` will be used as the reference.
non_identical: If True, mean divisor will be the number of non-identical pairs.
kwargs: See :meth:`historydag.sum_rf_distances` for additional keyword arguments
Returns:
The average rf-distance between pairs of histories, where the first history
Expand All @@ -2433,7 +2553,9 @@ def average_pairwise_rf_distance(
``non_identical`` is True, in which case the number of histories which appear
in both DAGs is subtracted from this constant.
"""
sum_pairwise_distance = self.sum_rf_distances(reference_dag=reference_dag)
sum_pairwise_distance = self.sum_rf_distances(
reference_dag=reference_dag, **kwargs
)
if reference_dag is None:
# ignore the diagonal in the distance matrix, since it contains
# zeros:
Expand Down
Loading

0 comments on commit 262eae4

Please sign in to comment.