Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function agg_by_strata, which is a generalized version of the compute_freq_by_strata #659

Merged
merged 9 commits into from
Jan 17, 2024
196 changes: 152 additions & 44 deletions gnomad/utils/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1901,80 +1901,188 @@ def compute_freq_by_strata(
True.
:return: Table or MatrixTable with allele frequencies by strata.
"""
if not group_membership_includes_raw_group:
# Add the 'raw' group to the 'group_membership' annotation.
mt = mt.annotate_cols(
group_membership=hl.array([mt.group_membership[0]]).extend(
mt.group_membership
)
)

# Add adj_groups global annotation indicating that the second element in
# group_membership is 'raw' and all others are 'adj'.
mt = mt.annotate_globals(
adj_groups=hl.range(hl.len(mt.group_membership.take(1)[0])).map(
lambda x: x != 1
)
)

if entry_agg_funcs is None:
entry_agg_funcs = {}

def _get_freq_expr(gt_expr: hl.expr.CallExpression) -> hl.expr.StructExpression:
    """
    Aggregate call statistics for `gt_expr` and keep only the index-1 allele values.

    The alleles passed to `hl.agg.call_stats` are read from the `alleles` field of
    the object that `gt_expr` is sourced from.

    .. note::

        Only index 1 of the per-allele 'AC', 'AF', and 'homozygote_count' arrays
        is kept, i.e. the site is assumed to be bi-allelic.

    :param gt_expr: CallExpression to compute call statistics on.
    :return: StructExpression with call statistics, where 'AC', 'AF', and
        'homozygote_count' are replaced by their index-1 element.
    """
    # The alleles live on the object the call expression was derived from.
    source = gt_expr._indices.source
    call_stats = hl.agg.call_stats(gt_expr, source.alleles)

    # Replace each per-allele array with its non-ref (index 1) element.
    alt_only = {
        field: call_stats[field][1] for field in ("AC", "AF", "homozygote_count")
    }

    return call_stats.annotate(**alt_only)

entry_agg_funcs["freq"] = (lambda x: x.GT, _get_freq_expr)

return agg_by_strata(mt, entry_agg_funcs, select_fields).drop("adj_groups")


def agg_by_strata(
    mt: hl.MatrixTable,
    entry_agg_funcs: Dict[str, Tuple[Callable, Callable]],
    select_fields: Optional[List[str]] = None,
    group_membership_ht: Optional[hl.Table] = None,
) -> hl.Table:
    """
    Get row expression for annotations of each entry aggregation function(s) by strata.

    The entry aggregation functions are applied to the MatrixTable entries and
    aggregated. If no `group_membership_ht` (like the one returned by
    `generate_freq_group_membership_array`) is supplied, `mt` must contain a
    'group_membership' annotation that is a list of bools to aggregate the columns by.

    Which groups are restricted to 'adj' entries is determined from the globals of
    `group_membership_ht` (or of `mt` when no `group_membership_ht` is given):

        - If an 'adj_groups' global is present, it is used as is.
        - Otherwise, if a 'freq_meta' global is present, a group is adj filtered
          when its 'group' value is "adj".
        - Otherwise, no group is adj filtered.

    .. note::

        When at least one group is adj filtered, `mt` must have an 'adj' entry
        annotation.

    :param mt: Input MatrixTable.
    :param entry_agg_funcs: Dict of entry aggregation functions where the
        keys of the dict are the names of the annotations and the values are tuples
        of functions. The first function is used to transform the `mt` entries in some
        way, and the second function is used to aggregate the output from the first
        function.
    :param select_fields: Optional list of row fields from `mt` to keep on the output
        Table.
    :param group_membership_ht: Optional Table containing group membership annotations
        to stratify the aggregations by. If not provided, the 'group_membership'
        annotation is expected to be present on `mt`.
    :return: Table with annotations of stratified aggregations.
    :raises ValueError: If no 'group_membership' annotation is available, or if the
        length of 'adj_groups' differs from the length of 'group_membership'.
    """
    if group_membership_ht is None and "group_membership" not in mt.col:
        raise ValueError(
            "The 'group_membership' annotation is not found in the input MatrixTable "
            "and 'group_membership_ht' is not specified."
        )

    if select_fields is None:
        select_fields = []

    # Count the samples once; the column set is not changed by anything below.
    n_samples = mt.count_cols()
    if group_membership_ht is None:
        logger.info(
            "'group_membership_ht' is not specified, using sample stratification "
            "indicated by the 'group_membership' annotation on the input MatrixTable."
        )
        group_globals = mt.index_globals()
    else:
        logger.info(
            "'group_membership_ht' is specified, using sample stratification indicated "
            "by its 'group_membership' annotation."
        )
        group_globals = group_membership_ht.index_globals()
        mt = mt.annotate_cols(
            group_membership=group_membership_ht[mt.col_key].group_membership
        )

    global_expr = {}
    n_groups = len(mt.group_membership.take(1)[0])
    if "adj_groups" in group_globals:
        logger.info(
            "Using the 'adj_groups' global annotation to determine adj filtered "
            "stratification groups."
        )
        global_expr["adj_groups"] = group_globals.adj_groups
    elif "freq_meta" in group_globals:
        logger.info(
            "No 'adj_groups' global annotation found, using the 'freq_meta' global "
            "annotation to determine adj filtered stratification groups."
        )
        global_expr["adj_groups"] = group_globals.freq_meta.map(
            lambda x: x.get("group", "NA") == "adj"
        )
    else:
        logger.info(
            "No 'adj_groups' or 'freq_meta' global annotations found. All groups will "
            "be considered non-adj."
        )
        global_expr["adj_groups"] = hl.range(n_groups).map(lambda x: False)

    n_adj_groups = hl.eval(hl.len(global_expr["adj_groups"]))
    if n_adj_groups != n_groups:
        raise ValueError(
            f"The number of elements in the 'adj_groups' ({n_adj_groups}) global "
            "annotation does not match the number of elements in the "
            f"'group_membership' annotation ({n_groups})!",
        )

    # Keep only the entries needed for the aggregation functions.
    select_expr = {ann: f[0](mt) for ann, f in entry_agg_funcs.items()}
    has_adj = hl.eval(hl.any(global_expr["adj_groups"]))
    if has_adj:
        select_expr["adj"] = mt.adj

    mt = mt.select_entries(**select_expr)

    # Convert MT to HT with a row annotation that is an array of all samples entries
    # for that variant.
    ht = mt.localize_entries("entries", "cols")

    # For each stratification group in group_membership, determine the indices of the
    # samples that belong to that group.
    global_expr["indices_by_group"] = hl.range(n_groups).map(
        lambda g_i: hl.range(n_samples).filter(
            lambda s_i: ht.cols[s_i].group_membership[g_i]
        )
    )
    ht = ht.annotate_globals(**global_expr)

    # Pull out each annotation that will be used in the array aggregation below as its
    # own ArrayExpression. This is important to prevent memory issues when performing
    # the below array aggregations.
    # Only the entry annotations in `select_expr` are converted to arrays; the row
    # fields in `select_fields` are carried through unchanged.
    ht = ht.select(
        *select_fields,
        **{ann: ht.entries.map(lambda e: e[ann]) for ann in select_expr},
    )

    def _agg_by_group(
        ht: hl.Table, agg_func: Callable, ann_expr: hl.expr.ArrayExpression
    ) -> hl.expr.ArrayExpression:
        """
        Aggregate `ann_expr` by group using the `agg_func` function.

        :param ht: Input Hail Table.
        :param agg_func: Aggregation function to apply to `ann_expr`.
        :param ann_expr: Expression to aggregate by group.
        :return: Aggregated array expression.
        """

        def _agg_sample(i):
            # Aggregation over the annotation of the sample at index `i`.
            return agg_func(ann_expr[i])

        def _agg_sample_by_adj(i, adj):
            # Without any adj group there is no 'adj' array on `ht`, so the
            # aggregation is never filtered.
            if not has_adj:
                return _agg_sample(i)

            # Build both aggregations explicitly. Rebinding a single lambda that
            # refers to itself would hand a Python function, not an expression,
            # to `hl.agg.filter`.
            return hl.if_else(
                adj, hl.agg.filter(ht.adj[i], _agg_sample(i)), _agg_sample(i)
            )

        return hl.map(
            lambda s_indices, adj: s_indices.aggregate(
                lambda i: _agg_sample_by_adj(i, adj)
            ),
            ht.indices_by_group,
            ht.adj_groups,
        )

    # Add annotations for any supplied entry transform and aggregation functions.
    ht = ht.select(
        *select_fields,
        **{ann: _agg_by_group(ht, f[1], ht[ann]) for ann, f in entry_agg_funcs.items()},
    )

    return ht.drop("cols")
Expand Down