From a1a0994cf8f01007795d22fdeb1f6c846d9951ba Mon Sep 17 00:00:00 2001 From: Anihilatorgunn Date: Wed, 27 Dec 2023 18:43:21 +0300 Subject: [PATCH 1/6] AsyncPmap exception forwarding to main thread --- dpipe/itertools.py | 11 ++++++++--- tests/test_itertools.py | 9 +++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/dpipe/itertools.py b/dpipe/itertools.py index 0c15759..0b56a77 100644 --- a/dpipe/itertools.py +++ b/dpipe/itertools.py @@ -134,9 +134,12 @@ def start(self) -> None: self.__working_thread.start() def _prediction_func(self) -> None: - for value in self.__iterable: - self.__result_queue.put(self.__func(value, *self.__args, **self.__kwargs)) - self.__result_queue.put(FinishToken) + try: + for value in self.__iterable: + self.__result_queue.put(self.__func(value, *self.__args, **self.__kwargs)) + self.__result_queue.put(FinishToken) + except BaseException as e: + self.__result_queue.put(e) def __iter__(self): return self @@ -147,6 +150,8 @@ def __next__(self) -> Any: self.__working_thread.join() assert not self.__working_thread.is_alive() raise StopIteration + elif isinstance(obj, BaseException): + raise obj return obj diff --git a/tests/test_itertools.py b/tests/test_itertools.py index 5a9d651..1c2d4ed 100644 --- a/tests/test_itertools.py +++ b/tests/test_itertools.py @@ -75,3 +75,12 @@ def test_async_pmap(self): assert foo(i) == next(async_results) with self.assertRaises(StopIteration): next(async_results) + + def test_async_pmap_exception(self): + def exception_func(x): + raise ValueError + iterable = range(1) + async_results = AsyncPmap(exception_func, iterable) + async_results.start() + with self.assertRaises(ValueError): + out = next(async_results) From 976abe68a9e273a0224291c601fec31988fc4a19 Mon Sep 17 00:00:00 2001 From: Anihilatorgunn Date: Wed, 27 Dec 2023 19:54:45 +0300 Subject: [PATCH 2/6] microfix --- dpipe/itertools.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dpipe/itertools.py b/dpipe/itertools.py index 0b56a77..0316d4d 100644 --- a/dpipe/itertools.py +++ b/dpipe/itertools.py @@ -151,6 +151,8 @@ def __next__(self) -> Any: assert not self.__working_thread.is_alive() raise StopIteration elif isinstance(obj, BaseException): + self.__working_thread.join() + assert not self.__working_thread.is_alive() raise obj return obj From 50416703d5b6576e9ba1a397d8ad4b95f5760054 Mon Sep 17 00:00:00 2001 From: Philipenko Vladimir Date: Thu, 28 Dec 2023 12:36:01 +0300 Subject: [PATCH 3/6] Proper exhausting --- dpipe/itertools.py | 18 ++++++++++++++---- tests/test_itertools.py | 11 +++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/dpipe/itertools.py b/dpipe/itertools.py index 0316d4d..776a746 100644 --- a/dpipe/itertools.py +++ b/dpipe/itertools.py @@ -129,10 +129,16 @@ def __init__(self, func: Callable, iterable: Iterable, *args, **kwargs) -> None: self.__working_thread = Thread( target = self._prediction_func ) + self.__exhausted = False def start(self) -> None: self.__working_thread.start() + def stop(self) -> None: + self.__working_thread.join() + assert not self.__working_thread.is_alive() + self.__exhausted = True + def _prediction_func(self) -> None: try: for value in self.__iterable: @@ -141,19 +147,23 @@ def _prediction_func(self) -> None: except BaseException as e: self.__result_queue.put(e) + def __iter__(self): return self def __next__(self) -> Any: + if self.__exhausted: + raise StopIteration + obj = self.__result_queue.get() if obj is FinishToken: - self.__working_thread.join() - assert not self.__working_thread.is_alive() + self.stop() raise StopIteration + elif isinstance(obj, BaseException): - self.__working_thread.join() - assert not self.__working_thread.is_alive() + self.stop() raise obj + return obj diff --git a/tests/test_itertools.py b/tests/test_itertools.py index 1c2d4ed..6b153dd 100644 --- a/tests/test_itertools.py +++ b/tests/test_itertools.py @@ -84,3 +84,14 @@ def exception_func(x): async_results.start() with self.assertRaises(ValueError): out = next(async_results) + + def test_async_pmap_stopiteration(self): + iterable = range(1) + async_results = AsyncPmap(lambda x: x, iterable) + async_results.start() + + next(async_results) + with self.assertRaises(StopIteration): + out = next(async_results) + with self.assertRaises(StopIteration): + out = next(async_results) From c1d20951c0ef5465d80a015832c2a9cc8302e99d Mon Sep 17 00:00:00 2001 From: Philipenko Vladimir Date: Thu, 28 Dec 2023 12:36:31 +0300 Subject: [PATCH 4/6] version --- dpipe/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpipe/__version__.py b/dpipe/__version__.py index e1424ed..73e3bb4 100644 --- a/dpipe/__version__.py +++ b/dpipe/__version__.py @@ -1 +1 @@ -__version__ = '0.3.1' +__version__ = '0.3.2' From 210c8f9cfb599f5687995050a7afd0a7021f08c6 Mon Sep 17 00:00:00 2001 From: Philipenko Vladimir Date: Thu, 28 Dec 2023 15:24:22 +0300 Subject: [PATCH 5/6] Fix --- dpipe/itertools.py | 17 +++++++++-------- tests/test_itertools.py | 19 +++++++++++++++---- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/dpipe/itertools.py b/dpipe/itertools.py index 776a746..3e56455 100644 --- a/dpipe/itertools.py +++ b/dpipe/itertools.py @@ -142,10 +142,10 @@ def stop(self) -> None: def _prediction_func(self) -> None: try: for value in self.__iterable: - self.__result_queue.put(self.__func(value, *self.__args, **self.__kwargs)) - self.__result_queue.put(FinishToken) + self.__result_queue.put((self.__func(value, *self.__args, **self.__kwargs), True)) + self.__result_queue.put((FinishToken, True)) except BaseException as e: - self.__result_queue.put(e) + self.__result_queue.put((e, False)) def __iter__(self): @@ -155,15 +155,16 @@ def __next__(self) -> Any: if self.__exhausted: raise StopIteration - obj = self.__result_queue.get() - if obj is FinishToken: - self.stop() - raise StopIteration + obj, success = self.__result_queue.get() - elif isinstance(obj, BaseException): + if not success: self.stop() raise obj + if obj is FinishToken: + self.stop() + raise StopIteration + return obj diff --git a/tests/test_itertools.py b/tests/test_itertools.py index 6b153dd..87c85de 100644 --- a/tests/test_itertools.py +++ b/tests/test_itertools.py @@ -77,13 +77,24 @@ def test_async_pmap(self): next(async_results) def test_async_pmap_exception(self): - def exception_func(x): + exc = ValueError("I shouldn't be raised") + def return_exception_func(x): + return exc + def raise_exception_func(x): raise ValueError + iterable = range(1) - async_results = AsyncPmap(exception_func, iterable) - async_results.start() + + raised_asyncpmap = AsyncPmap(raise_exception_func, iterable) + returned_asyncpmap = AsyncPmap(return_exception_func, iterable) + + raised_asyncpmap.start() + returned_asyncpmap.start() + with self.assertRaises(ValueError): - out = next(async_results) + out = next(raised_asyncpmap) + + assert next(returned_asyncpmap) == exc def test_async_pmap_stopiteration(self): iterable = range(1) From 519a1669302a1ee6a8416cd3811b25028ad1dd15 Mon Sep 17 00:00:00 2001 From: Philipenko Vladimir Date: Thu, 28 Dec 2023 15:39:15 +0300 Subject: [PATCH 6/6] Simplify --- dpipe/itertools.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/dpipe/itertools.py b/dpipe/itertools.py index 3e56455..5edd35d 100644 --- a/dpipe/itertools.py +++ b/dpipe/itertools.py @@ -113,10 +113,6 @@ def pmap(func: Callable, iterable: Iterable, *args, **kwargs) -> Iterable: yield func(value, *args, **kwargs) -class FinishToken: - pass - - class AsyncPmap: def __init__(self, func: Callable, iterable: Iterable, *args, **kwargs) -> None: self.__func = func @@ -143,7 +139,7 @@ def _prediction_func(self) -> None: try: for value in self.__iterable: self.__result_queue.put((self.__func(value, *self.__args, **self.__kwargs), True)) - self.__result_queue.put((FinishToken, True)) + raise StopIteration except BaseException as e: self.__result_queue.put((e, False)) @@ -161,10 +157,6 @@ def __next__(self) -> Any: self.stop() raise obj - if obj is FinishToken: - self.stop() - raise StopIteration - return obj