def filter_meta_array(
    meta_expr: hl.expr.ArrayExpression,
    keys_to_keep: List[str] = None,
    keys_to_exclude: List[str] = None,
    key_value_pairs_to_keep: Dict[str, List[str]] = None,
    key_value_pairs_to_exclude: Dict[str, List[str]] = None,
    keep_combine_operator: str = "and",
    exclude_combine_operator: str = "and",
    combine_operator: str = "and",
    exact_match: bool = False,
) -> hl.expr.ArrayExpression:
    """
    Filter a metadata array expression based on keys and key-value pairs to keep/exclude.

    Each item of `meta_expr` is a dictionary. An item is retained when it satisfies
    the "keep" criteria (`keys_to_keep` and `key_value_pairs_to_keep`, combined with
    `keep_combine_operator`) and/or the "exclude" criteria (`keys_to_exclude` and
    `key_value_pairs_to_exclude`, combined with `exclude_combine_operator`). The two
    groups are then combined with `combine_operator`.

    If `exact_match` is True, the per-key presence checks for `keys_to_keep` are
    replaced by a single check that the item's key set is exactly the union of
    `keys_to_keep` and the keys of `key_value_pairs_to_keep`. The value checks of
    `key_value_pairs_to_keep` are still applied in addition. `exact_match` has no
    effect when neither `keys_to_keep` nor `key_value_pairs_to_keep` is provided, so
    exclusion-only filters behave the same with or without it.

    If no filtering criteria are provided at all, `meta_expr` is returned unchanged.

    .. note::

        Value checks look up a missing key as the empty string, so an item lacking
        the key only matches a key-value pair if "" is one of the listed values.

    :param meta_expr: Metadata array expression to filter.
    :param keys_to_keep: List of keys to keep.
    :param keys_to_exclude: List of keys to exclude.
    :param key_value_pairs_to_keep: Dictionary of key-value pairs to keep. Each value
        may be a list of accepted values or a single value.
    :param key_value_pairs_to_exclude: Dictionary of key-value pairs to exclude. Each
        value may be a list of rejected values or a single value.
    :param keep_combine_operator: Whether to use "and" or "or" to combine the filtering
        criteria for keys/key-value pairs to keep.
    :param exclude_combine_operator: Whether to use "and" or "or" to combine the
        filtering criteria for keys/key-value pairs to exclude.
    :param combine_operator: Whether to use "and" or "or" to combine the keep and
        exclude filtering criteria.
    :param exact_match: Whether to apply the keep criteria only to items with exactly
        the specified keys.
    :return: The filtered metadata array expression.
    :raises ValueError: If any of the combine operators is not "and" or "or".
    """
    # Normalize the optional arguments to empty containers of the documented types.
    keys_to_keep = list(keys_to_keep or [])
    keys_to_exclude = list(keys_to_exclude or [])
    key_value_pairs_to_keep = dict(key_value_pairs_to_keep or {})
    key_value_pairs_to_exclude = dict(key_value_pairs_to_exclude or {})

    combine_operator_map = {"and": hl.all, "or": hl.any}
    for o in [keep_combine_operator, exclude_combine_operator, combine_operator]:
        if o not in combine_operator_map:
            raise ValueError(
                "The combine operators must be one of 'and' or 'or', but found" f" {o}!"
            )

    # Resolve the operator names to the Hail functions implementing them. Separate
    # local names are used so the string parameters are not rebound to functions.
    keep_combine_func = combine_operator_map[keep_combine_operator]
    exclude_combine_func = combine_operator_map[exclude_combine_operator]
    combine_func = combine_operator_map[combine_operator]

    # With no criteria there is nothing to filter on. Returning early also avoids
    # handing an empty Python list to hl.all/hl.any.
    if not (
        keys_to_keep
        or keys_to_exclude
        or key_value_pairs_to_keep
        or key_value_pairs_to_exclude
    ):
        return meta_expr

    # Keys an item must have, and only have, when exact_match is requested.
    exact_keys = set(keys_to_keep) | set(key_value_pairs_to_keep)

    def _as_list(v):
        """Wrap a single value in a list; return lists unchanged."""
        return v if isinstance(v, list) else [v]

    def _get_filter(m: hl.expr.DictExpression) -> hl.expr.BooleanExpression:
        """
        Get the filter to apply to the metadata item.

        :param m: Metadata item.
        :return: Filter to apply to the metadata item.
        """
        # Only build the exact key-set comparison when there are keys to compare
        # against. An empty Python set cannot be turned into a typed Hail set, and
        # requiring "no keys at all" is never what an exclusion-only filter means.
        if exact_match and exact_keys:
            keep_filter = [hl.set(exact_keys) == hl.set(m.keys())]
        else:
            keep_filter = [m.contains(k) for k in keys_to_keep]

        # Items must also carry one of the accepted values for each key-value pair
        # to keep.
        keep_filter += [
            hl.literal(_as_list(v)).contains(m.get(k, ""))
            for k, v in key_value_pairs_to_keep.items()
        ]

        # Items must not have the excluded keys, and must not carry any of the
        # rejected values for each key-value pair to exclude.
        exclude_filter = [~m.contains(k) for k in keys_to_exclude] + [
            ~hl.literal(_as_list(v)).contains(m.get(k, ""))
            for k, v in key_value_pairs_to_exclude.items()
        ]

        # The early return above guarantees at least one of the two groups is
        # non-empty, so `filters` is never empty here.
        filters = []
        if keep_filter:
            filters.append(keep_combine_func(keep_filter))
        if exclude_filter:
            filters.append(exclude_combine_func(exclude_filter))

        return combine_func(filters)

    return meta_expr.filter(lambda m: _get_filter(m))
