Skip to content

Commit

Permalink
Merge pull request #9 from inngest/transform_output
Browse files Browse the repository at this point in the history
Add transform_output middleware hook
  • Loading branch information
amh4r authored Nov 3, 2023
2 parents 47c04a4 + 80db793 commit 5485e0b
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 83 deletions.
2 changes: 2 additions & 0 deletions inngest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# MiddlewareSync,
# )
from ._internal.step_lib import Step, StepSync
from ._internal.types import Serializable

__all__ = [
"Batch",
Expand All @@ -33,6 +34,7 @@
# "MiddlewareSync",
"NonRetriableError",
"RateLimit",
"Serializable",
"Step",
"StepSync",
"Throttle",
Expand Down
87 changes: 41 additions & 46 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,21 +239,24 @@ async def call_function(
"""

# No memoized data means we're calling the function for the first time.
if len(call.steps.keys()) == 0:
await self._client.middleware.before_run_execution()
is_first_call = len(call.steps.keys()) == 0
if is_first_call:
await self._client.middleware.before_function_execution()

# Give middleware the opportunity to change some of params passed to the
# user's handler.
call_input = await self._client.middleware.transform_input(
logger=self._logger,
)

# Validate the request signature.
validation_res = req_sig.validate(self._signing_key)
if result.is_err(validation_res):
err = validation_res.err_value
await self._client.middleware.before_response()
return CommResponse.from_error(
self._logger,
self._framework,
err,
validation_res.err_value,
)

match self._get_function(fn_id):
Expand All @@ -268,18 +271,14 @@ async def call_function(
if isinstance(call_res, execution.FunctionCallResponse):
# Only call this hook if we get a return at the function
# level.
await self._client.middleware.after_run_execution()
await self._client.middleware.after_function_execution()

comm_res = CommResponse.from_call_result(
self._logger,
self._framework,
call_res,
)
case result.Err(err):
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger,
self._framework,
Expand All @@ -301,9 +300,12 @@ def call_function_sync(
"""

# No memoized data means we're calling the function for the first time.
if len(call.steps.keys()) == 0:
self._client.middleware.before_run_execution_sync()
is_first_call = len(call.steps.keys()) == 0
if is_first_call:
self._client.middleware.before_function_execution_sync()

# Give middleware the opportunity to change some of params passed to the
# user's handler.
match self._client.middleware.transform_input_sync(
logger=self._logger,
):
Expand All @@ -317,48 +319,41 @@ def call_function_sync(
err,
)

# Validate the request signature.
validation_res = req_sig.validate(self._signing_key)
if result.is_err(validation_res):
err = validation_res.err_value
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._client.middleware.before_response_sync()
return CommResponse.from_error(
self._logger,
self._framework,
err,
validation_res.err_value,
)
else:
match self._get_function(fn_id):
case result.Ok(fn):
call_res = fn.call_sync(
call,
self._client,
fn_id,
call_input,
)

if isinstance(call_res, execution.FunctionCallResponse):
# Only call this hook if we get a return at the function
# level.
self._client.middleware.after_run_execution_sync()
match self._get_function(fn_id):
case result.Ok(fn):
call_res = fn.call_sync(
call,
self._client,
fn_id,
call_input,
)

comm_res = CommResponse.from_call_result(
self._logger,
self._framework,
call_res,
)
case result.Err(err):
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger,
self._framework,
err,
)
if isinstance(call_res, execution.FunctionCallResponse):
# Only call this hook if we get a return at the function
# level.
self._client.middleware.after_function_execution_sync()

comm_res = CommResponse.from_call_result(
self._logger,
self._framework,
call_res,
)
case result.Err(err):
comm_res = CommResponse.from_error(
self._logger,
self._framework,
err,
)

self._client.middleware.before_response_sync()
return comm_res
Expand Down
1 change: 1 addition & 0 deletions inngest/_internal/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ErrorCode(enum.Enum):
INVALID_BODY = "invalid_body"
INVALID_FUNCTION_CONFIG = "invalid_function_config"
INVALID_REQUEST_SIGNATURE = "invalid_request_signature"
INVALID_TRANSFORM = "invalid_transform"
MISMATCHED_SYNC = "mismatched_sync"
MISSING_EVENT_KEY = "missing_event_key"
MISSING_FUNCTION = "missing_function"
Expand Down
10 changes: 10 additions & 0 deletions inngest/_internal/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ def __init__(self, message: str | None = None) -> None:
)


class InvalidTransform(InternalError):
status_code: int = http.HTTPStatus.INTERNAL_SERVER_ERROR

def __init__(self, message: str | None = None) -> None:
super().__init__(
code=const.ErrorCode.INVALID_TRANSFORM,
message=message,
)


class MissingEventKey(InternalError):
status_code: int = http.HTTPStatus.INTERNAL_SERVER_ERROR

Expand Down
20 changes: 18 additions & 2 deletions inngest/_internal/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ async def call(
)
)

output = await client.middleware.transform_output(output)

