From 210c8f9cfb599f5687995050a7afd0a7021f08c6 Mon Sep 17 00:00:00 2001 From: Philipenko Vladimir Date: Thu, 28 Dec 2023 15:24:22 +0300 Subject: [PATCH] Fix --- dpipe/itertools.py | 17 +++++++++-------- tests/test_itertools.py | 19 +++++++++++++++---- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/dpipe/itertools.py b/dpipe/itertools.py index 776a746..3e56455 100644 --- a/dpipe/itertools.py +++ b/dpipe/itertools.py @@ -142,10 +142,10 @@ def stop(self) -> None: def _prediction_func(self) -> None: try: for value in self.__iterable: - self.__result_queue.put(self.__func(value, *self.__args, **self.__kwargs)) - self.__result_queue.put(FinishToken) + self.__result_queue.put((self.__func(value, *self.__args, **self.__kwargs), True)) + self.__result_queue.put((FinishToken, True)) except BaseException as e: - self.__result_queue.put(e) + self.__result_queue.put((e, False)) def __iter__(self): @@ -155,15 +155,16 @@ def __next__(self) -> Any: if self.__exhausted: raise StopIteration - obj = self.__result_queue.get() - if obj is FinishToken: - self.stop() - raise StopIteration + obj, success = self.__result_queue.get() - elif isinstance(obj, BaseException): + if not success: self.stop() raise obj + if obj is FinishToken: + self.stop() + raise StopIteration + return obj diff --git a/tests/test_itertools.py b/tests/test_itertools.py index 6b153dd..87c85de 100644 --- a/tests/test_itertools.py +++ b/tests/test_itertools.py @@ -77,13 +77,24 @@ def test_async_pmap(self): next(async_results) def test_async_pmap_exception(self): - def exception_func(x): + exc = ValueError("I shouldn't be raised") + def return_exception_func(x): + return exc + def raise_exception_func(x): raise ValueError + iterable = range(1) - async_results = AsyncPmap(exception_func, iterable) - async_results.start() + + raised_asyncpmap = AsyncPmap(raise_exception_func, iterable) + returned_asyncpmap = AsyncPmap(return_exception_func, iterable) + + raised_asyncpmap.start() + returned_asyncpmap.start() + with self.assertRaises(ValueError): - out = next(async_results) + out = next(raised_asyncpmap) + + assert next(returned_asyncpmap) == exc def test_async_pmap_stopiteration(self): iterable = range(1)