Skip to content

Commit

Permalink
Add in-band sync capability (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r authored Sep 17, 2024
1 parent 28c3a4f commit f9aee3b
Show file tree
Hide file tree
Showing 31 changed files with 855 additions and 419 deletions.
475 changes: 304 additions & 171 deletions inngest/_internal/comm_lib/handler.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions inngest/_internal/comm_lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def wrapper(
) -> CommResponse:
req.headers = net.normalize_headers(req.headers)

request_signing_key = net.validate_request(
request_signing_key = net.validate_sig(
body=req.body,
headers=req.headers,
mode=self._client._mode,
Expand Down Expand Up @@ -145,7 +145,7 @@ def wrapper(
) -> CommResponse:
req.headers = net.normalize_headers(req.headers)

request_signing_key = net.validate_request(
request_signing_key = net.validate_sig(
body=req.body,
headers=req.headers,
mode=self._client._mode,
Expand Down
2 changes: 2 additions & 0 deletions inngest/_internal/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import importlib.metadata
import typing

AUTHOR: typing.Final = "inngest"
DEFAULT_API_ORIGIN: typing.Final = "https://api.inngest.com/"
DEFAULT_EVENT_API_ORIGIN: typing.Final = "https://inn.gs/"
DEV_SERVER_ORIGIN: typing.Final = "http://127.0.0.1:8288/"
Expand All @@ -10,6 +11,7 @@


class EnvKey(enum.Enum):
ALLOW_IN_BAND_SYNC = "INNGEST_ALLOW_IN_BAND_SYNC"
API_BASE_URL = "INNGEST_API_BASE_URL"

# Sets both API and EVENT base URLs. API_BASE_URL and EVENT_API_BASE_URL
Expand Down
4 changes: 3 additions & 1 deletion inngest/_internal/const_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ def test_version_matches_pyproject() -> None:
pyproject: dict[str, object] = toml.load(f)
project = pyproject.get("project")
assert isinstance(project, dict)
assert const.VERSION == project.get("version"), "If this is local development, run `make install`"
assert const.VERSION == project.get(
"version"
), "If this is local development, run `make install`"
9 changes: 9 additions & 0 deletions inngest/_internal/env_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def get_url(
return parsed


def is_true(env_var: const.EnvKey) -> bool:
val = os.getenv(env_var.value)
if val is None:
return False
val = val.strip()

return val.lower() in ("true", "1")


def is_truthy(env_var: const.EnvKey) -> bool:
val = os.getenv(env_var.value)
if val is None:
Expand Down
26 changes: 18 additions & 8 deletions inngest/_internal/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,25 +266,31 @@ async def fetch_with_thready_safety(
)


def sign(body: bytes, signing_key: str) -> types.MaybeError[str]:
def sign(
body: bytes,
signing_key: str,
unix_ms: typing.Optional[int] = None,
) -> types.MaybeError[str]:
if unix_ms is None:
unix_ms = round(time.time())

canonicalized = transforms.canonicalize(body)
if isinstance(canonicalized, Exception):
return canonicalized
raise canonicalized

mac = hmac.new(
transforms.remove_signing_key_prefix(signing_key).encode("utf-8"),
canonicalized,
hashlib.sha256,
)
unix_ms = round(time.time())
mac.update(str(unix_ms).encode("utf-8"))
sig = mac.hexdigest()

# Order matters since Inngest Cloud compares strings
return f"t={unix_ms}&s={sig}"


def _validate_request(
def _validate_sig(
*,
body: bytes,
headers: dict[str, str],
Expand All @@ -294,6 +300,10 @@ def _validate_request(
if mode == server_lib.ServerKind.DEV_SERVER:
return None

canonicalized = transforms.canonicalize(body)
if isinstance(canonicalized, Exception):
raise canonicalized

timestamp = None
signature = None
sig_header = headers.get(server_lib.HeaderKey.SIGNATURE.value)
Expand All @@ -320,7 +330,7 @@ def _validate_request(

mac = hmac.new(
transforms.remove_signing_key_prefix(signing_key).encode("utf-8"),
body,
canonicalized,
hashlib.sha256,
)

Expand All @@ -333,7 +343,7 @@ def _validate_request(
return signing_key


def validate_request(
def validate_sig(
*,
body: bytes,
headers: dict[str, str],
Expand All @@ -358,7 +368,7 @@ def validate_request(
if isinstance(canonicalized, Exception):
return canonicalized

err = _validate_request(
err = _validate_sig(
body=canonicalized,
headers=headers,
mode=mode,
Expand All @@ -368,7 +378,7 @@ def validate_request(
# If the signature validation failed but there's a "fallback"
# signing key, attempt to validate the signature with the fallback
# key
err = _validate_request(
err = _validate_sig(
body=canonicalized,
headers=headers,
mode=mode,
Expand Down
10 changes: 5 additions & 5 deletions inngest/_internal/net_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_success(self) -> None:
}

assert not isinstance(
net.validate_request(
net.validate_sig(
body=body,
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand All @@ -128,7 +128,7 @@ def test_escape_sequences(self) -> None:
}

assert not isinstance(
net.validate_request(
net.validate_sig(
body=b'{"msg":"a \\u0026 b"}',
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand All @@ -153,7 +153,7 @@ def test_body_tamper(self) -> None:

body = json.dumps({"msg": "you've been hacked"}).encode("utf-8")

validation = net.validate_request(
validation = net.validate_sig(
body=body,
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand All @@ -177,7 +177,7 @@ def test_rotation(self) -> None:
}

assert not isinstance(
net.validate_request(
net.validate_sig(
body=body,
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand All @@ -201,7 +201,7 @@ def test_fails_for_both_signing_keys(self) -> None:
}

assert isinstance(
net.validate_request(
net.validate_sig(
body=body,
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand Down
6 changes: 6 additions & 0 deletions inngest/_internal/server_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Probe,
QueryParamKey,
ServerKind,
SyncKind,
)
from .event import Event
from .execution_request import ServerRequest
Expand All @@ -21,6 +22,8 @@
Concurrency,
Debounce,
FunctionConfig,
InBandSynchronizeRequest,
InBandSynchronizeResponse,
Priority,
RateLimit,
Retries,
Expand All @@ -44,6 +47,8 @@
"Framework",
"FunctionConfig",
"HeaderKey",
"InBandSynchronizeRequest",
"InBandSynchronizeResponse",
"InternalEvents",
"Opcode",
"PREFERRED_EXECUTION_VERSION",
Expand All @@ -58,6 +63,7 @@
"ServerKind",
"ServerRequest",
"Step",
"SyncKind",
"Throttle",
"TriggerCron",
"TriggerEvent",
Expand Down
6 changes: 6 additions & 0 deletions inngest/_internal/server_lib/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class HeaderKey(enum.Enum):
SERVER_KIND = "x-inngest-server-kind"
SERVER_TIMING = "server-timing"
SIGNATURE = "x-inngest-signature"
SYNC_KIND = "x-inngest-sync-kind"
USER_AGENT = "user-agent"


Expand Down Expand Up @@ -88,6 +89,11 @@ class ServerKind(enum.Enum):
DEV_SERVER = "dev"


class SyncKind(enum.Enum):
IN_BAND = "in_band"
OUT_OF_BAND = "out_of_band"


# If the Server sends this step ID then it isn't targeting a specific step
UNSPECIFIED_STEP_ID: typing.Final = "step"

Expand Down
14 changes: 11 additions & 3 deletions inngest/_internal/server_lib/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,37 @@


class Capabilities(types.BaseModel):
in_band_sync: str = "v1"
trust_probe: str = "v1"


class UnauthenticatedInspection(types.BaseModel):
schema_version: str = "2024-05-24"

authentication_succeeded: typing.Optional[bool]
authentication_succeeded: typing.Optional[typing.Literal[False]]
function_count: int
has_event_key: bool
has_signing_key: bool
has_signing_key_fallback: bool
mode: ServerKind


class AuthenticatedInspection(UnauthenticatedInspection):
class AuthenticatedInspection(types.BaseModel):
schema_version: str = "2024-05-24"

api_origin: str
app_id: str
authentication_succeeded: bool = True
authentication_succeeded: typing.Literal[True] = True
capabilities: Capabilities = Capabilities()
env: typing.Optional[str]
event_api_origin: str
event_key_hash: typing.Optional[str]
framework: str
function_count: int
has_event_key: bool
has_signing_key: bool
has_signing_key_fallback: bool
mode: ServerKind
sdk_language: str = const.LANGUAGE
sdk_version: str = const.VERSION
serve_origin: typing.Optional[str]
Expand Down
21 changes: 19 additions & 2 deletions inngest/_internal/server_lib/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import pydantic

from inngest._internal import errors, transforms, types
from inngest._internal import const, errors, transforms, types

from .consts import DeployType, Framework
from .inspection import Capabilities
from .inspection import AuthenticatedInspection, Capabilities


class _BaseConfig(types.BaseModel):
Expand Down Expand Up @@ -174,3 +174,20 @@ class SynchronizeRequest(types.BaseModel):
sdk: str
url: str
v: str


class InBandSynchronizeRequest(types.BaseModel):
url: str


class InBandSynchronizeResponse(types.BaseModel):
app_id: str
env: typing.Optional[str]
framework: Framework
functions: list[FunctionConfig]
inspection: AuthenticatedInspection
platform: typing.Optional[str]
sdk_author: str = const.AUTHOR
sdk_language: str = const.LANGUAGE
sdk_version: str = const.VERSION
url: str
23 changes: 11 additions & 12 deletions inngest/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@

import flask

from inngest._internal import (
client_lib,
comm_lib,
function,
server_lib,
transforms,
)
from inngest._internal import client_lib, comm_lib, function, server_lib

FRAMEWORK = server_lib.Framework.FLASK

Expand Down Expand Up @@ -76,7 +70,7 @@ def _create_handler_async(
@app.route("/api/inngest", methods=["GET", "POST", "PUT"])
async def inngest_api() -> typing.Union[flask.Response, str]:
comm_req = comm_lib.CommRequest(
body=flask.request.data,
body=_get_body_bytes(),
headers=dict(flask.request.headers.items()),
query_params=flask.request.args,
raw_request=flask.request,
Expand Down Expand Up @@ -118,7 +112,7 @@ def _create_handler_sync(
@app.route("/api/inngest", methods=["GET", "POST", "PUT"])
def inngest_api() -> typing.Union[flask.Response, str]:
comm_req = comm_lib.CommRequest(
body=flask.request.data,
body=_get_body_bytes(),
headers=dict(flask.request.headers.items()),
query_params=flask.request.args,
raw_request=flask.request,
Expand Down Expand Up @@ -149,17 +143,22 @@ def inngest_api() -> typing.Union[flask.Response, str]:
return ""


def _get_body_bytes() -> bytes:
flask.request.get_data(as_text=True)
return flask.request.data


def _to_response(
client: client_lib.Inngest,
comm_res: comm_lib.CommResponse,
) -> flask.Response:
body = transforms.dump_json(comm_res.body)
body = comm_res.body_bytes()
if isinstance(body, Exception):
comm_res = comm_lib.CommResponse.from_error(client.logger, body)
body = json.dumps(comm_res.body)
body = json.dumps(comm_res.body).encode("utf-8")

return flask.Response(
headers=comm_res.headers,
response=body.encode("utf-8"),
response=body,
status=comm_res.status_code,
)
Loading

0 comments on commit f9aee3b

Please sign in to comment.