# Ensure the output is JSON-serializable.
match transforms.dump_json(output):
case result.Ok(_):
Expand All @@ -243,9 +245,11 @@ async def call(

return execution.FunctionCallResponse(data=output)
except step_lib.Interrupt as interrupt:
output = await client.middleware.transform_output(interrupt.data)

return [
execution.StepCallResponse(
data=interrupt.data,
data=output,
display_name=interrupt.display_name,
id=interrupt.hashed_id,
name=interrupt.name,
Expand Down Expand Up @@ -292,6 +296,12 @@ def call_sync(
),
)

match client.middleware.transform_output_sync(output):
case result.Ok(output):
pass
case result.Err(err):
return execution.CallError.from_error(err)

match transforms.dump_json(output):
case result.Ok(output_str):
pass
Expand All @@ -305,9 +315,15 @@ def call_sync(
)
)
except step_lib.Interrupt as interrupt:
match client.middleware.transform_output_sync(interrupt.data):
case result.Ok(output):
pass
case result.Err(err):
return execution.CallError.from_error(err)

return [
execution.StepCallResponse(
data=interrupt.data,
data=output,
display_name=interrupt.display_name,
id=interrupt.hashed_id,
name=interrupt.name,
Expand Down
91 changes: 66 additions & 25 deletions inngest/_internal/middleware_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import typing

from . import errors, execution, result, transforms
from . import errors, execution, result, transforms, types

BlankHook = typing.Callable[[], typing.Awaitable[None] | None]

Expand All @@ -14,11 +14,10 @@ class CallInputTransform:


class Middleware:
async def after_run_execution(self) -> None:
async def after_function_execution(self) -> None:
"""
After a function run is done executing. Called once per run
regardless of the number of steps. Will still be called if the run
failed.
After a function is done executing. Called once per run regardless of
the number of steps. Will still be called if the run failed.
"""

return None
Expand All @@ -33,10 +32,10 @@ async def before_response(self) -> None:

return None

async def before_run_execution(self) -> None:
async def before_function_execution(self) -> None:
"""
Before a function run starts executing. Called once per run
regardless of the number of steps.
Before a function starts executing. Called once per run regardless of
the number of steps.
"""

return None
Expand All @@ -49,13 +48,24 @@ async def transform_input(self) -> CallInputTransform:

return CallInputTransform()

async def transform_output(
self,
output: types.Serializable,
) -> types.Serializable:
"""
After a function or step returns. Used to modify the returned data.
Called multiple times per run when using steps. Not called when an error
is thrown.
"""

return output


class MiddlewareSync:
def after_run_execution(self) -> None:
def after_function_execution(self) -> None:
"""
After a function run is done executing. Called once per run
regardless of the number of steps. Will still be called if the run
failed.
After a function is done executing. Called once per run regardless of
the number of steps. Will still be called if the run failed.
"""

return None
Expand All @@ -70,10 +80,10 @@ def before_response(self) -> None:

return None

def before_run_execution(self) -> None:
def before_function_execution(self) -> None:
"""
Before a function run starts executing. Called once per run
regardless of the number of steps.
Before a function starts executing. Called once per run regardless of
the number of steps.
"""

return None
Expand All @@ -86,6 +96,18 @@ def transform_input(self) -> CallInputTransform:

return CallInputTransform()

def transform_output(
self,
output: types.Serializable,
) -> types.Serializable:
"""
After a function or step returns. Used to modify the returned data.
Called multiple times per run when using steps. Not called when an error
is thrown.
"""

return output


_mismatched_sync = errors.MismatchedSync(
"encountered async middleware in non-async context"
Expand All @@ -102,15 +124,15 @@ def __init__(
def add(self, middleware: Middleware | MiddlewareSync) -> None:
self._middleware = [*self._middleware, middleware]

async def after_run_execution(self) -> None:
async def after_function_execution(self) -> None:
for m in self._middleware:
await transforms.maybe_await(m.after_run_execution())
await transforms.maybe_await(m.after_function_execution())

def after_run_execution_sync(self) -> result.MaybeError[None]:
def after_function_execution_sync(self) -> result.MaybeError[None]:
for m in self._middleware:
if inspect.iscoroutinefunction(m.after_run_execution):
if inspect.iscoroutinefunction(m.after_function_execution):
return result.Err(_mismatched_sync)
m.after_run_execution()
m.after_function_execution()
return result.Ok(None)

async def before_response(self) -> None:
Expand All @@ -124,15 +146,15 @@ def before_response_sync(self) -> result.MaybeError[None]:
m.before_response()
return result.Ok(None)

async def before_run_execution(self) -> None:
async def before_function_execution(self) -> None:
for m in self._middleware:
await transforms.maybe_await(m.before_run_execution())
await transforms.maybe_await(m.before_function_execution())

def before_run_execution_sync(self) -> result.MaybeError[None]:
def before_function_execution_sync(self) -> result.MaybeError[None]:
for m in self._middleware:
if inspect.iscoroutinefunction(m.before_run_execution):
if inspect.iscoroutinefunction(m.before_function_execution):
return result.Err(_mismatched_sync)
m.before_run_execution()
m.before_function_execution()
return result.Ok(None)

async def transform_input(
Expand All @@ -157,3 +179,22 @@ def transform_input_sync(
if t.logger is not None:
logger = t.logger
return result.Ok(execution.CallInput(logger=logger))

async def transform_output(
self,
output: types.Serializable,
) -> types.Serializable:
for m in self._middleware:
output = await transforms.maybe_await(m.transform_output(output))
return output

def transform_output_sync(
self,
output: types.Serializable,
) -> result.MaybeError[types.Serializable]:
for m in self._middleware:
if isinstance(m, Middleware):
return result.Err(_mismatched_sync)

output = m.transform_output(output)
return result.Ok(output)
Loading

0 comments on commit 5485e0b

Please sign in to comment.