diff --git a/storey/flow.py b/storey/flow.py index 37bef0af..f22fd1ff 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -358,7 +358,7 @@ class Choice(Flow): :type default: Flow :param name: Name of this step, as it should appear in logs. Defaults to class name (Choice). :type name: string - :param full_event: Whether user functions should receive and/or return Event objects (when True), + :param full_event: Whether user functions should receive and return Event objects (when True), or only the payload (when False). Defaults to False. :type full_event: boolean """ @@ -471,7 +471,7 @@ class Map(_UnaryFunctionFlow): :type long_running: boolean :param name: Name of this step, as it should appear in logs. Defaults to class name (Map). :type name: string - :param full_event: Whether user functions should receive and/or return Event objects (when True), + :param full_event: Whether user functions should receive and return Event objects (when True), or only the payload (when False). Defaults to False. :type full_event: boolean """ @@ -491,7 +491,7 @@ class Filter(_UnaryFunctionFlow): :type long_running: boolean :param name: Name of this step, as it should appear in logs. Defaults to class name (Filter). :type name: string - :param full_event: Whether user functions should receive and/or return Event objects (when True), or only the + :param full_event: Whether user functions should receive and return Event objects (when True), or only the payload (when False). Defaults to False. :type full_event: boolean """ @@ -511,7 +511,7 @@ class FlatMap(_UnaryFunctionFlow): :type long_running: boolean :param name: Name of this step, as it should appear in logs. Defaults to class name (FlatMap). :type name: string - :param full_event: Whether user functions should receive and/or return Event objects (when True), or only + :param full_event: Whether user functions should receive and return Event objects (when True), or only the payload (when False). Defaults to False. :type full_event: boolean """ @@ -533,7 +533,7 @@ class Extend(_UnaryFunctionFlow): :type long_running: boolean :param name: Name of this step, as it should appear in logs. Defaults to class name (Extend). :type name: string - :param full_event: Whether user functions should receive and/or return Event objects (when True), or only the + :param full_event: Whether user functions should receive and return Event objects (when True), or only the payload (when False). Defaults to False. :type full_event: boolean """ @@ -739,7 +739,7 @@ class Reduce(Flow): :type fn: Function ((object, Event) => object) :param name: Name of this step, as it should appear in logs. Defaults to class name (Reduce). :type name: string - :param full_event: Whether user functions should receive and/or return Event objects (when True), + :param full_event: Whether user functions should receive and return Event objects (when True), or only the payload (when False). Defaults to False. :type full_event: boolean """ @@ -950,6 +950,8 @@ class ConcurrentExecution(_ConcurrentJobExecution): :param backoff_factor: Wait time in seconds between retries (default 1) :param pass_context: If False, the process_event function will be called with just one parameter (event). If True, the process_event function will be called with two parameters (event, context). Defaults to False. + :param full_event: Whether event processor should receive and return Event objects (when True), + or only the payload (when False). Defaults to False. """ _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"] @@ -963,6 +965,11 @@ def __init__( ): super().__init__(**kwargs) + if concurrency_mechanism == "multiprocessing" and kwargs.get("full_event"): + raise ValueError( + 'concurrency_mechanism="multiprocessing" may not be used in conjunction with full_event=True' + ) + self._event_processor = event_processor if concurrency_mechanism and concurrency_mechanism not in self._supported_concurrency_mechanisms: @@ -986,16 +993,23 @@ def __init__( self._pass_context = pass_context async def _process_event(self, event): - args = [event] + args = [event if self._full_event else event.body] + if self._pass_context: args.append(self.context) if self._executor: result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, *args) else: result = self._event_processor(*args) + if asyncio.iscoroutine(result): result = await result - return result + + if self._full_event: + return result + else: + event.body = result + return event async def _handle_completed(self, event, response): await self._do_downstream(response) @@ -1010,7 +1024,7 @@ class SendToHttp(_ConcurrentJobExecution): :type join_from_response: Function ((Event, HttpResponse)=>Event) :param name: Name of this step, as it should appear in logs. Defaults to class name (SendToHttp). :type name: string - :param full_event: Whether user functions should receive and/or return Event objects (when True), + :param full_event: Whether user functions should receive and return Event objects (when True), or only the payload (when False). Defaults to False. :type full_event: boolean """ @@ -1213,7 +1227,7 @@ class JoinWithV3IOTable(_ConcurrentJobExecution): :type attributes: string :param name: Name of this step, as it should appear in logs. Defaults to class name (JoinWithV3IOTable). :type name: string - :param full_event: Whether user functions should receive and/or return Event objects (when True), or only + :param full_event: Whether user functions should receive and return Event objects (when True), or only the payload (when False). Defaults to False. :type full_event: boolean """ @@ -1270,7 +1284,7 @@ class JoinWithTable(_ConcurrentJobExecution): :param join_function: Joins the original event with relevant data received from the storage. Event is dropped when this function returns None. Defaults to assume the event's body is a dict-like object and updating it. :param name: Name of this step, as it should appear in logs. Defaults to class name (JoinWithTable). - :param full_event: Whether user functions should receive and/or return Event objects (when True), or only the + :param full_event: Whether user functions should receive and return Event objects (when True), or only the payload (when False). Defaults to False. :param context: Context object that holds global configurations and secrets. """ diff --git a/tests/test_concurrent_execution.py b/tests/test_concurrent_execution.py index a1f84bbd..59b957e0 100644 --- a/tests/test_concurrent_execution.py +++ b/tests/test_concurrent_execution.py @@ -4,7 +4,7 @@ import pytest from storey import AsyncEmitSource -from storey.flow import ConcurrentExecution, Reduce, build_flow +from storey.flow import Complete, ConcurrentExecution, Reduce, build_flow from tests.test_flow import append_and_return event_processing_duration = 0.5 @@ -76,3 +76,41 @@ async def async_test_concurrent_execution(concurrency_mechanism, event_processor ) def test_concurrent_execution(concurrency_mechanism, event_processor, pass_context): asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context)) + + +async def async_test_concurrent_execution_multiprocessing_and_complete(): + controller = build_flow( + [ + AsyncEmitSource(), + ConcurrentExecution( + event_processor=process_event_slow_processing, + concurrency_mechanism="multiprocessing", + max_in_flight=2, + ), + Complete(), + ] + ).run() + + event_body = "hello" + try: + res = await controller.emit(event_body) + assert res == event_body + finally: + await controller.terminate() + await controller.await_termination() + + +def test_concurrent_execution_multiprocessing_and_complete(): + asyncio.run(async_test_concurrent_execution_multiprocessing_and_complete()) + + +def test_concurrent_execution_multiprocessing_and_full_event(): + with pytest.raises( + ValueError, + match='concurrency_mechanism="multiprocessing" may not be used in conjunction with full_event=True', + ): + ConcurrentExecution( + event_processor=process_event_slow_processing, + concurrency_mechanism="multiprocessing", + full_event=True, + )