Skip to content

Commit

Permalink
Add transform_input middleware hook
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Nov 2, 2023
1 parent 7583e2a commit cc46670
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 62 deletions.
7 changes: 6 additions & 1 deletion inngest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
TriggerCron,
TriggerEvent,
)
from ._internal.middleware_lib import Middleware, MiddlewareSync
from ._internal.middleware_lib import (
CallInputTransform,
Middleware,
MiddlewareSync,
)
from ._internal.step_lib import Step, StepSync

__all__ = [
Expand All @@ -21,6 +25,7 @@
"Event",
"Function",
"Inngest",
"CallInputTransform",
"Middleware",
"MiddlewareSync",
"NonRetriableError",
Expand Down
86 changes: 53 additions & 33 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,49 +240,53 @@ async def call_function(

# No memoized data means we're calling the function for the first time.
if len(call.steps.keys()) == 0:
self._client.middleware.before_run_execution_sync()
await self._client.middleware.before_run_execution()

comm_res: CommResponse
call_input = await self._client.middleware.transform_input(
logger=self._logger,
)

validation_res = req_sig.validate(self._signing_key)
if result.is_err(validation_res):
err = validation_res.err_value
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
await self._client.middleware.before_response()
return CommResponse.from_error(
self._logger,
self._framework,
err,
)
else:
match self._get_function(fn_id):
case result.Ok(fn):
call_res = await fn.call(call, self._client, fn_id)

if isinstance(call_res, execution.FunctionCallResponse):
# Only call this hook if we get a return at the function
# level.
self._client.middleware.after_run_execution_sync()
match self._get_function(fn_id):
case result.Ok(fn):
call_res = await fn.call(
call,
self._client,
fn_id,
call_input,
)

comm_res = CommResponse.from_call_result(
self._logger,
self._framework,
call_res,
)
case result.Err(err):
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger,
self._framework,
err,
)
if isinstance(call_res, execution.FunctionCallResponse):
# Only call this hook if we get a return at the function
# level.
await self._client.middleware.after_run_execution()

self._client.middleware.before_response_sync()
comm_res = CommResponse.from_call_result(
self._logger,
self._framework,
call_res,
)
case result.Err(err):
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger,
self._framework,
err,
)

await self._client.middleware.before_response()
return comm_res

