Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transform_input middleware hook #8

Merged
merged 7 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/flask/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import flask
import src.inngest

Expand All @@ -7,6 +9,10 @@
app = flask.Flask(__name__)


log = logging.getLogger("werkzeug")
log.setLevel(logging.ERROR)


inngest.flask.serve(
app,
src.inngest.inngest_client,
Expand Down
14 changes: 11 additions & 3 deletions inngest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
TriggerCron,
TriggerEvent,
)
from ._internal.middleware_lib import Middleware, MiddlewareSync
from ._internal.middleware_lib import CallInputTransform

# TODO: Uncomment when middleware is ready for external use.
# from ._internal.middleware_lib import (
# Middleware,
# MiddlewareSync,
# )
from ._internal.step_lib import Step, StepSync

__all__ = [
Expand All @@ -21,8 +27,10 @@
"Event",
"Function",
"Inngest",
"Middleware",
"MiddlewareSync",
"CallInputTransform",
# TODO: Uncomment when middleware is ready for external use.
# "Middleware",
# "MiddlewareSync",
"NonRetriableError",
"RateLimit",
"Step",
Expand Down
86 changes: 53 additions & 33 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,49 +240,53 @@ async def call_function(

# No memoized data means we're calling the function for the first time.
if len(call.steps.keys()) == 0:
self._client.middleware.before_run_execution_sync()
await self._client.middleware.before_run_execution()

comm_res: CommResponse
call_input = await self._client.middleware.transform_input(
logger=self._logger,
)

validation_res = req_sig.validate(self._signing_key)
if result.is_err(validation_res):
err = validation_res.err_value
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
await self._client.middleware.before_response()
return CommResponse.from_error(
self._logger,
self._framework,
err,
)
else:
match self._get_function(fn_id):
case result.Ok(fn):
call_res = await fn.call(call, self._client, fn_id)

if isinstance(call_res, execution.FunctionCallResponse):
# Only call this hook if we get a return at the function
# level.
self._client.middleware.after_run_execution_sync()
match self._get_function(fn_id):
case result.Ok(fn):
call_res = await fn.call(
call,
self._client,
fn_id,
call_input,
)

comm_res = CommResponse.from_call_result(
self._logger,
self._framework,
call_res,
)
case result.Err(err):
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger,
self._framework,
err,
)
if isinstance(call_res, execution.FunctionCallResponse):
# Only call this hook if we get a return at the function
# level.
await self._client.middleware.after_run_execution()

self._client.middleware.before_response_sync()
comm_res = CommResponse.from_call_result(
self._logger,
self._framework,
call_res,
)
case result.Err(err):
extra = {}
if isinstance(err, errors.InternalError):
extra["code"] = err.code
self._logger.error(err, extra=extra)
comm_res = CommResponse.from_error(
self._logger,
self._framework,
err,
)

await self._client.middleware.before_response()
return comm_res

def call_function_sync(
Expand All @@ -300,7 +304,18 @@ def call_function_sync(
if len(call.steps.keys()) == 0:
self._client.middleware.before_run_execution_sync()

comm_res: CommResponse
match self._client.middleware.transform_input_sync(
logger=self._logger,
):
case result.Ok(call_input):
pass
case result.Err(err):
self._client.middleware.before_response_sync()
return CommResponse.from_error(
self._logger,
self._framework,
err,
)

validation_res = req_sig.validate(self._signing_key)
if result.is_err(validation_res):
Expand All @@ -317,7 +332,12 @@ def call_function_sync(
else:
match self._get_function(fn_id):
case result.Ok(fn):
call_res = fn.call_sync(call, self._client, fn_id)
call_res = fn.call_sync(
call,
self._client,
fn_id,
call_input,
)

if isinstance(call_res, execution.FunctionCallResponse):
# Only call this hook if we get a return at the function
Expand Down
7 changes: 7 additions & 0 deletions inngest/_internal/execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import dataclasses
import enum
import logging
import typing

from . import errors, event_lib, transforms, types
Expand All @@ -23,6 +25,11 @@ class CallStack(types.BaseModel):
stack: list[str]


@dataclasses.dataclass
class CallInput:
logger: logging.Logger


class CallError(types.BaseModel):
"""
When an error that occurred during a call. Used for both function- and step-level
Expand Down
8 changes: 8 additions & 0 deletions inngest/_internal/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import hashlib
import inspect
import logging
import typing

import pydantic
Expand Down Expand Up @@ -38,6 +39,7 @@ def __call__(
attempt: int,
event: event_lib.Event,
events: list[event_lib.Event],
logger: logging.Logger,
run_id: str,
step: step_lib.Step,
) -> typing.Awaitable[types.Serializable]:
Expand All @@ -52,6 +54,7 @@ def __call__(
attempt: int,
event: event_lib.Event,
events: list[event_lib.Event],
logger: logging.Logger,
run_id: str,
step: step_lib.StepSync,
) -> types.Serializable:
Expand Down Expand Up @@ -177,6 +180,7 @@ async def call(
call: execution.Call,
client: client_lib.Inngest,
fn_id: str,
call_input: execution.CallInput,
) -> execution.CallResult:
try:
handler: FunctionHandlerAsync | FunctionHandlerSync
Expand All @@ -201,6 +205,7 @@ async def call(
attempt=call.ctx.attempt,
event=call.event,
events=call.events,
logger=call_input.logger,
run_id=call.ctx.run_id,
step=step_lib.Step(
client,
Expand All @@ -213,6 +218,7 @@ async def call(
attempt=call.ctx.attempt,
event=call.event,
events=call.events,
logger=call_input.logger,
run_id=call.ctx.run_id,
step=step_lib.StepSync(
client,
Expand Down Expand Up @@ -255,6 +261,7 @@ def call_sync(
call: execution.Call,
client: client_lib.Inngest,
fn_id: str,
call_input: execution.CallInput,
) -> execution.CallResult:
try:
handler: FunctionHandlerAsync | FunctionHandlerSync
Expand All @@ -276,6 +283,7 @@ def call_sync(
attempt=call.ctx.attempt,
event=call.event,
events=call.events,
logger=call_input.logger,
run_id=call.ctx.run_id,
step=step_lib.StepSync(
client,
Expand Down
Loading
Loading