-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
19cce93
commit 81bfb23
Showing
5 changed files
with
323 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
from typing import Tuple, Iterable, List | ||
from tqdm import tqdm | ||
from kipoiseq.dataclasses import Variant, Interval | ||
|
||
|
||
class VariantQuery: | ||
|
||
def __init__(self, func): | ||
self.func = func | ||
|
||
def __call__(self, variant: Variant): | ||
return self.func(variant) | ||
|
||
def __or__(self, other): | ||
return VariantQuery(lambda variant: self(variant) or other(variant)) | ||
|
||
def __and__(self, other): | ||
return VariantQuery(lambda variant: self(variant) and other(variant)) | ||
|
||
|
||
class FilterVariantQuery(VariantQuery): | ||
|
||
def __init__(self, filter='PASS'): | ||
self.filter = filter | ||
|
||
def __call__(self, variant): | ||
return variant.filter == self.filter | ||
|
||
|
||
class VariantIntervalQuery: | ||
|
||
def __init__(self, func): | ||
self.func = func | ||
|
||
def __call__(self, variants: List[Variant], interval: Interval): | ||
return self.func(variants, interval) | ||
|
||
def __or__(self, other): | ||
return VariantIntervalQuery( | ||
lambda variants, interval: ( | ||
i or j for i, j in zip(self(variants, interval), | ||
other(variants, interval)))) | ||
|
||
def __and__(self, other): | ||
return VariantIntervalQuery( | ||
lambda variants, interval: ( | ||
i and j for i, j in zip(self(variants, interval), | ||
other(variants, interval)))) | ||
|
||
|
||
class NumberVariantQuery(VariantIntervalQuery): | ||
""" | ||
Closure for variant query. Filter variants for interval | ||
if number of variants in given limits. | ||
""" | ||
|
||
def __init__(self, max_num=float('inf'), min_num=0): | ||
# TODO: sample speficity | ||
self.max_num = max_num | ||
self.min_num = min_num | ||
|
||
def __call__(self, variants, interval): | ||
if self.max_num >= len(variants) >= self.min_num: | ||
return [True] * len(variants) | ||
else: | ||
return [False] * len(variants) | ||
|
||
|
||
_VariantIntervalType = List[Tuple[Iterable[Variant], Interval]] | ||
|
||
|
||
class VariantIntervalQueryable: | ||
|
||
def __init__(self, vcf, variant_intervals: _VariantIntervalType, | ||
progress=False): | ||
""" | ||
Query object of variants. | ||
Args: | ||
vcf: cyvcf2.VCF objects. | ||
variants: iter of (variant, interval) tuples. | ||
""" | ||
self.vcf = vcf | ||
|
||
if progress: | ||
self.variant_intervals = tqdm(variant_intervals) | ||
else: | ||
self.variant_intervals = variant_intervals | ||
|
||
def __iter__(self): | ||
for variants, interval in self.variant_intervals: | ||
yield from variants | ||
|
||
def filter(self, query: VariantQuery): | ||
""" | ||
Filters variant given conduction. | ||
Args: | ||
query: function which get a variant as input and filtered iter of | ||
variants. | ||
""" | ||
self.variant_intervals = [ | ||
(filter(query, variants), Interval) | ||
for variants, interval in self.variant_intervals | ||
] | ||
return self | ||
|
||
def filter_range(self, query: VariantIntervalQuery): | ||
""" | ||
Filters variant given conduction. | ||
Args: | ||
query: function which get variants and an interval as input | ||
and filtered iter of variants. | ||
""" | ||
self.variant_intervals = list(self._filter_range(query)) | ||
return self | ||
|
||
def _filter_range(self, query: VariantIntervalQuery): | ||
for variants, interval in self.variant_intervals: | ||
variants = list(variants) | ||
yield ( | ||
v | ||
for v, cond in zip(variants, query(variants, interval)) | ||
if cond | ||
), interval | ||
|
||
def to_vcf(self, path): | ||
""" | ||
Parse query result as vcf file. | ||
Args: | ||
path: path of the file. | ||
""" | ||
from cyvcf2 import Writer | ||
writer = Writer(path, self.vcf) | ||
for v in self: | ||
writer.write_record(v.source) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import pytest | ||
|
||
vcf_file = 'tests/data/test.vcf.gz' | ||
sample_5kb_fasta_file = 'tests/data/sample.5kb.fa' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import pytest | ||
from conftest import vcf_file | ||
from kipoiseq.dataclasses import Variant, Interval | ||
from kipoiseq.extractors.vcf_seq import MultiSampleVCF | ||
from kipoiseq.extractors.vcf_query import * | ||
|
||
|
||
@pytest.fixture | ||
def query_true(): | ||
return VariantQuery(lambda v: True) | ||
|
||
|
||
@pytest.fixture | ||
def query_false(): | ||
return VariantQuery(lambda v: False) | ||
|
||
|
||
def test_base_query__and__(query_false, query_true): | ||
assert not (query_false & query_true)(None) | ||
|
||
|
||
def test_base_query__or__(query_false, query_true): | ||
assert (query_false | query_true)(None) | ||
|
||
|
||
@pytest.fixture | ||
def variant_queryable(): | ||
vcf = MultiSampleVCF(vcf_file) | ||
return VariantIntervalQueryable(vcf, [ | ||
( | ||
[ | ||
Variant('chr1', 12, 'A', 'T'), | ||
Variant('chr1', 18, 'A', 'C', filter='q10'), | ||
], | ||
Interval('chr1', 10, 20) | ||
), | ||
( | ||
[ | ||
Variant('chr2', 120, 'AT', 'AAAT'), | ||
], | ||
Interval('chr2', 110, 200) | ||
) | ||
]) | ||
|
||
|
||
def test_variant_queryable__iter__(variant_queryable): | ||
variants = list(variant_queryable) | ||
assert len(variants) == 3 | ||
assert variants[0].ref == 'A' | ||
assert variants[0].alt == 'T' | ||
|
||
|
||
def test_variant_queryable_filter_1(variant_queryable): | ||
assert len(list(variant_queryable.filter(lambda v: v.alt == 'T'))) == 1 | ||
|
||
|
||
def test_variant_queryable_filter_2(variant_queryable): | ||
assert len(list(variant_queryable.filter(lambda v: v.ref == 'A'))) == 2 | ||
|
||
|
||
def test_variant_filter_range(variant_queryable): | ||
assert 2 == len(list(variant_queryable.filter_range( | ||
lambda variants, interval: (v.ref == 'A' for v in variants)))) | ||
|
||
|
||
def test_VariantQueryable_filter_by_num_max(variant_queryable): | ||
assert 1 == len(list(variant_queryable.filter_range( | ||
NumberVariantQuery(max_num=1)))) | ||
|
||
|
||
def test_VariantQueryable_filter_by_num_min(variant_queryable): | ||
assert 2 == len(list(variant_queryable.filter_range( | ||
NumberVariantQuery(min_num=2)))) | ||
|
||
|
||
def test_VariantQueryable_filter_variant_query_2(variant_queryable): | ||
assert 2 == len(list(variant_queryable.filter(FilterVariantQuery()))) | ||
|
||
|
||
def test_VariantQueryable_filter_variant_query_3(variant_queryable): | ||
assert 3 == len(list(variant_queryable.filter( | ||
FilterVariantQuery() | FilterVariantQuery(filter='q10')))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters