Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vovaf709 committed Dec 28, 2023
1 parent c1d2095 commit 210c8f9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
17 changes: 9 additions & 8 deletions dpipe/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ def stop(self) -> None:
def _prediction_func(self) -> None:
try:
for value in self.__iterable:
self.__result_queue.put(self.__func(value, *self.__args, **self.__kwargs))
self.__result_queue.put(FinishToken)
self.__result_queue.put((self.__func(value, *self.__args, **self.__kwargs), True))
self.__result_queue.put((FinishToken, True))
except BaseException as e:
self.__result_queue.put(e)
self.__result_queue.put((e, False))


def __iter__(self):
Expand All @@ -155,15 +155,16 @@ def __next__(self) -> Any:
if self.__exhausted:
raise StopIteration

obj = self.__result_queue.get()
if obj is FinishToken:
self.stop()
raise StopIteration
obj, success = self.__result_queue.get()

elif isinstance(obj, BaseException):
if not success:
self.stop()
raise obj

if obj is FinishToken:
self.stop()
raise StopIteration

return obj


Expand Down
19 changes: 15 additions & 4 deletions tests/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,24 @@ def test_async_pmap(self):
next(async_results)

def test_async_pmap_exception(self):
def exception_func(x):
exc = ValueError("I shouldn't be raised")
def return_exception_func(x):
return exc
def raise_exception_func(x):
raise ValueError

iterable = range(1)
async_results = AsyncPmap(exception_func, iterable)
async_results.start()

raised_asyncpmap = AsyncPmap(raise_exception_func, iterable)
returned_asyncpmap = AsyncPmap(return_exception_func, iterable)

raised_asyncpmap.start()
returned_asyncpmap.start()

with self.assertRaises(ValueError):
out = next(async_results)
out = next(raised_asyncpmap)

assert next(returned_asyncpmap) == exc

def test_async_pmap_stopiteration(self):
iterable = range(1)
Expand Down

0 comments on commit 210c8f9

Please sign in to comment.