From 2c5cab801d838e642cdea3f1115772c95aec3d52 Mon Sep 17 00:00:00 2001 From: Aaron Harper Date: Tue, 31 Oct 2023 22:03:33 -0400 Subject: [PATCH] Unify create_function and create_function_sync --- examples/functions/batch.py | 3 +- examples/functions/cancel.py | 2 +- examples/functions/debounce.py | 2 +- examples/functions/duplicate_step_name.py | 7 +- examples/functions/error_step.py | 3 +- examples/functions/no_steps.py | 3 +- examples/functions/on_failure.py | 2 +- examples/functions/print_event.py | 4 +- examples/functions/send_event.py | 3 +- examples/functions/two_steps_and_sleep.py | 3 +- examples/functions/wait_for_event.py | 3 +- inngest/__init__.py | 13 +- inngest/_internal/comm.py | 49 +-- inngest/_internal/comm_test.py | 4 +- inngest/_internal/errors.py | 10 + inngest/_internal/function.py | 347 +++++++++++++++++++ inngest/_internal/function/__init__.py | 13 - inngest/_internal/function/base.py | 117 ------- inngest/_internal/function/function_async.py | 135 -------- inngest/_internal/function/function_sync.py | 143 -------- inngest/_internal/types.py | 4 + inngest/flask.py | 2 +- inngest/tornado.py | 2 +- tests/cases/base.py | 7 +- tests/cases/cancel.py | 4 +- tests/cases/client_send.py | 4 +- tests/cases/debounce.py | 4 +- tests/cases/event_payload.py | 4 +- tests/cases/function_args.py | 4 +- tests/cases/no_steps.py | 4 +- tests/cases/on_failure.py | 4 +- tests/cases/sleep.py | 4 +- tests/cases/sleep_until.py | 4 +- tests/cases/two_steps.py | 4 +- tests/cases/unserializable_step_output.py | 4 +- tests/cases/wait_for_event_fulfill.py | 4 +- tests/cases/wait_for_event_timeout.py | 4 +- tests/test_fast_api.py | 2 +- tests/test_flask.py | 4 +- 39 files changed, 438 insertions(+), 501 deletions(-) create mode 100644 inngest/_internal/function.py delete mode 100644 inngest/_internal/function/__init__.py delete mode 100644 inngest/_internal/function/base.py delete mode 100644 inngest/_internal/function/function_async.py delete mode 100644 inngest/_internal/function/function_sync.py diff --git a/examples/functions/batch.py b/examples/functions/batch.py index e547d4f4..f3cddbb2 100644 --- a/examples/functions/batch.py +++ b/examples/functions/batch.py @@ -3,13 +3,12 @@ import inngest -@inngest.create_function_sync( +@inngest.create_function( batch_events=inngest.Batch( max_size=2, timeout=datetime.timedelta(minutes=1), ), fn_id="batch", - name="Batch", trigger=inngest.TriggerEvent(event="app/batch"), ) def fn_sync( diff --git a/examples/functions/cancel.py b/examples/functions/cancel.py index a0d68cba..f942bfef 100644 --- a/examples/functions/cancel.py +++ b/examples/functions/cancel.py @@ -3,7 +3,7 @@ import inngest -@inngest.create_function_sync( +@inngest.create_function( cancel=[inngest.Cancel(event="app/cancel.cancel")], fn_id="cancel", trigger=inngest.TriggerEvent(event="app/cancel"), diff --git a/examples/functions/debounce.py b/examples/functions/debounce.py index 0a6440a4..7d0db5c9 100644 --- a/examples/functions/debounce.py +++ b/examples/functions/debounce.py @@ -3,7 +3,7 @@ import inngest -@inngest.create_function_sync( +@inngest.create_function( debounce=inngest.Debounce(period=datetime.timedelta(seconds=5)), fn_id="debounce", trigger=inngest.TriggerEvent(event="app/debounce"), diff --git a/examples/functions/duplicate_step_name.py b/examples/functions/duplicate_step_name.py index 37d3e167..1b86d515 100644 --- a/examples/functions/duplicate_step_name.py +++ b/examples/functions/duplicate_step_name.py @@ -1,10 +1,9 @@ import inngest -@inngest.create_function_sync( - fn_id="duplicate_step_name", - name="Duplicate step name", - trigger=inngest.TriggerEvent(event="app/duplicate_step_name"), +@inngest.create_function( + fn_id="duplicate_step_name_sync", + trigger=inngest.TriggerEvent(event="app/duplicate_step_name_sync"), ) def fn_sync(*, step: inngest.StepSync, **_kwargs: object) -> str: for _ in range(3): diff --git a/examples/functions/error_step.py b/examples/functions/error_step.py index 11616931..f4faa6d3 100644 --- a/examples/functions/error_step.py +++ b/examples/functions/error_step.py @@ -5,9 +5,8 @@ class MyError(Exception): pass -@inngest.create_function_sync( +@inngest.create_function( fn_id="error_step", - name="Error step", retries=0, trigger=inngest.TriggerEvent(event="app/error_step"), ) diff --git a/examples/functions/no_steps.py b/examples/functions/no_steps.py index 628d80b3..ea8e61a1 100644 --- a/examples/functions/no_steps.py +++ b/examples/functions/no_steps.py @@ -1,9 +1,8 @@ import inngest -@inngest.create_function_sync( +@inngest.create_function( fn_id="no_steps", - name="No steps", trigger=inngest.TriggerEvent(event="app/no_steps"), ) def fn_sync(**_kwargs: object) -> int: diff --git a/examples/functions/on_failure.py b/examples/functions/on_failure.py index 20f4b13a..327f6bd4 100644 --- a/examples/functions/on_failure.py +++ b/examples/functions/on_failure.py @@ -11,7 +11,7 @@ def _on_failure( print("on_failure called") -@inngest.create_function_sync( +@inngest.create_function( fn_id="on_failure", on_failure=_on_failure, retries=0, diff --git a/examples/functions/print_event.py b/examples/functions/print_event.py index 71c9765a..a2da54e2 100644 --- a/examples/functions/print_event.py +++ b/examples/functions/print_event.py @@ -1,9 +1,8 @@ import inngest -@inngest.create_function_sync( +@inngest.create_function( fn_id="print_event", - name="Print event", trigger=inngest.TriggerEvent(event="app/print_event"), ) def fn_sync( @@ -24,7 +23,6 @@ def _print_user() -> dict[str, object]: @inngest.create_function( fn_id="print_event_async", - name="Print event (async)", trigger=inngest.TriggerEvent(event="app/print_event_async"), ) async def fn( diff --git a/examples/functions/send_event.py b/examples/functions/send_event.py index c94a01ea..d0b6d3e7 100644 --- a/examples/functions/send_event.py +++ b/examples/functions/send_event.py @@ -1,9 +1,8 @@ import inngest -@inngest.create_function_sync( +@inngest.create_function( fn_id="send_event", - name="Send event", trigger=inngest.TriggerEvent(event="app/send_event"), ) def fn_sync(*, step: inngest.StepSync, **_kwargs: object) -> None: diff --git a/examples/functions/two_steps_and_sleep.py b/examples/functions/two_steps_and_sleep.py index 2d060c05..5bf4ae72 100644 --- a/examples/functions/two_steps_and_sleep.py +++ b/examples/functions/two_steps_and_sleep.py @@ -3,9 +3,8 @@ import inngest -@inngest.create_function_sync( +@inngest.create_function( fn_id="two_steps_and_sleep", - name="Two steps and sleep", trigger=inngest.TriggerEvent(event="app/two_steps_and_sleep"), ) def fn_sync(*, step: inngest.StepSync, **_kwargs: object) -> str: diff --git a/examples/functions/wait_for_event.py b/examples/functions/wait_for_event.py index b6873d25..272bb151 100644 --- a/examples/functions/wait_for_event.py +++ b/examples/functions/wait_for_event.py @@ -3,9 +3,8 @@ import inngest -@inngest.create_function_sync( +@inngest.create_function( fn_id="wait_for_event", - name="wait_for_event", trigger=inngest.TriggerEvent(event="app/wait_for_event"), ) def fn_sync(*, step: inngest.StepSync, **_kwargs: object) -> None: diff --git a/inngest/__init__.py b/inngest/__init__.py index 29acc124..ec1adc26 100644 --- a/inngest/__init__.py +++ b/inngest/__init__.py @@ -1,14 +1,7 @@ from ._internal.client_lib import Inngest from ._internal.errors import NonRetriableError from ._internal.event_lib import Event -from ._internal.function import ( - Function, - FunctionOpts, - FunctionOptsSync, - FunctionSync, - create_function, - create_function_sync, -) +from ._internal.function import Function, create_function from ._internal.function_config import ( Batch, Cancel, @@ -26,9 +19,6 @@ "Debounce", "Event", "Function", - "FunctionOpts", - "FunctionOptsSync", - "FunctionSync", "Inngest", "NonRetriableError", "RateLimit", @@ -38,5 +28,4 @@ "TriggerCron", "TriggerEvent", "create_function", - "create_function_sync", ] diff --git a/inngest/_internal/comm.py b/inngest/_internal/comm.py index 6438d88f..3ae691e8 100644 --- a/inngest/_internal/comm.py +++ b/inngest/_internal/comm.py @@ -75,7 +75,7 @@ def from_error( class CommHandler: _base_url: str _client: client_lib.Inngest - _fns: dict[str, function.Function | function.FunctionSync] + _fns: dict[str, function.Function] _framework: const.Framework _is_production: bool _logger: logging.Logger @@ -87,7 +87,7 @@ def __init__( base_url: str | None = None, client: client_lib.Inngest, framework: const.Framework, - functions: list[function.Function] | list[function.FunctionSync], + functions: list[function.Function], logger: logging.Logger, signing_key: str | None = None, ) -> None: @@ -190,12 +190,6 @@ async def call_function( case result.Err(err): return CommResponse.from_error(err, self._framework) - if not isinstance(fn, function.Function): - return CommResponse.from_error( - errors.MismatchedSync(f"function {fn_id} is not asynchronous"), - self._framework, - ) - return self._create_response(await fn.call(call, self._client, fn_id)) def call_function_sync( @@ -211,23 +205,24 @@ def call_function_sync( validation_res = req_sig.validate(self._signing_key) if result.is_err(validation_res): - return CommResponse.from_error( - validation_res.err_value, self._framework - ) + err = validation_res.err_value + extra = {} + if isinstance(err, errors.InternalError): + extra["code"] = err.code + self._logger.error(err, extra=extra) + return CommResponse.from_error(err, self._framework) match self._get_function(fn_id): case result.Ok(fn): pass case result.Err(err): + extra = {} + if isinstance(err, errors.InternalError): + extra["code"] = err.code + self._logger.error(err, extra=extra) return CommResponse.from_error(err, self._framework) - if not isinstance(fn, function.FunctionSync): - return CommResponse.from_error( - errors.MismatchedSync(f"function {fn_id} is not asynchronous"), - self._framework, - ) - - return self._create_response(fn.call(call, self._client, fn_id)) + return self._create_response(fn.call_sync(call, self._client, fn_id)) def _create_response( self, @@ -252,7 +247,17 @@ def _create_response( comm_res.body = transforms.prep_body(out) comm_res.status_code = 206 elif isinstance(call_res, execution.CallError): - comm_res.body = transforms.prep_body(call_res.model_dump()) + match call_res.to_dict(): + case result.Ok(d): + body = transforms.prep_body(d) + case result.Err(err): + return CommResponse.from_error(err, self._framework) + + self._logger.error( + call_res.message, + extra={"is_internal": call_res.is_internal}, + ) + comm_res.body = body comm_res.status_code = 500 if call_res.is_retriable is False: @@ -264,7 +269,7 @@ def _create_response( def _get_function( self, fn_id: str - ) -> result.Result[function.Function | function.FunctionSync, Exception]: + ) -> result.Result[function.Function, Exception]: # Look for the function ID in the list of user functions, but also # look for it in the list of on_failure functions. for _fn in self._fns.values(): @@ -370,6 +375,8 @@ async def register( case result.Ok(_): pass case result.Err(err): + print(err) + self._logger.error(err) return CommResponse.from_error(err, self._framework) async with httpx.AsyncClient() as client: @@ -379,6 +386,8 @@ async def register( await client.send(req) ) case result.Err(err): + print(err) + self._logger.error(err) return CommResponse.from_error(err, self._framework) return res diff --git a/inngest/_internal/comm_test.py b/inngest/_internal/comm_test.py index cbc119f0..af7df3c1 100644 --- a/inngest/_internal/comm_test.py +++ b/inngest/_internal/comm_test.py @@ -25,7 +25,7 @@ def test_full_config(self) -> None: fully-specified config. """ - @inngest.create_function_sync( + @inngest.create_function( batch_events=inngest.Batch( max_size=2, timeout=datetime.timedelta(minutes=1) ), @@ -63,7 +63,7 @@ def fn(**_kwargs: object) -> int: assert False, f"Unexpected error: {err}" def test_no_functions(self) -> None: - functions: list[inngest.FunctionSync] = [] + functions: list[inngest.Function] = [] handler = comm.CommHandler( base_url="http://foo.bar", diff --git a/inngest/_internal/errors.py b/inngest/_internal/errors.py index e285aa9a..1ed99a7e 100644 --- a/inngest/_internal/errors.py +++ b/inngest/_internal/errors.py @@ -161,6 +161,16 @@ def __init__(self, message: str | None = None) -> None: ) +class UnknownError(InternalError): + status_code: int = http.HTTPStatus.INTERNAL_SERVER_ERROR + + def __init__(self, message: str | None = None) -> None: + super().__init__( + code=const.ErrorCode.UNKNOWN, + message=message, + ) + + class UnserializableOutput(InternalError): status_code: int = http.HTTPStatus.INTERNAL_SERVER_ERROR diff --git a/inngest/_internal/function.py b/inngest/_internal/function.py new file mode 100644 index 00000000..13a07c8c --- /dev/null +++ b/inngest/_internal/function.py @@ -0,0 +1,347 @@ +from __future__ import annotations + +import dataclasses +import hashlib +import inspect +import json +import typing + +import pydantic + +from inngest._internal import ( + client_lib, + const, + errors, + event_lib, + execution, + function_config, + step_lib, + types, +) + + +@dataclasses.dataclass +class _Config: + # The user-defined function + main: function_config.FunctionConfig + + # The internal on_failure function + on_failure: function_config.FunctionConfig | None + + +@typing.runtime_checkable +class FunctionHandlerAsync(typing.Protocol): + def __call__( + self, + *, + attempt: int, + event: event_lib.Event, + events: list[event_lib.Event], + run_id: str, + step: step_lib.Step, + ) -> typing.Awaitable[types.JSONSerializableOutput]: + ... + + +@typing.runtime_checkable +class FunctionHandlerSync(typing.Protocol): + def __call__( + self, + *, + attempt: int, + event: event_lib.Event, + events: list[event_lib.Event], + run_id: str, + step: step_lib.StepSync, + ) -> types.JSONSerializableOutput: + ... + + +def _is_function_handler_async( + value: FunctionHandlerAsync | FunctionHandlerSync, +) -> typing.TypeGuard[FunctionHandlerAsync]: + return inspect.iscoroutinefunction(value) + + +def _is_function_handler_sync( + value: FunctionHandlerAsync | FunctionHandlerSync, +) -> typing.TypeGuard[FunctionHandlerSync]: + return not inspect.iscoroutinefunction(value) + + +class FunctionOpts(types.BaseModel): + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + batch_events: function_config.Batch | None = None + cancel: list[function_config.Cancel] | None = None + debounce: function_config.Debounce | None = None + id: str + name: str | None = None + on_failure: FunctionHandlerAsync | FunctionHandlerSync | None = None + rate_limit: function_config.RateLimit | None = None + retries: int | None = None + throttle: function_config.Throttle | None = None + + def convert_validation_error( + self, + err: pydantic.ValidationError, + ) -> BaseException: + return errors.InvalidConfig.from_validation_error(err) + + +def create_function( + *, + batch_events: function_config.Batch | None = None, + cancel: list[function_config.Cancel] | None = None, + debounce: function_config.Debounce | None = None, + fn_id: str, + name: str | None = None, + on_failure: FunctionHandlerAsync | FunctionHandlerSync | None = None, + rate_limit: function_config.RateLimit | None = None, + retries: int | None = None, + throttle: function_config.Throttle | None = None, + trigger: function_config.TriggerCron | function_config.TriggerEvent, +) -> typing.Callable[[FunctionHandlerAsync | FunctionHandlerSync], Function]: + def decorator(func: FunctionHandlerAsync | FunctionHandlerSync) -> Function: + return Function( + FunctionOpts( + batch_events=batch_events, + cancel=cancel, + debounce=debounce, + id=fn_id, + name=name, + on_failure=on_failure, + rate_limit=rate_limit, + retries=retries, + throttle=throttle, + ), + trigger, + func, + ) + + return decorator + + +class Function: + _handler: FunctionHandlerAsync | FunctionHandlerSync + _on_failure_fn_id: str | None = None + _opts: FunctionOpts + _trigger: function_config.TriggerCron | function_config.TriggerEvent + + @property + def id(self) -> str: + return self._opts.id + + @property + def is_handler_async(self) -> bool: + return _is_function_handler_async(self._handler) + + @property + def on_failure_fn_id(self) -> str | None: + return self._on_failure_fn_id + + def __init__( + self, + opts: FunctionOpts, + trigger: function_config.TriggerCron | function_config.TriggerEvent, + handler: FunctionHandlerAsync | FunctionHandlerSync, + ) -> None: + self._handler = handler + self._opts = opts + self._trigger = trigger + + if opts.on_failure is not None: + # Create a random suffix to avoid collisions with the main + # function's ID. + suffix = hashlib.sha1(opts.id.encode("utf-8")).hexdigest()[:8] + + self._on_failure_fn_id = f"{opts.id}-{suffix}" + + async def call( + self, + call: execution.Call, + client: client_lib.Inngest, + fn_id: str, + ) -> list[execution.CallResponse] | str | execution.CallError: + try: + handler: FunctionHandlerAsync | FunctionHandlerSync + if self.id == fn_id: + handler = self._handler + elif self.on_failure_fn_id == fn_id: + if self._opts.on_failure is None: + return execution.CallError.from_error( + errors.MissingFunction("on_failure not defined") + ) + handler = self._opts.on_failure + else: + return execution.CallError.from_error( + errors.MissingFunction("function ID mismatch") + ) + + # Determine whether the handler is async (i.e. if we need to await + # it). Sync functions are OK in async contexts, so it's OK if the + # handler is sync. + if _is_function_handler_async(handler): + res = await handler( + attempt=call.ctx.attempt, + event=call.event, + events=call.events, + run_id=call.ctx.run_id, + step=step_lib.Step( + client, + call.steps, + step_lib.StepIDCounter(), + ), + ) + elif _is_function_handler_sync(handler): + res = handler( + attempt=call.ctx.attempt, + event=call.event, + events=call.events, + run_id=call.ctx.run_id, + step=step_lib.StepSync( + client, + call.steps, + step_lib.StepIDCounter(), + ), + ) + else: + # Should be unreachable. + return execution.CallError.from_error( + errors.UnknownError( + "unable to determine function handler type" + ) + ) + + return json.dumps(res) + except step_lib.Interrupt as out: + return [ + execution.CallResponse( + data=out.data, + display_name=out.display_name, + id=out.hashed_id, + name=out.name, + op=out.op, + opts=out.opts, + ) + ] + except Exception as err: + return execution.CallError.from_error(err) + + def call_sync( + self, + call: execution.Call, + client: client_lib.Inngest, + fn_id: str, + ) -> list[execution.CallResponse] | str | execution.CallError: + try: + handler: FunctionHandlerAsync | FunctionHandlerSync + if self.id == fn_id: + handler = self._handler + elif self.on_failure_fn_id == fn_id: + if self._opts.on_failure is None: + return execution.CallError.from_error( + errors.MissingFunction("on_failure not defined") + ) + handler = self._opts.on_failure + else: + return execution.CallError.from_error( + errors.MissingFunction("function ID mismatch") + ) + + if _is_function_handler_sync(handler): + res = handler( + attempt=call.ctx.attempt, + event=call.event, + events=call.events, + run_id=call.ctx.run_id, + step=step_lib.StepSync( + client, + call.steps, + step_lib.StepIDCounter(), + ), + ) + + return json.dumps(res) + + return execution.CallError.from_error( + errors.MismatchedSync( + "encountered async function in non-async context" + ) + ) + except step_lib.Interrupt as out: + return [ + execution.CallResponse( + data=out.data, + display_name=out.display_name, + id=out.hashed_id, + name=out.name, + op=out.op, + opts=out.opts, + ) + ] + except Exception as err: + return execution.CallError.from_error(err) + + def get_config(self, app_url: str) -> _Config: + fn_id = self._opts.id + + name = fn_id + if self._opts.name is not None: + name = self._opts.name + + if self._opts.retries is not None: + retries = function_config.Retries(attempts=self._opts.retries) + else: + retries = None + + main = function_config.FunctionConfig( + batch_events=self._opts.batch_events, + cancel=self._opts.cancel, + debounce=self._opts.debounce, + id=fn_id, + name=name, + rate_limit=self._opts.rate_limit, + steps={ + const.ROOT_STEP_ID: function_config.Step( + id=const.ROOT_STEP_ID, + name=const.ROOT_STEP_ID, + retries=retries, + runtime=function_config.Runtime( + type="http", + url=f"{app_url}?fnId={fn_id}&stepId={const.ROOT_STEP_ID}", + ), + ), + }, + throttle=self._opts.throttle, + triggers=[self._trigger], + ) + + on_failure = None + if self.on_failure_fn_id is not None: + on_failure = function_config.FunctionConfig( + id=self.on_failure_fn_id, + name=f"{name} (on_failure handler)", + steps={ + const.ROOT_STEP_ID: function_config.Step( + id=const.ROOT_STEP_ID, + name=const.ROOT_STEP_ID, + retries=function_config.Retries(attempts=0), + runtime=function_config.Runtime( + type="http", + url=f"{app_url}?fnId={self.on_failure_fn_id}&stepId={const.ROOT_STEP_ID}", + ), + ) + }, + triggers=[ + function_config.TriggerEvent( + event=const.InternalEvents.FUNCTION_FAILED.value, + expression=f"event.data.function_id == '{self.id}'", + ) + ], + ) + + return _Config(main=main, on_failure=on_failure) + + def get_id(self) -> str: + return self._opts.id diff --git a/inngest/_internal/function/__init__.py b/inngest/_internal/function/__init__.py deleted file mode 100644 index 2edb4142..00000000 --- a/inngest/_internal/function/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .base import FunctionBase -from .function_async import Function, FunctionOpts, create_function -from .function_sync import FunctionOptsSync, FunctionSync, create_function_sync - -__all__ = [ - "Function", - "FunctionBase", - "FunctionOpts", - "FunctionOptsSync", - "FunctionSync", - "create_function", - "create_function_sync", -] diff --git a/inngest/_internal/function/base.py b/inngest/_internal/function/base.py deleted file mode 100644 index ba755c77..00000000 --- a/inngest/_internal/function/base.py +++ /dev/null @@ -1,117 +0,0 @@ -from __future__ import annotations - -import dataclasses -import typing - -import pydantic - -from inngest._internal import const, errors, function_config, types - -FunctionHandlerT = typing.TypeVar("FunctionHandlerT") - - -class FunctionBase(typing.Generic[FunctionHandlerT]): - _handler: FunctionHandlerT - _on_failure_fn_id: str | None = None - _opts: FunctionOptsBase[FunctionHandlerT] - _trigger: function_config.TriggerCron | function_config.TriggerEvent - - @property - def id(self) -> str: - return self._opts.id - - @property - def on_failure_fn_id(self) -> str | None: - return self._on_failure_fn_id - - def get_config(self, app_url: str) -> _Config: - fn_id = self._opts.id - - name = fn_id - if self._opts.name is not None: - name = self._opts.name - - if self._opts.retries is not None: - retries = function_config.Retries(attempts=self._opts.retries) - else: - retries = None - - main = function_config.FunctionConfig( - batch_events=self._opts.batch_events, - cancel=self._opts.cancel, - debounce=self._opts.debounce, - id=fn_id, - name=name, - rate_limit=self._opts.rate_limit, - steps={ - const.ROOT_STEP_ID: function_config.Step( - id=const.ROOT_STEP_ID, - name=const.ROOT_STEP_ID, - retries=retries, - runtime=function_config.Runtime( - type="http", - url=f"{app_url}?fnId={fn_id}&stepId={const.ROOT_STEP_ID}", - ), - ), - }, - throttle=self._opts.throttle, - triggers=[self._trigger], - ) - - on_failure = None - if self.on_failure_fn_id is not None: - on_failure = function_config.FunctionConfig( - id=self.on_failure_fn_id, - name=f"{name} (on_failure handler)", - steps={ - const.ROOT_STEP_ID: function_config.Step( - id=const.ROOT_STEP_ID, - name=const.ROOT_STEP_ID, - retries=function_config.Retries(attempts=0), - runtime=function_config.Runtime( - type="http", - url=f"{app_url}?fnId={self.on_failure_fn_id}&stepId={const.ROOT_STEP_ID}", - ), - ) - }, - triggers=[ - function_config.TriggerEvent( - event=const.InternalEvents.FUNCTION_FAILED.value, - expression=f"event.data.function_id == '{self.id}'", - ) - ], - ) - - return _Config(main=main, on_failure=on_failure) - - def get_id(self) -> str: - return self._opts.id - - -class FunctionOptsBase(types.BaseModel, typing.Generic[FunctionHandlerT]): - model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) - - batch_events: function_config.Batch | None = None - cancel: list[function_config.Cancel] | None = None - debounce: function_config.Debounce | None = None - id: str - name: str | None = None - on_failure: FunctionHandlerT | None = None - rate_limit: function_config.RateLimit | None = None - retries: int | None = None - throttle: function_config.Throttle | None = None - - def convert_validation_error( - self, - err: pydantic.ValidationError, - ) -> BaseException: - return errors.InvalidConfig.from_validation_error(err) - - -@dataclasses.dataclass -class _Config: - # The user-defined function - main: function_config.FunctionConfig - - # The internal on_failure function - on_failure: function_config.FunctionConfig | None diff --git a/inngest/_internal/function/function_async.py b/inngest/_internal/function/function_async.py deleted file mode 100644 index 9ab7a88e..00000000 --- a/inngest/_internal/function/function_async.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import annotations - -import hashlib -import json -import typing - -from inngest._internal import ( - client_lib, - errors, - event_lib, - execution, - function_config, - step_lib, -) - -from . import base - - -@typing.runtime_checkable -class _FunctionHandler(typing.Protocol): - def __call__( - self, - *, - attempt: int, - event: event_lib.Event, - events: list[event_lib.Event], - run_id: str, - step: step_lib.Step, - ) -> typing.Awaitable[object]: - ... - - -class FunctionOpts(base.FunctionOptsBase[_FunctionHandler]): - pass - - -def create_function( - *, - batch_events: function_config.Batch | None = None, - cancel: list[function_config.Cancel] | None = None, - debounce: function_config.Debounce | None = None, - fn_id: str, - name: str | None = None, - on_failure: _FunctionHandler | None = None, - rate_limit: function_config.RateLimit | None = None, - retries: int | None = None, - throttle: function_config.Throttle | None = None, - trigger: function_config.TriggerCron | function_config.TriggerEvent, -) -> typing.Callable[[_FunctionHandler], Function]: - def decorator(func: _FunctionHandler) -> Function: - return Function( - FunctionOpts( - batch_events=batch_events, - cancel=cancel, - debounce=debounce, - id=fn_id, - name=name, - on_failure=on_failure, - rate_limit=rate_limit, - retries=retries, - throttle=throttle, - ), - trigger, - func, - ) - - return decorator - - -class Function(base.FunctionBase[_FunctionHandler]): - def __init__( - self, - opts: FunctionOpts, - trigger: function_config.TriggerCron | function_config.TriggerEvent, - handler: _FunctionHandler, - ) -> None: - self._handler = handler - self._opts = opts - self._trigger = trigger - - if opts.on_failure is not None: - # Create a random suffix to avoid collisions with the main - # function's ID. - suffix = hashlib.sha1(opts.id.encode("utf-8")).hexdigest()[:8] - - self._on_failure_fn_id = f"{opts.id}-{suffix}" - - async def call( - self, - call: execution.Call, - client: client_lib.Inngest, - fn_id: str, - ) -> list[execution.CallResponse] | str | execution.CallError: - try: - handler: _FunctionHandler - - if self.id == fn_id: - handler = self._handler - elif self.on_failure_fn_id == fn_id: - if self._opts.on_failure is None: - return execution.CallError.from_error( - errors.MissingFunction("on_failure not defined") - ) - handler = self._opts.on_failure - else: - return execution.CallError.from_error( - errors.MissingFunction("function ID mismatch") - ) - - res = await handler( - attempt=call.ctx.attempt, - event=call.event, - events=call.events, - run_id=call.ctx.run_id, - step=step_lib.Step( - client, - call.steps, - step_lib.StepIDCounter(), - ), - ) - - return json.dumps(res) - except step_lib.Interrupt as out: - return [ - execution.CallResponse( - data=out.data, - display_name=out.display_name, - id=out.hashed_id, - name=out.name, - op=out.op, - opts=out.opts, - ) - ] - except Exception as err: - return execution.CallError.from_error(err) diff --git a/inngest/_internal/function/function_sync.py b/inngest/_internal/function/function_sync.py deleted file mode 100644 index 6363d8d5..00000000 --- a/inngest/_internal/function/function_sync.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations - -import hashlib -import json -import typing - -from inngest._internal import ( - client_lib, - errors, - event_lib, - execution, - function_config, -) - -from .. import step_lib -from . import base - - -@typing.runtime_checkable -class _FunctionHandlerSync(typing.Protocol): - def __call__( - self, - *, - attempt: int, - event: event_lib.Event, - events: list[event_lib.Event], - run_id: str, - step: step_lib.StepSync, - ) -> object: - ... - - -class FunctionOptsSync(base.FunctionOptsBase[_FunctionHandlerSync]): - pass - - -def create_function_sync( - *, - batch_events: function_config.Batch | None = None, - cancel: list[function_config.Cancel] | None = None, - debounce: function_config.Debounce | None = None, - fn_id: str, - name: str | None = None, - on_failure: _FunctionHandlerSync | None = None, - rate_limit: function_config.RateLimit | None = None, - retries: int | None = None, - throttle: function_config.Throttle | None = None, - trigger: function_config.TriggerCron | function_config.TriggerEvent, -) -> typing.Callable[[_FunctionHandlerSync], FunctionSync]: - """ - Synchronous version of create_function. - """ - - def decorator(func: _FunctionHandlerSync) -> FunctionSync: - return FunctionSync( - FunctionOptsSync( - batch_events=batch_events, - cancel=cancel, - debounce=debounce, - id=fn_id, - name=name, - on_failure=on_failure, - rate_limit=rate_limit, - retries=retries, - throttle=throttle, - ), - trigger, - func, - ) - - return decorator - - -class FunctionSync(base.FunctionBase[_FunctionHandlerSync]): - """ - Synchronous version of Function. - """ - - def __init__( - self, - opts: FunctionOptsSync, - trigger: function_config.TriggerCron | function_config.TriggerEvent, - handler: _FunctionHandlerSync, - ) -> None: - self._handler = handler - self._opts = opts - self._trigger = trigger - - if opts.on_failure is not None: - # Create a random suffix to avoid collisions with the main - # function's ID. - suffix = hashlib.sha1(opts.id.encode("utf-8")).hexdigest()[:8] - - self._on_failure_fn_id = f"{opts.id}-{suffix}" - - def call( - self, - call: execution.Call, - client: client_lib.Inngest, - fn_id: str, - ) -> list[execution.CallResponse] | str | execution.CallError: - try: - handler: _FunctionHandlerSync - - if self.id == fn_id: - handler = self._handler - elif self.on_failure_fn_id == fn_id: - if self._opts.on_failure is None: - return execution.CallError.from_error( - errors.MissingFunction("on_failure not defined") - ) - handler = self._opts.on_failure - else: - return execution.CallError.from_error( - errors.MissingFunction("function ID mismatch") - ) - - res = handler( - attempt=call.ctx.attempt, - event=call.event, - events=call.events, - run_id=call.ctx.run_id, - step=step_lib.StepSync( - client, - call.steps, - step_lib.StepIDCounter(), - ), - ) - - return json.dumps(res) - except step_lib.Interrupt as out: - return [ - execution.CallResponse( - data=out.data, - display_name=out.display_name, - id=out.hashed_id, - name=out.name, - op=out.op, - opts=out.opts, - ) - ] - except Exception as err: - return execution.CallError.from_error(err) diff --git a/inngest/_internal/types.py b/inngest/_internal/types.py index 35b3b0e9..04300085 100644 --- a/inngest/_internal/types.py +++ b/inngest/_internal/types.py @@ -10,6 +10,10 @@ EmptySentinel = object() +JSONSerializableOutput = ( + bool | float | int | str | dict | list | tuple[object, ...] | None +) + class BaseModel(pydantic.BaseModel): model_config = pydantic.ConfigDict(strict=True) diff --git a/inngest/flask.py b/inngest/flask.py index 86fc9f91..7608f00a 100644 --- a/inngest/flask.py +++ b/inngest/flask.py @@ -8,7 +8,7 @@ def serve( app: flask.Flask, client: client_lib.Inngest, - functions: list[function.FunctionSync], + functions: list[function.Function], *, base_url: str | None = None, signing_key: str | None = None, diff --git a/inngest/tornado.py b/inngest/tornado.py index 9139cf4e..2e169a6a 100644 --- a/inngest/tornado.py +++ b/inngest/tornado.py @@ -17,7 +17,7 @@ def serve( app: tornado.web.Application, client: client_lib.Inngest, - functions: list[function.FunctionSync], + functions: list[function.Function], *, base_url: str | None = None, signing_key: str | None = None, diff --git a/tests/cases/base.py b/tests/cases/base.py index 37d213b1..94faae34 100644 --- a/tests/cases/base.py +++ b/tests/cases/base.py @@ -22,15 +22,10 @@ def assertion() -> None: return self.run_id -FunctionT = typing.TypeVar( - "FunctionT", bound=inngest.Function | inngest.FunctionSync -) - - @dataclasses.dataclass class Case: event_name: str - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function name: str run_test: typing.Callable[[object], None] state: BaseState diff --git a/tests/cases/cancel.py b/tests/cases/cancel.py index 531e4ba8..61ebb44f 100644 --- a/tests/cases/cancel.py +++ b/tests/cases/cancel.py @@ -21,7 +21,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( cancel=[ inngest.Cancel( event=f"{event_name}.cancel", @@ -85,7 +85,7 @@ def assert_is_done() -> None: base.wait_for(assert_is_done) - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/client_send.py b/tests/cases/client_send.py index b65ccec3..8024b6d0 100644 --- a/tests/cases/client_send.py +++ b/tests/cases/client_send.py @@ -17,7 +17,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = base.BaseState() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -41,7 +41,7 @@ def run_test(_self: object) -> None: tests.helper.RunStatus.COMPLETED, ) - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/debounce.py b/tests/cases/debounce.py index ecaa0e9c..d3a6bd2e 100644 --- a/tests/cases/debounce.py +++ b/tests/cases/debounce.py @@ -21,7 +21,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( debounce=inngest.Debounce( period=datetime.timedelta(seconds=1), ), @@ -59,7 +59,7 @@ def run_test(_self: object) -> None: ) assert state.run_count == 1, f"Expected 1 run but got {state.run_count}" - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/event_payload.py b/tests/cases/event_payload.py index 8b566ad9..c6e6c071 100644 --- a/tests/cases/event_payload.py +++ b/tests/cases/event_payload.py @@ -19,7 +19,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -62,7 +62,7 @@ def run_test(_self: object) -> None: assert state.event.ts > 0 assert state.event.user == {"a": {"b": "c"}} - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/function_args.py b/tests/cases/function_args.py index a75b9df2..a48cb2f0 100644 --- a/tests/cases/function_args.py +++ b/tests/cases/function_args.py @@ -22,7 +22,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -73,7 +73,7 @@ def run_test(_self: object) -> None: assert isinstance(state.events, list) and len(state.events) == 1 assert isinstance(state.step, (inngest.Step, inngest.StepSync)) - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/no_steps.py b/tests/cases/no_steps.py index ea009aea..11355901 100644 --- a/tests/cases/no_steps.py +++ b/tests/cases/no_steps.py @@ -15,7 +15,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = base.BaseState() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -39,7 +39,7 @@ def run_test(_self: object) -> None: tests.helper.RunStatus.COMPLETED, ) - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/on_failure.py b/tests/cases/on_failure.py index da79a21b..666b97a9 100644 --- a/tests/cases/on_failure.py +++ b/tests/cases/on_failure.py @@ -59,7 +59,7 @@ async def on_failure_async( state.on_failure_run_id = run_id state.step = step - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, on_failure=on_failure_sync, retries=0, @@ -112,7 +112,7 @@ def run_test(_self: object) -> None: assert isinstance(state.events, list) and len(state.events) == 1 assert isinstance(state.step, (inngest.Step, inngest.StepSync)) - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/sleep.py b/tests/cases/sleep.py index e578a211..f88b9078 100644 --- a/tests/cases/sleep.py +++ b/tests/cases/sleep.py @@ -25,7 +25,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -80,7 +80,7 @@ def run_test(_self: object) -> None: seconds=2 ) - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/sleep_until.py b/tests/cases/sleep_until.py index a98e87bd..c0338103 100644 --- a/tests/cases/sleep_until.py +++ b/tests/cases/sleep_until.py @@ -25,7 +25,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -84,7 +84,7 @@ def run_test(_self: object) -> None: seconds=2 ) - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/two_steps.py b/tests/cases/two_steps.py index d4f30544..98bcedfe 100644 --- a/tests/cases/two_steps.py +++ b/tests/cases/two_steps.py @@ -20,7 +20,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -79,7 +79,7 @@ def run_test(_self: object) -> None: assert state.step_1_counter == 1 assert state.step_2_counter == 1 - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/unserializable_step_output.py b/tests/cases/unserializable_step_output.py index 21e5a564..6441b357 100644 --- a/tests/cases/unserializable_step_output.py +++ b/tests/cases/unserializable_step_output.py @@ -20,7 +20,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -81,7 +81,7 @@ def run_test(_self: object) -> None: assert isinstance(state.error, errors.UnserializableOutput) assert str(state.error) == "Object of type Foo is not JSON serializable" - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/wait_for_event_fulfill.py b/tests/cases/wait_for_event_fulfill.py index 9447b76e..4d28774c 100644 --- a/tests/cases/wait_for_event_fulfill.py +++ b/tests/cases/wait_for_event_fulfill.py @@ -22,7 +22,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -78,7 +78,7 @@ def run_test(_self: object) -> None: assert state.result.name == f"{event_name}.fulfill" assert state.result.ts > 0 - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/cases/wait_for_event_timeout.py b/tests/cases/wait_for_event_timeout.py index f439e91b..ab37e1ba 100644 --- a/tests/cases/wait_for_event_timeout.py +++ b/tests/cases/wait_for_event_timeout.py @@ -21,7 +21,7 @@ def create( event_name = base.create_event_name(framework, test_name, is_sync) state = _State() - @inngest.create_function_sync( + @inngest.create_function( fn_id=test_name, retries=0, trigger=inngest.TriggerEvent(event=event_name), @@ -68,7 +68,7 @@ def run_test(_self: object) -> None: ) assert state.result is None - fn: inngest.Function | inngest.FunctionSync + fn: inngest.Function if is_sync: fn = fn_sync else: diff --git a/tests/test_fast_api.py b/tests/test_fast_api.py index 3cff5caf..f558faa8 100644 --- a/tests/test_fast_api.py +++ b/tests/test_fast_api.py @@ -36,7 +36,7 @@ def setUpClass(cls) -> None: case.fn for case in _cases # Should always be true but mypy doesn't know that - if isinstance(case.fn, inngest.Function) + # if case.fn.is_handler_async ], ) cls.fast_api_client = fastapi.testclient.TestClient(cls.app) diff --git a/tests/test_flask.py b/tests/test_flask.py index 14a0d6b5..8dc062ee 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -36,7 +36,7 @@ def setUpClass(cls) -> None: case.fn for case in _cases # Should always be true but mypy doesn't know that - if isinstance(case.fn, inngest.FunctionSync) + if isinstance(case.fn, inngest.Function) ], ) cls.app = app.test_client() @@ -89,7 +89,7 @@ def test_dev_server_to_prod(self) -> None: is_production=True, ) - @inngest.create_function_sync( + @inngest.create_function( fn_id="foo", retries=0, trigger=inngest.TriggerEvent(event="app/foo"),