From a7715961e404fc4ce8d96fd5c029b4dc426795ed Mon Sep 17 00:00:00 2001 From: Aaron Harper Date: Wed, 1 Nov 2023 09:07:37 -0400 Subject: [PATCH 1/2] Combine serve and serve_sync --- README.md | 38 +++++++---------------- inngest/_internal/comm.py | 7 +++-- inngest/_internal/execution.py | 9 ++++++ inngest/_internal/function.py | 33 +++++++++++++++++--- inngest/_internal/step_lib/step_async.py | 32 +++++++++++++++++-- inngest/_internal/step_lib/step_sync.py | 5 +++ inngest/_internal/types.py | 12 +++++-- inngest/flask.py | 33 +++++++++++--------- inngest/tornado.py | 2 +- tests/cases/unserializable_step_output.py | 2 +- tests/test_flask.py | 4 +-- 11 files changed, 122 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 50811513..4e355a35 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ # Inngest Python SDK -## 🚧 Currently in alpha! Not guaranteed to be production ready! 🚧 +> 🚧 Currently in beta! It hasn't been battle-tested in production environments yet. Supported frameworks: @@ -33,9 +33,11 @@ Supported frameworks: ## Usage +> 💡 Most of these examples don't show `async` functions but you can mix `async` and non-`async` functions in the same app! + - [Basic](#basic-no-steps) - [Step run](#step-run) -- [Avoiding async/await](#avoiding-async-functions) +- [Async function](#async-function) ### Basic (no steps) @@ -46,7 +48,6 @@ import flask import inngest.flask import requests - @inngest.create_function( fn_id="find_person", trigger=inngest.TriggerEvent(event="app/person.find"), @@ -61,7 +62,6 @@ async def fetch_person( res = requests.get(f"https://swapi.dev/api/people/{person_id}") return res.json() - app = flask.Flask(__name__) inngest_client = inngest.Inngest(app_id="flask_example") @@ -89,10 +89,10 @@ The following example registers a function that will: fn_id="find_ships", trigger=inngest.TriggerEvent(event="app/ships.find"), ) -async def fetch_ships( +def fetch_ships( *, event: inngest.Event, - step: inngest.Step, + step: inngest.StepSync, **_kwargs: object, ) -> dict: """ @@ -125,37 +125,23 @@ async def fetch_ships( } ``` -### Avoiding async functions - -Completely avoiding `async`` functions only requires 2 differences: - -1. Use `step: inngest.StepSync` instead of `step: inngest.Step` -2. Use `serve_sync` instead of `serve` +### Async functions ```py @inngest.create_function( fn_id="find_person", trigger=inngest.TriggerEvent(event="app/person.find"), ) -def fetch_person( +async def fetch_person( *, event: inngest.Event, - step: inngest.StepSync, + step: inngest.Step, **_kwargs: object, ) -> dict: person_id = event.data["person_id"] - res = requests.get(f"https://swapi.dev/api/people/{person_id}") - return res.json() - - -app = flask.Flask(__name__) -inngest_client = inngest.Inngest(app_id="flask_example") - -inngest.flask.serve_sync( - app, - inngest_client, - [fetch_person], -) + async with httpx.AsyncClient() as client: + res = await client.get(f"https://swapi.dev/api/people/{person_id}") + return res.json() ``` > 💡 You can mix `async` and non-`async` functions in the same app! diff --git a/inngest/_internal/comm.py b/inngest/_internal/comm.py index 3ae691e8..302e1c87 100644 --- a/inngest/_internal/comm.py +++ b/inngest/_internal/comm.py @@ -18,6 +18,7 @@ registration, result, transforms, + types, ) @@ -226,7 +227,9 @@ def call_function_sync( def _create_response( self, - call_res: list[execution.CallResponse] | str | execution.CallError, + call_res: list[execution.CallResponse] + | types.Serializable + | execution.CallError, ) -> CommResponse: comm_res = CommResponse( headers={ @@ -235,7 +238,7 @@ def _create_response( } ) - if isinstance(call_res, list): + if execution.is_call_responses(call_res): out: list[dict[str, object]] = [] for item in call_res: match item.to_dict(): diff --git a/inngest/_internal/execution.py b/inngest/_internal/execution.py index fb7e34af..92b64b1f 100644 --- a/inngest/_internal/execution.py +++ b/inngest/_internal/execution.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum +import typing from . import errors, event_lib, transforms, types @@ -49,6 +50,14 @@ class CallResponse(types.BaseModel): opts: dict[str, object] | None = None +def is_call_responses( + value: object, +) -> typing.TypeGuard[list[CallResponse]]: + if not isinstance(value, list): + return False + return all(isinstance(item, CallResponse) for item in value) + + class Opcode(enum.Enum): SLEEP = "Sleep" STEP = "Step" diff --git a/inngest/_internal/function.py b/inngest/_internal/function.py index 13a07c8c..e91cee2a 100644 --- a/inngest/_internal/function.py +++ b/inngest/_internal/function.py @@ -16,6 +16,7 @@ execution, function_config, step_lib, + transforms, types, ) @@ -39,7 +40,7 @@ def __call__( events: list[event_lib.Event], run_id: str, step: step_lib.Step, - ) -> typing.Awaitable[types.JSONSerializableOutput]: + ) -> typing.Awaitable[types.Serializable]: ... @@ -53,7 +54,7 @@ def __call__( events: list[event_lib.Event], run_id: str, step: step_lib.StepSync, - ) -> types.JSONSerializableOutput: + ) -> types.Serializable: ... @@ -134,8 +135,22 @@ def id(self) -> str: @property def is_handler_async(self) -> bool: + """ + Whether the main handler is async. + """ return _is_function_handler_async(self._handler) + @property + def is_on_failure_handler_async(self) -> bool | None: + """ + Whether the on_failure handler is async. Returns None if there isn't an + on_failure handler. + """ + + if self._opts.on_failure is None: + return None + return _is_function_handler_async(self._opts.on_failure) + @property def on_failure_fn_id(self) -> str | None: return self._on_failure_fn_id @@ -162,7 +177,9 @@ async def call( call: execution.Call, client: client_lib.Inngest, fn_id: str, - ) -> list[execution.CallResponse] | str | execution.CallError: + ) -> list[ + execution.CallResponse + ] | types.Serializable | execution.CallError: try: handler: FunctionHandlerAsync | FunctionHandlerSync if self.id == fn_id: @@ -213,7 +230,7 @@ async def call( ) ) - return json.dumps(res) + return res except step_lib.Interrupt as out: return [ execution.CallResponse( @@ -226,6 +243,10 @@ async def call( ) ] except Exception as err: + # An error occurred with the user's code. Print the traceback to + # help them debug. + print(transforms.get_traceback(err)) + return execution.CallError.from_error(err) def call_sync( @@ -281,6 +302,10 @@ def call_sync( ) ] except Exception as err: + # An error occurred with the user's code. Print the traceback to + # help them debug. + print(transforms.get_traceback(err)) + return execution.CallError.from_error(err) def get_config(self, app_url: str) -> _Config: diff --git a/inngest/_internal/step_lib/step_async.py b/inngest/_internal/step_lib/step_async.py index 3cf3f859..2fdddda6 100644 --- a/inngest/_internal/step_lib/step_async.py +++ b/inngest/_internal/step_lib/step_async.py @@ -1,4 +1,5 @@ import datetime +import inspect import json import typing @@ -26,13 +27,35 @@ def __init__( self._memos = memos self._step_id_counter = step_id_counter + @typing.overload async def run( self, step_id: str, - handler: typing.Callable[[], typing.Awaitable[types.T]], - ) -> types.T: + handler: typing.Callable[[], typing.Awaitable[types.SerializableT]], + ) -> types.SerializableT: + ... + + @typing.overload + async def run( + self, + step_id: str, + handler: typing.Callable[[], types.SerializableT], + ) -> types.SerializableT: + ... + + async def run( + self, + step_id: str, + handler: typing.Callable[[], typing.Awaitable[types.SerializableT]] + | typing.Callable[[], types.SerializableT], + ) -> types.SerializableT: """ Run logic that should be retried on error and memoized after success. + + Args: + step_id: Unique step ID within the function. If the same step ID is + encountered multiple times then it'll get an index suffix. + handler: The logic to run. Can be async or sync. """ hashed_id = self._get_hashed_id(step_id) @@ -41,7 +64,10 @@ async def run( if memo is not types.EmptySentinel: return memo # type: ignore - output = await handler() + if inspect.iscoroutinefunction(handler): + output = await handler() + else: + output = handler() try: json.dumps(output) diff --git a/inngest/_internal/step_lib/step_sync.py b/inngest/_internal/step_lib/step_sync.py index cad9c1b6..3997a4e8 100644 --- a/inngest/_internal/step_lib/step_sync.py +++ b/inngest/_internal/step_lib/step_sync.py @@ -33,6 +33,11 @@ def run( ) -> types.T: """ Run logic that should be retried on error and memoized after success. + + Args: + step_id: Unique step ID within the function. If the same step ID is + encountered multiple times then it'll get an index suffix. + handler: The logic to run. """ hashed_id = self._get_hashed_id(step_id) diff --git a/inngest/_internal/types.py b/inngest/_internal/types.py index 04300085..3d84e283 100644 --- a/inngest/_internal/types.py +++ b/inngest/_internal/types.py @@ -10,9 +10,17 @@ EmptySentinel = object() -JSONSerializableOutput = ( - bool | float | int | str | dict | list | tuple[object, ...] | None +Serializable = ( + bool + | float + | int + | str + | dict[typing.Any, typing.Any] + | list[typing.Any] + | tuple[typing.Any, ...] + | None ) +SerializableT = typing.TypeVar("SerializableT", bound=Serializable) class BaseModel(pydantic.BaseModel): diff --git a/inngest/flask.py b/inngest/flask.py index 5d4ab95d..fe2a0d13 100644 --- a/inngest/flask.py +++ b/inngest/flask.py @@ -22,6 +22,21 @@ def serve( signing_key=signing_key, ) + async_mode = any( + function.is_handler_async or function.is_on_failure_handler_async + for function in functions + ) + if async_mode: + _create_handler_async(app, client, handler) + else: + _create_handler_sync(app, client, handler) + + +def _create_handler_async( + app: flask.Flask, + client: client_lib.Inngest, + handler: comm.CommHandler, +) -> None: @app.route("/api/inngest", methods=["GET", "POST", "PUT"]) async def inngest_api() -> flask.Response | str: headers = net.normalize_headers(dict(flask.request.headers.items())) @@ -60,26 +75,15 @@ async def inngest_api() -> flask.Response | str: ) ) + # Should be unreachable return "" -def serve_sync( +def _create_handler_sync( app: flask.Flask, client: client_lib.Inngest, - functions: list[function.Function], - *, - base_url: str | None = None, - signing_key: str | None = None, + handler: comm.CommHandler, ) -> None: - handler = comm.CommHandler( - base_url=base_url or client.base_url, - client=client, - framework=const.Framework.FLASK, - functions=functions, - logger=app.logger, - signing_key=signing_key, - ) - @app.route("/api/inngest", methods=["GET", "POST", "PUT"]) def inngest_api() -> flask.Response | str: headers = net.normalize_headers(dict(flask.request.headers.items())) @@ -118,6 +122,7 @@ def inngest_api() -> flask.Response | str: ) ) + # Should be unreachable return "" diff --git a/inngest/tornado.py b/inngest/tornado.py index 282f2b84..2e169a6a 100644 --- a/inngest/tornado.py +++ b/inngest/tornado.py @@ -14,7 +14,7 @@ ) -def serve_sync( +def serve( app: tornado.web.Application, client: client_lib.Inngest, functions: list[function.Function], diff --git a/tests/cases/unserializable_step_output.py b/tests/cases/unserializable_step_output.py index 6441b357..e2268f6d 100644 --- a/tests/cases/unserializable_step_output.py +++ b/tests/cases/unserializable_step_output.py @@ -65,7 +65,7 @@ async def step_1() -> Foo: return Foo() try: - await step.run("step_1", step_1) + await step.run("step_1", step_1) # type: ignore except BaseException as err: state.error = err raise diff --git a/tests/test_flask.py b/tests/test_flask.py index e5e4a5eb..4b259ed3 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -29,7 +29,7 @@ def setUpClass(cls) -> None: app = flask.Flask(__name__) app.logger.disabled = True - inngest.flask.serve_sync( + inngest.flask.serve( app, _client, [ @@ -98,7 +98,7 @@ def fn(**_kwargs: object) -> None: pass app = flask.Flask(__name__) - inngest.flask.serve_sync( + inngest.flask.serve( app, client, [fn], From 8b31acc93f5f2a8347aad5f92ea2c46b34c79fd5 Mon Sep 17 00:00:00 2001 From: Aaron Harper Date: Wed, 1 Nov 2023 09:18:37 -0400 Subject: [PATCH 2/2] Add transforms.dump_json --- inngest/_internal/function.py | 10 ++++------ inngest/_internal/step_lib/step_async.py | 12 ++++++------ inngest/_internal/step_lib/step_sync.py | 12 ++++++------ inngest/_internal/transforms.py | 8 ++++++++ 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/inngest/_internal/function.py b/inngest/_internal/function.py index e91cee2a..a6e56538 100644 --- a/inngest/_internal/function.py +++ b/inngest/_internal/function.py @@ -3,7 +3,6 @@ import dataclasses import hashlib import inspect -import json import typing import pydantic @@ -254,7 +253,9 @@ def call_sync( call: execution.Call, client: client_lib.Inngest, fn_id: str, - ) -> list[execution.CallResponse] | str | execution.CallError: + ) -> ( + list[execution.CallResponse] | types.Serializable | execution.CallError + ): try: handler: FunctionHandlerAsync | FunctionHandlerSync if self.id == fn_id: @@ -271,7 +272,7 @@ def call_sync( ) if _is_function_handler_sync(handler): - res = handler( + return handler( attempt=call.ctx.attempt, event=call.event, events=call.events, @@ -282,9 +283,6 @@ def call_sync( step_lib.StepIDCounter(), ), ) - - return json.dumps(res) - return execution.CallError.from_error( errors.MismatchedSync( "encountered async function in non-async context" diff --git a/inngest/_internal/step_lib/step_async.py b/inngest/_internal/step_lib/step_async.py index 2fdddda6..f3dfa975 100644 --- a/inngest/_internal/step_lib/step_async.py +++ b/inngest/_internal/step_lib/step_async.py @@ -1,11 +1,9 @@ import datetime import inspect -import json import typing from inngest._internal import ( client_lib, - errors, event_lib, execution, result, @@ -69,10 +67,12 @@ async def run( else: output = handler() - try: - json.dumps(output) - except TypeError as err: - raise errors.UnserializableOutput(str(err)) from None + # Check whether output is serializable + match transforms.dump_json(output): + case result.Ok(_): + pass + case result.Err(err): + raise err raise base.Interrupt( hashed_id=hashed_id, diff --git a/inngest/_internal/step_lib/step_sync.py b/inngest/_internal/step_lib/step_sync.py index 3997a4e8..0ba45a35 100644 --- a/inngest/_internal/step_lib/step_sync.py +++ b/inngest/_internal/step_lib/step_sync.py @@ -1,10 +1,8 @@ import datetime -import json import typing from inngest._internal import ( client_lib, - errors, event_lib, execution, result, @@ -48,10 +46,12 @@ def run( output = handler() - try: - json.dumps(output) - except TypeError as err: - raise errors.UnserializableOutput(str(err)) from None + # Check whether output is serializable + match transforms.dump_json(output): + case result.Ok(_): + pass + case result.Err(err): + raise err raise base.Interrupt( hashed_id=hashed_id, diff --git a/inngest/_internal/transforms.py b/inngest/_internal/transforms.py index ee4dfc36..e82ec323 100644 --- a/inngest/_internal/transforms.py +++ b/inngest/_internal/transforms.py @@ -1,5 +1,6 @@ import datetime import hashlib +import json import re import traceback @@ -22,6 +23,13 @@ def hash_step_id(step_id: str) -> str: return hashlib.sha1(step_id.encode("utf-8")).hexdigest() +def dump_json(obj: object) -> result.Result[str, errors.UnserializableOutput]: + try: + return result.Ok(json.dumps(obj)) + except Exception as err: + return result.Err(errors.UnserializableOutput(str(err))) + + def remove_signing_key_prefix(key: str) -> str: prefix_match = re.match(r"^signkey-[\w]+-", key) prefix = ""