diff --git a/.vscode/settings.json b/.vscode/settings.json index b2a34a5..183e9b2 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -6,5 +6,6 @@ }, "editor.formatOnSave": true, "isort.args": ["--profile", "black"], - "isort.check": true + "isort.check": true, + "mypy-type-checker.preferDaemon": true } diff --git a/examples/flask/app.py b/examples/flask/app.py index 28e5ed4..7317d0e 100644 --- a/examples/flask/app.py +++ b/examples/flask/app.py @@ -1,21 +1,16 @@ -import logging - import flask -import src.inngest +from src.inngest import inngest_client -import examples.functions import inngest.flask +from examples import functions app = flask.Flask(__name__) - - -log = logging.getLogger("werkzeug") -log.setLevel(logging.ERROR) +inngest_client.set_logger(app.logger) inngest.flask.serve( app, - src.inngest.inngest_client, - examples.functions.functions_sync, + inngest_client, + functions.functions_sync, ) app.run(port=8000) diff --git a/inngest/__init__.py b/inngest/__init__.py index 7a4d9b0..ed34f48 100644 --- a/inngest/__init__.py +++ b/inngest/__init__.py @@ -11,15 +11,15 @@ TriggerCron, TriggerEvent, ) -from ._internal.middleware_lib import CallInputTransform # TODO: Uncomment when middleware is ready for external use. # from ._internal.middleware_lib import ( # Middleware, # MiddlewareSync, # ) +# from ._internal.execution import TransformableCallInput from ._internal.step_lib import Step, StepSync -from ._internal.types import Serializable +from ._internal.types import Logger, Serializable __all__ = [ "Batch", @@ -28,10 +28,11 @@ "Event", "Function", "Inngest", - "CallInputTransform", + "Logger", # TODO: Uncomment when middleware is ready for external use. # "Middleware", # "MiddlewareSync", + # "TransformableCallInput", "NonRetriableError", "RateLimit", "Serializable", diff --git a/inngest/_internal/client_lib.py b/inngest/_internal/client_lib.py index b9c961b..bb598ac 100644 --- a/inngest/_internal/client_lib.py +++ b/inngest/_internal/client_lib.py @@ -3,14 +3,19 @@ import logging import os import time +import typing import urllib.parse import httpx -from . import const, env, errors, event_lib, middleware_lib, net, result +from . import const, env, errors, event_lib, middleware_lib, net, result, types class Inngest: + middleware: list[ + typing.Type[middleware_lib.Middleware | middleware_lib.MiddlewareSync] + ] + def __init__( self, *, @@ -18,9 +23,11 @@ def __init__( base_url: str | None = None, event_key: str | None = None, is_production: bool | None = None, - logger: logging.Logger | None = None, + logger: types.Logger | None = None, middleware: list[ - middleware_lib.Middleware | middleware_lib.MiddlewareSync + typing.Type[ + middleware_lib.Middleware | middleware_lib.MiddlewareSync + ] ] | None = None, ) -> None: @@ -28,7 +35,7 @@ def __init__( self.base_url = base_url self.is_production = is_production or env.is_prod() self.logger = logger or logging.getLogger(__name__) - self.middleware = middleware_lib.MiddlewareManager(middleware or []) + self.middleware = middleware or [] if event_key is None: if not self.is_production: @@ -80,6 +87,14 @@ def _build_send_request( ) ) + def add_middleware( + self, + middleware: typing.Type[ + middleware_lib.Middleware | middleware_lib.MiddlewareSync + ], + ) -> None: + self.middleware = [*self.middleware, middleware] + async def send( self, events: event_lib.Event | list[event_lib.Event], @@ -110,6 +125,9 @@ def send_sync( raise err return ids + def set_logger(self, logger: types.Logger) -> None: + self.logger = logger + def _extract_ids(body: object) -> list[str]: if not isinstance(body, dict) or "ids" not in body: diff --git a/inngest/_internal/comm.py b/inngest/_internal/comm.py index fd216cf..3fb72f7 100644 --- a/inngest/_internal/comm.py +++ b/inngest/_internal/comm.py @@ -2,19 +2,19 @@ import http import json -import logging import os +import typing import urllib.parse import httpx from inngest._internal import ( - client_lib, const, errors, execution, function, function_config, + middleware_lib, net, registration, result, @@ -22,6 +22,10 @@ types, ) +# Prevent circular import +if typing.TYPE_CHECKING: + from inngest._internal import client_lib + class CommResponse: def __init__( @@ -42,7 +46,7 @@ def is_success(self) -> bool: @classmethod def from_call_result( cls, - logger: logging.Logger, + logger: types.Logger, framework: const.Framework, call_res: execution.CallResult, ) -> CommResponse: @@ -107,7 +111,7 @@ def from_call_result( @classmethod def from_error( cls, - logger: logging.Logger, + logger: types.Logger, framework: const.Framework, err: Exception, ) -> CommResponse: @@ -138,7 +142,6 @@ class CommHandler: _fns: dict[str, function.Function] _framework: const.Framework _is_production: bool - _logger: logging.Logger _signing_key: str | None def __init__( @@ -148,19 +151,18 @@ def __init__( client: client_lib.Inngest, framework: const.Framework, functions: list[function.Function], - logger: logging.Logger, signing_key: str | None = None, ) -> None: + self._client = client self._is_production = client.is_production - self._logger = logger if not self._is_production: - self._logger.info("Dev Server mode enabled") + self._client.logger.info("Dev Server mode enabled") base_url = base_url or os.getenv(const.EnvKey.BASE_URL.value) if base_url is None: if not self._is_production: - self._logger.info("Defaulting API origin to Dev Server") + self._client.logger.info("Defaulting API origin to Dev Server") base_url = const.DEV_SERVER_ORIGIN else: base_url = const.DEFAULT_API_ORIGIN @@ -170,7 +172,6 @@ def __init__( except Exception as err: raise errors.InvalidBaseURL() from err - self._client = client self._fns = {fn.get_id(): fn for fn in functions} self._framework = framework @@ -178,7 +179,7 @@ def __init__( if self._client.is_production: signing_key = os.getenv(const.EnvKey.SIGNING_KEY.value) if signing_key is None: - self._logger.error("missing signing key") + self._client.logger.error("missing signing key") raise errors.MissingSigningKey() self._signing_key = signing_key @@ -238,23 +239,20 @@ async def call_function( Handles a function call from the Executor. """ - # No memoized data means we're calling the function for the first time. - is_first_call = len(call.steps.keys()) == 0 - if is_first_call: - await self._client.middleware.before_function_execution() + middleware = middleware_lib.MiddlewareManager(self._client) # Give middleware the opportunity to change some of params passed to the # user's handler. - call_input = await self._client.middleware.transform_input( - logger=self._logger, + call_input = await middleware.transform_input( + execution.TransformableCallInput(logger=self._client.logger), ) # Validate the request signature. validation_res = req_sig.validate(self._signing_key) if result.is_err(validation_res): - await self._client.middleware.before_response() + await middleware.before_response() return CommResponse.from_error( - self._logger, + self._client.logger, self._framework, validation_res.err_value, ) @@ -266,26 +264,27 @@ async def call_function( self._client, fn_id, call_input, + middleware, ) if isinstance(call_res, execution.FunctionCallResponse): # Only call this hook if we get a return at the function # level. - await self._client.middleware.after_function_execution() + await middleware.after_execution() comm_res = CommResponse.from_call_result( - self._logger, + self._client.logger, self._framework, call_res, ) case result.Err(err): comm_res = CommResponse.from_error( - self._logger, + self._client.logger, self._framework, err, ) - await self._client.middleware.before_response() + await middleware.before_response() return comm_res def call_function_sync( @@ -299,22 +298,19 @@ def call_function_sync( Handles a function call from the Executor. """ - # No memoized data means we're calling the function for the first time. - is_first_call = len(call.steps.keys()) == 0 - if is_first_call: - self._client.middleware.before_function_execution_sync() + middleware = middleware_lib.MiddlewareManager(self._client) # Give middleware the opportunity to change some of params passed to the # user's handler. - match self._client.middleware.transform_input_sync( - logger=self._logger, + match middleware.transform_input_sync( + execution.TransformableCallInput(logger=self._client.logger), ): case result.Ok(call_input): pass case result.Err(err): - self._client.middleware.before_response_sync() + middleware.before_response_sync() return CommResponse.from_error( - self._logger, + self._client.logger, self._framework, err, ) @@ -322,9 +318,9 @@ def call_function_sync( # Validate the request signature. validation_res = req_sig.validate(self._signing_key) if result.is_err(validation_res): - self._client.middleware.before_response_sync() + middleware.before_response_sync() return CommResponse.from_error( - self._logger, + self._client.logger, self._framework, validation_res.err_value, ) @@ -336,26 +332,27 @@ def call_function_sync( self._client, fn_id, call_input, + middleware, ) if isinstance(call_res, execution.FunctionCallResponse): # Only call this hook if we get a return at the function # level. - self._client.middleware.after_function_execution_sync() + middleware.after_execution_sync() comm_res = CommResponse.from_call_result( - self._logger, + self._client.logger, self._framework, call_res, ) case result.Err(err): comm_res = CommResponse.from_error( - self._logger, + self._client.logger, self._framework, err, ) - self._client.middleware.before_response_sync() + middleware.before_response_sync() return comm_res def _get_function( @@ -415,14 +412,14 @@ def _parse_registration_response( server_res_body = server_res.json() except Exception: return CommResponse.from_error( - self._logger, + self._client.logger, self._framework, errors.RegistrationError("response is not valid JSON"), ) if not isinstance(server_res_body, dict): return CommResponse.from_error( - self._logger, + self._client.logger, self._framework, errors.RegistrationError("response is not an object"), ) @@ -438,7 +435,7 @@ def _parse_registration_response( if not isinstance(msg, str): msg = "registration failed" comm_res = CommResponse.from_error( - self._logger, + self._client.logger, self._framework, errors.RegistrationError(msg.strip()), ) @@ -459,9 +456,9 @@ async def register( case result.Ok(_): pass case result.Err(err): - self._logger.error(err) + self._client.logger.error(err) return CommResponse.from_error( - self._logger, + self._client.logger, self._framework, err, ) @@ -473,9 +470,9 @@ async def register( await client.send(req) ) case result.Err(err): - self._logger.error(err) + self._client.logger.error(err) return CommResponse.from_error( - self._logger, + self._client.logger, self._framework, err, ) @@ -496,9 +493,9 @@ def register_sync( case result.Ok(_): pass case result.Err(err): - self._logger.error(err) + self._client.logger.error(err) return CommResponse.from_error( - self._logger, + self._client.logger, self._framework, err, ) @@ -508,9 +505,9 @@ def register_sync( case result.Ok(req): res = self._parse_registration_response(client.send(req)) case result.Err(err): - self._logger.error(err) + self._client.logger.error(err) return CommResponse.from_error( - self._logger, + self._client.logger, self._framework, err, ) diff --git a/inngest/_internal/comm_test.py b/inngest/_internal/comm_test.py index af7df3c..77094ac 100644 --- a/inngest/_internal/comm_test.py +++ b/inngest/_internal/comm_test.py @@ -52,7 +52,6 @@ def fn(**_kwargs: object) -> int: client=self.client, framework=const.Framework.FLASK, functions=[fn], - logger=self.client.logger, ) assert result.is_ok(handler.get_function_configs("http://foo.bar")) @@ -70,7 +69,6 @@ def test_no_functions(self) -> None: client=self.client, framework=const.Framework.FLASK, functions=functions, - logger=self.client.logger, ) match handler.get_function_configs("http://foo.bar"): diff --git a/inngest/_internal/execution.py b/inngest/_internal/execution.py index 8a2b2f7..b8b1f12 100644 --- a/inngest/_internal/execution.py +++ b/inngest/_internal/execution.py @@ -2,7 +2,6 @@ import dataclasses import enum -import logging import typing from . import errors, event_lib, transforms, types @@ -26,8 +25,8 @@ class CallStack(types.BaseModel): @dataclasses.dataclass -class CallInput: - logger: logging.Logger +class TransformableCallInput: + logger: types.Logger class CallError(types.BaseModel): diff --git a/inngest/_internal/function.py b/inngest/_internal/function.py index 59356c9..42f2afd 100644 --- a/inngest/_internal/function.py +++ b/inngest/_internal/function.py @@ -3,7 +3,6 @@ import dataclasses import hashlib import inspect -import logging import typing import pydantic @@ -15,6 +14,7 @@ event_lib, execution, function_config, + middleware_lib, result, step_lib, transforms, @@ -39,7 +39,7 @@ def __call__( attempt: int, event: event_lib.Event, events: list[event_lib.Event], - logger: logging.Logger, + logger: types.Logger, run_id: str, step: step_lib.Step, ) -> typing.Awaitable[types.Serializable]: @@ -54,7 +54,7 @@ def __call__( attempt: int, event: event_lib.Event, events: list[event_lib.Event], - logger: logging.Logger, + logger: types.Logger, run_id: str, step: step_lib.StepSync, ) -> types.Serializable: @@ -180,8 +180,15 @@ async def call( call: execution.Call, client: client_lib.Inngest, fn_id: str, - call_input: execution.CallInput, + call_input: execution.TransformableCallInput, + middleware: middleware_lib.MiddlewareManager, ) -> execution.CallResult: + memos = step_lib.StepMemos(call.steps) + + # No memoized data means we're calling the function for the first time. + if memos.size == 0: + await middleware.before_execution() + try: handler: FunctionHandlerAsync | FunctionHandlerSync if self.id == fn_id: @@ -209,7 +216,8 @@ async def call( run_id=call.ctx.run_id, step=step_lib.Step( client, - call.steps, + memos, + middleware, step_lib.StepIDCounter(), ), ) @@ -222,7 +230,8 @@ async def call( run_id=call.ctx.run_id, step=step_lib.StepSync( client, - call.steps, + memos, + middleware, step_lib.StepIDCounter(), ), ) @@ -234,7 +243,7 @@ async def call( ) ) - output = await client.middleware.transform_output(output) + output = await middleware.transform_output(output) # Ensure the output is JSON-serializable. match transforms.dump_json(output): @@ -245,7 +254,7 @@ async def call( return execution.FunctionCallResponse(data=output) except step_lib.Interrupt as interrupt: - output = await client.middleware.transform_output(interrupt.data) + output = await middleware.transform_output(interrupt.data) return [ execution.StepCallResponse( @@ -265,8 +274,15 @@ def call_sync( call: execution.Call, client: client_lib.Inngest, fn_id: str, - call_input: execution.CallInput, + call_input: execution.TransformableCallInput, + middleware: middleware_lib.MiddlewareManager, ) -> execution.CallResult: + memos = step_lib.StepMemos(call.steps) + + # No memoized data means we're calling the function for the first time. + if memos.size == 0: + middleware.before_execution_sync() + try: handler: FunctionHandlerAsync | FunctionHandlerSync if self.id == fn_id: @@ -291,12 +307,13 @@ def call_sync( run_id=call.ctx.run_id, step=step_lib.StepSync( client, - call.steps, + memos, + middleware, step_lib.StepIDCounter(), ), ) - match client.middleware.transform_output_sync(output): + match middleware.transform_output_sync(output): case result.Ok(output): pass case result.Err(err): @@ -315,7 +332,7 @@ def call_sync( ) ) except step_lib.Interrupt as interrupt: - match client.middleware.transform_output_sync(interrupt.data): + match middleware.transform_output_sync(interrupt.data): case result.Ok(output): pass case result.Err(err): diff --git a/inngest/_internal/log.py b/inngest/_internal/log.py new file mode 100644 index 0000000..26caa5f --- /dev/null +++ b/inngest/_internal/log.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import logging +import typing + +from . import execution, middleware_lib, types + +# Prevent circular import +if typing.TYPE_CHECKING: + from . import client_lib + + +# https://github.com/python/typeshed/issues/7855#issuecomment-1128857842 +if typing.TYPE_CHECKING: + _LoggerAdapter = logging.LoggerAdapter[types.Logger] +else: + _LoggerAdapter = logging.LoggerAdapter + + +class LoggerProxy: + """ + Wraps a logger, allowing us to disable logging when we want to. This is + important because we may call a function multiple times and we don't want + duplicate logs. + """ + + _proxied_methods = ( + "critical", + "debug", + "error", + "exception", + "fatal", + "info", + "log", + "warn", + "warning", + ) + + def __init__(self, logger: types.Logger) -> None: + self._is_enabled = False + self.logger = logger + + def __getattr__(self, name: str) -> object: + if name in self._proxied_methods and not self._is_enabled: + # Return noop + return lambda *args, **kwargs: None + + return getattr(self.logger, name) + + def enable(self) -> None: + self._is_enabled = True + + +class LoggerMiddleware(middleware_lib.MiddlewareSync): + def __init__(self, client: client_lib.Inngest) -> None: + super().__init__(client) + self.logger = LoggerProxy(client.logger) + + def before_execution(self) -> None: + # Enable logging because we've encountered new code. + self.logger.enable() + + def transform_input( + self, + call_input: execution.TransformableCallInput, + ) -> execution.TransformableCallInput: + self.logger.logger = call_input.logger + call_input.logger = self.logger # type: ignore + return call_input diff --git a/inngest/_internal/middleware_lib.py b/inngest/_internal/middleware_lib.py deleted file mode 100644 index d8fd3a5..0000000 --- a/inngest/_internal/middleware_lib.py +++ /dev/null @@ -1,200 +0,0 @@ -import dataclasses -import inspect -import logging -import typing - -from . import errors, execution, result, transforms, types - -BlankHook = typing.Callable[[], typing.Awaitable[None] | None] - - -@dataclasses.dataclass -class CallInputTransform: - logger: logging.Logger | None = None - - -class Middleware: - async def after_function_execution(self) -> None: - """ - After a function is done executing. Called once per run regardless of - the number of steps. Will still be called if the run failed. - """ - - return None - - async def before_response(self) -> None: - """ - After the output has been set and before the response is sent - back to Inngest. This is where you can perform any final actions before - the response is sent back to Inngest. Called multiple times per run when - using steps. - """ - - return None - - async def before_function_execution(self) -> None: - """ - Before a function starts executing. Called once per run regardless of - the number of steps. - """ - - return None - - async def transform_input(self) -> CallInputTransform: - """ - Before calling a function or step. Used to replace certain arguments in - the function. Called multiple times per run when using steps. - """ - - return CallInputTransform() - - async def transform_output( - self, - output: types.Serializable, - ) -> types.Serializable: - """ - After a function or step returns. Used to modify the returned data. - Called multiple times per run when using steps. Not called when an error - is thrown. - """ - - return output - - -class MiddlewareSync: - def after_function_execution(self) -> None: - """ - After a function is done executing. Called once per run regardless of - the number of steps. Will still be called if the run failed. - """ - - return None - - def before_response(self) -> None: - """ - After the output has been set and before the response is sent - back to Inngest. This is where you can perform any final actions before - the response is sent back to Inngest. Called multiple times per run when - using steps. - """ - - return None - - def before_function_execution(self) -> None: - """ - Before a function starts executing. Called once per run regardless of - the number of steps. - """ - - return None - - def transform_input(self) -> CallInputTransform: - """ - Before calling a function or step. Used to replace certain arguments in - the function. Called multiple times per run when using steps. - """ - - return CallInputTransform() - - def transform_output( - self, - output: types.Serializable, - ) -> types.Serializable: - """ - After a function or step returns. Used to modify the returned data. - Called multiple times per run when using steps. Not called when an error - is thrown. - """ - - return output - - -_mismatched_sync = errors.MismatchedSync( - "encountered async middleware in non-async context" -) - - -class MiddlewareManager: - def __init__( - self, - middleware: list[Middleware | MiddlewareSync], - ): - self._middleware = middleware - - def add(self, middleware: Middleware | MiddlewareSync) -> None: - self._middleware = [*self._middleware, middleware] - - async def after_function_execution(self) -> None: - for m in self._middleware: - await transforms.maybe_await(m.after_function_execution()) - - def after_function_execution_sync(self) -> result.MaybeError[None]: - for m in self._middleware: - if inspect.iscoroutinefunction(m.after_function_execution): - return result.Err(_mismatched_sync) - m.after_function_execution() - return result.Ok(None) - - async def before_response(self) -> None: - for m in self._middleware: - await transforms.maybe_await(m.before_response()) - - def before_response_sync(self) -> result.MaybeError[None]: - for m in self._middleware: - if inspect.iscoroutinefunction(m.before_response): - return result.Err(_mismatched_sync) - m.before_response() - return result.Ok(None) - - async def before_function_execution(self) -> None: - for m in self._middleware: - await transforms.maybe_await(m.before_function_execution()) - - def before_function_execution_sync(self) -> result.MaybeError[None]: - for m in self._middleware: - if inspect.iscoroutinefunction(m.before_function_execution): - return result.Err(_mismatched_sync) - m.before_function_execution() - return result.Ok(None) - - async def transform_input( - self, - logger: logging.Logger, - ) -> execution.CallInput: - for m in self._middleware: - t = await transforms.maybe_await(m.transform_input()) - if t.logger is not None: - logger = t.logger - return execution.CallInput(logger=logger) - - def transform_input_sync( - self, - logger: logging.Logger, - ) -> result.MaybeError[execution.CallInput]: - for m in self._middleware: - if isinstance(m, Middleware): - return result.Err(_mismatched_sync) - - t = m.transform_input() - if t.logger is not None: - logger = t.logger - return result.Ok(execution.CallInput(logger=logger)) - - async def transform_output( - self, - output: types.Serializable, - ) -> types.Serializable: - for m in self._middleware: - output = await transforms.maybe_await(m.transform_output(output)) - return output - - def transform_output_sync( - self, - output: types.Serializable, - ) -> result.MaybeError[types.Serializable]: - for m in self._middleware: - if isinstance(m, Middleware): - return result.Err(_mismatched_sync) - - output = m.transform_output(output) - return result.Ok(output) diff --git a/inngest/_internal/middleware_lib/__init__.py b/inngest/_internal/middleware_lib/__init__.py new file mode 100644 index 0000000..26abed0 --- /dev/null +++ b/inngest/_internal/middleware_lib/__init__.py @@ -0,0 +1,8 @@ +from .manager import MiddlewareManager +from .middleware import Middleware, MiddlewareSync + +__all__ = [ + "Middleware", + "MiddlewareManager", + "MiddlewareSync", +] diff --git a/inngest/_internal/middleware_lib/log.py b/inngest/_internal/middleware_lib/log.py new file mode 100644 index 0000000..da9c64c --- /dev/null +++ b/inngest/_internal/middleware_lib/log.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import logging +import typing + +from inngest._internal import execution, types + +from .middleware import MiddlewareSync + +# Prevent circular import +if typing.TYPE_CHECKING: + from inngest._internal import client_lib + + +# https://github.com/python/typeshed/issues/7855#issuecomment-1128857842 +if typing.TYPE_CHECKING: + _LoggerAdapter = logging.LoggerAdapter[types.Logger] +else: + _LoggerAdapter = logging.LoggerAdapter + + +class LoggerProxy: + _proxied_methods = ( + "critical", + "debug", + "error", + "exception", + "fatal", + "info", + "log", + "warn", + "warning", + ) + + def __init__(self, logger: types.Logger) -> None: + self._is_enabled = False + self.logger = logger + + def __getattr__(self, name: str) -> object: + if name in self._proxied_methods and not self._is_enabled: + # Return noop + return lambda *args, **kwargs: None + + return getattr(self.logger, name) + + def enable(self) -> None: + self._is_enabled = True + + +class LoggerMiddleware(MiddlewareSync): + def __init__(self, client: client_lib.Inngest) -> None: + super().__init__(client) + self.logger = LoggerProxy(client.logger) + + def before_execution(self) -> None: + self.logger.enable() + + def transform_input( + self, + call_input: execution.TransformableCallInput, + ) -> execution.TransformableCallInput: + self.logger.logger = call_input.logger + call_input.logger = self.logger # type: ignore + return call_input diff --git a/inngest/_internal/middleware_lib/manager.py b/inngest/_internal/middleware_lib/manager.py new file mode 100644 index 0000000..31d9431 --- /dev/null +++ b/inngest/_internal/middleware_lib/manager.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import inspect +import typing + +from inngest._internal import errors, execution, result, transforms, types + +from .log import LoggerMiddleware +from .middleware import Middleware, MiddlewareSync + +# Prevent circular import +if typing.TYPE_CHECKING: + from inngest._internal import client_lib + + +MiddlewareT = typing.TypeVar("MiddlewareT", bound=Middleware) +MiddlewareSyncT = typing.TypeVar("MiddlewareSyncT", bound=MiddlewareSync) + + +_mismatched_sync = errors.MismatchedSync( + "encountered async middleware in non-async context" +) + +DEFAULT_MIDDLEWARE: list[typing.Type[Middleware | MiddlewareSync]] = [ + LoggerMiddleware +] + + +class MiddlewareManager: + def __init__(self, client: client_lib.Inngest) -> None: + middleware = [ + *client.middleware, + *DEFAULT_MIDDLEWARE, + ] + self._middleware = [m(client) for m in middleware] + + self._disabled_methods = set[str]() + + def add(self, middleware: Middleware | MiddlewareSync) -> None: + self._middleware = [*self._middleware, middleware] + + async def after_execution(self) -> None: + for m in self._middleware: + await transforms.maybe_await(m.after_execution()) + + def after_execution_sync(self) -> result.MaybeError[None]: + for m in self._middleware: + if inspect.iscoroutinefunction(m.after_execution): + return result.Err(_mismatched_sync) + m.after_execution() + return result.Ok(None) + + async def before_execution(self) -> None: + method_name = inspect.currentframe().f_code.co_name # type: ignore + if method_name in self._disabled_methods: + return None + + for m in self._middleware: + await transforms.maybe_await(m.before_execution()) + self._disabled_methods.add(method_name) + + def before_execution_sync(self) -> result.MaybeError[None]: + method_name = inspect.currentframe().f_code.co_name # type: ignore + if method_name in self._disabled_methods: + return result.Ok(None) + + for m in self._middleware: + if inspect.iscoroutinefunction(m.before_execution): + return result.Err(_mismatched_sync) + m.before_execution() + self._disabled_methods.add(method_name) + return result.Ok(None) + + async def before_response(self) -> None: + for m in self._middleware: + await transforms.maybe_await(m.before_response()) + + def before_response_sync(self) -> result.MaybeError[None]: + for m in self._middleware: + if inspect.iscoroutinefunction(m.before_response): + return result.Err(_mismatched_sync) + m.before_response() + return result.Ok(None) + + async def transform_input( + self, + call_input: execution.TransformableCallInput, + ) -> execution.TransformableCallInput: + for m in self._middleware: + call_input = await transforms.maybe_await( + m.transform_input(call_input), + ) + return call_input + + def transform_input_sync( + self, + call_input: execution.TransformableCallInput, + ) -> result.MaybeError[execution.TransformableCallInput]: + for m in self._middleware: + if isinstance(m, Middleware): + return result.Err(_mismatched_sync) + + call_input = m.transform_input(call_input) + return result.Ok(call_input) + + async def transform_output( + self, + output: types.Serializable, + ) -> types.Serializable: + for m in self._middleware: + output = await transforms.maybe_await(m.transform_output(output)) + return output + + def transform_output_sync( + self, + output: types.Serializable, + ) -> result.MaybeError[types.Serializable]: + for m in self._middleware: + if isinstance(m, Middleware): + return result.Err(_mismatched_sync) + + output = m.transform_output(output) + return result.Ok(output) diff --git a/inngest/_internal/middleware_lib/middleware.py b/inngest/_internal/middleware_lib/middleware.py new file mode 100644 index 0000000..33737dc --- /dev/null +++ b/inngest/_internal/middleware_lib/middleware.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import typing + +from inngest._internal import execution, types + +# Prevent circular import +if typing.TYPE_CHECKING: + from inngest._internal import client_lib + + +class Middleware: + def __init__(self, client: client_lib.Inngest) -> None: + self.client = client + + async def after_execution(self) -> None: + """ + After executing new code. Called multiple times per run when using + steps. + """ + + return None + + async def before_execution(self) -> None: + """ + Before executing new code. Called multiple times per run when using + steps. + """ + + return None + + async def before_response(self) -> None: + """ + After the output has been set and before the response is sent + back to Inngest. This is where you can perform any final actions before + the response is sent back to Inngest. Called multiple times per run when + using steps. + """ + + return None + + async def transform_input( + self, + call_input: execution.TransformableCallInput, + ) -> execution.TransformableCallInput: + """ + Before calling a function or step. Used to replace certain arguments in + the function. Called multiple times per run when using steps. + """ + + return call_input + + async def transform_output( + self, + output: types.Serializable, + ) -> types.Serializable: + """ + After a function or step returns. Used to modify the returned data. + Called multiple times per run when using steps. Not called when an error + is thrown. + """ + + return output + + +class MiddlewareSync: + client: client_lib.Inngest + + def __init__(self, client: client_lib.Inngest) -> None: + self.client = client + + def after_execution(self) -> None: + """ + After executing new code. Called multiple times per run when using + steps. + """ + + return None + + def before_execution(self) -> None: + """ + Before executing new code. Called multiple times per run when using + steps. + """ + + return None + + def before_response(self) -> None: + """ + After the output has been set and before the response is sent + back to Inngest. This is where you can perform any final actions before + the response is sent back to Inngest. Called multiple times per run when + using steps. + """ + + return None + + def transform_input( + self, + call_input: execution.TransformableCallInput, + ) -> execution.TransformableCallInput: + """ + Before calling a function or step. Used to replace certain arguments in + the function. Called multiple times per run when using steps. + """ + + return call_input + + def transform_output( + self, + output: types.Serializable, + ) -> types.Serializable: + """ + After a function or step returns. Used to modify the returned data. + Called multiple times per run when using steps. Not called when an error + is thrown. + """ + + return output diff --git a/inngest/_internal/step_lib/__init__.py b/inngest/_internal/step_lib/__init__.py index 8ed2c84..b792e46 100644 --- a/inngest/_internal/step_lib/__init__.py +++ b/inngest/_internal/step_lib/__init__.py @@ -1,4 +1,4 @@ -from .base import Interrupt, StepIDCounter +from .base import Interrupt, StepIDCounter, StepMemos from .step_async import Step from .step_sync import StepSync @@ -6,5 +6,6 @@ "Interrupt", "Step", "StepIDCounter", + "StepMemos", "StepSync", ] diff --git a/inngest/_internal/step_lib/base.py b/inngest/_internal/step_lib/base.py index bb9d820..902127a 100644 --- a/inngest/_internal/step_lib/base.py +++ b/inngest/_internal/step_lib/base.py @@ -2,12 +2,53 @@ import threading -from inngest._internal import execution, transforms, types +from inngest._internal import ( + client_lib, + execution, + middleware_lib, + transforms, + types, +) + + +class StepMemos: + """ + Holds memoized step output. + """ + + def __init__(self, memos: dict[str, object]) -> None: + self._memos = memos + + def get(self, hashed_id: str) -> object: + if hashed_id in self._memos: + memo = self._memos[hashed_id] + + # Remove memo + self._memos = { + k: v for k, v in self._memos.items() if k != hashed_id + } + + return memo + + return types.EmptySentinel + + @property + def size(self) -> int: + return len(self._memos) class StepBase: - _memos: dict[str, object] - _step_id_counter: StepIDCounter + def __init__( + self, + client: client_lib.Inngest, + memos: StepMemos, + middleware: middleware_lib.MiddlewareManager, + step_id_counter: StepIDCounter, + ) -> None: + self._client = client + self._memos = memos + self._middleware = middleware + self._step_id_counter = step_id_counter def _get_hashed_id(self, step_id: str) -> str: id_count = self._step_id_counter.increment(step_id) @@ -15,14 +56,29 @@ def _get_hashed_id(self, step_id: str) -> str: step_id = f"{step_id}:{id_count - 1}" return transforms.hash_step_id(step_id) - def _get_memo(self, hashed_id: str) -> object: - if hashed_id in self._memos: - return self._memos[hashed_id] + async def get_memo(self, hashed_id: str) -> object: + memo = self._memos.get(hashed_id) - return types.EmptySentinel + if self._memos.size == 0: + await self._middleware.before_execution() + + return memo + + def get_memo_sync(self, hashed_id: str) -> object: + memo = self._memos.get(hashed_id) + + if self._memos.size == 0: + self._middleware.before_execution_sync() + + return memo class StepIDCounter: + """ + Counts the number of times a step ID has been used. We support reused step + IDs so we need a way to keep track of their counts. + """ + def __init__(self) -> None: self._counts: dict[str, int] = {} self._mutex = threading.Lock() @@ -36,10 +92,13 @@ def increment(self, hashed_id: str) -> int: return self._counts[hashed_id] -# Extend BaseException to avoid being caught by the user's code. Users can still -# catch it if they do a "bare except", but that's a known antipattern in the -# Python world. class Interrupt(BaseException): + """ + Extend BaseException to avoid being caught by the user's code. Users can + still catch it if they do a "bare except", but that's a known antipattern in + the Python world. + """ + def __init__( self, *, diff --git a/inngest/_internal/step_lib/step_async.py b/inngest/_internal/step_lib/step_async.py index 190a2c9..62fd3c8 100644 --- a/inngest/_internal/step_lib/step_async.py +++ b/inngest/_internal/step_lib/step_async.py @@ -1,29 +1,12 @@ import datetime import typing -from inngest._internal import ( - client_lib, - event_lib, - execution, - result, - transforms, - types, -) +from inngest._internal import event_lib, execution, result, transforms, types from . import base class Step(base.StepBase): - def __init__( - self, - client: client_lib.Inngest, - memos: dict[str, object], - step_id_counter: base.StepIDCounter, - ) -> None: - self._client = client - self._memos = memos - self._step_id_counter = step_id_counter - @typing.overload async def run( self, @@ -57,10 +40,12 @@ async def run( hashed_id = self._get_hashed_id(step_id) - memo = self._get_memo(hashed_id) + memo = await self.get_memo(hashed_id) if memo is not types.EmptySentinel: return memo # type: ignore + await self._middleware.before_execution() + # Ensure the output is JSON-serializable. match transforms.dump_json(await transforms.maybe_await(handler())): case result.Ok(output): @@ -122,7 +107,7 @@ async def sleep_until( hashed_id = self._get_hashed_id(step_id) - memo = self._get_memo(hashed_id) + memo = await self.get_memo(hashed_id) if memo is not types.EmptySentinel: return memo # type: ignore @@ -152,7 +137,7 @@ async def wait_for_event( hashed_id = self._get_hashed_id(step_id) - memo = self._get_memo(hashed_id) + memo = await self.get_memo(hashed_id) if memo is not types.EmptySentinel: if memo is None: # Timeout diff --git a/inngest/_internal/step_lib/step_sync.py b/inngest/_internal/step_lib/step_sync.py index 0ee9926..d4ed194 100644 --- a/inngest/_internal/step_lib/step_sync.py +++ b/inngest/_internal/step_lib/step_sync.py @@ -1,29 +1,12 @@ import datetime import typing -from inngest._internal import ( - client_lib, - event_lib, - execution, - result, - transforms, - types, -) +from inngest._internal import event_lib, execution, result, transforms, types from . import base class StepSync(base.StepBase): - def __init__( - self, - client: client_lib.Inngest, - memos: dict[str, object], - step_id_counter: base.StepIDCounter, - ) -> None: - self._client = client - self._memos = memos - self._step_id_counter = step_id_counter - def run( self, step_id: str, @@ -40,10 +23,12 @@ def run( hashed_id = self._get_hashed_id(step_id) - memo = self._get_memo(hashed_id) + memo = self.get_memo_sync(hashed_id) if memo is not types.EmptySentinel: return memo # type: ignore + self._middleware.before_execution_sync() + # Ensure the output is JSON-serializable. match transforms.dump_json(handler()): case result.Ok(output): @@ -105,7 +90,7 @@ def sleep_until( hashed_id = self._get_hashed_id(step_id) - memo = self._get_memo(hashed_id) + memo = self.get_memo_sync(hashed_id) if memo is not types.EmptySentinel: return memo # type: ignore @@ -135,7 +120,7 @@ def wait_for_event( hashed_id = self._get_hashed_id(step_id) - memo = self._get_memo(hashed_id) + memo = self.get_memo_sync(hashed_id) if memo is not types.EmptySentinel: if memo is None: # Timeout diff --git a/inngest/_internal/types.py b/inngest/_internal/types.py index a2dc90b..b8f5f80 100644 --- a/inngest/_internal/types.py +++ b/inngest/_internal/types.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import logging import typing import pydantic @@ -71,3 +72,5 @@ def to_dict(self) -> result.Result[dict[str, object], Exception]: BaseModelT = typing.TypeVar("BaseModelT", bound=BaseModel) + +Logger: typing.TypeAlias = logging.Logger | logging.LoggerAdapter diff --git a/inngest/fast_api.py b/inngest/fast_api.py index 793313c..8ad2795 100644 --- a/inngest/fast_api.py +++ b/inngest/fast_api.py @@ -27,7 +27,6 @@ def serve( client=client, framework=const.Framework.FAST_API, functions=functions, - logger=client.logger, signing_key=signing_key, ) diff --git a/inngest/flask.py b/inngest/flask.py index a449603..73c23d7 100644 --- a/inngest/flask.py +++ b/inngest/flask.py @@ -28,7 +28,6 @@ def serve( client=client, framework=const.Framework.FLASK, functions=functions, - logger=app.logger, signing_key=signing_key, ) diff --git a/inngest/tornado.py b/inngest/tornado.py index 9a9ae70..98f09fd 100644 --- a/inngest/tornado.py +++ b/inngest/tornado.py @@ -29,7 +29,6 @@ def serve( client=client, framework=const.Framework.TORNADO, functions=functions, - logger=client.logger, signing_key=signing_key, ) diff --git a/tests/cases/__init__.py b/tests/cases/__init__.py index 0faa306..ab6c65c 100644 --- a/tests/cases/__init__.py +++ b/tests/cases/__init__.py @@ -5,6 +5,7 @@ debounce, event_payload, function_args, + logger, middleware, no_steps, on_failure, @@ -24,6 +25,7 @@ def create_cases(framework: str) -> list[base.Case]: debounce, event_payload, function_args, + logger, middleware, no_steps, on_failure, @@ -47,6 +49,7 @@ def create_cases_sync(framework: str) -> list[base.Case]: debounce, event_payload, function_args, + logger, middleware, no_steps, on_failure, diff --git a/tests/cases/function_args.py b/tests/cases/function_args.py index 97ad3aa..9db1a1c 100644 --- a/tests/cases/function_args.py +++ b/tests/cases/function_args.py @@ -1,5 +1,3 @@ -import logging - import inngest import tests.helper @@ -33,7 +31,7 @@ def fn_sync( attempt: int, event: inngest.Event, events: list[inngest.Event], - logger: logging.Logger, + logger: inngest.Logger, run_id: str, step: inngest.StepSync, ) -> None: @@ -53,7 +51,7 @@ async def fn_async( attempt: int, event: inngest.Event, events: list[inngest.Event], - logger: logging.Logger, + logger: inngest.Logger, run_id: str, step: inngest.Step, ) -> None: diff --git a/tests/cases/logger.py b/tests/cases/logger.py new file mode 100644 index 0000000..140c74f --- /dev/null +++ b/tests/cases/logger.py @@ -0,0 +1,119 @@ +import logging + +import inngest +import tests.helper + +from . import base + +_TEST_NAME = "logger" + + +class StatefulLogger(logging.Logger): + """ + Fake logger that stores calls to its methods. We can use this to assert that + logger methods are properly called (e.g. no duplicates). + """ + + def __init__(self) -> None: + super().__init__("test") + self.info_calls: list[object] = [] + + def info(self, msg: object, *args: object, **kwargs: object) -> None: + self.info_calls.append(msg) + + +def create( + framework: str, + is_sync: bool, +) -> base.Case: + test_name = base.create_test_name(_TEST_NAME, is_sync) + event_name = base.create_event_name(framework, test_name, is_sync) + state = base.BaseState() + + _logger = StatefulLogger() + + @inngest.create_function( + fn_id=test_name, + retries=0, + trigger=inngest.TriggerEvent(event=event_name), + ) + def fn_sync( + *, + logger: inngest.Logger, + step: inngest.StepSync, + run_id: str, + **_kwargs: object, + ) -> None: + logger.info("function start") + state.run_id = run_id + + def _first_step() -> None: + logger.info("first_step") + + step.run("first_step", _first_step) + + logger.info("between steps") + + def _second_step() -> None: + logger.info("second_step") + + step.run("second_step", _second_step) + logger.info("function end") + + @inngest.create_function( + fn_id=test_name, + retries=0, + trigger=inngest.TriggerEvent(event=event_name), + ) + async def fn_async( + *, + logger: inngest.Logger, + step: inngest.Step, + run_id: str, + **_kwargs: object, + ) -> None: + logger.info("function start") + state.run_id = run_id + + def _first_step() -> None: + logger.info("first_step") + + await step.run("first_step", _first_step) + + logger.info("between steps") + + def _second_step() -> None: + logger.info("second_step") + + await step.run("second_step", _second_step) + logger.info("function end") + + def run_test(self: base.TestClass) -> None: + self.client.set_logger(_logger) + self.client.send_sync(inngest.Event(name=event_name)) + run_id = state.wait_for_run_id() + tests.helper.client.wait_for_run_status( + run_id, + tests.helper.RunStatus.COMPLETED, + ) + + assert _logger.info_calls == [ + "function start", + "first_step", + "between steps", + "second_step", + "function end", + ], _logger.info_calls + + if is_sync: + fn = fn_sync + else: + fn = fn_async + + return base.Case( + event_name=event_name, + fn=fn, + run_test=run_test, + state=state, + name=test_name, + ) diff --git a/tests/cases/middleware.py b/tests/cases/middleware.py index f4c83fa..e54c055 100644 --- a/tests/cases/middleware.py +++ b/tests/cases/middleware.py @@ -1,11 +1,10 @@ -import logging -import unittest.mock +import typing import inngest import tests.helper # TODO: Remove when middleware is ready for external use. -from inngest._internal import middleware_lib +from inngest._internal import execution, middleware_lib, types from . import base @@ -25,24 +24,27 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - _logger = unittest.mock.Mock() - - middleware: middleware_lib.Middleware | middleware_lib.MiddlewareSync + middleware: typing.Type[ + middleware_lib.Middleware | middleware_lib.MiddlewareSync + ] if is_sync: class _MiddlewareSync(middleware_lib.MiddlewareSync): - def after_function_execution(self) -> None: - state.hook_list.append("after_function_execution") + def after_execution(self) -> None: + state.hook_list.append("after_execution") def before_response(self) -> None: state.hook_list.append("before_response") - def before_function_execution(self) -> None: - state.hook_list.append("before_function_execution") + def before_execution(self) -> None: + state.hook_list.append("before_execution") - def transform_input(self) -> inngest.CallInputTransform: + def transform_input( + self, + call_input: execution.TransformableCallInput, + ) -> execution.TransformableCallInput: state.hook_list.append("transform_input") - return inngest.CallInputTransform(logger=_logger) + return call_input def transform_output( self, @@ -51,23 +53,26 @@ def transform_output( state.hook_list.append("transform_output") return output - middleware = _MiddlewareSync() + middleware = _MiddlewareSync else: class _MiddlewareAsync(middleware_lib.Middleware): - async def after_function_execution(self) -> None: - state.hook_list.append("after_function_execution") + async def after_execution(self) -> None: + state.hook_list.append("after_execution") async def before_response(self) -> None: state.hook_list.append("before_response") - async def before_function_execution(self) -> None: - state.hook_list.append("before_function_execution") + async def before_execution(self) -> None: + state.hook_list.append("before_execution") - async def transform_input(self) -> inngest.CallInputTransform: + async def transform_input( + self, + call_input: execution.TransformableCallInput, + ) -> execution.TransformableCallInput: state.hook_list.append("transform_input") - return inngest.CallInputTransform(logger=_logger) + return call_input async def transform_output( self, @@ -76,7 +81,7 @@ async def transform_output( state.hook_list.append("transform_output") return output - middleware = _MiddlewareAsync() + middleware = _MiddlewareAsync @inngest.create_function( fn_id=test_name, @@ -85,11 +90,12 @@ async def transform_output( ) def fn_sync( *, - logger: logging.Logger, + logger: types.Logger, step: inngest.StepSync, run_id: str, **_kwargs: object, ) -> None: + logger.info("function start") state.run_id = run_id def _first_step() -> None: @@ -97,10 +103,13 @@ def _first_step() -> None: step.run("first_step", _first_step) + logger.info("between steps") + def _second_step() -> None: logger.info("second_step") step.run("second_step", _second_step) + logger.info("function end") @inngest.create_function( fn_id=test_name, @@ -109,11 +118,12 @@ def _second_step() -> None: ) async def fn_async( *, - logger: logging.Logger, + logger: types.Logger, step: inngest.Step, run_id: str, **_kwargs: object, ) -> None: + logger.info("function start") state.run_id = run_id def _first_step() -> None: @@ -121,13 +131,16 @@ def _first_step() -> None: await step.run("first_step", _first_step) + logger.info("between steps") + def _second_step() -> None: logger.info("second_step") await step.run("second_step", _second_step) + logger.info("function end") def run_test(self: base.TestClass) -> None: - self.client.middleware.add(middleware) + self.client.add_middleware(middleware) self.client.send_sync(inngest.Event(name=event_name)) run_id = state.wait_for_run_id() tests.helper.client.wait_for_run_status( @@ -137,23 +150,24 @@ def run_test(self: base.TestClass) -> None: # Assert that the middleware hooks were called in the correct order assert state.hook_list == [ - "before_function_execution", + # Entry 1 "transform_input", + "before_execution", "transform_output", - "before_response", # first_step done + "before_response", + # Entry 2 "transform_input", + "before_execution", "transform_output", - "before_response", # second_step done + "before_response", + # Entry 3 "transform_input", + "before_execution", "transform_output", - "after_function_execution", - "before_response", # Function done + "after_execution", + "before_response", ], state.hook_list - # Assert that the middleware was able to transform the input - _logger.info.assert_any_call("first_step") - _logger.info.assert_any_call("second_step") - if is_sync: fn = fn_sync else: diff --git a/tests/cases/on_failure.py b/tests/cases/on_failure.py index 5aefbc6..14f4883 100644 --- a/tests/cases/on_failure.py +++ b/tests/cases/on_failure.py @@ -1,5 +1,3 @@ -import logging - import inngest import tests.helper @@ -37,7 +35,7 @@ def on_failure_sync( attempt: int, event: inngest.Event, events: list[inngest.Event], - logger: logging.Logger, + logger: inngest.Logger, run_id: str, step: inngest.StepSync, ) -> None: @@ -52,7 +50,7 @@ async def on_failure_async( attempt: int, event: inngest.Event, events: list[inngest.Event], - logger: logging.Logger, + logger: inngest.Logger, run_id: str, step: inngest.Step, ) -> None: