Skip to content

Commit

Permalink
Merge pull request #86 from neuro-ml/develop
Browse files Browse the repository at this point in the history
Pass exceptions to the main thread in `AsyncPmap`
  • Loading branch information
vovaf709 authored Dec 28, 2023
2 parents 913e671 + 519a166 commit 778d36b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
2 changes: 1 addition & 1 deletion dpipe/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.1'
__version__ = '0.3.2'
32 changes: 21 additions & 11 deletions dpipe/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,6 @@ def pmap(func: Callable, iterable: Iterable, *args, **kwargs) -> Iterable:
yield func(value, *args, **kwargs)


class FinishToken:
pass


class AsyncPmap:
def __init__(self, func: Callable, iterable: Iterable, *args, **kwargs) -> None:
self.__func = func
Expand All @@ -129,24 +125,38 @@ def __init__(self, func: Callable, iterable: Iterable, *args, **kwargs) -> None:
self.__working_thread = Thread(
target = self._prediction_func
)
self.__exhausted = False

def start(self) -> None:
self.__working_thread.start()

def stop(self) -> None:
self.__working_thread.join()
assert not self.__working_thread.is_alive()
self.__exhausted = True

def _prediction_func(self) -> None:
for value in self.__iterable:
self.__result_queue.put(self.__func(value, *self.__args, **self.__kwargs))
self.__result_queue.put(FinishToken)
try:
for value in self.__iterable:
self.__result_queue.put((self.__func(value, *self.__args, **self.__kwargs), True))
raise StopIteration
except BaseException as e:
self.__result_queue.put((e, False))


def __iter__(self):
return self

def __next__(self) -> Any:
obj = self.__result_queue.get()
if obj is FinishToken:
self.__working_thread.join()
assert not self.__working_thread.is_alive()
if self.__exhausted:
raise StopIteration

obj, success = self.__result_queue.get()

if not success:
self.stop()
raise obj

return obj


Expand Down
31 changes: 31 additions & 0 deletions tests/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,34 @@ def test_async_pmap(self):
assert foo(i) == next(async_results)
with self.assertRaises(StopIteration):
next(async_results)

def test_async_pmap_exception(self):
exc = ValueError("I shouldn't be raised")
def return_exception_func(x):
return exc
def raise_exception_func(x):
raise ValueError

iterable = range(1)

raised_asyncpmap = AsyncPmap(raise_exception_func, iterable)
returned_asyncpmap = AsyncPmap(return_exception_func, iterable)

raised_asyncpmap.start()
returned_asyncpmap.start()

with self.assertRaises(ValueError):
out = next(raised_asyncpmap)

assert next(returned_asyncpmap) == exc

def test_async_pmap_stopiteration(self):
iterable = range(1)
async_results = AsyncPmap(lambda x: x, iterable)
async_results.start()

next(async_results)
with self.assertRaises(StopIteration):
out = next(async_results)
with self.assertRaises(StopIteration):
out = next(async_results)

0 comments on commit 778d36b

Please sign in to comment.