Skip to content

Commit

Permalink
format and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Dec 11, 2023
1 parent 378f45c commit ec0ad60
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 23 deletions.
1 change: 0 additions & 1 deletion historydag/beast_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def dag_from_beast_trees(
include_sequence_names_in_labels: If True, augment leaf node labels with a ``name`` attribute
containing the name of the corresponding sequence. Useful for distinguishing leaves when
observed sequences are not unique.
"""
dp_trees = load_beast_trees(
beast_xml_file,
Expand Down
6 changes: 3 additions & 3 deletions historydag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2315,7 +2315,7 @@ def trim_optimal_rf_distance(
history,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients
one_sided_coefficients=one_sided_coefficients,
)
return self.trim_optimal_weight(**kwargs, optimal_func=optimal_func)

Expand All @@ -2339,7 +2339,7 @@ def optimal_rf_distance(
history,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients
one_sided_coefficients=one_sided_coefficients,
)
return self.optimal_weight_annotate(**kwargs, optimal_func=optimal_func)

Expand All @@ -2363,7 +2363,7 @@ def count_rf_distances(
history,
rooted=rooted,
one_sided=one_sided,
one_sided_coefficients=one_sided_coefficients
one_sided_coefficients=one_sided_coefficients,
)
return self.weight_count(**kwargs)

Expand Down
22 changes: 11 additions & 11 deletions historydag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,27 +665,28 @@ def make_rfdistance_countfuncs(
of the IntState is computed as `a + sign(b) + |B|`, which on the UA node of the hDAG gives RF distance.
"""

rf_type_suffix = 'distance'
if one_sided_coefficients != (1,1):
rf_type_suffix = 'nonstandard'
rf_type_suffix = "distance"
if one_sided_coefficients != (1, 1):
rf_type_suffix = "nonstandard"

if one_sided is None:
pass
elif one_sided.lower() == 'left':
elif one_sided.lower() == "left":
one_sided_coefficients = (1, 0)
one_sided_suffix = 'left_difference'
elif one_sided.lower() == 'right':
rf_type_suffix = "left_difference"
elif one_sided.lower() == "right":
one_sided_coefficients = (0, 1)
one_sided_suffix = 'right_difference'
rf_type_suffix = "right_difference"
else:
raise ValueError(f"Argument `one_sided` must have value 'left', 'right', or None, not {one_sided}")
raise ValueError(
f"Argument `one_sided` must have value 'left', 'right', or None, not {one_sided}"
)

s, t = one_sided_coefficients

taxa = frozenset(n.label for n in ref_tree.get_leaves())

if not rooted:
# TODO sidedness not tested for rooted

def split(node):
cu = node.clade_union()
Expand Down Expand Up @@ -742,10 +743,9 @@ def edge_func(n1, n2):
summer(w.state for w in wlist)
),
},
name="RF_unrooted_distance",
name="RF_unrooted_distance_" + rf_type_suffix,
)
else:
# TODO sidedness not tested for unrooted
ref_cus = frozenset(
node.clade_union() for node in ref_tree.preorder(skip_ua_node=True)
)
Expand Down
37 changes: 29 additions & 8 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,18 +493,30 @@ def rooted_rf_distance(history1, history2):
cladeset2 = {n.clade_union() for n in history2.preorder(skip_ua_node=True)}
return len(cladeset1 ^ cladeset2)


def test_right_left_rf_add_correctly():
# In both the rooted and unrooted cases, left and right RF distances should
# sum to the normal RF distance.
for rooted in (True, False):
for dag in dags:
ref_tree = dag.sample()
left_kwargs = dagutils.make_rfdistance_countfuncs(ref_tree, rooted=rooted, one_sided='left')
right_kwargs = dagutils.make_rfdistance_countfuncs(ref_tree, rooted=rooted, one_sided='right')
left_kwargs = dagutils.make_rfdistance_countfuncs(
ref_tree, rooted=rooted, one_sided="left"
)
right_kwargs = dagutils.make_rfdistance_countfuncs(
ref_tree, rooted=rooted, one_sided="right"
)
kwargs = dagutils.make_rfdistance_countfuncs(ref_tree, rooted=rooted)

for tree in dag:
assert tree.optimal_weight_annotate(**left_kwargs) + tree.optimal_weight_annotate(**right_kwargs) == tree.optimal_weight_annotate(**kwargs)
assert tree.optimal_weight_annotate(
**left_kwargs
) + tree.optimal_weight_annotate(
**right_kwargs
) == tree.optimal_weight_annotate(
**kwargs
)


def test_right_left_rf_collapse():
"""
Expand Down Expand Up @@ -533,18 +545,25 @@ def test_right_left_rf_collapse():
continue
else:
count += 1
left_kwargs = dagutils.make_rfdistance_countfuncs(ctree, rooted=rooted, one_sided='left')
left_kwargs = dagutils.make_rfdistance_countfuncs(
ctree, rooted=rooted, one_sided="left"
)
assert tree.optimal_weight_annotate(**left_kwargs) == 0
oleft_kwargs = dagutils.make_rfdistance_countfuncs(tree, rooted=rooted, one_sided='left')
oleft_kwargs = dagutils.make_rfdistance_countfuncs(
tree, rooted=rooted, one_sided="left"
)
assert ctree.optimal_weight_annotate(**oleft_kwargs) > 0
right_kwargs = dagutils.make_rfdistance_countfuncs(ctree, rooted=rooted, one_sided='right')
right_kwargs = dagutils.make_rfdistance_countfuncs(
ctree, rooted=rooted, one_sided="right"
)
assert tree.optimal_weight_annotate(**right_kwargs) > 0
oright_kwargs = dagutils.make_rfdistance_countfuncs(tree, rooted=rooted, one_sided='right')
oright_kwargs = dagutils.make_rfdistance_countfuncs(
tree, rooted=rooted, one_sided="right"
)
assert ctree.optimal_weight_annotate(**oright_kwargs) == 0
assert count > 0



def test_rf_rooted_distances():
for dag in dags:
ref_tree = dag.sample()
Expand Down Expand Up @@ -642,8 +661,10 @@ def test_optimal_sum_rf_distance():
calculated_sum = tree.optimal_sum_rf_distance(dag)
assert calculated_sum == expected_sum


# ############# END RF Distance Tests: ###############


def test_trim_range():
for curr_dag in [dags[-1], cdags[-1]]:
history_dag = curr_dag.copy()
Expand Down

0 comments on commit ec0ad60

Please sign in to comment.