Skip to content

Commit

Permalink
queryable-vcf-files
Browse files Browse the repository at this point in the history
  • Loading branch information
MuhammedHasan committed Oct 7, 2019
1 parent 19cce93 commit 81bfb23
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 3 deletions.
138 changes: 138 additions & 0 deletions kipoiseq/extractors/vcf_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from typing import Tuple, Iterable, List
from tqdm import tqdm
from kipoiseq.dataclasses import Variant, Interval


class VariantQuery:
    """
    Condition on a single variant.

    Wraps a callable and lets conditions be combined with `|` and `&`.
    """

    def __init__(self, func):
        """
        Args:
          func: callable which gets a variant as input and returns
            the result of the condition.
        """
        self.func = func

    def __call__(self, variant: Variant):
        """Evaluate the wrapped callable on `variant`."""
        return self.func(variant)

    def __or__(self, other):
        """Query which holds if this query or `other` holds."""
        def either(variant):
            return self(variant) or other(variant)

        return VariantQuery(either)

    def __and__(self, other):
        """Query which holds if both this query and `other` hold."""
        def both(variant):
            return self(variant) and other(variant)

        return VariantQuery(both)


class FilterVariantQuery(VariantQuery):
    """
    Variant query which holds if the `filter` attribute of the variant
    equals the expected value.
    """

    def __init__(self, filter='PASS'):
        """
        Args:
          filter: value the `filter` attribute of a variant is compared with.
        """
        self.filter = filter

    def __call__(self, variant):
        """Return whether `variant.filter` equals the expected value."""
        observed = variant.filter
        return observed == self.filter


class VariantIntervalQuery:
    """
    Condition on the variants of an interval.

    Wraps a callable which gets the variants and the interval as input and
    returns one condition per variant. Queries can be combined with
    `|` and `&`; combined queries return a generator of conditions.
    """

    def __init__(self, func):
        """
        Args:
          func: callable which gets variants and an interval as input.
        """
        self.func = func

    def __call__(self, variants: List[Variant], interval: Interval):
        """Evaluate the wrapped callable on `variants` and `interval`."""
        return self.func(variants, interval)

    def __or__(self, other):
        """Query which holds per variant if this query or `other` holds."""
        def either(variants, interval):
            paired = zip(self(variants, interval), other(variants, interval))
            return (left or right for left, right in paired)

        return VariantIntervalQuery(either)

    def __and__(self, other):
        """Query which holds per variant if this query and `other` hold."""
        def both(variants, interval):
            paired = zip(self(variants, interval), other(variants, interval))
            return (left and right for left, right in paired)

        return VariantIntervalQuery(both)


class NumberVariantQuery(VariantIntervalQuery):
    """
    Variant-interval query on the number of variants in an interval.

    Keeps all variants of an interval if their number is between
    `min_num` and `max_num` (both inclusive), drops all of them otherwise.
    """

    def __init__(self, max_num=float('inf'), min_num=0):
        """
        Args:
          max_num: largest accepted number of variants in an interval.
          min_num: smallest accepted number of variants in an interval.
        """
        # TODO: sample specificity
        self.max_num = max_num
        self.min_num = min_num

    def __call__(self, variants, interval):
        """Return one identical condition per variant of the interval."""
        count = len(variants)
        if not (self.max_num >= count >= self.min_num):
            return [False] * count
        return [True] * count


# List of (variants, interval) tuples: an iterable of variants paired with
# an interval.
_VariantIntervalType = List[Tuple[Iterable[Variant], Interval]]


