Skip to content

Commit

Permalink
Use result more; frameworks
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Oct 31, 2023
1 parent d2add0c commit ded3e06
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 39 deletions.
6 changes: 3 additions & 3 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def body(self, body: object) -> None:
def from_error(
cls,
err: Exception,
framework: str,
framework: const.Framework,
) -> CommResponse:
code = const.ErrorCode.UNKNOWN.value
status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR.value
Expand All @@ -76,7 +76,7 @@ class CommHandler:
_base_url: str
_client: client_lib.Inngest
_fns: dict[str, function.Function | function.FunctionSync]
_framework: str
_framework: const.Framework
_is_production: bool
_logger: logging.Logger
_signing_key: str | None
Expand All @@ -86,7 +86,7 @@ def __init__(
*,
base_url: str | None = None,
client: client_lib.Inngest,
framework: str,
framework: const.Framework,
functions: list[function.Function] | list[function.FunctionSync],
logger: logging.Logger,
signing_key: str | None = None,
Expand Down
9 changes: 3 additions & 6 deletions inngest/_internal/comm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import unittest

import inngest
from inngest._internal import errors, result
from inngest._internal import const, errors, result

from . import comm

Expand Down Expand Up @@ -50,7 +50,7 @@ def fn(**_kwargs: object) -> int:
handler = comm.CommHandler(
base_url="http://foo.bar",
client=self.client,
framework="test",
framework=const.Framework.FLASK,
functions=[fn],
logger=self.client.logger,
)
Expand All @@ -68,7 +68,7 @@ def test_no_functions(self) -> None:
handler = comm.CommHandler(
base_url="http://foo.bar",
client=self.client,
framework="test",
framework=const.Framework.FLASK,
functions=functions,
logger=self.client.logger,
)
Expand All @@ -79,6 +79,3 @@ def test_no_functions(self) -> None:
case result.Err(err):
assert isinstance(err, errors.InvalidConfig)
assert str(err) == "no functions found"

# with pytest.raises(errors.InvalidConfig, match="no functions found"):
# handler.get_function_configs("http://foo.bar")
6 changes: 6 additions & 0 deletions inngest/_internal/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ class ErrorCode(enum.Enum):
UNSERIALIZABLE_OUTPUT = "unserializable_output"


class Framework(enum.StrEnum):
FAST_API = "fast_api"
FLASK = "flask"
TORNADO = "tornado"


class HeaderKey(enum.Enum):
CONTENT_TYPE = "Content-Type"
FORWARDED_FOR = "X-Forwarded-For"
Expand Down
13 changes: 12 additions & 1 deletion inngest/_internal/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import enum

from . import event_lib, types
from . import errors, event_lib, transforms, types


class Call(types.BaseModel):
Expand All @@ -23,11 +23,22 @@ class CallStack(types.BaseModel):


class CallError(types.BaseModel):
is_internal: bool
is_retriable: bool
message: str
name: str
stack: str

@classmethod
def from_error(cls, err: Exception) -> CallError:
return cls(
is_internal=isinstance(err, errors.InternalError),
is_retriable=isinstance(err, errors.NonRetriableError) is False,
message=str(err),
name=type(err).__name__,
stack=transforms.get_traceback(err),
)


class CallResponse(types.BaseModel):
data: object
Expand Down
18 changes: 7 additions & 11 deletions inngest/_internal/function/function_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import hashlib
import json
import traceback
import typing

from inngest._internal import (
Expand Down Expand Up @@ -99,10 +98,14 @@ async def call(
handler = self._handler
elif self.on_failure_fn_id == fn_id:
if self._opts.on_failure is None:
raise errors.MissingFunction("on_failure not defined")
return execution.CallError.from_error(
errors.MissingFunction("on_failure not defined")
)
handler = self._opts.on_failure
else:
raise errors.MissingFunction("function ID mismatch")
return execution.CallError.from_error(
errors.MissingFunction("function ID mismatch")
)

res = await handler(
attempt=call.ctx.attempt,
Expand All @@ -129,11 +132,4 @@ async def call(
)
]
except Exception as err:
is_retriable = isinstance(err, errors.NonRetriableError) is False

return execution.CallError(
is_retriable=is_retriable,
message=str(err),
name=type(err).__name__,
stack=traceback.format_exc(),
)
return execution.CallError.from_error(err)
18 changes: 7 additions & 11 deletions inngest/_internal/function/function_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import hashlib
import json
import traceback
import typing

from inngest._internal import (
Expand Down Expand Up @@ -107,10 +106,14 @@ def call(
handler = self._handler
elif self.on_failure_fn_id == fn_id:
if self._opts.on_failure is None:
raise errors.MissingFunction("on_failure not defined")
return execution.CallError.from_error(
errors.MissingFunction("on_failure not defined")
)
handler = self._opts.on_failure
else:
raise errors.MissingFunction("function ID mismatch")
return execution.CallError.from_error(
errors.MissingFunction("function ID mismatch")
)

res = handler(
attempt=call.ctx.attempt,
Expand All @@ -137,11 +140,4 @@ def call(
)
]
except Exception as err:
is_retriable = isinstance(err, errors.NonRetriableError) is False

return execution.CallError(
is_retriable=is_retriable,
message=str(err),
name=type(err).__name__,
stack=traceback.format_exc(),
)
return execution.CallError.from_error(err)
4 changes: 2 additions & 2 deletions inngest/_internal/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@

def create_headers(
*,
framework: str | None = None,
framework: const.Framework | None = None,
) -> dict[str, str]:
headers = {
const.HeaderKey.USER_AGENT.value: f"inngest-{const.LANGUAGE}:v{const.VERSION}",
const.HeaderKey.SDK.value: f"inngest-{const.LANGUAGE}:v{const.VERSION}",
}

if framework is not None:
headers[const.HeaderKey.FRAMEWORK.value] = framework
headers[const.HeaderKey.FRAMEWORK.value] = framework.value

return headers

Expand Down
4 changes: 2 additions & 2 deletions inngest/_internal/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pydantic

from . import function_config, types
from . import const, function_config, types


class DeployType(enum.Enum):
Expand All @@ -12,7 +12,7 @@ class DeployType(enum.Enum):
class RegisterRequest(types.BaseModel):
app_name: str
deploy_type: DeployType
framework: str
framework: const.Framework
functions: list[function_config.FunctionConfig] = pydantic.Field(
min_length=1
)
Expand Down
7 changes: 7 additions & 0 deletions inngest/_internal/transforms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import datetime
import hashlib
import re
import traceback

from . import errors, result, types


def get_traceback(err: Exception) -> str:
return "".join(
traceback.format_exception(type(err), err, err.__traceback__)
)


def hash_signing_key(key: str) -> str:
return hashlib.sha256(
bytearray.fromhex(remove_signing_key_prefix(key))
Expand Down
2 changes: 1 addition & 1 deletion inngest/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def serve(
handler = comm.CommHandler(
base_url=base_url or client.base_url,
client=client,
framework="flask",
framework=const.Framework.FAST_API,
functions=functions,
logger=client.logger,
signing_key=signing_key,
Expand Down
2 changes: 1 addition & 1 deletion inngest/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def serve(
handler = comm.CommHandler(
base_url=base_url or client.base_url,
client=client,
framework="flask",
framework=const.Framework.FLASK,
functions=functions,
logger=app.logger,
signing_key=signing_key,
Expand Down
2 changes: 1 addition & 1 deletion inngest/tornado.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def serve(
handler = comm.CommHandler(
base_url=base_url or client.base_url,
client=client,
framework="flask",
framework=const.Framework.TORNADO,
functions=functions,
logger=client.logger,
signing_key=signing_key,
Expand Down

0 comments on commit ded3e06

Please sign in to comment.