Skip to content

Commit

Permalink
Add API to terminate and await termination (#549)
Browse files Browse the repository at this point in the history
To make it easier for the caller to do both regardless of whether source is sync or async.

[ML-8738](https://iguazio.atlassian.net/browse/ML-8738)
  • Loading branch information
gtopper authored Dec 16, 2024
1 parent a00e997 commit ee998ab
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
24 changes: 20 additions & 4 deletions storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,17 @@ def emit(
self._emit_fn(event)
return awaitable_result

def terminate(self):
"""Terminates the associated flow."""
def terminate(self, wait=False):
"""
Terminates the associated flow.
:param wait: Whether to wait for the flow to terminate before returning.
:returns: None if wait=False. If wait=True, the termination result will be returned.
"""
self._emit_fn(_termination_obj)
if wait:
return self._await_termination_fn()

def await_termination(self):
"""Awaits the termination of the flow. To be called after terminate. Returns the termination result of the
Expand Down Expand Up @@ -482,9 +490,17 @@ async def emit(
raise result
return result

async def terminate(self):
"""Terminates the associated flow."""
async def terminate(self, wait=False):
"""
Terminates the associated flow.
:param wait: Whether to wait for the flow to terminate before returning.
:returns: None if wait=False. If wait=True, the termination result will be returned.
"""
await self._emit_fn(_termination_obj)
if wait:
return await self.await_termination()

async def await_termination(self):
"""
Expand Down
6 changes: 2 additions & 4 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ def test_offset_commit():
event.shard_id = shard
event.offset = offset
controller.emit(event)
controller.terminate()
termination_result = controller.await_termination()
termination_result = controller.terminate(wait=True)
assert termination_result == 330

offsets = copy.copy(platform.offsets)
Expand Down Expand Up @@ -225,9 +224,8 @@ async def async_offset_commit():
try:
assert offsets == {("/", i): num_records_per_shard for i in range(num_shards)}
finally:
await controller.terminate()
termination_result = await controller.terminate(wait=True)

termination_result = await controller.await_termination()
assert termination_result == 330


Expand Down

0 comments on commit ee998ab

Please sign in to comment.