class VariantIntervalQueryable:
    """
    Chainable query over variants grouped by interval.
    """

    def __init__(self, vcf, variant_intervals: _VariantIntervalType,
                 progress=False):
        """
        Query object of variants.

        Args:
          vcf: cyvcf2.VCF object.
          variant_intervals: iter of (variants, interval) tuples.
          progress: if True, wrap `variant_intervals` with tqdm.
        """
        self.vcf = vcf

        if progress:
            self.variant_intervals = tqdm(variant_intervals)
        else:
            self.variant_intervals = variant_intervals

    def __iter__(self):
        """Iterate over the variants of all intervals, interval by interval."""
        for variants, _ in self.variant_intervals:
            yield from variants

    def filter(self, query: VariantQuery):
        """
        Filters variants given condition.

        Args:
          query: function which gets a variant as input and returns a truthy
            value if the variant should be kept.

        Returns:
          self, so calls can be chained.
        """
        # Keep each group paired with its own `interval`; pairing it with the
        # `Interval` class would lose the interval for later `filter_range`
        # calls. The builtin `filter` is lazy, so each group of filtered
        # variants can be consumed only once.
        self.variant_intervals = [
            (filter(query, variants), interval)
            for variants, interval in self.variant_intervals
        ]
        return self

    def filter_range(self, query: VariantIntervalQuery):
        """
        Filters variants given condition on the variants of an interval.

        Args:
          query: function which gets variants and an interval as input
            and returns an iterable with one condition per variant.

        Returns:
          self, so calls can be chained.
        """
        self.variant_intervals = list(self._filter_range(query))
        return self

    def _filter_range(self, query: VariantIntervalQuery):
        """Yield (kept variants, interval) for each interval."""
        for variants, interval in self.variant_intervals:
            # Materialize the variants: they are passed to the query and
            # iterated again to select the ones to keep.
            variants = list(variants)
            yield (
                v
                for v, cond in zip(variants, query(variants, interval))
                if cond
            ), interval

    def to_vcf(self, path):
        """
        Write the variants of the query result as vcf file.

        Args:
          path: path of the file.
        """
        from cyvcf2 import Writer
        writer = Writer(path, self.vcf)
        try:
            for v in self:
                writer.write_record(v.source)
        finally:
            # Close the writer even if iteration fails so the file handle
            # is released and written records are flushed.
            writer.close()
68 changes: 68 additions & 0 deletions kipoiseq/extractors/vcf_seq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pyfaidx import Sequence, complement
from kipoiseq.extractors import BaseExtractor, FastaStringExtractor
from kipoiseq.dataclasses import Variant, Interval
from kipoiseq.extractors.vcf_query import VariantIntervalQueryable


try:
from cyvcf2 import VCF
Expand Down Expand Up @@ -33,8 +35,74 @@ def _region(self, interval):

def has_variant(self, variant, sample_id):
    """
    Check if the sample has the variant based on its genotype type.

    Args:
      variant: variant object whose `source.gt_types` is indexed by sample.
      sample_id: key of the sample in `self.sample_mapping`.
    """
    genotype_types = variant.source.gt_types
    sample_index = self.sample_mapping[sample_id]
    return self._has_variant_gt(genotype_types[sample_index])

def _has_variant_gt(self, gt_type):
    """
    Return whether the genotype type is neither 0 nor 2.

    NOTE(review): 0 and 2 are presumably the cyvcf2 codes for
    homozygous reference and unknown genotype - confirm.
    """
    differs_from_zero = gt_type != 0
    return differs_from_zero and gt_type != 2

def query_variants(self, intervals, sample_id=None, progress=False):
    """
    Fetch variants for the given intervals from the vcf file,
    restricted to a sample if a sample id is given.

    Args:
      intervals: iterable of interval objects.
      sample_id (str, optional): sample id in vcf file.
      progress (bool, optional): passed to the returned queryable.

    Returns:
      VariantIntervalQueryable: queryable object which allows you
        to query the fetched variants.

    Examples:
      To fetch variants if only a single variant is present in interval.

      >>> (MultiSampleVCF(vcf_path)
      ...     .query_variants(intervals)
      ...     .filter(lambda variant: variant.qual > 10)
      ...     .filter_range(NumberVariantQuery(max_num=1))
      ...     .to_vcf(output_path))
    """
    # Lazy pairs: variants are fetched only when the queryable is consumed.
    variant_intervals = (
        (self.fetch_variants(interval, sample_id=sample_id), interval)
        for interval in intervals
    )
    return VariantIntervalQueryable(
        self, variant_intervals, progress=progress)

