Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add in-band sync capability #155

Merged
merged 8 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
475 changes: 304 additions & 171 deletions inngest/_internal/comm_lib/handler.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions inngest/_internal/comm_lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def wrapper(
) -> CommResponse:
req.headers = net.normalize_headers(req.headers)

request_signing_key = net.validate_request(
request_signing_key = net.validate_sig(
body=req.body,
headers=req.headers,
mode=self._client._mode,
Expand Down Expand Up @@ -145,7 +145,7 @@ def wrapper(
) -> CommResponse:
req.headers = net.normalize_headers(req.headers)

request_signing_key = net.validate_request(
request_signing_key = net.validate_sig(
body=req.body,
headers=req.headers,
mode=self._client._mode,
Expand Down
2 changes: 2 additions & 0 deletions inngest/_internal/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import importlib.metadata
import typing

AUTHOR: typing.Final = "inngest"
DEFAULT_API_ORIGIN: typing.Final = "https://api.inngest.com/"
DEFAULT_EVENT_API_ORIGIN: typing.Final = "https://inn.gs/"
DEV_SERVER_ORIGIN: typing.Final = "http://127.0.0.1:8288/"
Expand All @@ -10,6 +11,7 @@


class EnvKey(enum.Enum):
ALLOW_IN_BAND_SYNC = "INNGEST_ALLOW_IN_BAND_SYNC"
API_BASE_URL = "INNGEST_API_BASE_URL"

# Sets both API and EVENT base URLs. API_BASE_URL and EVENT_API_BASE_URL
Expand Down
4 changes: 3 additions & 1 deletion inngest/_internal/const_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ def test_version_matches_pyproject() -> None:
pyproject: dict[str, object] = toml.load(f)
project = pyproject.get("project")
assert isinstance(project, dict)
assert const.VERSION == project.get("version"), "If this is local development, run `make install`"
assert const.VERSION == project.get(
"version"
), "If this is local development, run `make install`"
9 changes: 9 additions & 0 deletions inngest/_internal/env_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def get_url(
return parsed


def is_true(env_var: const.EnvKey) -> bool:
val = os.getenv(env_var.value)
if val is None:
return False
val = val.strip()

return val.lower() in ("true", "1")


def is_truthy(env_var: const.EnvKey) -> bool:
val = os.getenv(env_var.value)
if val is None:
Expand Down
26 changes: 18 additions & 8 deletions inngest/_internal/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,25 +266,31 @@ async def fetch_with_thready_safety(
)


def sign(body: bytes, signing_key: str) -> types.MaybeError[str]:
def sign(
body: bytes,
signing_key: str,
unix_ms: typing.Optional[int] = None,
) -> types.MaybeError[str]:
if unix_ms is None:
unix_ms = round(time.time())

canonicalized = transforms.canonicalize(body)
if isinstance(canonicalized, Exception):
return canonicalized
raise canonicalized

mac = hmac.new(
transforms.remove_signing_key_prefix(signing_key).encode("utf-8"),
canonicalized,
hashlib.sha256,
)
unix_ms = round(time.time())
mac.update(str(unix_ms).encode("utf-8"))
sig = mac.hexdigest()

# Order matters since Inngest Cloud compares strings
return f"t={unix_ms}&s={sig}"


def _validate_request(
def _validate_sig(
*,
body: bytes,
headers: dict[str, str],
Expand All @@ -294,6 +300,10 @@ def _validate_request(
if mode == server_lib.ServerKind.DEV_SERVER:
return None

canonicalized = transforms.canonicalize(body)
if isinstance(canonicalized, Exception):
raise canonicalized

timestamp = None
signature = None
sig_header = headers.get(server_lib.HeaderKey.SIGNATURE.value)
Expand All @@ -320,7 +330,7 @@ def _validate_request(

mac = hmac.new(
transforms.remove_signing_key_prefix(signing_key).encode("utf-8"),
body,
canonicalized,
hashlib.sha256,
)

Expand All @@ -333,7 +343,7 @@ def _validate_request(
return signing_key


def validate_request(
def validate_sig(
*,
body: bytes,
headers: dict[str, str],
Expand All @@ -358,7 +368,7 @@ def validate_request(
if isinstance(canonicalized, Exception):
return canonicalized

err = _validate_request(
err = _validate_sig(
body=canonicalized,
headers=headers,
mode=mode,
Expand All @@ -368,7 +378,7 @@ def validate_request(
# If the signature validation failed but there's a "fallback"
# signing key, attempt to validate the signature with the fallback
# key
err = _validate_request(
err = _validate_sig(
body=canonicalized,
headers=headers,
mode=mode,
Expand Down
10 changes: 5 additions & 5 deletions inngest/_internal/net_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_success(self) -> None:
}

assert not isinstance(
net.validate_request(
net.validate_sig(
body=body,
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand All @@ -128,7 +128,7 @@ def test_escape_sequences(self) -> None:
}

assert not isinstance(
net.validate_request(
net.validate_sig(
body=b'{"msg":"a \\u0026 b"}',
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand All @@ -153,7 +153,7 @@ def test_body_tamper(self) -> None:

body = json.dumps({"msg": "you've been hacked"}).encode("utf-8")

validation = net.validate_request(
validation = net.validate_sig(
body=body,
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand All @@ -177,7 +177,7 @@ def test_rotation(self) -> None:
}

assert not isinstance(
net.validate_request(
net.validate_sig(
body=body,
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand All @@ -201,7 +201,7 @@ def test_fails_for_both_signing_keys(self) -> None:
}

assert isinstance(
net.validate_request(
net.validate_sig(
body=body,
headers=headers,
mode=server_lib.ServerKind.CLOUD,
Expand Down
6 changes: 6 additions & 0 deletions inngest/_internal/server_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Probe,
QueryParamKey,
ServerKind,
SyncKind,
)
from .event import Event
from .execution_request import ServerRequest
Expand All @@ -21,6 +22,8 @@
Concurrency,
Debounce,
FunctionConfig,
InBandSynchronizeRequest,
InBandSynchronizeResponse,
Priority,
RateLimit,
Retries,
Expand All @@ -44,6 +47,8 @@
"Framework",
"FunctionConfig",
"HeaderKey",
"InBandSynchronizeRequest",
"InBandSynchronizeResponse",
"InternalEvents",
"Opcode",
"PREFERRED_EXECUTION_VERSION",
Expand All @@ -58,6 +63,7 @@
"ServerKind",
"ServerRequest",
"Step",
"SyncKind",
"Throttle",
"TriggerCron",
"TriggerEvent",
Expand Down
6 changes: 6 additions & 0 deletions inngest/_internal/server_lib/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class HeaderKey(enum.Enum):
SERVER_KIND = "x-inngest-server-kind"
SERVER_TIMING = "server-timing"
SIGNATURE = "x-inngest-signature"
SYNC_KIND = "x-inngest-sync-kind"
USER_AGENT = "user-agent"


Expand Down Expand Up @@ -88,6 +89,11 @@ class ServerKind(enum.Enum):
DEV_SERVER = "dev"


class SyncKind(enum.Enum):
IN_BAND = "in_band"
OUT_OF_BAND = "out_of_band"


# If the Server sends this step ID then it isn't targeting a specific step
UNSPECIFIED_STEP_ID: typing.Final = "step"

Expand Down
14 changes: 11 additions & 3 deletions inngest/_internal/server_lib/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,37 @@


class Capabilities(types.BaseModel):
in_band_sync: str = "v1"
trust_probe: str = "v1"


class UnauthenticatedInspection(types.BaseModel):
schema_version: str = "2024-05-24"

authentication_succeeded: typing.Optional[bool]
authentication_succeeded: typing.Optional[typing.Literal[False]]
function_count: int
has_event_key: bool
has_signing_key: bool
has_signing_key_fallback: bool
mode: ServerKind


class AuthenticatedInspection(UnauthenticatedInspection):
class AuthenticatedInspection(types.BaseModel):
schema_version: str = "2024-05-24"

api_origin: str
app_id: str
authentication_succeeded: bool = True
authentication_succeeded: typing.Literal[True] = True
capabilities: Capabilities = Capabilities()
env: typing.Optional[str]
event_api_origin: str
event_key_hash: typing.Optional[str]
framework: str
function_count: int
has_event_key: bool
has_signing_key: bool
has_signing_key_fallback: bool
mode: ServerKind
sdk_language: str = const.LANGUAGE
sdk_version: str = const.VERSION
serve_origin: typing.Optional[str]
Expand Down
21 changes: 19 additions & 2 deletions inngest/_internal/server_lib/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import pydantic

from inngest._internal import errors, transforms, types
from inngest._internal import const, errors, transforms, types

from .consts import DeployType, Framework
from .inspection import Capabilities
from .inspection import AuthenticatedInspection, Capabilities


class _BaseConfig(types.BaseModel):
Expand Down Expand Up @@ -174,3 +174,20 @@ class SynchronizeRequest(types.BaseModel):
sdk: str
url: str
v: str


class InBandSynchronizeRequest(types.BaseModel):
url: str


class InBandSynchronizeResponse(types.BaseModel):
app_id: str
env: typing.Optional[str]
framework: Framework
functions: list[FunctionConfig]
inspection: AuthenticatedInspection
platform: typing.Optional[str]
sdk_author: str = const.AUTHOR
sdk_language: str = const.LANGUAGE
sdk_version: str = const.VERSION
url: str
23 changes: 11 additions & 12 deletions inngest/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@

import flask

from inngest._internal import (
client_lib,
comm_lib,
function,
server_lib,
transforms,
)
from inngest._internal import client_lib, comm_lib, function, server_lib

FRAMEWORK = server_lib.Framework.FLASK

Expand Down Expand Up @@ -76,7 +70,7 @@ def _create_handler_async(
@app.route("/api/inngest", methods=["GET", "POST", "PUT"])
async def inngest_api() -> typing.Union[flask.Response, str]:
comm_req = comm_lib.CommRequest(
body=flask.request.data,
body=_get_body_bytes(),
headers=dict(flask.request.headers.items()),
query_params=flask.request.args,
raw_request=flask.request,
Expand Down Expand Up @@ -118,7 +112,7 @@ def _create_handler_sync(
@app.route("/api/inngest", methods=["GET", "POST", "PUT"])
def inngest_api() -> typing.Union[flask.Response, str]:
comm_req = comm_lib.CommRequest(
body=flask.request.data,
body=_get_body_bytes(),
headers=dict(flask.request.headers.items()),
query_params=flask.request.args,
raw_request=flask.request,
Expand Down Expand Up @@ -149,17 +143,22 @@ def inngest_api() -> typing.Union[flask.Response, str]:
return ""


def _get_body_bytes() -> bytes:
flask.request.get_data(as_text=True)
return flask.request.data


def _to_response(
client: client_lib.Inngest,
comm_res: comm_lib.CommResponse,
) -> flask.Response:
body = transforms.dump_json(comm_res.body)
body = comm_res.body_bytes()
if isinstance(body, Exception):
comm_res = comm_lib.CommResponse.from_error(client.logger, body)
body = json.dumps(comm_res.body)
body = json.dumps(comm_res.body).encode("utf-8")

return flask.Response(
headers=comm_res.headers,
response=body.encode("utf-8"),
response=body,
status=comm_res.status_code,
)
Loading
Loading