def agg_by_strata(
    mt: hl.MatrixTable,
    entry_agg_funcs: Dict[str, Tuple[Callable, Callable]],
    select_fields: Optional[List[str]] = None,
    group_membership_ht: Optional[hl.Table] = None,
) -> hl.Table:
    """
    Get row expression for annotations of each entry aggregation function(s) by strata.

    The entry aggregation functions are applied to the MatrixTable entries and
    aggregated. If no `group_membership_ht` (like the one returned by
    `generate_freq_group_membership_array`) is supplied, `mt` must contain a
    'group_membership' annotation that is a list of bools to aggregate the columns by.

    Which groups are restricted to adj entries is taken from the globals of
    `group_membership_ht` (or of `mt` when no `group_membership_ht` is given):
    'adj_groups' is used if present, otherwise it is derived from 'freq_meta',
    otherwise every group is treated as non-adj.

    .. note::

        If any group is flagged as adj, `mt` must have an 'adj' entry field. The
        returned Table keeps the 'adj_groups' and 'indices_by_group' global
        annotations that are added during the aggregation.

    :param mt: Input MatrixTable.
    :param entry_agg_funcs: Dict of entry aggregation functions where the
        keys of the dict are the names of the annotations and the values are tuples
        of functions. The first function is used to transform the `mt` entries in some
        way, and the second function is used to aggregate the output from the first
        function.
    :param select_fields: Optional list of row fields from `mt` to keep on the output
        Table.
    :param group_membership_ht: Optional Table containing group membership annotations
        to stratify the aggregations by. If not provided, the 'group_membership'
        annotation is expected to be present on `mt`.
    :return: Table with annotations of stratified aggregations.
    :raises ValueError: If neither `group_membership_ht` nor a 'group_membership'
        column annotation on `mt` is available, or if the length of 'adj_groups'
        does not match the length of 'group_membership'.
    """
    if group_membership_ht is None and "group_membership" not in mt.col:
        raise ValueError(
            "The 'group_membership' annotation is not found in the input MatrixTable "
            "and 'group_membership_ht' is not specified."
        )

    if select_fields is None:
        select_fields = []

    if group_membership_ht is None:
        logger.info(
            "'group_membership_ht' is not specified, using sample stratification "
            "indicated by the 'group_membership' annotation on the input MatrixTable."
        )
        group_globals = mt.index_globals()
    else:
        logger.info(
            "'group_membership_ht' is specified, using sample stratification indicated "
            "by its 'group_membership' annotation."
        )
        group_globals = group_membership_ht.index_globals()
        mt = mt.annotate_cols(
            group_membership=group_membership_ht[mt.col_key].group_membership
        )

    global_expr = {}
    n_groups = len(mt.group_membership.take(1)[0])
    if "adj_groups" in group_globals:
        logger.info(
            "Using the 'adj_groups' global annotation to determine adj filtered "
            "stratification groups."
        )
        global_expr["adj_groups"] = group_globals.adj_groups
    elif "freq_meta" in group_globals:
        logger.info(
            "No 'adj_groups' global annotation found, using the 'freq_meta' global "
            "annotation to determine adj filtered stratification groups."
        )
        global_expr["adj_groups"] = group_globals.freq_meta.map(
            lambda x: x.get("group", "NA") == "adj"
        )
    else:
        logger.info(
            "No 'adj_groups' or 'freq_meta' global annotations found. All groups will "
            "be considered non-adj."
        )
        global_expr["adj_groups"] = hl.range(n_groups).map(lambda x: False)

    n_adj_groups = hl.eval(hl.len(global_expr["adj_groups"]))
    if n_adj_groups != n_groups:
        raise ValueError(
            f"The number of elements in the 'adj_groups' ({n_adj_groups}) global "
            "annotation does not match the number of elements in the "
            f"'group_membership' annotation ({n_groups})!",
        )

    # Keep only the entries needed for the aggregation functions.
    select_expr = {ann: f[0](mt) for ann, f in entry_agg_funcs.items()}
    has_adj = hl.eval(hl.any(global_expr["adj_groups"]))
    if has_adj:
        # The 'adj' entry field is only needed when at least one group is adj filtered.
        select_expr["adj"] = mt.adj

    mt = mt.select_entries(**select_expr)

    # Count the samples once, the count is reused to build the per-group indices.
    n_samples = mt.count_cols()

    # Convert MT to HT with a row annotation that is an array of all samples entries
    # for that variant.
    ht = mt.localize_entries("entries", "cols")

    # For each stratification group in group_membership, determine the indices of the
    # samples that belong to that group.
    global_expr["indices_by_group"] = hl.range(n_groups).map(
        lambda g_i: hl.range(n_samples).filter(
            lambda s_i: ht.cols[s_i].group_membership[g_i]
        )
    )
    ht = ht.annotate_globals(**global_expr)

    # Pull out each annotation that will be used in the array aggregation below as its
    # own ArrayExpression. This is important to prevent memory issues when performing
    # the below array aggregations.
    # `select_fields` are row fields, so they are carried through unchanged; only the
    # entry-level annotations in `select_expr` are extracted from the entries array.
    ht = ht.select(
        *select_fields,
        **{ann: ht.entries.map(lambda e: e[ann]) for ann in select_expr},
    )

    def _agg_by_group(
        ht: hl.Table, agg_func: Callable, ann_expr: hl.expr.ArrayExpression
    ) -> hl.expr.ArrayExpression:
        """
        Aggregate `ann_expr` by group using the `agg_func` function.

        :param ht: Input Hail Table.
        :param agg_func: Aggregation function to apply to `ann_expr`.
        :param ann_expr: Expression to aggregate by group.
        :return: Aggregated array expression.
        """

        def _agg_all(i):
            # Aggregation over every sample of the group, no adj filtering.
            return agg_func(ann_expr[i])

        if has_adj:
            # Keep the un-filtered aggregator under its own name. Rebinding a single
            # lambda to a wrapper that refers to itself would hand the wrapper
            # function (not an aggregation expression) to `hl.agg.filter`.
            def _agg(i, adj):
                return hl.if_else(
                    adj, hl.agg.filter(ht.adj[i], _agg_all(i)), _agg_all(i)
                )

        else:

            def _agg(i, adj):
                return _agg_all(i)

        return hl.map(
            lambda s_indices, adj: s_indices.aggregate(lambda i: _agg(i, adj)),
            ht.indices_by_group,
            ht.adj_groups,
        )

    # Add annotations for any supplied entry transform and aggregation functions.
    ht = ht.select(
        *select_fields,
        **{ann: _agg_by_group(ht, f[1], ht[ann]) for ann, f in entry_agg_funcs.items()},
    )

    return ht.drop("cols")