From 33bf028f6235e6542d86df85669478c6d906b7fa Mon Sep 17 00:00:00 2001
From: Gal Topper
Date: Wed, 21 Feb 2024 15:36:06 +0800
Subject: [PATCH] Fix `AsyncEmitSource` (#499)

* Fix `AsyncEmitSource`

By replacing the queue timeout implementation.

[ML-4421](https://jira.iguazeng.com/browse/ML-4421)

* Delete print
---
 Review notes (reviewer annotations, not part of the commit message; placed
 after the "---" separator so they stay out of the commit and the diff):

 1. storey/queue.py, SimpleAsyncQueue: `capacity` is stored in
    `self._capacity`, but nothing in this patch reads it. `put()` always
    appends and never waits, so the queue is unbounded. The
    `asyncio.Queue(self._buffer_size)` it replaces in
    `AsyncEmitSource._init` was given the same value as its size limit.
    NOTE(review): if back-pressure on emit is still wanted, `put()` has to
    wait while the deque is full. Assumes `_buffer_size` is positive in
    practice -- confirm.

 2. SimpleAsyncQueue.get(): after being woken it calls `popleft()`
    unconditionally. If another `get()` call runs between `put()` and the
    woken waiter resuming, it sees a non-empty deque, skips waiting and
    takes the item; the woken waiter then raises IndexError on an empty
    deque. Only reachable with more than one concurrent consumer;
    `_run_loop` looks like the only consumer here -- confirm. Re-checking
    `self._deque` in a loop after waking would close the gap.

 3. SimpleAsyncQueue.get(timeout): the handle returned by `call_later()` is
    discarded, so the timer is not cancelled when an item arrives first.
    The callback stays scheduled until the timeout elapses and is then a
    no-op because `_release_waiter` checks `waiter.done()`.

 4. SimpleAsyncQueue.__init__ calls `asyncio.get_running_loop()`, so the
    queue can only be constructed while an event loop is running. `get()`
    fetches the running loop again for `create_future()` but schedules the
    timer on `self._loop`; presumably these are the same loop -- confirm
    for the `AsyncEmitSource._init` call site.

 5. tests/test_queue.py: `put()` contains no `await`, so
    `await q.put("x")` appears to complete before the task created by
    `asyncio.create_task(q.get(1))` first runs. That `get()` then finds the
    item already queued, and the waiter wake-up path (`set_result(True)`)
    does not seem to be exercised. Likewise the `put("z")` task on a
    capacity-2 queue does not show that `put()` ever waits.

 storey/queue.py     | 45 +++++++++++++++++++++++++++++++++++++++++++++
 storey/sources.py   |  7 ++++---
 tests/test_flow.py  | 18 ++++++++++++------
 tests/test_queue.py | 28 ++++++++++++++++++++++++++++
 4 files changed, 89 insertions(+), 9 deletions(-)
 create mode 100644 tests/test_queue.py

diff --git a/storey/queue.py b/storey/queue.py
index 89937db6..0793277f 100644
--- a/storey/queue.py
+++ b/storey/queue.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 #
 import asyncio
+import collections
 
 
 class AsyncQueue(asyncio.Queue):
@@ -51,3 +52,47 @@ def peek_nowait(self):
 
     def _peek(self):
         return self._queue[0]
+
+
+def _release_waiter(waiter):
+    if not waiter.done():
+        waiter.set_result(False)
+
+
+class SimpleAsyncQueue:
+    """
+    A simple async queue with built-in timeout.
+    """
+
+    def __init__(self, capacity):
+        self._capacity = capacity
+        self._deque = collections.deque()
+        self._not_empty_futures = collections.deque()
+        self._loop = asyncio.get_running_loop()
+
+    async def get(self, timeout=None):
+        if not self._deque:
+            not_empty_future = asyncio.get_running_loop().create_future()
+            self._not_empty_futures.append(not_empty_future)
+            if timeout is None:
+                await not_empty_future
+            else:
+                self._loop.call_later(timeout, _release_waiter, not_empty_future)
+                got_result = await not_empty_future
+                if not got_result:
+                    raise TimeoutError(f"Queue get() timed out after {timeout} seconds")
+
+        result = self._deque.popleft()
+        return result
+
+    async def put(self, item):
+        while self._not_empty_futures:
+            not_empty_future = self._not_empty_futures.popleft()
+            if not not_empty_future.done():
+                not_empty_future.set_result(True)
+                break
+
+        return self._deque.append(item)
+
+    def empty(self):
+        return len(self._deque) == 0
diff --git a/storey/sources.py b/storey/sources.py
index 22f09b49..795ee5b4 100644
--- a/storey/sources.py
+++ b/storey/sources.py
@@ -35,6 +35,7 @@
 
 from .dtypes import Event, _termination_obj
 from .flow import Complete, Flow
+from .queue import SimpleAsyncQueue
 from .utils import find_filters, find_partitions, url_to_file_system
 
 
@@ -601,7 +602,7 @@ def _init(self):
         super()._init()
         self._is_terminated = False
         self._outstanding_offsets = defaultdict(list)
-        self._q = asyncio.Queue(self._buffer_size)
+        self._q = SimpleAsyncQueue(self._buffer_size)
 
     async def _run_loop(self):
         committer = None
@@ -624,9 +625,9 @@ async def _run_loop(self):
             # In case we can't block because there are outstanding events
             while num_offsets_not_handled > 0:
                 try:
-                    event = await asyncio.wait_for(self._q.get(), self._max_wait_before_commit)
+                    event = await self._q.get(self._max_wait_before_commit)
                     break
-                except asyncio.TimeoutError:
+                except TimeoutError:
                     pass
                 num_offsets_not_handled = await _commit_handled_events(self._outstanding_offsets, committer)
                 events_handled_since_commit = 0
diff --git a/tests/test_flow.py b/tests/test_flow.py
index c611fec9..390d73f0 100644
--- a/tests/test_flow.py
+++ b/tests/test_flow.py
@@ -191,7 +191,7 @@ async def async_offset_commit():
 
     controller = build_flow(
         [
-            AsyncEmitSource(context=context, explicit_ack=True),
+            AsyncEmitSource(context=context, explicit_ack=True, max_wait_before_commit=1),
             Map(lambda x: x + 1),
             Filter(lambda x: x < 3),
             FlatMap(lambda x: [x, x * 10]),
@@ -208,13 +208,19 @@ async def async_offset_commit():
             event.shard_id = shard
             event.offset = offset
             await controller.emit(event)
-    print()
-    await controller.terminate()
-    termination_result = await controller.await_termination()
-    assert termination_result == 330
 
+    del event
+    # Make sure that offsets are committed even before termination
+    await asyncio.sleep(2)
     offsets = copy.copy(platform.offsets)
-    assert offsets == {("/", i): num_records_per_shard for i in range(num_shards)}
+
+    try:
+        assert offsets == {("/", i): num_records_per_shard for i in range(num_shards)}
+    finally:
+        await controller.terminate()
+
+    termination_result = await controller.await_termination()
+    assert termination_result == 330
 
 
 def test_async_offset_commit():
diff --git a/tests/test_queue.py b/tests/test_queue.py
new file mode 100644
index 00000000..e94a7c43
--- /dev/null
+++ b/tests/test_queue.py
@@ -0,0 +1,28 @@
+import asyncio
+
+import pytest
+
+from storey.queue import SimpleAsyncQueue
+
+
+async def async_test_simple_async_queue():
+    q = SimpleAsyncQueue(2)
+
+    with pytest.raises(TimeoutError):
+        await q.get(0)
+
+    get_task = asyncio.create_task(q.get(1))
+    await q.put("x")
+    assert await get_task == "x"
+
+    await q.put("x")
+    await q.put("y")
+    put_task = asyncio.create_task(q.put("z"))
+    assert await q.get() == "x"
+    await put_task
+    assert await q.get() == "y"
+    assert await q.get() == "z"
+
+
+def test_simple_async_queue():
+    asyncio.run(async_test_simple_async_queue())