Skip to content

Commit

Permalink
Respect full_event in ConcurrentExecution (#527)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtopper authored Jun 10, 2024
1 parent a387cb2 commit eed7bfa
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 12 deletions.
36 changes: 25 additions & 11 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ class Choice(Flow):
:type default: Flow
:param name: Name of this step, as it should appear in logs. Defaults to class name (Choice).
:type name: string
:param full_event: Whether user functions should receive and/or return Event objects (when True),
:param full_event: Whether user functions should receive and return Event objects (when True),
or only the payload (when False). Defaults to False.
:type full_event: boolean
"""
Expand Down Expand Up @@ -471,7 +471,7 @@ class Map(_UnaryFunctionFlow):
:type long_running: boolean
:param name: Name of this step, as it should appear in logs. Defaults to class name (Map).
:type name: string
:param full_event: Whether user functions should receive and/or return Event objects (when True),
:param full_event: Whether user functions should receive and return Event objects (when True),
or only the payload (when False). Defaults to False.
:type full_event: boolean
"""
Expand All @@ -491,7 +491,7 @@ class Filter(_UnaryFunctionFlow):
:type long_running: boolean
:param name: Name of this step, as it should appear in logs. Defaults to class name (Filter).
:type name: string
:param full_event: Whether user functions should receive and/or return Event objects (when True), or only the
:param full_event: Whether user functions should receive and return Event objects (when True), or only the
payload (when False). Defaults to False.
:type full_event: boolean
"""
Expand All @@ -511,7 +511,7 @@ class FlatMap(_UnaryFunctionFlow):
:type long_running: boolean
:param name: Name of this step, as it should appear in logs. Defaults to class name (FlatMap).
:type name: string
:param full_event: Whether user functions should receive and/or return Event objects (when True), or only
:param full_event: Whether user functions should receive and return Event objects (when True), or only
the payload (when False). Defaults to False.
:type full_event: boolean
"""
Expand All @@ -533,7 +533,7 @@ class Extend(_UnaryFunctionFlow):
:type long_running: boolean
:param name: Name of this step, as it should appear in logs. Defaults to class name (Extend).
:type name: string
:param full_event: Whether user functions should receive and/or return Event objects (when True), or only the
:param full_event: Whether user functions should receive and return Event objects (when True), or only the
payload (when False). Defaults to False.
:type full_event: boolean
"""
Expand Down Expand Up @@ -739,7 +739,7 @@ class Reduce(Flow):
:type fn: Function ((object, Event) => object)
:param name: Name of this step, as it should appear in logs. Defaults to class name (Reduce).
:type name: string
:param full_event: Whether user functions should receive and/or return Event objects (when True),
:param full_event: Whether user functions should receive and return Event objects (when True),
or only the payload (when False). Defaults to False.
:type full_event: boolean
"""
Expand Down Expand Up @@ -950,6 +950,8 @@ class ConcurrentExecution(_ConcurrentJobExecution):
:param backoff_factor: Wait time in seconds between retries (default 1)
:param pass_context: If False, the process_event function will be called with just one parameter (event). If True,
the process_event function will be called with two parameters (event, context). Defaults to False.
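:param full_event: Whether the event processor should receive and return Event objects (when True),
or only the payload (when False). Defaults to False. May not be set to True when
concurrency_mechanism="multiprocessing" (a ValueError is raised in that case).
:type full_event: boolean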
"""

_supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"]
Expand All @@ -963,6 +965,11 @@ def __init__(
):
super().__init__(**kwargs)

if concurrency_mechanism == "multiprocessing" and kwargs.get("full_event"):
raise ValueError(
'concurrency_mechanism="multiprocessing" may not be used in conjunction with full_event=True'
)

self._event_processor = event_processor

if concurrency_mechanism and concurrency_mechanism not in self._supported_concurrency_mechanisms:
Expand All @@ -986,16 +993,23 @@ def __init__(
self._pass_context = pass_context

async def _process_event(self, event):
args = [event]
args = [event if self._full_event else event.body]

if self._pass_context:
args.append(self.context)
if self._executor:
result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, *args)
else:
result = self._event_processor(*args)

if asyncio.iscoroutine(result):
result = await result
return result

if self._full_event:
return result
else:
event.body = result
return event

async def _handle_completed(self, event, response):
await self._do_downstream(response)
Expand All @@ -1010,7 +1024,7 @@ class SendToHttp(_ConcurrentJobExecution):
:type join_from_response: Function ((Event, HttpResponse)=>Event)
:param name: Name of this step, as it should appear in logs. Defaults to class name (SendToHttp).
:type name: string
:param full_event: Whether user functions should receive and/or return Event objects (when True),
:param full_event: Whether user functions should receive and return Event objects (when True),
or only the payload (when False). Defaults to False.
:type full_event: boolean
"""
Expand Down Expand Up @@ -1213,7 +1227,7 @@ class JoinWithV3IOTable(_ConcurrentJobExecution):
:type attributes: string
:param name: Name of this step, as it should appear in logs. Defaults to class name (JoinWithV3IOTable).
:type name: string
:param full_event: Whether user functions should receive and/or return Event objects (when True), or only
:param full_event: Whether user functions should receive and return Event objects (when True), or only
the payload (when False). Defaults to False.
:type full_event: boolean
"""
Expand Down Expand Up @@ -1270,7 +1284,7 @@ class JoinWithTable(_ConcurrentJobExecution):
:param join_function: Joins the original event with relevant data received from the storage. Event is dropped when
this function returns None. Defaults to assume the event's body is a dict-like object and updating it.
:param name: Name of this step, as it should appear in logs. Defaults to class name (JoinWithTable).
:param full_event: Whether user functions should receive and/or return Event objects (when True), or only the
:param full_event: Whether user functions should receive and return Event objects (when True), or only the
payload (when False). Defaults to False.
:param context: Context object that holds global configurations and secrets.
"""
Expand Down
40 changes: 39 additions & 1 deletion tests/test_concurrent_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from storey import AsyncEmitSource
from storey.flow import ConcurrentExecution, Reduce, build_flow
from storey.flow import Complete, ConcurrentExecution, Reduce, build_flow
from tests.test_flow import append_and_return

event_processing_duration = 0.5
Expand Down Expand Up @@ -76,3 +76,41 @@ async def async_test_concurrent_execution(concurrency_mechanism, event_processor
)
def test_concurrent_execution(concurrency_mechanism, event_processor, pass_context):
asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context))


async def async_test_concurrent_execution_multiprocessing_and_complete():
    """Emit one event through a multiprocessing ConcurrentExecution step followed by Complete.

    Asserts that the value awaited from ``emit`` equals the emitted body, and always
    terminates the flow afterwards.
    """
    steps = [
        AsyncEmitSource(),
        ConcurrentExecution(
            event_processor=process_event_slow_processing,
            concurrency_mechanism="multiprocessing",
            max_in_flight=2,
        ),
        Complete(),
    ]
    controller = build_flow(steps).run()

    payload = "hello"
    try:
        returned = await controller.emit(payload)
        assert returned == payload
    finally:
        # Shut the flow down even when the assertion above fails.
        await controller.terminate()
        await controller.await_termination()


def test_concurrent_execution_multiprocessing_and_complete():
    """Synchronous entry point that drives the async multiprocessing + Complete scenario."""
    scenario = async_test_concurrent_execution_multiprocessing_and_complete()
    asyncio.run(scenario)


def test_concurrent_execution_multiprocessing_and_full_event():
    """Constructing ConcurrentExecution with multiprocessing and full_event=True must raise ValueError."""
    expected_message = 'concurrency_mechanism="multiprocessing" may not be used in conjunction with full_event=True'
    step_kwargs = {
        "event_processor": process_event_slow_processing,
        "concurrency_mechanism": "multiprocessing",
        "full_event": True,
    }

    with pytest.raises(ValueError, match=expected_message):
        ConcurrentExecution(**step_kwargs)

0 comments on commit eed7bfa

Please sign in to comment.