Skip to content

Commit

Permalink
Apply review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
KoalaQin committed Dec 20, 2024
1 parent 1e50c05 commit e966735
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions gnomad/utils/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,12 @@ def filter_to_gencode_cds(
:param genes: Optional gene(s) to filter to. Can be a single gene string or list of genes. Default is None.
:param by_gene_symbol: Whether to filter by gene symbol. Default is True. If False, will filter by gene ID.
:param padding_bp: Number of bases to pad the CDS intervals by. Default is 0.
:param max_intervals: Maximum number of intervals to collect before filtering. Default is 1000.
:param max_collect_intervals: Maximum number of intervals for the use of
`hl.filter_intervals` for filtering. When the number of intervals to filter is
greater than this number, `filter`/`filter_rows` will be used instead. The
reason for this is that `hl.filter_intervals` is faster, but when the
number of intervals is too large, this can cause memory errors. Default is
3000.
:return: Table/MatrixTable filtered to loci in Gencode CDS intervals.
"""
if gencode_ht is None:
Expand All @@ -470,28 +475,22 @@ def filter_to_gencode_cds(
" documentation for more details to confirm it's being used as intended."
)
if genes:
genes_upper = (
genes = hl.literal(
[genes.upper()] if isinstance(genes, str) else [g.upper() for g in genes]
)
if by_gene_symbol:
gencode_ht = gencode_ht.filter(
hl.literal(genes_upper).contains(gencode_ht.gene_name)
)
else:
gencode_ht = gencode_ht.filter(
hl.literal(genes_upper).contains(gencode_ht.gene_id)
)
gene_field = "gene_name" if by_gene_symbol else "gene_id"
gencode_ht = gencode_ht.filter(genes.contains(gencode_ht[gene_field]))

interval_expr = gencode_ht.interval
if padding_bp:
gencode_ht = gencode_ht.annotate(
padded_interval=hl.locus_interval(
gencode_ht.interval.start.contig,
gencode_ht.interval.start.position - padding_bp,
gencode_ht.interval.end.position + padding_bp,
includes_start=gencode_ht.interval.includes_start,
includes_end=gencode_ht.interval.includes_end,
reference_genome=gencode_ht.interval.start.dtype.reference_genome,
)
interval_expr = hl.locus_interval(
interval_expr.start.contig,
interval_expr.start.position - padding_bp,
interval_expr.end.position + padding_bp,
includes_start=interval_expr.includes_start,
includes_end=interval_expr.includes_end,
reference_genome=interval_expr.start.dtype.reference_genome,
)

# Only collect intervals if there are fewer than or equal to `max_collect_intervals` to avoid memory issues.
Expand Down

0 comments on commit e966735

Please sign in to comment.