From 00a13e6d67c9aab7d309528beacef4c3807043ac Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Tue, 9 Jan 2024 10:10:29 -0700 Subject: [PATCH 1/8] Extract strata aggregation into its own function and use in `compute_freq_by_strata` --- gnomad/utils/annotations.py | 245 ++++++++++++++++++++++++++---------- 1 file changed, 179 insertions(+), 66 deletions(-) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index b59fed83a..0720a6315 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1861,25 +1861,19 @@ def generate_freq_group_membership_array( return ht -def compute_freq_by_strata( +def agg_by_strata( mt: hl.MatrixTable, entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None, select_fields: Optional[List[str]] = None, - group_membership_includes_raw_group: bool = True, + group_membership_ht: Optional[hl.Table] = None, ) -> hl.Table: """ - Compute call statistics and, when passed, entry aggregation function(s) by strata. - - The computed call statistics are AC, AF, AN, and homozygote_count. The entry - aggregation functions are applied to the MatrixTable entries and aggregated. The - MatrixTable must contain a 'group_membership' annotation (like the one added by - `generate_freq_group_membership_array`) that is a list of bools to aggregate the - columns by. + Get row expression for annotations of each entry aggregation function(s) by strata. - .. note:: - This function is primarily used through `annotate_freq` but can be used - independently if desired. Please see the `annotate_freq` function for more - complete documentation. + The entry aggregation functions are applied to the MatrixTable entries and + aggregated. If no `group_membership_ht` (like the one returned by + `generate_freq_group_membership_array`) is supplied, `mt` must contain a + 'group_membership' annotation that is a list of bools to aggregate the columns by. :param mt: Input MatrixTable. 
:param entry_agg_funcs: Optional dict of entry aggregation functions. When @@ -1890,15 +1884,9 @@ def compute_freq_by_strata( function. :param select_fields: Optional list of row fields from `mt` to keep on the output Table. - :param group_membership_includes_raw_group: Whether the 'group_membership' - annotation includes an entry for the 'raw' group, representing all samples. If - False, the 'raw' group is inserted as the second element in all added - annotations using the same 'group_membership', resulting - in array lengths of 'group_membership'+1. If True, the second element of each - added annotation is still the 'raw' group, but the group membership is - determined by the values in the second element of 'group_membership', and the - output annotations will be the same length as 'group_membership'. Default is - True. + :param group_membership_ht: Optional Table containing group membership annotations + to stratify the coverage stats by. If not provided, the 'group_membership' + annotation is expected to be present on `mt`. :return: Table or MatrixTable with allele frequencies by strata. """ if entry_agg_funcs is None: @@ -1907,79 +1895,204 @@ def compute_freq_by_strata( select_fields = [] n_samples = mt.count_cols() - n_groups = len(mt.group_membership.take(1)[0]) - ht = mt.localize_entries("entries", "cols") - ht = ht.annotate_globals( - indices_by_group=hl.range(n_groups).map( - lambda g_i: hl.range(n_samples).filter( - lambda s_i: ht.cols[s_i].group_membership[g_i] + global_expr = {} + if "adj_group" in mt.index_globals(): + global_expr["adj_group"] = mt.index_globals().adj_group + logger.info("Using the 'adj_group' global annotation found on the input MT.") + + if group_membership_ht is None and "group_membership" not in mt.col: + raise ValueError( + "The 'group_membership' annotation is not found in the input MatrixTable " + "and 'group_membership_ht' is not specified." 
+ ) + elif group_membership_ht is None: + logger.info( + "'group_membership_ht' is not specified, using sample stratification " + "indicated by the 'group_membership' annotation on mt." + ) + n_groups = len(mt.group_membership.take(1)[0]) + else: + logger.info( + "'group_membership_ht' is specified, using sample stratification indicated " + "by its 'group_membership' annotation." + ) + n_groups = len(group_membership_ht.group_membership.take(1)[0]) + mt = mt.annotate_cols( + group_membership=group_membership_ht[mt.col_key].group_membership + ) + if "adj_group" not in global_expr: + if "adj_group" in group_membership_ht.index_globals(): + global_expr["adj_group"] = mt.index_globals().adj_group + logger.info( + "Using the 'adj_group' global annotation on 'group_membership_ht'." + ) + elif "freq_meta" in group_membership_ht.index_globals(): + logger.info( + "The 'freq_meta' global annotation is found in " + "'group_membership_ht', using it to determine the adj filtered " + "stratification groups." + ) + freq_meta = group_membership_ht.index_globals().freq_meta + + global_expr["adj_group"] = freq_meta.map( + lambda x: x.get("group", "NA") == "adj" ) + + if "adj_group" not in global_expr: + global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False) + + n_adj_group = hl.eval(hl.len(global_expr["adj_group"])) + if hl.eval(hl.len(global_expr["adj_group"])) != n_groups: + raise ValueError( + f"The number of elements in the 'adj_group' ({n_adj_group}) global " + "annotation does not match the number of elements in the " + f"'group_membership' annotation ({n_groups})!", + ) + + # Keep only the entries needed for the aggregation functions. + select_expr = {} + if hl.eval(hl.any(global_expr["adj_group"])): + select_expr["adj"] = mt.adj + + select_expr.update(**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()}) + mt = mt.select_entries(**select_expr) + + # Convert MT to HT with a row annotation that is an array of all samples entries + # for that variant. 
+ ht = mt.localize_entries("entries", "cols") + + # For each stratification group in group_membership, determine the indices of the + # samples that belong to that group. + global_expr["indices_by_group"] = hl.range(n_groups).map( + lambda g_i: hl.range(n_samples).filter( + lambda s_i: ht.cols[s_i].group_membership[g_i] ) ) + ht = ht.annotate_globals(**global_expr) + # Pull out each annotation that will be used in the array aggregation below as its # own ArrayExpression. This is important to prevent memory issues when performing # the below array aggregations. ht = ht.select( - *select_fields, - adj_array=ht.entries.map(lambda e: e.adj), - gt_array=ht.entries.map(lambda e: e.GT), **{ - ann: hl.map(lambda e, s: f[0](e, s), ht.entries, ht.cols) - for ann, f in entry_agg_funcs.items() - }, + ann: ht.entries.map(lambda e: e[ann]) + for ann in select_fields + list(select_expr.keys()) + } ) def _agg_by_group( - ht: hl.Table, agg_func: Callable, ann_expr: hl.expr.ArrayExpression, *args + ht: hl.Table, agg_func: Callable, ann_expr: hl.expr.ArrayExpression ) -> hl.expr.ArrayExpression: """ Aggregate `agg_expr` by group using the `agg_func` function. :param ht: Input Hail Table. - :param agg_func: Aggregation function to apply to `agg_expr`. - :param agg_expr: Expression to aggregate by group. - :param args: Additional arguments to pass to the `agg_func`. + :param agg_func: Aggregation function to apply to `ann_expr`. + :param ann_expr: Expression to aggregate by group. :return: Aggregated array expression. """ - adj_agg_expr = ht.indices_by_group.map( - lambda s_indices: s_indices.aggregate( - lambda i: hl.agg.filter(ht.adj_array[i], agg_func(ann_expr[i], *args)) - ) - ) - # Create final agg list by inserting or changing the "raw" group, - # representing all samples, in the adj_agg_list. 
- raw_agg_expr = ann_expr.aggregate(lambda x: agg_func(x, *args)) - if group_membership_includes_raw_group: - extend_idx = 2 - else: - extend_idx = 1 - - adj_agg_expr = ( - adj_agg_expr[:1].append(raw_agg_expr).extend(adj_agg_expr[extend_idx:]) + return hl.map( + lambda s_indices, adj: s_indices.aggregate( + lambda i: hl.if_else( + adj, + hl.agg.filter(ht.adj[i], agg_func(ann_expr[i])), + agg_func(ann_expr[i]), + ) + ), + ht.indices_by_group, + ht.adj_group, ) - return adj_agg_expr - - freq_expr = _agg_by_group(ht, hl.agg.call_stats, ht.gt_array, ht.alleles) - - # Select non-ref allele (assumes bi-allelic). - freq_expr = freq_expr.map( - lambda cs: cs.annotate( - AC=cs.AC[1], - AF=cs.AF[1], - homozygote_count=cs.homozygote_count[1], - ) - ) # Add annotations for any supplied entry transform and aggregation functions. ht = ht.select( *select_fields, **{ann: _agg_by_group(ht, f[1], ht[ann]) for ann, f in entry_agg_funcs.items()}, - freq=freq_expr, ) return ht.drop("cols") +def compute_freq_by_strata( + mt: hl.MatrixTable, + entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None, + select_fields: Optional[List[str]] = None, + group_membership_includes_raw_group: bool = True, +) -> hl.Table: + """ + Compute call statistics and, when passed, entry aggregation function(s) by strata. + + The computed call statistics are AC, AF, AN, and homozygote_count. The entry + aggregation functions are applied to the MatrixTable entries and aggregated. The + MatrixTable must contain a 'group_membership' annotation (like the one added by + `generate_freq_group_membership_array`) that is a list of bools to aggregate the + columns by. + + .. note:: + This function is primarily used through `annotate_freq` but can be used + independently if desired. Please see the `annotate_freq` function for more + complete documentation. + + :param mt: Input MatrixTable. + :param entry_agg_funcs: Optional dict of entry aggregation functions. 
When + specified, additional annotations are added to the output Table/MatrixTable. + The keys of the dict are the names of the annotations and the values are tuples + of functions. The first function is used to transform the `mt` entries in some + way, and the second function is used to aggregate the output from the first + function. + :param select_fields: Optional list of row fields from `mt` to keep on the output + Table. + :param group_membership_includes_raw_group: Whether the 'group_membership' + annotation includes an entry for the 'raw' group, representing all samples. If + False, the 'raw' group is inserted as the second element in all added + annotations using the same 'group_membership', resulting + in array lengths of 'group_membership'+1. If True, the second element of each + added annotation is still the 'raw' group, but the group membership is + determined by the values in the second element of 'group_membership', and the + output annotations will be the same length as 'group_membership'. Default is + True. + :return: Table or MatrixTable with allele frequencies by strata. + """ + if not group_membership_includes_raw_group: + # Add the 'raw' group to the 'group_membership' annotation. + mt = mt.annotate_cols( + group_membership=hl.array([mt.group_membership[0]]).extend( + mt.group_membership + ) + ) + + # Add adj_group global annotation indicating that the second element in + # group_membership is 'raw' and all others are 'adj'. + mt = mt.annotate_globals( + adj_group=hl.range(hl.len(mt.group_membership.take(1)[0])).map(lambda x: x != 1) + ) + + if entry_agg_funcs is None: + entry_agg_funcs = {} + + def _get_freq_expr(gt_expr: hl.expr.CallExpression) -> hl.expr.StructExpression: + """ + Get struct expression with call statistics. + + :param gt_expr: CallExpression to compute call statistics on. + :return: StructExpression with call statistics. + """ + # Get the source Table for the CallExpression to grab alleles. 
+ ht = gt_expr._indices.source + freq_expr = hl.agg.call_stats(gt_expr, ht.alleles) + # Select non-ref allele (assumes bi-allelic). + freq_expr = freq_expr.annotate( + AC=freq_expr.AC[1], + AF=freq_expr.AF[1], + homozygote_count=freq_expr.homozygote_count[1], + ) + + return freq_expr + + entry_agg_funcs["freq"] = (lambda x: x.GT, _get_freq_expr) + + return agg_by_strata(mt, entry_agg_funcs, select_fields).drop("adj_group") + + def update_structured_annotations( ht: hl.Table, annotation_update_exprs: Dict[str, hl.Expression], From 48703075c1631c6a717e828f8a8531a514c57afd Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Wed, 10 Jan 2024 09:30:36 -0700 Subject: [PATCH 2/8] Switch order of functions for review --- gnomad/utils/annotations.py | 162 ++++++++++++++++++------------------ 1 file changed, 81 insertions(+), 81 deletions(-) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index 0720a6315..e1f505a9e 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1861,6 +1861,87 @@ def generate_freq_group_membership_array( return ht +def compute_freq_by_strata( + mt: hl.MatrixTable, + entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None, + select_fields: Optional[List[str]] = None, + group_membership_includes_raw_group: bool = True, +) -> hl.Table: + """ + Compute call statistics and, when passed, entry aggregation function(s) by strata. + + The computed call statistics are AC, AF, AN, and homozygote_count. The entry + aggregation functions are applied to the MatrixTable entries and aggregated. The + MatrixTable must contain a 'group_membership' annotation (like the one added by + `generate_freq_group_membership_array`) that is a list of bools to aggregate the + columns by. + + .. note:: + This function is primarily used through `annotate_freq` but can be used + independently if desired. Please see the `annotate_freq` function for more + complete documentation. 
+ + :param mt: Input MatrixTable. + :param entry_agg_funcs: Optional dict of entry aggregation functions. When + specified, additional annotations are added to the output Table/MatrixTable. + The keys of the dict are the names of the annotations and the values are tuples + of functions. The first function is used to transform the `mt` entries in some + way, and the second function is used to aggregate the output from the first + function. + :param select_fields: Optional list of row fields from `mt` to keep on the output + Table. + :param group_membership_includes_raw_group: Whether the 'group_membership' + annotation includes an entry for the 'raw' group, representing all samples. If + False, the 'raw' group is inserted as the second element in all added + annotations using the same 'group_membership', resulting + in array lengths of 'group_membership'+1. If True, the second element of each + added annotation is still the 'raw' group, but the group membership is + determined by the values in the second element of 'group_membership', and the + output annotations will be the same length as 'group_membership'. Default is + True. + :return: Table or MatrixTable with allele frequencies by strata. + """ + if not group_membership_includes_raw_group: + # Add the 'raw' group to the 'group_membership' annotation. + mt = mt.annotate_cols( + group_membership=hl.array([mt.group_membership[0]]).extend( + mt.group_membership + ) + ) + + # Add adj_group global annotation indicating that the second element in + # group_membership is 'raw' and all others are 'adj'. + mt = mt.annotate_globals( + adj_group=hl.range(hl.len(mt.group_membership.take(1)[0])).map(lambda x: x != 1) + ) + + if entry_agg_funcs is None: + entry_agg_funcs = {} + + def _get_freq_expr(gt_expr: hl.expr.CallExpression) -> hl.expr.StructExpression: + """ + Get struct expression with call statistics. + + :param gt_expr: CallExpression to compute call statistics on. + :return: StructExpression with call statistics. 
+ """ + # Get the source Table for the CallExpression to grab alleles. + ht = gt_expr._indices.source + freq_expr = hl.agg.call_stats(gt_expr, ht.alleles) + # Select non-ref allele (assumes bi-allelic). + freq_expr = freq_expr.annotate( + AC=freq_expr.AC[1], + AF=freq_expr.AF[1], + homozygote_count=freq_expr.homozygote_count[1], + ) + + return freq_expr + + entry_agg_funcs["freq"] = (lambda x: x.GT, _get_freq_expr) + + return agg_by_strata(mt, entry_agg_funcs, select_fields).drop("adj_group") + + def agg_by_strata( mt: hl.MatrixTable, entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None, @@ -2012,87 +2093,6 @@ def _agg_by_group( return ht.drop("cols") -def compute_freq_by_strata( - mt: hl.MatrixTable, - entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None, - select_fields: Optional[List[str]] = None, - group_membership_includes_raw_group: bool = True, -) -> hl.Table: - """ - Compute call statistics and, when passed, entry aggregation function(s) by strata. - - The computed call statistics are AC, AF, AN, and homozygote_count. The entry - aggregation functions are applied to the MatrixTable entries and aggregated. The - MatrixTable must contain a 'group_membership' annotation (like the one added by - `generate_freq_group_membership_array`) that is a list of bools to aggregate the - columns by. - - .. note:: - This function is primarily used through `annotate_freq` but can be used - independently if desired. Please see the `annotate_freq` function for more - complete documentation. - - :param mt: Input MatrixTable. - :param entry_agg_funcs: Optional dict of entry aggregation functions. When - specified, additional annotations are added to the output Table/MatrixTable. - The keys of the dict are the names of the annotations and the values are tuples - of functions. The first function is used to transform the `mt` entries in some - way, and the second function is used to aggregate the output from the first - function. 
- :param select_fields: Optional list of row fields from `mt` to keep on the output - Table. - :param group_membership_includes_raw_group: Whether the 'group_membership' - annotation includes an entry for the 'raw' group, representing all samples. If - False, the 'raw' group is inserted as the second element in all added - annotations using the same 'group_membership', resulting - in array lengths of 'group_membership'+1. If True, the second element of each - added annotation is still the 'raw' group, but the group membership is - determined by the values in the second element of 'group_membership', and the - output annotations will be the same length as 'group_membership'. Default is - True. - :return: Table or MatrixTable with allele frequencies by strata. - """ - if not group_membership_includes_raw_group: - # Add the 'raw' group to the 'group_membership' annotation. - mt = mt.annotate_cols( - group_membership=hl.array([mt.group_membership[0]]).extend( - mt.group_membership - ) - ) - - # Add adj_group global annotation indicating that the second element in - # group_membership is 'raw' and all others are 'adj'. - mt = mt.annotate_globals( - adj_group=hl.range(hl.len(mt.group_membership.take(1)[0])).map(lambda x: x != 1) - ) - - if entry_agg_funcs is None: - entry_agg_funcs = {} - - def _get_freq_expr(gt_expr: hl.expr.CallExpression) -> hl.expr.StructExpression: - """ - Get struct expression with call statistics. - - :param gt_expr: CallExpression to compute call statistics on. - :return: StructExpression with call statistics. - """ - # Get the source Table for the CallExpression to grab alleles. - ht = gt_expr._indices.source - freq_expr = hl.agg.call_stats(gt_expr, ht.alleles) - # Select non-ref allele (assumes bi-allelic). 
- freq_expr = freq_expr.annotate( - AC=freq_expr.AC[1], - AF=freq_expr.AF[1], - homozygote_count=freq_expr.homozygote_count[1], - ) - - return freq_expr - - entry_agg_funcs["freq"] = (lambda x: x.GT, _get_freq_expr) - - return agg_by_strata(mt, entry_agg_funcs, select_fields).drop("adj_group") - - def update_structured_annotations( ht: hl.Table, annotation_update_exprs: Dict[str, hl.Expression], From ec53431667ffd0ca7ec24ff4db07ffc165854788 Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Wed, 10 Jan 2024 13:56:38 -0700 Subject: [PATCH 3/8] Add fixes from testing --- gnomad/utils/annotations.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index e1f505a9e..ec40287ea 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1997,27 +1997,27 @@ def agg_by_strata( "'group_membership_ht' is specified, using sample stratification indicated " "by its 'group_membership' annotation." ) + group_globals = group_membership_ht.index_globals() n_groups = len(group_membership_ht.group_membership.take(1)[0]) mt = mt.annotate_cols( group_membership=group_membership_ht[mt.col_key].group_membership ) if "adj_group" not in global_expr: - if "adj_group" in group_membership_ht.index_globals(): - global_expr["adj_group"] = mt.index_globals().adj_group + if "adj_group" in group_globals: + global_expr["adj_group"] = group_globals.adj_group logger.info( "Using the 'adj_group' global annotation on 'group_membership_ht'." ) - elif "freq_meta" in group_membership_ht.index_globals(): + elif "freq_meta" in group_globals: logger.info( "The 'freq_meta' global annotation is found in " "'group_membership_ht', using it to determine the adj filtered " "stratification groups." 
) - freq_meta = group_membership_ht.index_globals().freq_meta - - global_expr["adj_group"] = freq_meta.map( - lambda x: x.get("group", "NA") == "adj" - ) + freq_meta = group_globals.freq_meta + global_expr["adj_group"] = freq_meta.map( + lambda x: x.get("group", "NA") == "adj" + ) if "adj_group" not in global_expr: global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False) @@ -2032,8 +2032,10 @@ def agg_by_strata( # Keep only the entries needed for the aggregation functions. select_expr = {} + has_adj = False if hl.eval(hl.any(global_expr["adj_group"])): select_expr["adj"] = mt.adj + has_adj = True select_expr.update(**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()}) mt = mt.select_entries(**select_expr) @@ -2072,14 +2074,12 @@ def _agg_by_group( :param ann_expr: Expression to aggregate by group. :return: Aggregated array expression. """ + f = lambda i, adj: agg_func(ann_expr[i]) + if has_adj: + f = lambda i, adj: hl.if_else(adj, hl.agg.filter(ht.adj[i], f), f) + return hl.map( - lambda s_indices, adj: s_indices.aggregate( - lambda i: hl.if_else( - adj, - hl.agg.filter(ht.adj[i], agg_func(ann_expr[i])), - agg_func(ann_expr[i]), - ) - ), + lambda s_indices, adj: s_indices.aggregate(lambda i: f(i, adj)), ht.indices_by_group, ht.adj_group, ) From 5245a614bdefe3e0e5447cd7f1c3a2973fe69c83 Mon Sep 17 00:00:00 2001 From: Mike Wilson Date: Thu, 11 Jan 2024 15:27:01 -0500 Subject: [PATCH 4/8] Rearrange and enforce adj_group and group_membership being on the same HT/MT --- gnomad/utils/annotations.py | 80 ++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index ec40287ea..a312f6684 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1957,73 +1957,73 @@ def agg_by_strata( 'group_membership' annotation that is a list of bools to aggregate the columns by. :param mt: Input MatrixTable. 
- :param entry_agg_funcs: Optional dict of entry aggregation functions. When - specified, additional annotations are added to the output Table/MatrixTable. - The keys of the dict are the names of the annotations and the values are tuples + :param entry_agg_funcs: Dict of entry aggregation functions where the + keys of the dict are the names of the annotations and the values are tuples of functions. The first function is used to transform the `mt` entries in some way, and the second function is used to aggregate the output from the first function. :param select_fields: Optional list of row fields from `mt` to keep on the output Table. :param group_membership_ht: Optional Table containing group membership annotations - to stratify the coverage stats by. If not provided, the 'group_membership' + to stratify the aggregations by. If not provided, the 'group_membership' annotation is expected to be present on `mt`. - :return: Table or MatrixTable with allele frequencies by strata. + :return: Table with annotations of stratified aggregations. """ if entry_agg_funcs is None: - entry_agg_funcs = {} - if select_fields is None: - select_fields = [] - - n_samples = mt.count_cols() - global_expr = {} - if "adj_group" in mt.index_globals(): - global_expr["adj_group"] = mt.index_globals().adj_group - logger.info("Using the 'adj_group' global annotation found on the input MT.") + raise TypeError( + "'agg_by_strata' expects a 'entry_agg_funcs' dictionary but it was not" + " supplied. Without the dictionary, no aggregations will occur." + ) if group_membership_ht is None and "group_membership" not in mt.col: raise ValueError( "The 'group_membership' annotation is not found in the input MatrixTable " "and 'group_membership_ht' is not specified." 
) - elif group_membership_ht is None: + + if select_fields is None: + select_fields = [] + + if group_membership_ht is None: logger.info( "'group_membership_ht' is not specified, using sample stratification " "indicated by the 'group_membership' annotation on mt." ) - n_groups = len(mt.group_membership.take(1)[0]) + group_globals = mt.index_globals() else: logger.info( "'group_membership_ht' is specified, using sample stratification indicated " "by its 'group_membership' annotation." ) group_globals = group_membership_ht.index_globals() - n_groups = len(group_membership_ht.group_membership.take(1)[0]) mt = mt.annotate_cols( group_membership=group_membership_ht[mt.col_key].group_membership ) - if "adj_group" not in global_expr: - if "adj_group" in group_globals: - global_expr["adj_group"] = group_globals.adj_group - logger.info( - "Using the 'adj_group' global annotation on 'group_membership_ht'." - ) - elif "freq_meta" in group_globals: - logger.info( - "The 'freq_meta' global annotation is found in " - "'group_membership_ht', using it to determine the adj filtered " - "stratification groups." - ) - freq_meta = group_globals.freq_meta - global_expr["adj_group"] = freq_meta.map( - lambda x: x.get("group", "NA") == "adj" - ) - if "adj_group" not in global_expr: + global_expr = {} + n_groups = len(mt.group_membership.take(1)[0]) + if "adj_group" in group_globals: + global_expr["adj_group"] = group_globals.adj_group + logger.info("Using the 'adj_group' global annotation on 'group_membership_ht'.") + elif "freq_meta" in group_globals: + logger.info( + "The 'freq_meta' global annotation is found in " + "'group_membership_ht', using it to determine the adj filtered " + "stratification groups." 
+ ) + freq_meta = group_globals.freq_meta + global_expr["adj_group"] = freq_meta.map( + lambda x: x.get("group", "NA") == "adj" + ) + else: global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False) + # NOTE: Unsure if we still want this check here since the adj_group and n_groups + # always be from the same table or built within this function? Its a cheap operation + # so I'm leaning towards keeping it even though I'm not sure this is the right place + # for this check. n_adj_group = hl.eval(hl.len(global_expr["adj_group"])) - if hl.eval(hl.len(global_expr["adj_group"])) != n_groups: + if n_adj_group != n_groups: raise ValueError( f"The number of elements in the 'adj_group' ({n_adj_group}) global " "annotation does not match the number of elements in the " @@ -2031,13 +2031,11 @@ def agg_by_strata( ) # Keep only the entries needed for the aggregation functions. - select_expr = {} - has_adj = False - if hl.eval(hl.any(global_expr["adj_group"])): + select_expr = {**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()}} + has_adj = hl.eval(hl.any(global_expr["adj_group"])) + if has_adj: select_expr["adj"] = mt.adj - has_adj = True - select_expr.update(**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()}) mt = mt.select_entries(**select_expr) # Convert MT to HT with a row annotation that is an array of all samples entries @@ -2047,7 +2045,7 @@ def agg_by_strata( # For each stratification group in group_membership, determine the indices of the # samples that belong to that group. 
global_expr["indices_by_group"] = hl.range(n_groups).map( - lambda g_i: hl.range(n_samples).filter( + lambda g_i: hl.range(mt.count_cols()).filter( lambda s_i: ht.cols[s_i].group_membership[g_i] ) ) From 5abc5aa6e751d31f74ec5279a37dd5dff58600e8 Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Tue, 16 Jan 2024 12:10:06 -0700 Subject: [PATCH 5/8] Small clean-up after merge of suggestions --- gnomad/utils/annotations.py | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index a312f6684..4ec00ef35 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1944,7 +1944,7 @@ def _get_freq_expr(gt_expr: hl.expr.CallExpression) -> hl.expr.StructExpression: def agg_by_strata( mt: hl.MatrixTable, - entry_agg_funcs: Optional[Dict[str, Tuple[Callable, Callable]]] = None, + entry_agg_funcs: Dict[str, Tuple[Callable, Callable]], select_fields: Optional[List[str]] = None, group_membership_ht: Optional[hl.Table] = None, ) -> hl.Table: @@ -1969,12 +1969,6 @@ def agg_by_strata( annotation is expected to be present on `mt`. :return: Table with annotations of stratified aggregations. """ - if entry_agg_funcs is None: - raise TypeError( - "'agg_by_strata' expects a 'entry_agg_funcs' dictionary but it was not" - " supplied. Without the dictionary, no aggregations will occur." - ) - if group_membership_ht is None and "group_membership" not in mt.col: raise ValueError( "The 'group_membership' annotation is not found in the input MatrixTable " "and 'group_membership_ht' is not specified." @@ -1987,7 +1981,7 @@ def agg_by_strata( if group_membership_ht is None: logger.info( "'group_membership_ht' is not specified, using sample stratification " - "indicated by the 'group_membership' annotation on mt." + "indicated by the 'group_membership' annotation on the input MatrixTable." 
) group_globals = mt.index_globals() else: @@ -2003,25 +1997,26 @@ def agg_by_strata( global_expr = {} n_groups = len(mt.group_membership.take(1)[0]) if "adj_group" in group_globals: + logger.info( + "Using the 'adj_group' global annotation to determine adj filtered " + "stratification groups." + ) global_expr["adj_group"] = group_globals.adj_group - logger.info("Using the 'adj_group' global annotation on 'group_membership_ht'.") elif "freq_meta" in group_globals: logger.info( - "The 'freq_meta' global annotation is found in " - "'group_membership_ht', using it to determine the adj filtered " - "stratification groups." + "No 'adj_group' global annotation found, using the 'freq_meta' global " + "annotation to determine adj filtered stratification groups." ) - freq_meta = group_globals.freq_meta - global_expr["adj_group"] = freq_meta.map( + global_expr["adj_group"] = group_globals.freq_meta.map( lambda x: x.get("group", "NA") == "adj" ) else: + logger.info( + "No 'adj_group' or 'freq_meta' global annotations found. All groups will " + "be considered non-adj." + ) global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False) - # NOTE: Unsure if we still want this check here since the adj_group and n_groups - # always be from the same table or built within this function? Its a cheap operation - # so I'm leaning towards keeping it even though I'm not sure this is the right place - # for this check. 
n_adj_group = hl.eval(hl.len(global_expr["adj_group"])) if n_adj_group != n_groups: raise ValueError( From fda52c885b8d84f7fce8586f1c617b076c0d6f4b Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Wed, 17 Jan 2024 10:39:21 -0700 Subject: [PATCH 6/8] Apply suggestions from code review Co-authored-by: Mike Wilson --- gnomad/utils/annotations.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index 4ec00ef35..2932ff904 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1909,10 +1909,10 @@ def compute_freq_by_strata( ) ) - # Add adj_group global annotation indicating that the second element in + # Add adj_groups global annotation indicating that the second element in # group_membership is 'raw' and all others are 'adj'. mt = mt.annotate_globals( - adj_group=hl.range(hl.len(mt.group_membership.take(1)[0])).map(lambda x: x != 1) + adj_groups=hl.range(hl.len(mt.group_membership.take(1)[0])).map(lambda x: x != 1) ) if entry_agg_funcs is None: @@ -1939,7 +1939,7 @@ def _get_freq_expr(gt_expr: hl.expr.CallExpression) -> hl.expr.StructExpression: entry_agg_funcs["freq"] = (lambda x: x.GT, _get_freq_expr) - return agg_by_strata(mt, entry_agg_funcs, select_fields).drop("adj_group") + return agg_by_strata(mt, entry_agg_funcs, select_fields).drop("adj_groups") def agg_by_strata( @@ -1996,38 +1996,38 @@ def agg_by_strata( global_expr = {} n_groups = len(mt.group_membership.take(1)[0]) - if "adj_group" in group_globals: + if "adj_groups" in group_globals: logger.info( "Using the 'adj_group' global annotation to determine adj filtered " "stratification groups." 
) - global_expr["adj_group"] = group_globals.adj_group + global_expr["adj_groups"] = group_globals.adj_groups elif "freq_meta" in group_globals: logger.info( - "No 'adj_group' global annotation found, using the 'freq_meta' global " + "No 'adj_groups' global annotation found, using the 'freq_meta' global " "annotation to determine adj filtered stratification groups." ) - global_expr["adj_group"] = group_globals.freq_meta.map( + global_expr["adj_groups"] = group_globals.freq_meta.map( lambda x: x.get("group", "NA") == "adj" ) else: logger.info( - "No 'adj_group' or 'freq_meta' global annotations found. All groups will " + "No 'adj_groups' or 'freq_meta' global annotations found. All groups will " "be considered non-adj." ) - global_expr["adj_group"] = hl.range(n_groups).map(lambda x: False) + global_expr["adj_groups"] = hl.range(n_groups).map(lambda x: False) - n_adj_group = hl.eval(hl.len(global_expr["adj_group"])) - if n_adj_group != n_groups: + n_adj_groups = hl.eval(hl.len(global_expr["adj_groups"])) + if n_adj_groups != n_groups: raise ValueError( - f"The number of elements in the 'adj_group' ({n_adj_group}) global " + f"The number of elements in the 'adj_groups' ({n_adj_groups}) global " "annotation does not match the number of elements in the " f"'group_membership' annotation ({n_groups})!", ) # Keep only the entries needed for the aggregation functions. select_expr = {**{ann: f[0](mt) for ann, f in entry_agg_funcs.items()}} - has_adj = hl.eval(hl.any(global_expr["adj_group"])) + has_adj = hl.eval(hl.any(global_expr["adj_groups"])) if has_adj: select_expr["adj"] = mt.adj @@ -2074,7 +2074,7 @@ def _agg_by_group( return hl.map( lambda s_indices, adj: s_indices.aggregate(lambda i: f(i, adj)), ht.indices_by_group, - ht.adj_group, + ht.adj_groups, ) # Add annotations for any supplied entry transform and aggregation functions. 
From 6640e58449b4a53883098cedd7cdcc38bae0daba Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Wed, 17 Jan 2024 10:39:45 -0700 Subject: [PATCH 7/8] Update gnomad/utils/annotations.py Co-authored-by: Mike Wilson --- gnomad/utils/annotations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index 2932ff904..7e721d395 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1998,7 +1998,7 @@ def agg_by_strata( n_groups = len(mt.group_membership.take(1)[0]) if "adj_groups" in group_globals: logger.info( - "Using the 'adj_group' global annotation to determine adj filtered " + "Using the 'adj_groups' global annotation to determine adj filtered " "stratification groups." ) global_expr["adj_groups"] = group_globals.adj_groups From 32859fd65d5cc439271a5b5f56f03a5769bc92ab Mon Sep 17 00:00:00 2001 From: jkgoodrich <33063077+jkgoodrich@users.noreply.github.com> Date: Wed, 17 Jan 2024 10:42:41 -0700 Subject: [PATCH 8/8] Format --- gnomad/utils/annotations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gnomad/utils/annotations.py b/gnomad/utils/annotations.py index 7e721d395..b99e6135e 100644 --- a/gnomad/utils/annotations.py +++ b/gnomad/utils/annotations.py @@ -1912,7 +1912,9 @@ def compute_freq_by_strata( # Add adj_groups global annotation indicating that the second element in # group_membership is 'raw' and all others are 'adj'. mt = mt.annotate_globals( - adj_groups=hl.range(hl.len(mt.group_membership.take(1)[0])).map(lambda x: x != 1) + adj_groups=hl.range(hl.len(mt.group_membership.take(1)[0])).map( + lambda x: x != 1 + ) ) if entry_agg_funcs is None: