Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rearrange and enforce adj_group and group_membership being on the sam… #666

Merged
merged 1 commit into from
Jan 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 39 additions & 41 deletions gnomad/utils/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,87 +1957,85 @@ def agg_by_strata(
'group_membership' annotation that is a list of bools to aggregate the columns by.

:param mt: Input MatrixTable.
:param entry_agg_funcs: Optional dict of entry aggregation functions. When
specified, additional annotations are added to the output Table/MatrixTable.
The keys of the dict are the names of the annotations and the values are tuples
:param entry_agg_funcs: Dict of entry aggregation functions where the
keys of the dict are the names of the annotations and the values are tuples
of functions. The first function is used to transform the `mt` entries in some
way, and the second function is used to aggregate the output from the first
function.
:param select_fields: Optional list of row fields from `mt` to keep on the output
Table.
:param group_membership_ht: Optional Table containing group membership annotations
to stratify the coverage stats by. If not provided, the 'group_membership'
to stratify the aggregations by. If not provided, the 'group_membership'
annotation is expected to be present on `mt`.
:return: Table or MatrixTable with allele frequencies by strata.
:return: Table with annotations of stratified aggregations.
"""
if entry_agg_funcs is None:
entry_agg_funcs = {}
if select_fields is None:
select_fields = []

n_samples = mt.count_cols()
global_expr = {}
if "adj_group" in mt.index_globals():
global_expr["adj_group"] = mt.index_globals().adj_group
logger.info("Using the 'adj_group' global annotation found on the input MT.")
raise TypeError(
"'agg_by_strata' expects a 'entry_agg_funcs' dictionary but it was not"
" supplied. Without the dictionary, no aggregations will occur."
)

if group_membership_ht is None and "group_membership" not in mt.col:
raise ValueError(
"The 'group_membership' annotation is not found in the input MatrixTable "
"and 'group_membership_ht' is not specified."
)
elif group_membership_ht is None:

if select_fields is None:
select_fields = []

if group_membership_ht is None:
logger.info(
"'group_membership_ht' is not specified, using sample stratification "
"indicated by the 'group_membership' annotation on mt."
)
n_groups = len(mt.group_membership.take(1)[0])
group_globals = mt.index_globals()
else:
logger.info(
"'group_membership_ht' is specified, using sample stratification indicated "
"by its 'group_membership' annotation."
)
group_globals = group_membership_ht.index_globals()
n_groups = len(group_membership_ht.group_membership.take(1)[0])
mt = mt.annotate_cols(
group_membership=group_membership_ht[mt.col_key].group_membership
)
if "adj_group" not in global_expr:
if "adj_group" in group_globals:
global_expr["adj_group"] = group_globals.adj_group
logger.info(
"Using the 'adj_group' global annotation on 'group_membership_ht'."
)
elif "freq_meta" in group_globals:
logger.info(
"The 'freq_meta' global annotation is found in "
"'group_membership_ht', using it to determine the adj filtered "
"stratification groups."
)
freq_meta = group_globals.freq_meta
global_expr["adj_group"] = freq_meta.map(
lambda x: x.get("group", "NA") == "adj"
)

if "adj_group" not in global_expr:
global_expr = {}
n_groups = len(mt.group_membership.take(1)[0])
if "adj_group" in group_globals:
global_expr["adj_group"] = group_globals.adj_group
logger.info("Using the 'adj_group' global annotation on 'group_membership_ht'.")
elif "freq_meta" in group_globals:
logger.info(
"The 'freq_meta' global annotation is found in "
"'group_membership_ht', using it to determine the adj filtered "
"stratification groups."
)
freq_meta = group_globals.freq_meta
global_expr["adj_group"] = freq_meta.map(
lambda x: x.get("group", "NA") == "adj"
)
else:
global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False)

# NOTE: Unsure if we still want this check here, since adj_group and n_groups will
# always come from the same table or be built within this function. It's a cheap
# operation, so I'm leaning towards keeping it even though I'm not sure this is the
# right place for this check.
n_adj_group = hl.eval(hl.len(global_expr["adj_group"]))
if hl.eval(hl.len(global_expr["adj_group"])) != n_groups:
if n_adj_group != n_groups:
raise ValueError(
f"The number of elements in the 'adj_group' ({n_adj_group}) global "
"annotation does not match the number of elements in the "
f"'group_membership' annotation ({n_groups})!",
)

# Keep only the entries needed for the aggregation functions.
select_expr = {}
has_adj = False
if hl.eval(hl.any(global_expr["adj_group"])):
select_expr = {**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()}}
has_adj = hl.eval(hl.any(global_expr["adj_group"]))
if has_adj:
select_expr["adj"] = mt.adj
has_adj = True

select_expr.update(**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()})
mt = mt.select_entries(**select_expr)

# Convert MT to HT with a row annotation that is an array of all samples entries
Expand All @@ -2047,7 +2045,7 @@ def agg_by_strata(
# For each stratification group in group_membership, determine the indices of the
# samples that belong to that group.
global_expr["indices_by_group"] = hl.range(n_groups).map(
lambda g_i: hl.range(n_samples).filter(
lambda g_i: hl.range(mt.count_cols()).filter(
lambda s_i: ht.cols[s_i].group_membership[g_i]
)
)
Expand Down