Skip to content

Commit

Permalink
Fix AsyncEmitSource (#499)
Browse files Browse the repository at this point in the history
* Fix `AsyncEmitSource`

By replacing the queue timeout implementation.

[ML-4421](https://jira.iguazeng.com/browse/ML-4421)

* Delete print
  • Loading branch information
gtopper authored Feb 21, 2024
1 parent d063afd commit 33bf028
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 9 deletions.
45 changes: 45 additions & 0 deletions storey/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
#
import asyncio
import collections


class AsyncQueue(asyncio.Queue):
Expand Down Expand Up @@ -51,3 +52,47 @@ def peek_nowait(self):

def _peek(self):
return self._queue[0]


def _release_waiter(waiter):
if not waiter.done():
waiter.set_result(False)


class SimpleAsyncQueue:
"""
A simple async queue with built-in timeout.
"""

def __init__(self, capacity):
self._capacity = capacity
self._deque = collections.deque()
self._not_empty_futures = collections.deque()
self._loop = asyncio.get_running_loop()

async def get(self, timeout=None):
if not self._deque:
not_empty_future = asyncio.get_running_loop().create_future()
self._not_empty_futures.append(not_empty_future)
if timeout is None:
await not_empty_future
else:
self._loop.call_later(timeout, _release_waiter, not_empty_future)
got_result = await not_empty_future
if not got_result:
raise TimeoutError(f"Queue get() timed out after {timeout} seconds")

result = self._deque.popleft()
return result

async def put(self, item):
while self._not_empty_futures:
not_empty_future = self._not_empty_futures.popleft()
if not not_empty_future.done():
not_empty_future.set_result(True)
break

return self._deque.append(item)

def empty(self):
return len(self._deque) == 0
7 changes: 4 additions & 3 deletions storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from .dtypes import Event, _termination_obj
from .flow import Complete, Flow
from .queue import SimpleAsyncQueue
from .utils import find_filters, find_partitions, url_to_file_system


Expand Down Expand Up @@ -601,7 +602,7 @@ def _init(self):
super()._init()
self._is_terminated = False
self._outstanding_offsets = defaultdict(list)
self._q = asyncio.Queue(self._buffer_size)
self._q = SimpleAsyncQueue(self._buffer_size)

async def _run_loop(self):
committer = None
Expand All @@ -624,9 +625,9 @@ async def _run_loop(self):
# In case we can't block because there are outstanding events
while num_offsets_not_handled > 0:
try:
event = await asyncio.wait_for(self._q.get(), self._max_wait_before_commit)
event = await self._q.get(self._max_wait_before_commit)
break
except asyncio.TimeoutError:
except TimeoutError:
pass
num_offsets_not_handled = await _commit_handled_events(self._outstanding_offsets, committer)
events_handled_since_commit = 0
Expand Down
18 changes: 12 additions & 6 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def async_offset_commit():

controller = build_flow(
[
AsyncEmitSource(context=context, explicit_ack=True),
AsyncEmitSource(context=context, explicit_ack=True, max_wait_before_commit=1),
Map(lambda x: x + 1),
Filter(lambda x: x < 3),
FlatMap(lambda x: [x, x * 10]),
Expand All @@ -208,13 +208,19 @@ async def async_offset_commit():
event.shard_id = shard
event.offset = offset
await controller.emit(event)
print()
await controller.terminate()
termination_result = await controller.await_termination()
assert termination_result == 330
del event

# Make sure that offsets are committed even before termination
await asyncio.sleep(2)
offsets = copy.copy(platform.offsets)
assert offsets == {("/", i): num_records_per_shard for i in range(num_shards)}

try:
assert offsets == {("/", i): num_records_per_shard for i in range(num_shards)}
finally:
await controller.terminate()

termination_result = await controller.await_termination()
assert termination_result == 330


def test_async_offset_commit():
Expand Down
28 changes: 28 additions & 0 deletions tests/test_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import asyncio

import pytest

from storey.queue import SimpleAsyncQueue


async def async_test_simple_async_queue():
q = SimpleAsyncQueue(2)

with pytest.raises(TimeoutError):
await q.get(0)

get_task = asyncio.create_task(q.get(1))
await q.put("x")
assert await get_task == "x"

await q.put("x")
await q.put("y")
put_task = asyncio.create_task(q.put("z"))
assert await q.get() == "x"
await put_task
assert await q.get() == "y"
assert await q.get() == "z"


def test_simple_async_queue():
asyncio.run(async_test_simple_async_queue())

0 comments on commit 33bf028

Please sign in to comment.