diff --git a/tests/batch_iter/test_pipeline.py b/tests/batch_iter/test_pipeline.py index 4871603..e1b2047 100644 --- a/tests/batch_iter/test_pipeline.py +++ b/tests/batch_iter/test_pipeline.py @@ -1,4 +1,5 @@ import time +from collections import Counter from itertools import repeat import pytest @@ -38,25 +39,32 @@ def test_parallel(): assert abs(faster - delta / 2) < sleep -# TODO: uncomment as soon as #68 is solved -# def test_loky(): -# size = 100 -# for i, item in enumerate(wrap_pipeline(range(size), Loky(lambda x: x ** 2, n_workers=2))): -# assert item == i ** 2 -# assert i == size - 1 -# # at this point the first worker is killed -# # start a new one -# for i, item in enumerate(wrap_pipeline(range(size), Loky(lambda x: x ** 2, n_workers=2))): -# assert item == i ** 2 -# assert i == size - 1 - -# # several workers -# for i, item in enumerate(wrap_pipeline( -# range(size), -# Loky(lambda x: x ** 2, n_workers=2), -# Loky(lambda x: x ** 2, n_workers=2))): -# assert item == i ** 4 -# assert i == size - 1 +# TODO: check order of output itmes as soon as #68 is solved +def test_loky(): + size = 100 + + source_items = list(range(size)) + items = [] + + for i, item in enumerate(wrap_pipeline(source_items, Loky(lambda x: x ** 2, n_workers=2))): + items.append(item) + assert Counter(items) == Counter(map(lambda x: x ** 2, source_items)) + # at this point the first worker is killed + # start a new one + items = [] + for i, item in enumerate(wrap_pipeline(range(size), Loky(lambda x: x ** 2, n_workers=2))): + items.append(item) + assert Counter(items) == Counter(map(lambda x: x ** 2, source_items)) + + # several workers + items = [] + for i, item in enumerate(wrap_pipeline( + range(size), + Loky(lambda x: x ** 2, n_workers=2), + Loky(lambda x: x ** 2, n_workers=2))): + items.append(item) + assert Counter(items) == Counter(map(lambda x: x ** 4, source_items)) + assert i == size - 1 def test_premature_stop():