code-block:: python + + { + "gen_anc": {"values": ["global", "afr"], "keep": True}, + "downsampling": {"keep": True}, + "subset": {"keep": False}, + } The items can be kept or removed from `meta_indexed_expr` and `meta_expr` based on the value of `keep`. For example if `meta_indexed_exprs` is {'freq': ht.freq, @@ -660,9 +772,8 @@ def filter_arrays_by_meta( specified in the `items_to_filter` parameter. For example, by default, if `keep` is True, `combine_operator` is "and", and `items_to_filter` is ["sex", "downsampling"], then all items in `meta_expr` with both "sex" and "downsampling" as keys will be - kept. However, if `exact_match` is True, then the items - in `meta_expr` will only be kept if "sex" and "downsampling" are the only keys in - the meta dict. + kept. However, if `exact_match` is True, then the items in `meta_expr` will only be + kept if "sex" and "downsampling" are the only keys in the meta dict. :param meta_expr: Metadata expression that contains the values of the elements in `meta_indexed_expr`. The most often used expression is `freq_meta` to index into @@ -672,8 +783,12 @@ def filter_arrays_by_meta( array or just a single expression indexed by the `meta_expr`. :param items_to_filter: Items to filter by, either a list or a dictionary. :param keep: Whether to keep or remove the items specified by `items_to_filter`. - :param combine_operator: Whether to use "and" or "or" to combine the items - specified by `items_to_filter`. + :param keep_combine_operator: Whether to use "and" or "or" to combine the filtering + criteria for keys/key-value pairs to keep. + :param exclude_combine_operator: Whether to use "and" or "or" to combine the + filtering criteria for keys/key-value pairs to exclude. + :param combine_operator: Whether to use "and" or "or" to combine the keep and + exclude filtering criteria. :param exact_match: Whether to apply the `keep` parameter to only the items specified in the `items_to_filter` parameter or to all items in `meta_expr`. 
See the example above for more details. Default is False. @@ -683,58 +798,84 @@ def filter_arrays_by_meta( """ meta_expr = meta_expr.collect(_localize=False)[0] + # If only a single array expression needs to be filtered, make meta_indexed_exprs + # a dictionary with a single key "_tmp" so it can be filtered in the same way as + # a dictionary of array expressions. if isinstance(meta_indexed_exprs, hl.expr.ArrayExpression): meta_indexed_exprs = {"_tmp": meta_indexed_exprs} - if combine_operator == "and": - operator_func = hl.all - elif combine_operator == "or": - operator_func = hl.any - else: - raise ValueError( - "combine_operator must be one of 'and' or 'or', but found" - f" {combine_operator}!" - ) - + # If items_to_filter is a list, convert it to a dictionary with the key being the + # item to filter and the value being None, so it can be filtered in the same way as + # a dictionary of items to filter. if isinstance(items_to_filter, list): - items_to_filter_set = hl.set(items_to_filter) - items_to_filter = [[k] for k in items_to_filter] - if exact_match: - filter_func = lambda m, k: ( - hl.len(hl.set(m.keys()).difference(items_to_filter_set)) == 0 - ) & m.contains(k) - else: - filter_func = lambda m, k: m.contains(k) + items_to_filter = {k: None for k in items_to_filter} elif isinstance(items_to_filter, dict): - items_to_filter = [ - [(k, v) for v in values] for k, values in items_to_filter.items() - ] - items_to_filter_set = hl.set(hl.flatten(items_to_filter)) - if exact_match: - filter_func = lambda m, k: ( - (hl.len(hl.set(m.items()).difference(items_to_filter_set)) == 0) - & (m.get(k[0], "") == k[1]) + # If items_to_filter is a dictionary with lists as values, convert the lists + # to dictionaries with the key "values" and the value being the list of values + # to filter by. 
+ items_to_filter = { + k: ( + v + if v is None or isinstance(v, dict) + else {"values": v if isinstance(v, list) else [v]} ) - else: - filter_func = lambda m, k: (m.get(k[0], "") == k[1]) + for k, v in items_to_filter.items() + } else: raise TypeError("items_to_filter must be a list or a dictionary!") - meta_expr = hl.enumerate(meta_expr).filter( - lambda m: hl.bind( - lambda x: hl.if_else(keep, x, ~x), - operator_func( - [hl.any([filter_func(m[1], v) for v in k]) for k in items_to_filter] - ), - ), + # Use filter_meta_array to filter the meta_expr to keep only the items specified + # by items_to_filter. + keys_to_keep = [] + keys_to_exclude = [] + key_value_pairs_to_keep = {} + key_value_pairs_to_exclude = {} + + for k, v in items_to_filter.items(): + # Set item_keep to 'keep' parameter if value is None or if 'keep' value is not + # defined in that items' dictionary. Otherwise, (if already defined in the + # item's dictionary), use the 'keep' value defined in the dictionary. + item_keep = keep if v is None or "keep" not in v else v["keep"] + + if item_keep: + if v is not None and "values" in v: + key_value_pairs_to_keep[k] = v["values"] + else: + keys_to_keep.append(k) + else: + if v is not None and "values" in v: + key_value_pairs_to_exclude[k] = v["values"] + else: + keys_to_exclude.append(k) + + filtered_meta_expr = filter_meta_array( + meta_expr, + keys_to_keep=keys_to_keep, + keys_to_exclude=keys_to_exclude, + key_value_pairs_to_keep=key_value_pairs_to_keep, + key_value_pairs_to_exclude=key_value_pairs_to_exclude, + keep_combine_operator=keep_combine_operator, + exclude_combine_operator=exclude_combine_operator, + combine_operator=combine_operator, + exact_match=exact_match, + ) + + # Filter the enumerated meta_exprs to only keep the items that match the metadata + # dictionaries in the filtered meta expression. 
@pytest.fixture(scope="class")
def metadata_combinations():
    """
    Top-level fixture to hold all metadata combinations.

    Returns the `MetaDataCombinations` class itself (not an instance). Its class
    attributes are lists of metadata dictionaries: first the building-block groups,
    then the lists the tests look up by name as expected filter results.
    """

    class MetaDataCombinations:
        # Building blocks: each list holds the metadata items sharing one set of
        # keys. The `_a`, `_a_b` variants are prefixes restricted to those
        # "gen_anc" values, and the unsuffixed list extends them with "c".
        only_group = [{"group": "adj"}, {"group": "raw"}]
        group_gen_anc_a = [{"group": "adj", "gen_anc": "a"}]
        group_gen_anc_a_b = [*group_gen_anc_a, {"group": "adj", "gen_anc": "b"}]
        group_gen_anc = [*group_gen_anc_a_b, {"group": "adj", "gen_anc": "c"}]
        group_sex = [{"group": "adj", "sex": "XX"}, {"group": "adj", "sex": "XY"}]
        group_subset = [
            {"group": "adj", "subset": "s1"},
            {"group": "raw", "subset": "s1"},
        ]
        group_gen_anc_a_sex = [
            {"group": "adj", "gen_anc": "a", "sex": "XX"},
            {"group": "adj", "gen_anc": "a", "sex": "XY"},
        ]
        group_gen_anc_b_sex = [
            {"group": "adj", "gen_anc": "b", "sex": "XX"},
            {"group": "adj", "gen_anc": "b", "sex": "XY"},
        ]
        group_gen_anc_a_b_sex = group_gen_anc_a_sex + group_gen_anc_b_sex
        group_gen_anc_sex = [
            *group_gen_anc_a_b_sex,
            {"group": "adj", "gen_anc": "c", "sex": "XX"},
            {"group": "adj", "gen_anc": "c", "sex": "XY"},
        ]
        group_gen_anc_a_subset = [{"group": "adj", "gen_anc": "a", "subset": "s1"}]
        group_gen_anc_a_b_subset = [
            *group_gen_anc_a_subset,
            {"group": "adj", "gen_anc": "b", "subset": "s1"},
        ]
        group_gen_anc_subset = [
            *group_gen_anc_a_b_subset,
            {"group": "adj", "gen_anc": "c", "subset": "s1"},
        ]
        group_sex_subset = [
            {"group": "adj", "sex": "XX", "subset": "s1"},
            {"group": "adj", "sex": "XY", "subset": "s1"},
        ]
        group_gen_anc_a_sex_subset = [
            {"group": "adj", "gen_anc": "a", "sex": "XX", "subset": "s1"},
            {"group": "adj", "gen_anc": "a", "sex": "XY", "subset": "s1"},
        ]
        group_gen_anc_b_sex_subset = [
            {"group": "adj", "gen_anc": "b", "sex": "XX", "subset": "s1"},
            {"group": "adj", "gen_anc": "b", "sex": "XY", "subset": "s1"},
        ]
        group_gen_anc_a_b_sex_subset = (
            group_gen_anc_a_sex_subset + group_gen_anc_b_sex_subset
        )
        group_gen_anc_sex_subset = [
            *group_gen_anc_a_b_sex_subset,
            {"group": "adj", "gen_anc": "c", "sex": "XX", "subset": "s1"},
            {"group": "adj", "gen_anc": "c", "sex": "XY", "subset": "s1"},
        ]
        group_gen_anc_a_downsampling = [
            {"group": "adj", "gen_anc": "a", "downsampling": "1"},
            {"group": "adj", "gen_anc": "a", "downsampling": "2"},
            {"group": "adj", "gen_anc": "a", "downsampling": "3"},
        ]
        group_gen_anc_a_b_downsampling = [
            *group_gen_anc_a_downsampling,
            {"group": "adj", "gen_anc": "b", "downsampling": "1"},
            {"group": "adj", "gen_anc": "b", "downsampling": "2"},
            {"group": "adj", "gen_anc": "b", "downsampling": "3"},
        ]
        group_gen_anc_downsampling = [
            *group_gen_anc_a_b_downsampling,
            {"group": "adj", "gen_anc": "c", "downsampling": "1"},
            {"group": "adj", "gen_anc": "c", "downsampling": "2"},
            {"group": "adj", "gen_anc": "c", "downsampling": "3"},
        ]
        group_gen_anc_a_subset_downsampling = [
            {"group": "adj", "gen_anc": "a", "subset": "s1", "downsampling": "1"},
            {"group": "adj", "gen_anc": "a", "subset": "s1", "downsampling": "2"},
            {"group": "adj", "gen_anc": "a", "subset": "s1", "downsampling": "3"},
        ]
        group_gen_anc_a_b_subset_downsampling = [
            *group_gen_anc_a_subset_downsampling,
            {"group": "adj", "gen_anc": "b", "subset": "s1", "downsampling": "1"},
            {"group": "adj", "gen_anc": "b", "subset": "s1", "downsampling": "2"},
            {"group": "adj", "gen_anc": "b", "subset": "s1", "downsampling": "3"},
        ]
        group_gen_anc_subset_downsampling = [
            *group_gen_anc_a_b_subset_downsampling,
            {"group": "adj", "gen_anc": "c", "subset": "s1", "downsampling": "1"},
            {"group": "adj", "gen_anc": "c", "subset": "s1", "downsampling": "2"},
            {"group": "adj", "gen_anc": "c", "subset": "s1", "downsampling": "3"},
        ]

        # Expected results, looked up by name with getattr() in the tests. The
        # tests compare with list equality, so the building blocks are
        # concatenated in the same relative order as in the `mock_meta_expr`
        # fixture. Do not reorder the operands.
        downsampling = group_gen_anc_downsampling + group_gen_anc_subset_downsampling
        sex = (
            group_sex + group_gen_anc_sex + group_sex_subset + group_gen_anc_sex_subset
        )
        no_sex = (
            only_group
            + group_gen_anc
            + group_subset
            + group_gen_anc_subset
            + group_gen_anc_downsampling
            + group_gen_anc_subset_downsampling
        )
        sex_and_gen_anc = group_gen_anc_sex + group_gen_anc_sex_subset
        no_sex_and_no_gen_anc = only_group + group_subset
        sex_or_subset = (
            group_sex
            + group_subset
            + group_gen_anc_sex
            + group_gen_anc_subset
            + group_sex_subset
            + group_gen_anc_sex_subset
            + group_gen_anc_subset_downsampling
        )
        sex_and_gen_anc_a = group_gen_anc_a_sex + group_gen_anc_a_sex_subset
        sex_or_gen_anc_a = (
            group_gen_anc_a
            + group_sex
            + group_gen_anc_sex
            + group_gen_anc_a_subset
            + group_sex_subset
            + group_gen_anc_sex_subset
            + group_gen_anc_a_downsampling
            + group_gen_anc_a_subset_downsampling
        )
        sex_and_gen_anc_a_or_b = (
            group_gen_anc_a_sex
            + group_gen_anc_b_sex
            + group_gen_anc_a_sex_subset
            + group_gen_anc_b_sex_subset
        )
        no_downsampling = (
            only_group
            + group_gen_anc
            + group_sex
            + group_subset
            + group_gen_anc_sex
            + group_gen_anc_subset
            + group_sex_subset
            + group_gen_anc_sex_subset
        )
        no_subset_and_no_downsampling = (
            only_group + group_gen_anc + group_sex + group_gen_anc_sex
        )
        no_subset_or_no_downsampling = no_downsampling + group_gen_anc_downsampling
        no_downsampling_and_no_gen_anc_c = (
            only_group
            + group_gen_anc_a_b
            + group_sex
            + group_subset
            + group_gen_anc_a_b_sex
            + group_gen_anc_a_b_subset
            + group_sex_subset
            + group_gen_anc_a_b_sex_subset
        )
        no_downsampling_or_no_gen_anc_c = (
            no_downsampling
            + group_gen_anc_a_b_downsampling
            + group_gen_anc_a_b_subset_downsampling
        )
        sex_and_no_subset = group_sex + group_gen_anc_sex
        sex_or_no_subset = (
            only_group
            + group_gen_anc
            + group_sex
            + group_gen_anc_sex
            + group_sex_subset
            + group_gen_anc_sex_subset
            + group_gen_anc_downsampling
        )

    return MetaDataCombinations
@pytest.fixture
def mock_meta_expr(metadata_combinations) -> hl.expr.ArrayExpression:
    """
    Build the mock metadata array used by the filtering tests.

    The array is the concatenation, in a fixed order, of the building-block groups
    exposed by the `metadata_combinations` fixture. Expected results in the tests
    are compared by list equality, so this order is significant.
    """
    ordered_groups = (
        "only_group",
        "group_gen_anc",
        "group_sex",
        "group_subset",
        "group_gen_anc_sex",
        "group_gen_anc_subset",
        "group_sex_subset",
        "group_gen_anc_sex_subset",
        "group_gen_anc_downsampling",
        "group_gen_anc_subset_downsampling",
    )

    # Accumulate into a fresh list so the fixture's class attributes are never
    # mutated.
    all_items = []
    for group_name in ordered_groups:
        all_items.extend(getattr(metadata_combinations, group_name))

    return hl.literal(all_items)
class TestFilterMetaArray:
    """Tests for the filter_meta_array function."""

    # Define some common parameters.
    # These are unpacked positionally into the parametrize tuples below, so their
    # lengths matter: `all_and` fills the three combine-operator slots, `ss_d_ex`
    # fills the first four slots (keys_to_keep, keys_to_exclude,
    # key_value_pairs_to_keep, key_value_pairs_to_exclude), `ds_ex` fills the first
    # three, and `s_ss` fills the first two.
    all_and = ["and", "and", "and"]
    s_ga_list = ["sex", "gen_anc"]
    s_ss_list = ["sex", "subset"]
    s_list = ["sex"]
    g_s_list = ["group", "sex"]
    ss_d_ex = [None, ["subset", "downsampling"], None, None]
    ds_ex = [None, ["downsampling"], None]
    s_ss = [["sex"], ["subset"]]

    # Key-value pair criteria; a value may be a single string or a list of strings.
    ga_a = {"gen_anc": "a"}
    ga_c = {"gen_anc": "c"}
    ga_ab = {"gen_anc": ["a", "b"]}

    @pytest.mark.parametrize(
        "keys_to_keep, keys_to_exclude, key_value_pairs_to_keep, key_value_pairs_to_exclude, keep_combine_operator, exclude_combine_operator, combine_operator, exact_match, expected",
        [
            # Keep criteria only.
            (s_list, None, None, None, *all_and, False, "sex"),
            (s_ga_list, None, None, None, *all_and, False, "sex_and_gen_anc"),
            (g_s_list, None, None, None, *all_and, True, "group_sex"),
            (s_ss_list, None, None, None, "or", "and", "and", False, "sex_or_subset"),
            (s_list, None, ga_a, None, *all_and, False, "sex_and_gen_anc_a"),
            (s_list, None, ga_a, None, "or", "and", "and", False, "sex_or_gen_anc_a"),
            (g_s_list, None, ga_a, None, *all_and, True, "group_gen_anc_a_sex"),
            (s_list, None, ga_ab, None, *all_and, False, "sex_and_gen_anc_a_or_b"),
            # Exclude criteria only.
            (*ds_ex, None, *all_and, False, "no_downsampling"),
            (*ss_d_ex, *all_and, False, "no_subset_and_no_downsampling"),
            (*ss_d_ex, "and", "or", "and", False, "no_subset_or_no_downsampling"),
            (*ds_ex, ga_c, *all_and, False, "no_downsampling_and_no_gen_anc_c"),
            (
                *ds_ex,
                ga_c,
                "and",
                "or",
                "and",
                False,
                "no_downsampling_or_no_gen_anc_c",
            ),
            # Keep and exclude criteria together.
            (*s_ss, None, None, *all_and, False, "sex_and_no_subset"),
            (*s_ss, None, None, "and", "and", "or", False, "sex_or_no_subset"),
        ],
    )
    def test_filter_meta_array(
        self,
        mock_meta_expr: hl.expr.ArrayExpression,
        keys_to_keep: List[str],
        keys_to_exclude: List[str],
        key_value_pairs_to_keep: Dict[str, Any],
        key_value_pairs_to_exclude: Dict[str, Any],
        keep_combine_operator: str,
        exclude_combine_operator: str,
        combine_operator: str,
        exact_match: bool,
        expected: str,
        metadata_combinations: Any,
    ) -> None:
        """
        Test filter_meta_array function.

        `expected` is the name of the attribute on the `metadata_combinations`
        fixture that holds the expected filtered list.
        """
        result = filter_meta_array(
            meta_expr=mock_meta_expr,
            keys_to_keep=keys_to_keep,
            keys_to_exclude=keys_to_exclude,
            key_value_pairs_to_keep=key_value_pairs_to_keep,
            key_value_pairs_to_exclude=key_value_pairs_to_exclude,
            keep_combine_operator=keep_combine_operator,
            exclude_combine_operator=exclude_combine_operator,
            combine_operator=combine_operator,
            exact_match=exact_match,
        )
        # List equality: both the retained items and their order must match.
        assert hl.eval(result) == getattr(metadata_combinations, expected)
class TestFilterArraysByMeta:
    """Tests for the filter_arrays_by_meta function."""

    @pytest.fixture
    def simple_mock_meta_expr(self):
        """Get simple mock meta expression for filter_arrays_by_meta."""
        return hl.literal(
            [
                {"key1": "value1", "key2": "value2"},
                {"key1": "value3", "key2": "value4"},
                {"key1": "value5", "key2": "value6"},
            ]
        )

    @pytest.fixture
    def simple_mock_meta_indexed_exprs(self):
        """Get simple mock meta-indexed expressions for filter_arrays_by_meta."""
        # Two arrays with one element per item of `simple_mock_meta_expr`.
        return {
            "expr1": hl.literal([1, 2, 3]),
            "expr2": hl.literal([4, 5, 6]),
        }

    # Inputs ("in": the `items_to_filter` argument) and expected outputs ("out": a
    # tuple of expected meta list, expected "expr1" values, expected "expr2"
    # values) for the simple test below. The "out" tuples are star-unpacked into
    # the parametrize rows.
    params = {
        "k1_keep": {
            "in": {"key1": {"values": ["value1", "value3"], "keep": True}},
            "out": (
                [
                    {"key1": "value1", "key2": "value2"},
                    {"key1": "value3", "key2": "value4"},
                ],
                [1, 2],
                [4, 5],
            ),
        },
        "k1_ex": {
            "in": {"key1": {"values": ["value1", "value3"], "keep": False}},
            "out": ([{"key1": "value5", "key2": "value6"}], [3], [6]),
        },
        "k12_keep": {
            "in": {
                "key1": {"values": ["value1"], "keep": True},
                "key2": {"values": ["value2"], "keep": True},
            },
            "out": ([{"key1": "value1", "key2": "value2"}], [1], [4]),
        },
    }

    @pytest.mark.parametrize(
        "items_to_filter, keep, combine_operator, exact_match, expected_meta, expected_expr1, expected_expr2",
        [
            (params["k1_keep"]["in"], True, "and", False, *params["k1_keep"]["out"]),
            (params["k1_ex"]["in"], False, "and", False, *params["k1_ex"]["out"]),
            (params["k12_keep"]["in"], True, "and", True, *params["k12_keep"]["out"]),
        ],
    )
    def test_filter_arrays_by_meta(
        self,
        simple_mock_meta_expr: hl.expr.ArrayExpression,
        simple_mock_meta_indexed_exprs: Dict[str, hl.expr.ArrayExpression],
        items_to_filter: Dict[str, Dict[str, Any]],
        keep: bool,
        combine_operator: str,
        exact_match: bool,
        expected_meta: List[Dict[str, str]],
        expected_expr1: List[int],
        expected_expr2: List[int],
    ) -> None:
        """
        Test filter_arrays_by_meta function.

        Checks both the filtered metadata and the two meta-indexed arrays that are
        filtered alongside it.
        """
        filtered_meta_expr, filtered_meta_indexed_exprs = filter_arrays_by_meta(
            meta_expr=simple_mock_meta_expr,
            meta_indexed_exprs=simple_mock_meta_indexed_exprs,
            items_to_filter=items_to_filter,
            keep=keep,
            combine_operator=combine_operator,
            exact_match=exact_match,
        )
        assert hl.eval(filtered_meta_expr) == expected_meta
        assert hl.eval(filtered_meta_indexed_exprs["expr1"]) == expected_expr1
        assert hl.eval(filtered_meta_indexed_exprs["expr2"]) == expected_expr2

    # Additional cases reusing complex cases from the metadata_combinations fixture.
    # `all_and` is star-unpacked into the three combine-operator slots of the
    # parametrize rows below. The remaining names are `items_to_filter` inputs in
    # the forms exercised here: a list of keys, a dict with None values, a dict
    # with per-key {"keep": ...} / {"values": ...} dicts, and a dict mapping a key
    # to a single value or a list of values.
    all_and = ["and", "and", "and"]
    s_ga_list = ["sex", "gen_anc"]
    ss_d_list = ["subset", "downsampling"]
    s_ga_dict = {"sex": None, "gen_anc": None}
    s_ga_keep = {"sex": {"keep": True}, "gen_anc": {"keep": True}}
    s_ga_ex = {"sex": {"keep": False}, "gen_anc": {"keep": False}}
    s_ga_a = {"sex": None, "gen_anc": "a"}
    s_ga_a_2 = {"sex": None, "gen_anc": ["a"]}
    s_ga_a_3 = {"sex": None, "gen_anc": {"values": "a"}}
    s_ga_a_4 = {"sex": None, "gen_anc": {"values": ["a"]}}
    s_ga_a_4_keep = {"sex": None, "gen_anc": {"values": ["a"], "keep": True}}
    g_s_ga_a = {"group": None, "sex": None, "gen_anc": "a"}
    s_ga_a_b = {"sex": None, "gen_anc": ["a", "b"]}
    d_ga_c = {"downsampling": None, "gen_anc": "c"}
    s_keep_ss_ex = {"sex": {"keep": True}, "subset": {"keep": False}}

    @pytest.mark.parametrize(
        "items_to_filter, keep, keep_combine_operator, exclude_combine_operator, combine_operator, exact_match, expected_meta",
        [
            (["sex"], True, *all_and, False, "sex"),
            ({"sex": None}, True, *all_and, False, "sex"),
            ({"sex": {"keep": True}}, True, *all_and, False, "sex"),
            # A per-item "keep" value takes precedence over the `keep` argument.
            ({"sex": {"keep": True}}, False, *all_and, False, "sex"),
            (["sex"], True, "or", "or", "or", False, "sex"),
            (["sex"], False, *all_and, False, "no_sex"),
            ({"sex": None}, False, *all_and, False, "no_sex"),
            ({"sex": {"keep": False}}, True, *all_and, False, "no_sex"),
            (s_ga_list, True, *all_and, False, "sex_and_gen_anc"),
            (s_ga_dict, True, *all_and, False, "sex_and_gen_anc"),
            (s_ga_keep, True, *all_and, False, "sex_and_gen_anc"),
            (s_ga_list, False, *all_and, False, "no_sex_and_no_gen_anc"),
            (s_ga_ex, True, *all_and, False, "no_sex_and_no_gen_anc"),
            (["sex", "subset"], True, "or", "and", "and", False, "sex_or_subset"),
            (ss_d_list, False, *all_and, False, "no_subset_and_no_downsampling"),
            (
                ss_d_list,
                False,
                "and",
                "or",
                "and",
                False,
                "no_subset_or_no_downsampling",
            ),
            (["group", "sex"], True, *all_and, True, "group_sex"),
            (["group"], True, *all_and, True, "only_group"),
            # The same key-value criterion written in each accepted input form.
            (s_ga_a, True, *all_and, False, "sex_and_gen_anc_a"),
            (s_ga_a_2, True, *all_and, False, "sex_and_gen_anc_a"),
            (s_ga_a_3, True, *all_and, False, "sex_and_gen_anc_a"),
            (s_ga_a_4, True, *all_and, False, "sex_and_gen_anc_a"),
            (s_ga_a_4_keep, True, *all_and, False, "sex_and_gen_anc_a"),
            (s_ga_a, True, "or", "and", "and", False, "sex_or_gen_anc_a"),
            (g_s_ga_a, True, *all_and, True, "group_gen_anc_a_sex"),
            (s_ga_a_b, True, *all_and, False, "sex_and_gen_anc_a_or_b"),
            (
                d_ga_c,
                False,
                "and",
                "or",
                "and",
                False,
                "no_downsampling_or_no_gen_anc_c",
            ),
            # Mixed per-item keep/exclude criteria.
            (s_keep_ss_ex, True, *all_and, False, "sex_and_no_subset"),
            (s_keep_ss_ex, True, "and", "and", "or", False, "sex_or_no_subset"),
        ],
    )
    def test_filter_arrays_by_meta_with_reuse(
        self,
        mock_meta_expr: hl.expr.ArrayExpression,
        items_to_filter: Union[List[str], Dict[str, Dict[str, Any]]],
        keep: bool,
        keep_combine_operator,
        exclude_combine_operator,
        combine_operator: str,
        exact_match: bool,
        expected_meta: str,
        metadata_combinations: Any,
    ) -> None:
        """
        Test filter_arrays_by_meta function with reused cases.

        `expected_meta` is the name of the attribute on the `metadata_combinations`
        fixture that holds the expected filtered list.
        """
        # The metadata array is also passed as the meta-indexed array, so both
        # outputs are expected to equal the same filtered list.
        filtered_meta_expr, filtered_meta_indexed_exprs = filter_arrays_by_meta(
            meta_expr=mock_meta_expr,
            meta_indexed_exprs={"meta_array": mock_meta_expr},
            items_to_filter=items_to_filter,
            keep=keep,
            keep_combine_operator=keep_combine_operator,
            exclude_combine_operator=exclude_combine_operator,
            combine_operator=combine_operator,
            exact_match=exact_match,
        )
        assert hl.eval(filtered_meta_expr) == getattr(
            metadata_combinations, expected_meta
        )
        assert hl.eval(filtered_meta_indexed_exprs["meta_array"]) == getattr(
            metadata_combinations, expected_meta
        )