def call_function_sync(
Expand All @@ -300,7 +304,18 @@ def call_function_sync(
if len(call.steps.keys()) == 0:
self._client.middleware.before_run_execution_sync()

comm_res: CommResponse
match self._client.middleware.transform_input_sync(
logger=self._logger,
):
case result.Ok(call_input):
pass
case result.Err(err):
self._client.middleware.before_response_sync()
return CommResponse.from_error(
self._logger,
self._framework,
err,
)

validation_res = req_sig.validate(self._signing_key)
if result.is_err(validation_res):
Expand All @@ -317,7 +332,12 @@ def call_function_sync(
else:
match self._get_function(fn_id):
case result.Ok(fn):
call_res = fn.call_sync(call, self._client, fn_id)
call_res = fn.call_sync(
call,
self._client,
fn_id,
call_input,
)

if isinstance(call_res, execution.FunctionCallResponse):
# Only call this hook if we get a return at the function
Expand Down
7 changes: 7 additions & 0 deletions inngest/_internal/execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import dataclasses
import enum
import logging
import typing

from . import errors, event_lib, transforms, types
Expand All @@ -23,6 +25,11 @@ class CallStack(types.BaseModel):
stack: list[str]


@dataclasses.dataclass
class CallInput:
logger: logging.Logger


class CallError(types.BaseModel):
"""
When an error that occurred during a call. Used for both function- and step-level
Expand Down
8 changes: 8 additions & 0 deletions inngest/_internal/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import hashlib
import inspect
import logging
import typing

import pydantic
Expand Down Expand Up @@ -38,6 +39,7 @@ def __call__(
attempt: int,
event: event_lib.Event,
events: list[event_lib.Event],
logger: logging.Logger,
run_id: str,
step: step_lib.Step,
) -> typing.Awaitable[types.Serializable]:
Expand All @@ -52,6 +54,7 @@ def __call__(
attempt: int,
event: event_lib.Event,
events: list[event_lib.Event],
logger: logging.Logger,
run_id: str,
step: step_lib.StepSync,
) -> types.Serializable:
Expand Down Expand Up @@ -177,6 +180,7 @@ async def call(
call: execution.Call,
client: client_lib.Inngest,
fn_id: str,
call_input: execution.CallInput,
) -> execution.CallResult:
try:
handler: FunctionHandlerAsync | FunctionHandlerSync
Expand All @@ -201,6 +205,7 @@ async def call(
attempt=call.ctx.attempt,
event=call.event,
events=call.events,
logger=call_input.logger,
run_id=call.ctx.run_id,
step=step_lib.Step(
client,
Expand All @@ -213,6 +218,7 @@ async def call(
attempt=call.ctx.attempt,
event=call.event,
events=call.events,
logger=call_input.logger,
run_id=call.ctx.run_id,
step=step_lib.StepSync(
client,
Expand Down Expand Up @@ -255,6 +261,7 @@ def call_sync(
call: execution.Call,
client: client_lib.Inngest,
fn_id: str,
call_input: execution.CallInput,
) -> execution.CallResult:
try:
handler: FunctionHandlerAsync | FunctionHandlerSync
Expand All @@ -276,6 +283,7 @@ def call_sync(
attempt=call.ctx.attempt,
event=call.event,
events=call.events,
logger=call_input.logger,
run_id=call.ctx.run_id,
step=step_lib.StepSync(
client,
Expand Down
95 changes: 73 additions & 22 deletions inngest/_internal/middleware_lib.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import dataclasses
import inspect
import logging
import typing

from . import errors, result
from . import errors, execution, result, transforms

BlankHook = typing.Callable[[], typing.Awaitable[None] | None]


@dataclasses.dataclass
class CallInputTransform:
logger: logging.Logger | None = None


class Middleware:
async def after_run_execution(self) -> None:
"""
Called after a function run is done executing. Called once per run
After a function run is done executing. Called once per run
regardless of the number of steps. Will still be called if the run
failed.
"""
Expand All @@ -18,7 +25,7 @@ async def after_run_execution(self) -> None:

async def before_response(self) -> None:
"""
Called after the output has been set and before the response is sent
After the output has been set and before the response is sent
back to Inngest. This is where you can perform any final actions before
the response is sent back to Inngest. Called multiple times per run when
using steps.
Expand All @@ -28,17 +35,25 @@ async def before_response(self) -> None:

async def before_run_execution(self) -> None:
"""
Called when a function run starts executing. Called once per run
Before a function run starts executing. Called once per run
regardless of the number of steps.
"""

return None

async def transform_input(self) -> CallInputTransform:
"""
Before calling a function or step. Used to replace certain arguments in
the function. Called multiple times per run when using steps.
"""

return CallInputTransform()


class MiddlewareSync:
def after_run_execution(self) -> None:
"""
Called after a function run is done executing. Called once per run
After a function run is done executing. Called once per run
regardless of the number of steps. Will still be called if the run
failed.
"""
Expand All @@ -47,7 +62,7 @@ def after_run_execution(self) -> None:

def before_response(self) -> None:
"""
Called after the output has been set and before the response is sent
After the output has been set and before the response is sent
back to Inngest. This is where you can perform any final actions before
the response is sent back to Inngest. Called multiple times per run when
using steps.
Expand All @@ -57,12 +72,25 @@ def before_response(self) -> None:

def before_run_execution(self) -> None:
"""
Called when a function run starts executing. Called once per run
Before a function run starts executing. Called once per run
regardless of the number of steps.
"""

return None

def transform_input(self) -> CallInputTransform:
"""
Before calling a function or step. Used to replace certain arguments in
the function. Called multiple times per run when using steps.
"""

return CallInputTransform()


_mismatched_sync = errors.MismatchedSync(
"encountered async middleware in non-async context"
)


class MiddlewareManager:
def __init__(
Expand All @@ -71,35 +99,58 @@ def __init__(
):
self._middleware = middleware

async def after_run_execution(self) -> None:
for m in self._middleware:
await transforms.maybe_await(m.after_run_execution())

def after_run_execution_sync(self) -> result.MaybeError[None]:
for m in self._middleware:
if inspect.iscoroutinefunction(m.after_run_execution):
return result.Err(
errors.MismatchedSync(
"encountered async middleware in non-async context"
)
)
return result.Err(_mismatched_sync)
m.after_run_execution()
return result.Ok(None)

async def before_response(self) -> None:
for m in self._middleware:
await transforms.maybe_await(m.before_response())

def before_response_sync(self) -> result.MaybeError[None]:
for m in self._middleware:
if inspect.iscoroutinefunction(m.before_response):
return result.Err(
errors.MismatchedSync(
"encountered async middleware in non-async context"
)
)
return result.Err(_mismatched_sync)
m.before_response()
return result.Ok(None)

async def before_run_execution(self) -> None:
for m in self._middleware:
await transforms.maybe_await(m.before_run_execution())

def before_run_execution_sync(self) -> result.MaybeError[None]:
for m in self._middleware:
if inspect.iscoroutinefunction(m.before_run_execution):
return result.Err(
errors.MismatchedSync(
"encountered async middleware in non-async context"
)
)
return result.Err(_mismatched_sync)
m.before_run_execution()
return result.Ok(None)

async def transform_input(
self,
logger: logging.Logger,
) -> execution.CallInput:
for m in self._middleware:
t = await transforms.maybe_await(m.transform_input())
if t.logger is not None:
logger = t.logger
return execution.CallInput(logger=logger)

def transform_input_sync(
self,
logger: logging.Logger,
) -> result.MaybeError[execution.CallInput]:
for m in self._middleware:
if isinstance(m, Middleware):
return result.Err(_mismatched_sync)

t = m.transform_input()
if t.logger is not None:
logger = t.logger
return result.Ok(execution.CallInput(logger=logger))
Loading

0 comments on commit cc46670

Please sign in to comment.