diff --git a/dpipe/__version__.py b/dpipe/__version__.py index e1424ed..73e3bb4 100644 --- a/dpipe/__version__.py +++ b/dpipe/__version__.py @@ -1 +1 @@ -__version__ = '0.3.1' +__version__ = '0.3.2' diff --git a/dpipe/itertools.py b/dpipe/itertools.py index 0c15759..5edd35d 100644 --- a/dpipe/itertools.py +++ b/dpipe/itertools.py @@ -113,10 +113,6 @@ def pmap(func: Callable, iterable: Iterable, *args, **kwargs) -> Iterable: yield func(value, *args, **kwargs) -class FinishToken: - pass - - class AsyncPmap: def __init__(self, func: Callable, iterable: Iterable, *args, **kwargs) -> None: self.__func = func @@ -129,24 +125,38 @@ def __init__(self, func: Callable, iterable: Iterable, *args, **kwargs) -> None: self.__working_thread = Thread( target = self._prediction_func ) + self.__exhausted = False def start(self) -> None: self.__working_thread.start() + def stop(self) -> None: + self.__working_thread.join() + assert not self.__working_thread.is_alive() + self.__exhausted = True + def _prediction_func(self) -> None: - for value in self.__iterable: - self.__result_queue.put(self.__func(value, *self.__args, **self.__kwargs)) - self.__result_queue.put(FinishToken) + try: + for value in self.__iterable: + self.__result_queue.put((self.__func(value, *self.__args, **self.__kwargs), True)) + raise StopIteration + except BaseException as e: + self.__result_queue.put((e, False)) + def __iter__(self): return self def __next__(self) -> Any: - obj = self.__result_queue.get() - if obj is FinishToken: - self.__working_thread.join() - assert not self.__working_thread.is_alive() + if self.__exhausted: raise StopIteration + + obj, success = self.__result_queue.get() + + if not success: + self.stop() + raise obj + return obj diff --git a/tests/test_itertools.py b/tests/test_itertools.py index 5a9d651..87c85de 100644 --- a/tests/test_itertools.py +++ b/tests/test_itertools.py @@ -75,3 +75,34 @@ def test_async_pmap(self): assert foo(i) == next(async_results) with self.assertRaises(StopIteration): next(async_results) + + def test_async_pmap_exception(self): + exc = ValueError("I shouldn't be raised") + def return_exception_func(x): + return exc + def raise_exception_func(x): + raise ValueError + + iterable = range(1) + + raised_asyncpmap = AsyncPmap(raise_exception_func, iterable) + returned_asyncpmap = AsyncPmap(return_exception_func, iterable) + + raised_asyncpmap.start() + returned_asyncpmap.start() + + with self.assertRaises(ValueError): + out = next(raised_asyncpmap) + + assert next(returned_asyncpmap) == exc + + def test_async_pmap_stopiteration(self): + iterable = range(1) + async_results = AsyncPmap(lambda x: x, iterable) + async_results.start() + + next(async_results) + with self.assertRaises(StopIteration): + out = next(async_results) + with self.assertRaises(StopIteration): + out = next(async_results)