Clean up compute_coverage_stats, change it to use agg_by_strata and have an optional group_membership_ht parameter #660

Merged
merged 8 commits on Jan 17, 2024
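A rough usage sketch of the two mutually exclusive stratification entry points (not part of the diff; the input paths, the sample metadata Table, and its `platform` annotation are hypothetical):

import hail as hl

from gnomad.utils.annotations import generate_freq_group_membership_array
from gnomad.utils.sparse_mt import compute_coverage_stats

vds = hl.vds.read_vds("gs://my-bucket/dataset.vds")  # hypothetical path
reference_ht = hl.read_table("gs://my-bucket/reference_context.ht")  # hypothetical path
meta_ht = hl.read_table("gs://my-bucket/sample_meta.ht")  # hypothetical path

# Option 1: pass `strata_expr`; the function builds the group membership Table
# internally with `generate_freq_group_membership_array`.
coverage_ht = compute_coverage_stats(
    vds,
    reference_ht,
    strata_expr=[{"platform": meta_ht[vds.variant_data.col_key].platform}],
)

# Option 2: pass a precomputed `group_membership_ht` (keyed by sample), mirroring what
# the function builds internally.
group_membership_ht = generate_freq_group_membership_array(
    meta_ht, [{"platform": meta_ht.platform}], no_raw_group=True
)
coverage_ht = compute_coverage_stats(
    vds, reference_ht, group_membership_ht=group_membership_ht
)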
248 changes: 129 additions & 119 deletions gnomad/utils/sparse_mt.py
@@ -6,6 +6,7 @@
import hail as hl

from gnomad.utils.annotations import (
agg_by_strata,
fs_from_sb,
generate_freq_group_membership_array,
get_adj_expr,
@@ -913,6 +914,7 @@ def compute_coverage_stats(
coverage_over_x_bins: List[int] = [1, 5, 10, 15, 20, 25, 30, 50, 100],
row_key_fields: List[str] = ["locus"],
strata_expr: Optional[List[Dict[str, hl.expr.StringExpression]]] = None,
group_membership_ht: Optional[hl.Table] = None,
) -> hl.Table:
"""
Compute coverage statistics for every base of the `reference_ht` provided.
@@ -934,38 +936,70 @@
:param row_key_fields: List of row key fields to use for joining `mtds` with
        `reference_ht`.
:param strata_expr: Optional list of dicts containing expressions to stratify the
        coverage stats by. Only one of `group_membership_ht` or `strata_expr` can be
        specified.
    :param group_membership_ht: Optional Table containing group membership annotations
        to stratify the coverage stats by. Only one of `group_membership_ht` or
        `strata_expr` can be specified.
    :return: Table with per-base coverage stats.
"""
is_vds = isinstance(mtds, hl.vds.VariantDataset)
if is_vds:
mt = mtds.variant_data
else:
mt = mtds

    # Determine the genotype field.
    gt_field = set(mt.entry) & {"GT", "LGT"}
    if len(gt_field) == 0:
        raise ValueError("No genotype field found in entry fields.")

    gt_field = gt_field.pop()
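    # For example, the variant data of a VDS typically carries local genotypes, so
    # `gt_field` is usually "LGT" there, while a dense MatrixTable typically has "GT".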

if group_membership_ht is not None and strata_expr is not None:
raise ValueError(
"Only one of 'group_membership_ht' or 'strata_expr' can be specified."
)

# Initialize no_strata and default strata_expr if neither group_membership_ht nor
# strata_expr is provided.
    no_strata = group_membership_ht is None and strata_expr is None
    if no_strata:
        strata_expr = {}

    if group_membership_ht is None:
        # Annotate the MT cols with each of the expressions in strata_expr and redefine
        # strata_expr based on the column HT with added annotations.
        ht = mt.annotate_cols(
            **{k: v for d in strata_expr for k, v in d.items()}
        ).cols()
        strata_expr = [{k: ht[k] for k in d} for d in strata_expr]

# Use 'generate_freq_group_membership_array' to create a group_membership Table
# that gives stratification group membership info based on 'strata_expr'. The
# returned Table has the following annotations: 'freq_meta',
# 'freq_meta_sample_count', and 'group_membership'. By default, this
        # function returns annotations where the second element is a placeholder for the
        # "raw" frequency of all samples: the first two elements describe the same
        # sample set, but 'freq_meta' starts with [{"group": "adj"}, {"group": "raw"},
        # ...]. Use `no_raw_group` to exclude the "raw" group so there is a single
        # annotation representing the full sample set. Update all 'freq_meta' entries'
        # "group" to "raw" because `generate_freq_group_membership_array` returns them
        # all as "adj" (it was built for frequency computation), and the coverage
        # computation should not do any adj filtering.
group_membership_ht = generate_freq_group_membership_array(
ht, strata_expr, no_raw_group=True
)
group_membership_ht = group_membership_ht.annotate_globals(
freq_meta=group_membership_ht.freq_meta.map(
lambda x: hl.dict(
x.items().map(
lambda m: hl.if_else(m[0] == "group", ("group", "raw"), m)
)
)
)
)
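        # For example (hypothetical strata), with strata_expr=[{"platform": ht.platform}]
        # and platforms "A" and "B", 'freq_meta' goes from
        # [{"group": "adj"}, {"group": "adj", "platform": "A"}, {"group": "adj", "platform": "B"}]
        # to the same entries with every "group" value set to "raw".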

n_samples = group_membership_ht.count()
sample_counts = group_membership_ht.index_globals().freq_meta_sample_count

@@ -991,30 +1025,28 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable:
"""
Outer join MatrixTable with reference Table.

        Add 'in_ref' annotation indicating whether a given position is found in the
        reference Table.

:param mt: Input MatrixTable.
:return: MatrixTable with 'in_ref' annotation added.
"""
        # Get the total number of samples.
        n_samples = mt.count_cols()
        entry_keep_fields = set(mt.entry) & {gt_field, "DP", "END"}
        mt_col_key_fields = list(mt.col_key)
        mt_row_key_fields = list(mt.row_key)
        t = mt.select_entries(*entry_keep_fields).select_cols().select_rows()

# Localize entries and perform an outer join with the reference HT.
t = t._localize_entries("__entries", "__cols")
        t = t.key_by(*row_key_fields)
        t = t.join(
            reference_ht.key_by(*row_key_fields).select(_in_ref=True), how="outer"
        )
        t = t.key_by(*mt_row_key_fields)

# Fill in missing entries with missing values for each entry field.
t = t.annotate(
__entries=hl.or_else(
t.__entries,
@@ -1024,8 +1056,10 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable:
)
)

# Unlocalize entries to turn the HT back to a MT.
return t._unlocalize_entries("__entries", "__cols", mt_col_key_fields)
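    # For example, a locus present in `reference_ht` but absent from `mt` becomes a row
    # with '_in_ref' set to True and missing entry fields; those missing DP values are
    # treated as 0 coverage in the aggregations below.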

# Densify VDS/sparse MT at all sites.
if is_vds:
mtds = hl.vds.VariantDataset(
mtds.reference_data.select_entries("END", "DP").select_cols().select_rows(),
@@ -1039,112 +1073,88 @@ def join_with_ref(mt: hl.MatrixTable) -> hl.MatrixTable:
# Densify
mt = hl.experimental.densify(mtds)

    # Filter rows where the reference is missing.
mt = mt.filter_rows(mt._in_ref)

    # Unfilter entries so that entries with no ref block overlap aren't null.
mt = mt.unfilter_entries()

    # Compute coverage stats.
    cov_bins = sorted(coverage_over_x_bins)
    rev_cov_bins = list(reversed(cov_bins))
    max_cov_bin = cov_bins[-1]
    cov_bins = hl.array(cov_bins)
entry_agg_funcs = {
"coverage_stats": (
lambda t: hl.if_else(hl.is_missing(t.DP) | hl.is_nan(t.DP), 0, t.DP),
lambda dp: hl.struct(
# This expression creates a counter DP -> number of samples for DP
# between 0 and max_cov_bin.
coverage_counter=hl.agg.counter(hl.min(max_cov_bin, dp)),
mean=hl.if_else(hl.is_nan(hl.agg.mean(dp)), 0, hl.agg.mean(dp)),
median_approx=hl.or_else(hl.agg.approx_median(dp), 0),
total_DP=hl.agg.sum(dp),
            ),
        )
    }
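    # Each value in `entry_agg_funcs` is a (transform, aggregation) pair as used here by
    # `agg_by_strata`: the first function maps an entry to the value to aggregate (DP,
    # with missing or NaN treated as 0) and the second aggregates those values within
    # each stratification group, so 'coverage_stats' below is an array with one struct
    # per group defined in the group membership Table.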
ht = agg_by_strata(mt, entry_agg_funcs, group_membership_ht=group_membership_ht)
ht = ht.checkpoint(hl.utils.new_temp_file("coverage_stats", "ht"))

    # This expression aggregates the DP counter in reverse order of the cov_bins and
    # computes the cumulative sum over them. It needs to be in reverse order because we
    # want the sum over samples covered by > X.
def _cov_stats(
cov_stat: hl.expr.StructExpression, n: hl.expr.Int32Expression
) -> hl.expr.StructExpression:
        # The coverage was already floored to the max_cov_bin, so no more
# aggregation is needed for the max bin.
count_expr = cov_stat.coverage_counter
max_bin_expr = hl.int32(count_expr.get(max_cov_bin, 0))

# For each of the other bins, coverage is summed between the boundaries.
bin_expr = hl.range(hl.len(cov_bins) - 1, 0, step=-1)
bin_expr = bin_expr.map(
lambda i: hl.sum(
hl.range(cov_bins[i - 1], cov_bins[i]).map(
lambda j: hl.int32(count_expr.get(j, 0))
                )
            )
        )
bin_expr = hl.cumulative_sum(hl.array([max_bin_expr]).extend(bin_expr))
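        # Worked example with hypothetical counts: for cov_bins = [1, 5, 10] and a
        # counter {0: 2, 1: 3, 5: 4, 10: 1} over 10 samples, max_bin_expr is 1, the
        # reversed per-interval sums are [4, 3] (DP in [5, 10) and [1, 5)), and the
        # cumulative sum is [1, 5, 8], giving over_10 = 1/10, over_5 = 5/10, and
        # over_1 = 8/10 below.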

bin_expr = {f"over_{x}": bin_expr[i] / n for i, x in enumerate(rev_cov_bins)}

return cov_stat.annotate(**bin_expr).drop("coverage_counter")

ht = ht.annotate(
coverage_stats=hl.map(
            lambda c, n: _cov_stats(c, n),
            ht.coverage_stats,
sample_counts,
)
)
current_keys = list(ht.key)
    ht = ht.key_by(*row_key_fields).select_globals()
    ht = ht.drop(*[k for k in current_keys if k not in row_key_fields])

group_globals = group_membership_ht.index_globals()
global_expr = {}
if no_strata:
# If there was no stratification, move coverage_stats annotations to the top
# level.
ht = ht.select(**{k: ht.coverage_stats[0][k] for k in ht.coverage_stats[0]})
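        # With the default `coverage_over_x_bins`, the resulting top-level row fields
        # are mean, median_approx, total_DP, and over_1 ... over_100; the matching
        # sample count is stored as a 'sample_count' global below.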
global_expr["sample_count"] = group_globals.freq_meta_sample_count[0]
else:
# If there was stratification, add the metadata and sample count info for the
# stratification to the globals.
        global_expr["coverage_stats_meta"] = group_globals.freq_meta.map(
            lambda x: hl.dict(x.items().filter(lambda m: m[0] != "group"))
        )
        global_expr["coverage_stats_meta_sample_count"] = (
            group_globals.freq_meta_sample_count
        )

ht = ht.annotate_globals(**global_expr)

return ht
