Merge pull request #13 (Add experimental parallel steps support)
Add experimental parallel steps support
amh4r authored Nov 11, 2023
2 parents aa257d2 + 6f8445f commit 813d706
Showing 12 changed files with 370 additions and 19 deletions.
16 changes: 16 additions & 0 deletions inngest/_internal/comm.py
@@ -228,8 +228,15 @@ async def call_function(
call: execution.Call,
fn_id: str,
req_sig: net.RequestSignature,
target_hashed_id: str,
) -> CommResponse:
"""Handle a function call from the Executor."""

if target_hashed_id == execution.UNSPECIFIED_STEP_ID:
target_step_id = None
else:
target_step_id = target_hashed_id

middleware = middleware_lib.MiddlewareManager.from_client(self._client)

# Validate the request signature.
@@ -248,6 +255,7 @@ async def call_function(
fn_id,
execution.TransformableInput(logger=self._client.logger),
middleware,
target_step_id,
)

return await self._respond(middleware, call_res)
@@ -258,8 +266,15 @@ def call_function_sync(
call: execution.Call,
fn_id: str,
req_sig: net.RequestSignature,
target_hashed_id: str,
) -> CommResponse:
"""Handle a function call from the Executor."""

if target_hashed_id == execution.UNSPECIFIED_STEP_ID:
target_step_id = None
else:
target_step_id = target_hashed_id

middleware = middleware_lib.MiddlewareManager.from_client(self._client)

# Validate the request signature.
@@ -278,6 +293,7 @@ def call_function_sync(
fn_id,
execution.TransformableInput(logger=self._client.logger),
middleware,
target_step_id,
)

return self._respond_sync(middleware, call_res)
5 changes: 5 additions & 0 deletions inngest/_internal/const.py
@@ -50,6 +50,11 @@ class HeaderKey(enum.Enum):
USER_AGENT = "User-Agent"


class QueryParamKey(enum.Enum):
FUNCTION_ID = "fnId"
STEP_ID = "stepId"


class InternalEvents(enum.Enum):
FUNCTION_FAILED = "inngest/function.failed"

5 changes: 5 additions & 0 deletions inngest/_internal/execution.py
@@ -87,6 +87,11 @@ def is_step_call_responses(


class Opcode(enum.Enum):
PLANNED = "StepPlanned"
SLEEP = "Sleep"
STEP = "Step"
WAIT_FOR_EVENT = "WaitForEvent"


# If the Executor sends this step ID then it isn't targeting a specific step.
UNSPECIFIED_STEP_ID = "step"
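
The const.py and execution.py changes above sketch the targeting protocol: the Executor can name a single step (presumably via the new `stepId` query parameter), and the sentinel `UNSPECIFIED_STEP_ID` means no specific step is targeted. Below is a hypothetical helper — not the SDK's actual serve code — showing how a framework adapter might read those parameters before calling `CommHandler.call_function`; `query_params` is an assumed plain dict.

```python
from inngest._internal import const, execution


def extract_call_target(query_params: dict[str, str]) -> tuple[str, str]:
    """Hypothetical helper: pull the function ID and targeted step ID
    out of a request's query string ("fnId" and "stepId")."""
    fn_id = query_params[const.QueryParamKey.FUNCTION_ID.value]

    # When the Executor isn't targeting a step it sends the sentinel value,
    # which comm.call_function translates to "no target" (None) internally.
    target_hashed_id = query_params.get(
        const.QueryParamKey.STEP_ID.value,
        execution.UNSPECIFIED_STEP_ID,
    )
    return fn_id, target_hashed_id
```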
35 changes: 24 additions & 11 deletions inngest/_internal/function.py
@@ -201,13 +201,14 @@ def __init__(

self._on_failure_fn_id = f"{opts.id}-{suffix}"

async def call(
async def call( # noqa: C901
self,
call: execution.Call,
client: client_lib.Inngest,
fn_id: str,
call_input: execution.TransformableInput,
middleware: middleware_lib.MiddlewareManager,
target_hashed_id: str | None,
) -> execution.CallResult:
middleware = middleware_lib.MiddlewareManager.from_manager(middleware)
for m in self._middleware:
@@ -262,6 +263,7 @@ async def call(
memos,
middleware,
step_lib.StepIDCounter(),
target_hashed_id,
),
)
elif _is_function_handler_sync(handler):
@@ -276,6 +278,7 @@ async def call(
memos,
middleware,
step_lib.StepIDCounter(),
target_hashed_id,
),
)
else:
@@ -302,12 +305,16 @@ async def call(
if isinstance(err, Exception):
return execution.CallError.from_error(err)

output = await middleware.transform_output(interrupt.response.data)
if isinstance(output, Exception):
return execution.CallError.from_error(output)
interrupt.response.data = output
# TODO: How should transform_output work with multiple responses?
if len(interrupt.responses) == 1:
output = await middleware.transform_output(
interrupt.responses[0].data
)
if isinstance(output, Exception):
return execution.CallError.from_error(output)
interrupt.responses[0].data = output

return [interrupt.response]
return interrupt.responses
except Exception as err:
return execution.CallError.from_error(err)

@@ -318,6 +325,7 @@ def call_sync(
fn_id: str,
call_input: execution.TransformableInput,
middleware: middleware_lib.MiddlewareManager,
target_hashed_id: str | None,
) -> execution.CallResult:
middleware = middleware_lib.MiddlewareManager.from_manager(middleware)
for m in self._middleware:
@@ -363,6 +371,7 @@ def call_sync(
memos,
middleware,
step_lib.StepIDCounter(),
target_hashed_id,
),
)
else:
@@ -386,12 +395,16 @@ def call_sync(
if isinstance(err, Exception):
return execution.CallError.from_error(err)

output = middleware.transform_output_sync(interrupt.response.data)
if isinstance(output, Exception):
return execution.CallError.from_error(output)
interrupt.response.data = output
# TODO: How should transform_output work with multiple responses?
if len(interrupt.responses) == 1:
output = middleware.transform_output_sync(
interrupt.responses[0].data
)
if isinstance(output, Exception):
return execution.CallError.from_error(output)
interrupt.responses[0].data = output

return [interrupt.response]
return interrupt.responses
except Exception as err:
return execution.CallError.from_error(err)

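
For context on the `interrupt.responses` change above: when a parallel group is planned, the interrupt now carries one `StepResponse` per discovered step, and that whole list is returned to the Executor (with `transform_output` applied only in the single-response case, per the TODO). A hedged illustration of what `call()`/`call_sync()` return in that case — step names and hashed IDs are made up:

```python
from inngest._internal import execution

# Roughly the call result after a "planning" pass over two parallel steps.
planned_responses = [
    execution.StepResponse(
        data=None,
        display_name="load-user",
        id="<hashed ID of load-user>",
        name="load-user",
        op=execution.Opcode.PLANNED,
    ),
    execution.StepResponse(
        data=None,
        display_name="load-account",
        id="<hashed ID of load-account>",
        name="load-account",
        op=execution.Opcode.PLANNED,
    ),
]
```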
22 changes: 20 additions & 2 deletions inngest/_internal/step_lib/base.py
@@ -42,11 +42,14 @@ def __init__(
memos: StepMemos,
middleware: middleware_lib.MiddlewareManager,
step_id_counter: StepIDCounter,
target_hashed_id: str | None,
) -> None:
self._client = client
self._inside_parallel = False
self._memos = memos
self._middleware = middleware
self._step_id_counter = step_id_counter
self._target_hashed_id = target_hashed_id

def _get_hashed_id(self, step_id: str) -> str:
id_count = self._step_id_counter.increment(step_id)
@@ -72,6 +75,15 @@ def _get_memo_sync(self, hashed_id: str) -> object:

return memo

def _handle_targeting(self, *, hashed_id: str, step_id: str) -> None:
is_targeting_enabled = self._target_hashed_id is not None
if not is_targeting_enabled:
return

is_targeted = self._target_hashed_id == hashed_id
if not is_targeted:
raise SkipInterrupt()


class StepIDCounter:
"""
@@ -101,9 +113,15 @@ class ResponseInterrupt(BaseException):

def __init__(
self,
response: execution.StepResponse,
responses: execution.StepResponse | list[execution.StepResponse],
) -> None:
self.response = response
if not isinstance(responses, list):
responses = [responses]
self.responses = responses


class SkipInterrupt(BaseException):
pass


class WaitForEventOpts(types.BaseModel):
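
A small usage sketch of the widened `ResponseInterrupt` constructor shown above (the `StepResponse` fields mirror the ones used in step_async.py; the values themselves are placeholders):

```python
from inngest._internal import execution
from inngest._internal.step_lib import base

resp_a = execution.StepResponse(
    data=None, display_name="a", id="hash-a", name="a", op=execution.Opcode.PLANNED
)
resp_b = execution.StepResponse(
    data=None, display_name="b", id="hash-b", name="b", op=execution.Opcode.PLANNED
)

# A single response is normalized into a one-element list...
print(len(base.ResponseInterrupt(resp_a).responses))  # 1
# ...while a list (e.g. responses collected during a parallel plan) is kept as-is.
print(len(base.ResponseInterrupt([resp_a, resp_b]).responses))  # 2
```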
62 changes: 62 additions & 0 deletions inngest/_internal/step_lib/step_async.py
@@ -7,6 +7,38 @@


class Step(base.StepBase):
async def _experimental_parallel(
self,
callables: tuple[typing.Callable[[], typing.Awaitable[types.T]], ...],
) -> tuple[types.T | None, ...]:
"""
Run multiple steps in parallel.
Args:
----
callables: An arbitrary number of step callbacks to run. These are
callables that contain the step (e.g. `lambda: step.run("my_step", my_step_fn)`).
"""

self._inside_parallel = True

outputs = tuple[types.T]()
responses: list[execution.StepResponse] = []
for cb in callables:
try:
output = await cb()
outputs = (*outputs, output)
except base.ResponseInterrupt as interrupt:
responses = [*responses, *interrupt.responses]
except base.SkipInterrupt:
pass

if len(responses) > 0:
raise base.ResponseInterrupt(responses)

self._inside_parallel = False
return outputs

@typing.overload
async def run(
self,
Expand Down Expand Up @@ -45,6 +77,24 @@ async def run(
if memo is not types.EmptySentinel:
return memo # type: ignore

is_targeting_enabled = self._target_hashed_id is not None
is_targeted = self._target_hashed_id == hashed_id
if is_targeting_enabled and not is_targeted:
# Skip this step because a different step is targeted.
raise base.SkipInterrupt()

if self._inside_parallel and not is_targeting_enabled:
# Plan this step because we're in parallel mode.
raise base.ResponseInterrupt(
execution.StepResponse(
data=None,
display_name=step_id,
id=hashed_id,
name=step_id,
op=execution.Opcode.PLANNED,
)
)

err = await self._middleware.before_execution()
if isinstance(err, Exception):
raise err
@@ -125,6 +175,12 @@ async def sleep_until(
if memo is not types.EmptySentinel:
return memo # type: ignore

is_targeting_enabled = self._target_hashed_id is not None
is_targeted = self._target_hashed_id == hashed_id
if is_targeting_enabled and not is_targeted:
# Skip this step because a different step is targeted.
raise base.SkipInterrupt()

err = await self._middleware.before_execution()
if isinstance(err, Exception):
raise err
@@ -170,6 +226,12 @@ async def wait_for_event(
# Fulfilled by an event
return event_lib.Event.model_validate(memo)

is_targeting_enabled = self._target_hashed_id is not None
is_targeted = self._target_hashed_id == hashed_id
if is_targeting_enabled and not is_targeted:
# Skip this step because a different step is targeted.
raise base.SkipInterrupt()

err = await self._middleware.before_execution()
if isinstance(err, Exception):
raise err
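
The async parallel helper above is easiest to see with a usage sketch. Everything below is hypothetical: the handler shape, the step names, and the `load_user`/`load_account` callables are made up, and `_experimental_parallel` is a private, experimental API that may change.

```python
import typing

from inngest._internal.step_lib import step_async


async def load_user() -> dict[str, int]:
    # Placeholder for real work (a DB query, an HTTP call, ...).
    return {"id": 1}


async def load_account() -> dict[str, int]:
    return {"id": 42}


async def handler(step: step_async.Step) -> tuple[typing.Any, ...]:
    # Each tuple element is a zero-arg callable wrapping a step.run call,
    # matching the docstring's `lambda: step.run("my_step", my_step_fn)` example.
    user, account = await step._experimental_parallel(
        (
            lambda: step.run("load-user", load_user),
            lambda: step.run("load-account", load_account),
        )
    )
    return user, account
```

Note that the helper itself iterates the callables sequentially in-process; as the planning and targeting changes in this PR suggest, the actual parallelism comes from the Executor re-invoking the function once per planned step, targeted via the new `stepId` query parameter.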
62 changes: 62 additions & 0 deletions inngest/_internal/step_lib/step_sync.py
@@ -7,6 +7,38 @@


class StepSync(base.StepBase):
def _experimental_parallel(
self,
callables: tuple[typing.Callable[[], types.T], ...],
) -> tuple[types.T | None, ...]:
"""
Run multiple steps in parallel.
Args:
----
callables: An arbitrary number of step callbacks to run. These are
callables that contain the step (e.g. `lambda: step.run("my_step", my_step_fn)`).
"""

self._inside_parallel = True

outputs = tuple[types.T]()
responses: list[execution.StepResponse] = []
for cb in callables:
try:
output = cb()
outputs = (*outputs, output)
except base.ResponseInterrupt as interrupt:
responses = [*responses, *interrupt.responses]
except base.SkipInterrupt:
pass

if len(responses) > 0:
raise base.ResponseInterrupt(responses)

self._inside_parallel = False
return outputs

def run(
self,
step_id: str,
@@ -28,6 +60,24 @@ def run(
if memo is not types.EmptySentinel:
return memo # type: ignore

is_targeting_enabled = self._target_hashed_id is not None
is_targeted = self._target_hashed_id == hashed_id
if is_targeting_enabled and not is_targeted:
# Skip this step because a different step is targeted.
raise base.SkipInterrupt()

if self._inside_parallel and not is_targeting_enabled:
# Plan this step because we're in parallel mode.
raise base.ResponseInterrupt(
execution.StepResponse(
data=None,
display_name=step_id,
id=hashed_id,
name=step_id,
op=execution.Opcode.PLANNED,
)
)

err = self._middleware.before_execution_sync()
if isinstance(err, Exception):
raise err
@@ -108,6 +158,12 @@ def sleep_until(
if memo is not types.EmptySentinel:
return memo # type: ignore

is_targeting_enabled = self._target_hashed_id is not None
is_targeted = self._target_hashed_id == hashed_id
if is_targeting_enabled and not is_targeted:
# Skip this step because a different step is targeted.
raise base.SkipInterrupt()

err = self._middleware.before_execution_sync()
if isinstance(err, Exception):
raise err
@@ -153,6 +209,12 @@ def wait_for_event(
# Fulfilled by an event
return event_lib.Event.model_validate(memo)

is_targeting_enabled = self._target_hashed_id is not None
is_targeted = self._target_hashed_id == hashed_id
if is_targeting_enabled and not is_targeted:
# Skip this step because a different step is targeted.
raise base.SkipInterrupt()

err = self._middleware.before_execution_sync()
if isinstance(err, Exception):
raise err
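
The synchronous variant mirrors the async one; a similarly hypothetical sketch (names and handler shape are assumptions):

```python
import typing

from inngest._internal.step_lib import step_sync


def load_user() -> dict[str, int]:
    return {"id": 1}


def load_account() -> dict[str, int]:
    return {"id": 42}


def handler(step: step_sync.StepSync) -> tuple[typing.Any, ...]:
    # Same shape as the async helper, but the callables are plain functions
    # and no await is needed.
    return step._experimental_parallel(
        (
            lambda: step.run("load-user", load_user),
            lambda: step.run("load-account", load_account),
        )
    )
```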