diff --git a/storey/sources.py b/storey/sources.py index 0abca29a..c99c37e4 100644 --- a/storey/sources.py +++ b/storey/sources.py @@ -175,9 +175,17 @@ def emit( self._emit_fn(event) return awaitable_result - def terminate(self): - """Terminates the associated flow.""" + def terminate(self, wait=False): + """ + Terminates the associated flow. + + :param wait: Whether to wait for the flow to terminate before returning. + + :returns: None if wait=False. If wait=True, the termination result will be returned. + """ self._emit_fn(_termination_obj) + if wait: + return self._await_termination_fn() def await_termination(self): """Awaits the termination of the flow. To be called after terminate. Returns the termination result of the @@ -482,9 +490,17 @@ async def emit( raise result return result - async def terminate(self): - """Terminates the associated flow.""" + async def terminate(self, wait=False): + """ + Terminates the associated flow. + + :param wait: Whether to wait for the flow to terminate before returning. + + :returns: None if wait=False. If wait=True, the termination result will be returned. + """ await self._emit_fn(_termination_obj) + if wait: + return await self.await_termination() async def await_termination(self): """ diff --git a/tests/test_flow.py b/tests/test_flow.py index 65f70a1c..233d8726 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -185,8 +185,7 @@ def test_offset_commit(): event.shard_id = shard event.offset = offset controller.emit(event) - controller.terminate() - termination_result = controller.await_termination() + termination_result = controller.terminate(wait=True) assert termination_result == 330 offsets = copy.copy(platform.offsets) @@ -225,9 +224,8 @@ async def async_offset_commit(): try: assert offsets == {("/", i): num_records_per_shard for i in range(num_shards)} finally: - await controller.terminate() + termination_result = await controller.terminate(wait=True) - termination_result = await controller.await_termination() assert termination_result == 330