def get_variant(self, variant):
    """
    Returns variant from vcf file. Lets you use the vcf file as a dict.

    Args:
      variant: variant object or variant id as string.

    Returns:
      Variant object.

    Raises:
      KeyError: if no fetched variant has the same ref and alt.

    Examples:
      >>> MultiSampleVCF(vcf_path).get_variant("chr1:4:T>C")
    """
    # isinstance also accepts subclasses of str, unlike a type() comparison.
    if isinstance(variant, str):
        variant = Variant.from_str(variant)

    variants = self.fetch_variants(
        Interval(variant.chrom, variant.pos, variant.pos))
    for v in variants:
        if v.ref == variant.ref and v.alt == variant.alt:
            return v
    raise KeyError('Variant %s not found in vcf file.' % str(variant))

def get_samples(self, variant):
    """
    Fetches the samples which have the given variant.

    Args:
      variant: variant object with one `gt_types` entry per sample.

    Returns:
      Dict[str, int]: samples which have the variant, with the
        genotype type as value.
    """
    return {
        sample: gt_type
        for sample, gt_type in zip(self.samples, variant.gt_types)
        if self._has_variant_gt(gt_type)
    }


class IntervalSeqBuilder(list):
"""
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import pytest

# Paths of the data files shared by the test modules.
vcf_file = 'tests/data/test.vcf.gz'
sample_5kb_fasta_file = 'tests/data/sample.5kb.fa'
82 changes: 82 additions & 0 deletions tests/extractors/test_vcf_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
from conftest import vcf_file
from kipoiseq.dataclasses import Variant, Interval
from kipoiseq.extractors.vcf_seq import MultiSampleVCF
from kipoiseq.extractors.vcf_query import *


@pytest.fixture
def query_true():
    """Variant query which holds for every variant."""
    def accept(v):
        return True

    return VariantQuery(accept)


@pytest.fixture
def query_false():
    """Variant query which holds for no variant."""
    def reject(v):
        return False

    return VariantQuery(reject)


def test_base_query__and__(query_false, query_true):
    """`&` of a false and a true query does not hold."""
    combined = query_false & query_true
    assert not combined(None)


def test_base_query__or__(query_false, query_true):
    """`|` of a false and a true query holds."""
    combined = query_false | query_true
    assert combined(None)


@pytest.fixture
def variant_queryable():
    """Queryable with two variants in a chr1 interval and one in chr2."""
    vcf = MultiSampleVCF(vcf_file)
    chr1_variants = [
        Variant('chr1', 12, 'A', 'T'),
        Variant('chr1', 18, 'A', 'C', filter='q10'),
    ]
    chr1_pair = (chr1_variants, Interval('chr1', 10, 20))
    chr2_variants = [
        Variant('chr2', 120, 'AT', 'AAAT'),
    ]
    chr2_pair = (chr2_variants, Interval('chr2', 110, 200))
    return VariantIntervalQueryable(vcf, [chr1_pair, chr2_pair])


def test_variant_queryable__iter__(variant_queryable):
    """Iterating the queryable yields the variants of all intervals."""
    found = list(variant_queryable)
    assert len(found) == 3
    first = found[0]
    assert first.ref == 'A'
    assert first.alt == 'T'


def test_variant_queryable_filter_1(variant_queryable):
    """Filtering by alt 'T' keeps a single variant."""
    def has_alt_t(v):
        return v.alt == 'T'

    kept = list(variant_queryable.filter(has_alt_t))
    assert len(kept) == 1


def test_variant_queryable_filter_2(variant_queryable):
    """Filtering by ref 'A' keeps two variants."""
    def has_ref_a(v):
        return v.ref == 'A'

    kept = list(variant_queryable.filter(has_ref_a))
    assert len(kept) == 2


def test_variant_filter_range(variant_queryable):
    """Range filter with per-variant conditions keeps ref 'A' variants."""
    def ref_is_a(variants, interval):
        return (v.ref == 'A' for v in variants)

    kept = list(variant_queryable.filter_range(ref_is_a))
    assert 2 == len(kept)


def test_VariantQueryable_filter_by_num_max(variant_queryable):
    """Only the interval with a single variant passes `max_num=1`."""
    at_most_one = NumberVariantQuery(max_num=1)
    kept = list(variant_queryable.filter_range(at_most_one))
    assert 1 == len(kept)


def test_VariantQueryable_filter_by_num_min(variant_queryable):
    """Only the interval with two variants passes `min_num=2`."""
    at_least_two = NumberVariantQuery(min_num=2)
    kept = list(variant_queryable.filter_range(at_least_two))
    assert 2 == len(kept)


def test_VariantQueryable_filter_variant_query_2(variant_queryable):
    """The default filter query keeps two of the three variants."""
    default_filter = FilterVariantQuery()
    kept = list(variant_queryable.filter(default_filter))
    assert 2 == len(kept)


def test_VariantQueryable_filter_variant_query_3(variant_queryable):
    """Default filter query or-ed with a 'q10' query keeps all variants."""
    default_or_q10 = FilterVariantQuery() | FilterVariantQuery(filter='q10')
    kept = list(variant_queryable.filter(default_or_q10))
    assert 3 == len(kept)
34 changes: 31 additions & 3 deletions tests/extractors/test_vcf_seq_extractor.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import pytest
from conftest import vcf_file, sample_5kb_fasta_file
from cyvcf2 import VCF
from pyfaidx import Sequence
from kipoiseq.extractors.vcf_seq import IntervalSeqBuilder
from kipoiseq.dataclasses import Variant, Interval
from kipoiseq.extractors import *

fasta_file = 'tests/data/sample.5kb.fa'
vcf_file = 'tests/data/test.vcf.gz'
fasta_file = sample_5kb_fasta_file

# Intervals on chr1 passed to `query_variants` in the tests below.
intervals = [
Interval('chr1', 4, 10),
Interval('chr1', 5, 30),
Interval('chr1', 20, 30)
]


@pytest.fixture
def multi_sample_vcf():
    """MultiSampleVCF created from the test vcf file."""
    vcf = MultiSampleVCF(vcf_file)
    return vcf


def test_multi_sample_vcf_fetch_variant(multi_sample_vcf):
def test_MultiSampleVCF_fetch_variant(multi_sample_vcf):
interval = Interval('chr1', 3, 5)
assert len(list(multi_sample_vcf.fetch_variants(interval))) == 2
assert len(list(multi_sample_vcf.fetch_variants(interval, 'NA00003'))) == 1
Expand All @@ -25,6 +31,28 @@ def test_multi_sample_vcf_fetch_variant(multi_sample_vcf):
assert len(list(multi_sample_vcf.fetch_variants(interval, 'NA00003'))) == 0


def test_MultiSampleVCF_query_variants(multi_sample_vcf):
    """Querying the module-level intervals yields five variants."""
    queryable = multi_sample_vcf.query_variants(intervals)
    found = list(queryable)
    assert len(found) == 5
    first, second = found[0], found[1]
    assert first.pos == 4
    assert second.pos == 5


def test_MultiSampleVCF_get_samples(multi_sample_vcf):
    """Only sample 'NA00003' has the first variant of the vcf file."""
    first_variant = list(multi_sample_vcf)[0]
    carriers = multi_sample_vcf.get_samples(first_variant)
    assert carriers == {'NA00003': 3}


def test_MultiSampleVCF_get_variant(multi_sample_vcf):
    """A variant id string is resolved to the matching variant."""
    found = multi_sample_vcf.get_variant("chr1:4:T>C")
    observed = (found.chrom, found.pos, found.ref, found.alt)
    assert observed == ('chr1', 4, 'T', 'C')


@pytest.fixture
def interval_seq_builder():
return IntervalSeqBuilder([
Expand Down

0 comments on commit 81bfb23

Please sign in to comment.