diff --git a/inngest/_internal/comm.py b/inngest/_internal/comm.py index 2f7f516..2eb2fab 100644 --- a/inngest/_internal/comm.py +++ b/inngest/_internal/comm.py @@ -54,7 +54,7 @@ def body(self, body: object) -> None: def from_error( cls, err: Exception, - framework: str, + framework: const.Framework, ) -> CommResponse: code = const.ErrorCode.UNKNOWN.value status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR.value @@ -76,7 +76,7 @@ class CommHandler: _base_url: str _client: client_lib.Inngest _fns: dict[str, function.Function | function.FunctionSync] - _framework: str + _framework: const.Framework _is_production: bool _logger: logging.Logger _signing_key: str | None @@ -86,7 +86,7 @@ def __init__( *, base_url: str | None = None, client: client_lib.Inngest, - framework: str, + framework: const.Framework, functions: list[function.Function] | list[function.FunctionSync], logger: logging.Logger, signing_key: str | None = None, diff --git a/inngest/_internal/comm_test.py b/inngest/_internal/comm_test.py index 2d652c0..cbc119f 100644 --- a/inngest/_internal/comm_test.py +++ b/inngest/_internal/comm_test.py @@ -5,7 +5,7 @@ import unittest import inngest -from inngest._internal import errors, result +from inngest._internal import const, errors, result from . import comm @@ -50,7 +50,7 @@ def fn(**_kwargs: object) -> int: handler = comm.CommHandler( base_url="http://foo.bar", client=self.client, - framework="test", + framework=const.Framework.FLASK, functions=[fn], logger=self.client.logger, ) @@ -68,7 +68,7 @@ def test_no_functions(self) -> None: handler = comm.CommHandler( base_url="http://foo.bar", client=self.client, - framework="test", + framework=const.Framework.FLASK, functions=functions, logger=self.client.logger, ) @@ -79,6 +79,3 @@ def test_no_functions(self) -> None: case result.Err(err): assert isinstance(err, errors.InvalidConfig) assert str(err) == "no functions found" - - # with pytest.raises(errors.InvalidConfig, match="no functions found"): - # handler.get_function_configs("http://foo.bar") diff --git a/inngest/_internal/const.py b/inngest/_internal/const.py index 733a7f0..0baffa3 100644 --- a/inngest/_internal/const.py +++ b/inngest/_internal/const.py @@ -32,6 +32,12 @@ class ErrorCode(enum.Enum): UNSERIALIZABLE_OUTPUT = "unserializable_output" +class Framework(enum.StrEnum): + FAST_API = "fast_api" + FLASK = "flask" + TORNADO = "tornado" + + class HeaderKey(enum.Enum): CONTENT_TYPE = "Content-Type" FORWARDED_FOR = "X-Forwarded-For" diff --git a/inngest/_internal/execution.py b/inngest/_internal/execution.py index edfbe61..fb7e34a 100644 --- a/inngest/_internal/execution.py +++ b/inngest/_internal/execution.py @@ -2,7 +2,7 @@ import enum -from . import event_lib, types +from . import errors, event_lib, transforms, types class Call(types.BaseModel): @@ -23,11 +23,22 @@ class CallStack(types.BaseModel): class CallError(types.BaseModel): + is_internal: bool is_retriable: bool message: str name: str stack: str + @classmethod + def from_error(cls, err: Exception) -> CallError: + return cls( + is_internal=isinstance(err, errors.InternalError), + is_retriable=isinstance(err, errors.NonRetriableError) is False, + message=str(err), + name=type(err).__name__, + stack=transforms.get_traceback(err), + ) + class CallResponse(types.BaseModel): data: object diff --git a/inngest/_internal/function/function_async.py b/inngest/_internal/function/function_async.py index 02163b5..9ab7a88 100644 --- a/inngest/_internal/function/function_async.py +++ b/inngest/_internal/function/function_async.py @@ -2,7 +2,6 @@ import hashlib import json -import traceback import typing from inngest._internal import ( @@ -99,10 +98,14 @@ async def call( handler = self._handler elif self.on_failure_fn_id == fn_id: if self._opts.on_failure is None: - raise errors.MissingFunction("on_failure not defined") + return execution.CallError.from_error( + errors.MissingFunction("on_failure not defined") + ) handler = self._opts.on_failure else: - raise errors.MissingFunction("function ID mismatch") + return execution.CallError.from_error( + errors.MissingFunction("function ID mismatch") + ) res = await handler( attempt=call.ctx.attempt, @@ -129,11 +132,4 @@ async def call( ) ] except Exception as err: - is_retriable = isinstance(err, errors.NonRetriableError) is False - - return execution.CallError( - is_retriable=is_retriable, - message=str(err), - name=type(err).__name__, - stack=traceback.format_exc(), - ) + return execution.CallError.from_error(err) diff --git a/inngest/_internal/function/function_sync.py b/inngest/_internal/function/function_sync.py index 8ba6e05..6363d8d 100644 --- a/inngest/_internal/function/function_sync.py +++ b/inngest/_internal/function/function_sync.py @@ -2,7 +2,6 @@ import hashlib import json -import traceback import typing from inngest._internal import ( @@ -107,10 +106,14 @@ def call( handler = self._handler elif self.on_failure_fn_id == fn_id: if self._opts.on_failure is None: - raise errors.MissingFunction("on_failure not defined") + return execution.CallError.from_error( + errors.MissingFunction("on_failure not defined") + ) handler = self._opts.on_failure else: - raise errors.MissingFunction("function ID mismatch") + return execution.CallError.from_error( + errors.MissingFunction("function ID mismatch") + ) res = handler( attempt=call.ctx.attempt, @@ -137,11 +140,4 @@ def call( ) ] except Exception as err: - is_retriable = isinstance(err, errors.NonRetriableError) is False - - return execution.CallError( - is_retriable=is_retriable, - message=str(err), - name=type(err).__name__, - stack=traceback.format_exc(), - ) + return execution.CallError.from_error(err) diff --git a/inngest/_internal/net.py b/inngest/_internal/net.py index b706eff..55b51a2 100644 --- a/inngest/_internal/net.py +++ b/inngest/_internal/net.py @@ -10,7 +10,7 @@ def create_headers( *, - framework: str | None = None, + framework: const.Framework | None = None, ) -> dict[str, str]: headers = { const.HeaderKey.USER_AGENT.value: f"inngest-{const.LANGUAGE}:v{const.VERSION}", @@ -18,7 +18,7 @@ def create_headers( } if framework is not None: - headers[const.HeaderKey.FRAMEWORK.value] = framework + headers[const.HeaderKey.FRAMEWORK.value] = framework.value return headers diff --git a/inngest/_internal/registration.py b/inngest/_internal/registration.py index 3e61af9..a34e4ce 100644 --- a/inngest/_internal/registration.py +++ b/inngest/_internal/registration.py @@ -2,7 +2,7 @@ import pydantic -from . import function_config, types +from . import const, function_config, types class DeployType(enum.Enum): @@ -12,7 +12,7 @@ class DeployType(enum.Enum): class RegisterRequest(types.BaseModel): app_name: str deploy_type: DeployType - framework: str + framework: const.Framework functions: list[function_config.FunctionConfig] = pydantic.Field( min_length=1 ) diff --git a/inngest/_internal/transforms.py b/inngest/_internal/transforms.py index 143cf02..ee4dfc3 100644 --- a/inngest/_internal/transforms.py +++ b/inngest/_internal/transforms.py @@ -1,10 +1,17 @@ import datetime import hashlib import re +import traceback from . import errors, result, types +def get_traceback(err: Exception) -> str: + return "".join( + traceback.format_exception(type(err), err, err.__traceback__) + ) + + def hash_signing_key(key: str) -> str: return hashlib.sha256( bytearray.fromhex(remove_signing_key_prefix(key)) diff --git a/inngest/fast_api.py b/inngest/fast_api.py index 86ab395..760e48a 100644 --- a/inngest/fast_api.py +++ b/inngest/fast_api.py @@ -16,7 +16,7 @@ def serve( handler = comm.CommHandler( base_url=base_url or client.base_url, client=client, - framework="flask", + framework=const.Framework.FAST_API, functions=functions, logger=client.logger, signing_key=signing_key, diff --git a/inngest/flask.py b/inngest/flask.py index 7dfb2cb..5a89c90 100644 --- a/inngest/flask.py +++ b/inngest/flask.py @@ -16,7 +16,7 @@ def serve( handler = comm.CommHandler( base_url=base_url or client.base_url, client=client, - framework="flask", + framework=const.Framework.FLASK, functions=functions, logger=app.logger, signing_key=signing_key, diff --git a/inngest/tornado.py b/inngest/tornado.py index 21ecf65..0e807b4 100644 --- a/inngest/tornado.py +++ b/inngest/tornado.py @@ -25,7 +25,7 @@ def serve( handler = comm.CommHandler( base_url=base_url or client.base_url, client=client, - framework="flask", + framework=const.Framework.TORNADO, functions=functions, logger=client.logger, signing_key=signing_key,