diff --git a/inngest/_internal/comm_lib/handler.py b/inngest/_internal/comm_lib/handler.py index d91b6d0e..db90de0b 100644 --- a/inngest/_internal/comm_lib/handler.py +++ b/inngest/_internal/comm_lib/handler.py @@ -11,6 +11,7 @@ from inngest._internal import ( client_lib, const, + env_lib, errors, execution_lib, function, @@ -42,6 +43,10 @@ def __init__( framework: server_lib.Framework, functions: list[function.Function], ) -> None: + # TODO: Default to true once in-band syncing is stable + self._allow_in_band_sync = env_lib.is_true( + const.EnvKey.ALLOW_IN_BAND_SYNC, + ) self._client = client self._mode = client._mode self._api_origin = client.api_origin @@ -332,174 +337,214 @@ def get_sync( status_code=403, ) - res_body: typing.Union[ - server_lib.AuthenticatedInspection, - server_lib.UnauthenticatedInspection, - ] - - is_signed_and_valid = isinstance(request_signing_key, str) - # Validate the request signature. - err = net.validate_request( - body=req.body, - headers=req.headers, - mode=self._client._mode, - signing_key=self._signing_key, - signing_key_fallback=self._signing_key_fallback, + inspection = _build_inspection_response( + self, + req, + request_signing_key, ) - if ( - self._client._mode != server_lib.ServerKind.CLOUD - or is_signed_and_valid is False - ): - authentication_succeeded = None - if isinstance(err, Exception): - authentication_succeeded = False - - res_body = server_lib.UnauthenticatedInspection( - authentication_succeeded=authentication_succeeded, - function_count=len(self._fns), - has_event_key=self._client.event_key is not None, - has_signing_key=self._signing_key is not None, - has_signing_key_fallback=self._signing_key_fallback is not None, - mode=self._mode, - ) - elif is_signed_and_valid is True: - event_key_hash = ( - transforms.hash_event_key(self._client.event_key) - if self._client.event_key - else None - ) - - signing_key_hash = ( - transforms.hash_signing_key(self._signing_key) - if self._signing_key - else None - ) - - signing_key_fallback_hash = ( - transforms.hash_signing_key(self._signing_key_fallback) - if self._signing_key_fallback - else None - ) - - res_body = server_lib.AuthenticatedInspection( - api_origin=self._client.api_origin, - app_id=self._client.app_id, - authentication_succeeded=True, - env=self._client.env, - event_api_origin=self._client.event_api_origin, - event_key_hash=event_key_hash, - framework=self._framework.value, - function_count=len(self._fns), - has_event_key=self._client.event_key is not None, - has_signing_key=self._signing_key is not None, - has_signing_key_fallback=self._signing_key_fallback is not None, - mode=self._mode, - serve_origin=req.serve_origin, - serve_path=req.serve_path, - signing_key_fallback_hash=signing_key_fallback_hash, - signing_key_hash=signing_key_hash, - ) + if isinstance(inspection, Exception): + return inspection - body_json = res_body.to_dict() - if isinstance(req.body, Exception): - body_json = { - "error": "failed to serialize inspection data", - } + res_body = inspection.to_dict() + if isinstance(res_body, Exception): + return res_body return CommResponse( - body=body_json, + body=res_body, status_code=200, ) - def _parse_registration_response( - self, - server_res: httpx.Response, - ) -> CommResponse: - try: - server_res_body = server_res.json() - except Exception: - return CommResponse.from_error( - self._client.logger, - errors.RegistrationFailedError("response is not valid JSON"), - ) - - if not isinstance(server_res_body, dict): - return CommResponse.from_error( - self._client.logger, - errors.RegistrationFailedError("response is not an object"), - ) + @wrap_handler(require_signature=False) + async def put( + self: CommHandler, + req: CommRequest, + request_signing_key: types.MaybeError[typing.Optional[str]], + ) -> typing.Union[CommResponse, Exception]: + """Handle a PUT request.""" - if server_res.status_code < 400: - return CommResponse( - body=server_res_body, - status_code=http.HTTPStatus.OK, - ) + syncer = _Syncer() - msg = server_res_body.get("error") - if not isinstance(msg, str): - msg = "registration failed" - comm_res = CommResponse.from_error( - self._client.logger, - errors.RegistrationFailedError(msg.strip()), - ) - comm_res.status_code = server_res.status_code - return comm_res + if ( + req.headers.get(server_lib.HeaderKey.SYNC_KIND.value) + == server_lib.SyncKind.IN_BAND.value + and self._allow_in_band_sync + ): + err: typing.Optional[Exception] = None + if isinstance(request_signing_key, Exception): + err = request_signing_key + elif request_signing_key is None: + err = Exception("request must be signed for in-band sync") + if err is not None: + return CommResponse.from_error( + self._client.logger, + err, + status=http.HTTPStatus.UNAUTHORIZED, + ) + return syncer.in_band(self, req, request_signing_key) + + return await syncer.out_of_band(self, req) - @wrap_handler(require_signature=False) - async def put( + @wrap_handler_sync(require_signature=False) + def put_sync( self: CommHandler, req: CommRequest, request_signing_key: types.MaybeError[typing.Optional[str]], ) -> typing.Union[CommResponse, Exception]: - """Handle a registration call.""" + """Handle a PUT request.""" - app_url = net.create_serve_url( - request_url=req.request_url, + syncer = _Syncer() + + if ( + req.headers.get(server_lib.HeaderKey.SYNC_KIND.value) + == server_lib.SyncKind.IN_BAND.value + and self._allow_in_band_sync + ): + err: typing.Optional[Exception] = None + if isinstance(request_signing_key, Exception): + err = request_signing_key + elif request_signing_key is None: + err = Exception("request must be signed for in-band sync") + if err is not None: + return CommResponse.from_error( + self._client.logger, + err, + status=http.HTTPStatus.UNAUTHORIZED, + ) + + return syncer.in_band(self, req, request_signing_key) + + return syncer.out_of_band_sync(self, req) + + +def _build_inspection_response( + handler: CommHandler, + req: CommRequest, + request_signing_key: types.MaybeError[typing.Optional[str]], +) -> types.MaybeError[ + typing.Union[ + server_lib.AuthenticatedInspection, + server_lib.UnauthenticatedInspection, + ] +]: + server_kind = transforms.get_server_kind(req.headers) + if isinstance(server_kind, Exception): + handler._client.logger.error(server_kind) + server_kind = None + + is_signed = isinstance(request_signing_key, str) + if is_signed: + event_key_hash = ( + transforms.hash_event_key(handler._client.event_key) + if handler._client.event_key + else None + ) + + signing_key_hash = ( + transforms.hash_signing_key(handler._signing_key) + if handler._signing_key + else None + ) + + signing_key_fallback_hash = ( + transforms.hash_signing_key(handler._signing_key_fallback) + if handler._signing_key_fallback + else None + ) + + return server_lib.AuthenticatedInspection( + api_origin=handler._client.api_origin, + app_id=handler._client.app_id, + authentication_succeeded=True, + env=handler._client.env, + event_api_origin=handler._client.event_api_origin, + event_key_hash=event_key_hash, + framework=handler._framework.value, + function_count=len(handler._fns), + has_event_key=handler._client.event_key is not None, + has_signing_key=handler._signing_key is not None, + has_signing_key_fallback=handler._signing_key_fallback is not None, + mode=handler._mode, serve_origin=req.serve_origin, serve_path=req.serve_path, + signing_key_fallback_hash=signing_key_fallback_hash, + signing_key_hash=signing_key_hash, ) - server_kind = transforms.get_server_kind(req.headers) - if isinstance(server_kind, Exception): - self._client.logger.error(server_kind) - server_kind = None + authentication_succeeded: typing.Optional[typing.Literal[False]] = None + if isinstance(request_signing_key, Exception): + authentication_succeeded = False - comm_res = self._validate_registration(server_kind) - if comm_res is not None: - return comm_res + return server_lib.UnauthenticatedInspection( + authentication_succeeded=authentication_succeeded, + function_count=len(handler._fns), + has_event_key=handler._client.event_key is not None, + has_signing_key=handler._signing_key is not None, + has_signing_key_fallback=handler._signing_key_fallback is not None, + mode=handler._mode, + ) - params = parse_query_params(req.query_params) - if isinstance(params, Exception): - return params - outgoing_req = self._build_register_request( - app_url=app_url, - server_kind=server_kind, - sync_id=params.sync_id, +class _Syncer: + def in_band( + self, + handler: CommHandler, + req: CommRequest, + request_signing_key: types.MaybeError[typing.Optional[str]], + ) -> types.MaybeError[CommResponse]: + if not isinstance(request_signing_key, str): + # This should be checked earlier, but we'll also check it here since + # it's critical + return Exception("request must be signed for in-band sync") + + req_body = server_lib.InBandSynchronizeRequest.from_raw(req.body) + if isinstance(req_body, Exception): + return req_body + + app_url = net.create_serve_url( + request_url=req_body.url, + serve_origin=req.serve_origin, + serve_path=req.serve_path, ) - if isinstance(outgoing_req, Exception): - return outgoing_req - res = await net.fetch_with_auth_fallback( - self._client._http_client, - self._client._http_client_sync, - outgoing_req, - signing_key=self._signing_key, - signing_key_fallback=self._signing_key_fallback, + fn_configs = handler.get_function_configs(app_url) + if isinstance(fn_configs, Exception): + return fn_configs + + inspection = _build_inspection_response( + handler, + req, + request_signing_key, ) - if isinstance(res, Exception): - return res + if isinstance(inspection, Exception): + return inspection + if isinstance(inspection, server_lib.UnauthenticatedInspection): + # Unreachable + return Exception("request must be signed for in-band sync") + + res_body = server_lib.InBandSynchronizeResponse( + app_id=handler._client.app_id, + env=handler._client.env, + framework=handler._framework, + functions=fn_configs, + inspection=inspection, + platform=None, + url=app_url, + ).to_dict() + if isinstance(res_body, Exception): + return res_body - return self._parse_registration_response(res) + return CommResponse( + body=res_body, + headers={ + server_lib.HeaderKey.SYNC_KIND.value: server_lib.SyncKind.IN_BAND.value, + }, + ) - @wrap_handler_sync(require_signature=False) - def put_sync( - self: CommHandler, + def _create_out_of_band_request( + self, + handler: CommHandler, req: CommRequest, - request_signing_key: types.MaybeError[typing.Optional[str]], - ) -> typing.Union[CommResponse, Exception]: - """Handle a registration call.""" - + ) -> types.MaybeError[typing.Union[CommResponse, httpx.Request]]: app_url = net.create_serve_url( request_url=req.request_url, serve_origin=req.serve_origin, @@ -508,52 +553,140 @@ def put_sync( server_kind = transforms.get_server_kind(req.headers) if isinstance(server_kind, Exception): - self._client.logger.error(server_kind) + handler._client.logger.error(server_kind) server_kind = None - comm_res = self._validate_registration(server_kind) - if comm_res is not None: - return comm_res + if server_kind is not None and server_kind != handler._mode: + msg: str + if server_kind == server_lib.ServerKind.DEV_SERVER: + msg = "Sync rejected since it's from a Dev Server but expected Cloud" + else: + msg = "Sync rejected since it's from Cloud but expected Dev Server" + + handler._client.logger.error(msg) + return CommResponse.from_error_code( + server_lib.ErrorCode.SERVER_KIND_MISMATCH, + msg, + http.HTTPStatus.BAD_REQUEST, + ) params = parse_query_params(req.query_params) if isinstance(params, Exception): return params - outgoing_req = self._build_register_request( - app_url=app_url, + registration_url = urllib.parse.urljoin( + handler._api_origin, + "/fn/register", + ) + + fn_configs = handler.get_function_configs(app_url) + if isinstance(fn_configs, Exception): + return fn_configs + + body = server_lib.SynchronizeRequest( + app_name=handler._client.app_id, + deploy_type=server_lib.DeployType.PING, + framework=handler._framework, + functions=fn_configs, + sdk=f"{const.LANGUAGE}:v{const.VERSION}", + url=app_url, + v="0.1", + ).to_dict() + if isinstance(body, Exception): + return body + + headers = net.create_headers( + env=handler._client.env, + framework=handler._framework, server_kind=server_kind, - sync_id=params.sync_id, ) - if isinstance(outgoing_req, Exception): - return outgoing_req - res = net.fetch_with_auth_fallback_sync( - self._client._http_client_sync, - outgoing_req, - signing_key=self._signing_key, - signing_key_fallback=self._signing_key_fallback, + outgoing_params = {} + if params.sync_id is not None: + outgoing_params[ + server_lib.QueryParamKey.SYNC_ID.value + ] = params.sync_id + + return handler._client._http_client_sync.build_request( + "POST", + registration_url, + headers=headers, + json=transforms.deep_strip_none(body), + params=outgoing_params, + timeout=30, + ) + + def _parse_out_of_band_response( + self, + handler: CommHandler, + res: httpx.Response, + ) -> types.MaybeError[CommResponse]: + try: + server_res_body = res.json() + except Exception: + return errors.RegistrationFailedError("response is not valid JSON") + + if not isinstance(server_res_body, dict): + return errors.RegistrationFailedError("response is not an object") + + if res.status_code >= 400: + msg = server_res_body.get("error") + if not isinstance(msg, str): + msg = "registration failed" + comm_res = CommResponse.from_error( + handler._client.logger, + errors.RegistrationFailedError(msg.strip()), + ) + comm_res.status_code = res.status_code + + return CommResponse( + body=server_res_body, + headers={ + server_lib.HeaderKey.SYNC_KIND.value: server_lib.SyncKind.OUT_OF_BAND.value, + }, + ) + + async def out_of_band( + self, + handler: CommHandler, + req: CommRequest, + ) -> types.MaybeError[CommResponse]: + prep = self._create_out_of_band_request(handler, req) + if isinstance(prep, Exception): + return prep + if isinstance(prep, CommResponse): + return prep + + res = await net.fetch_with_auth_fallback( + handler._client._http_client, + handler._client._http_client_sync, + prep, + signing_key=handler._signing_key, + signing_key_fallback=handler._signing_key_fallback, ) if isinstance(res, Exception): return res - return self._parse_registration_response(res) + return self._parse_out_of_band_response(handler, res) - def _validate_registration( + def out_of_band_sync( self, - server_kind: typing.Optional[server_lib.ServerKind], - ) -> typing.Optional[CommResponse]: - if server_kind is not None and server_kind != self._mode: - msg: str - if server_kind == server_lib.ServerKind.DEV_SERVER: - msg = "Sync rejected since it's from a Dev Server but expected Cloud" - else: - msg = "Sync rejected since it's from Cloud but expected Dev Server" + handler: CommHandler, + req: CommRequest, + ) -> types.MaybeError[CommResponse]: + prep = self._create_out_of_band_request(handler, req) + if isinstance(prep, Exception): + return prep + if isinstance(prep, CommResponse): + return prep - self._client.logger.error(msg) - return CommResponse.from_error_code( - server_lib.ErrorCode.SERVER_KIND_MISMATCH, - msg, - http.HTTPStatus.BAD_REQUEST, - ) + res = net.fetch_with_auth_fallback_sync( + handler._client._http_client_sync, + prep, + signing_key=handler._signing_key, + signing_key_fallback=handler._signing_key_fallback, + ) + if isinstance(res, Exception): + return res - return None + return self._parse_out_of_band_response(handler, res) diff --git a/inngest/_internal/comm_lib/utils.py b/inngest/_internal/comm_lib/utils.py index 7847066e..7357e1fd 100644 --- a/inngest/_internal/comm_lib/utils.py +++ b/inngest/_internal/comm_lib/utils.py @@ -76,7 +76,7 @@ async def wrapper( ) -> CommResponse: req.headers = net.normalize_headers(req.headers) - request_signing_key = net.validate_request( + request_signing_key = net.validate_sig( body=req.body, headers=req.headers, mode=self._client._mode, @@ -145,7 +145,7 @@ def wrapper( ) -> CommResponse: req.headers = net.normalize_headers(req.headers) - request_signing_key = net.validate_request( + request_signing_key = net.validate_sig( body=req.body, headers=req.headers, mode=self._client._mode, diff --git a/inngest/_internal/const.py b/inngest/_internal/const.py index f2bbac78..a40fc774 100644 --- a/inngest/_internal/const.py +++ b/inngest/_internal/const.py @@ -2,6 +2,7 @@ import importlib.metadata import typing +AUTHOR: typing.Final = "inngest" DEFAULT_API_ORIGIN: typing.Final = "https://api.inngest.com/" DEFAULT_EVENT_API_ORIGIN: typing.Final = "https://inn.gs/" DEV_SERVER_ORIGIN: typing.Final = "http://127.0.0.1:8288/" @@ -10,6 +11,7 @@ class EnvKey(enum.Enum): + ALLOW_IN_BAND_SYNC = "INNGEST_ALLOW_IN_BAND_SYNC" API_BASE_URL = "INNGEST_API_BASE_URL" # Sets both API and EVENT base URLs. API_BASE_URL and EVENT_API_BASE_URL diff --git a/inngest/_internal/const_test.py b/inngest/_internal/const_test.py index 830c4b9e..bf1aed9c 100644 --- a/inngest/_internal/const_test.py +++ b/inngest/_internal/const_test.py @@ -10,4 +10,6 @@ def test_version_matches_pyproject() -> None: pyproject: dict[str, object] = toml.load(f) project = pyproject.get("project") assert isinstance(project, dict) - assert const.VERSION == project.get("version"), "If this is local development, run `make install`" + assert const.VERSION == project.get( + "version" + ), "If this is local development, run `make install`" diff --git a/inngest/_internal/env_lib.py b/inngest/_internal/env_lib.py index 76f7c500..c4666a3d 100644 --- a/inngest/_internal/env_lib.py +++ b/inngest/_internal/env_lib.py @@ -38,6 +38,15 @@ def get_url( return parsed +def is_true(env_var: const.EnvKey) -> bool: + val = os.getenv(env_var.value) + if val is None: + return False + val = val.strip() + + return val.lower() in ("true", "1") + + def is_truthy(env_var: const.EnvKey) -> bool: val = os.getenv(env_var.value) if val is None: diff --git a/inngest/_internal/net.py b/inngest/_internal/net.py index bfac783c..978138f4 100644 --- a/inngest/_internal/net.py +++ b/inngest/_internal/net.py @@ -266,17 +266,23 @@ async def fetch_with_thready_safety( ) -def sign(body: bytes, signing_key: str) -> types.MaybeError[str]: +def sign( + body: bytes, + signing_key: str, + unix_ms: typing.Optional[int] = None, +) -> types.MaybeError[str]: + if unix_ms is None: + unix_ms = round(time.time()) + canonicalized = transforms.canonicalize(body) if isinstance(canonicalized, Exception): - return canonicalized + raise canonicalized mac = hmac.new( transforms.remove_signing_key_prefix(signing_key).encode("utf-8"), canonicalized, hashlib.sha256, ) - unix_ms = round(time.time()) mac.update(str(unix_ms).encode("utf-8")) sig = mac.hexdigest() @@ -284,7 +290,7 @@ def sign(body: bytes, signing_key: str) -> types.MaybeError[str]: return f"t={unix_ms}&s={sig}" -def _validate_request( +def _validate_sig( *, body: bytes, headers: dict[str, str], @@ -294,6 +300,10 @@ def _validate_request( if mode == server_lib.ServerKind.DEV_SERVER: return None + canonicalized = transforms.canonicalize(body) + if isinstance(canonicalized, Exception): + raise canonicalized + timestamp = None signature = None sig_header = headers.get(server_lib.HeaderKey.SIGNATURE.value) @@ -320,7 +330,7 @@ def _validate_request( mac = hmac.new( transforms.remove_signing_key_prefix(signing_key).encode("utf-8"), - body, + canonicalized, hashlib.sha256, ) @@ -333,7 +343,7 @@ def _validate_request( return signing_key -def validate_request( +def validate_sig( *, body: bytes, headers: dict[str, str], @@ -358,7 +368,7 @@ def validate_request( if isinstance(canonicalized, Exception): return canonicalized - err = _validate_request( + err = _validate_sig( body=canonicalized, headers=headers, mode=mode, @@ -368,7 +378,7 @@ def validate_request( # If the signature validation failed but there's a "fallback" # signing key, attempt to validate the signature with the fallback # key - err = _validate_request( + err = _validate_sig( body=canonicalized, headers=headers, mode=mode, diff --git a/inngest/_internal/net_test.py b/inngest/_internal/net_test.py index c2951d58..1dc8d7fd 100644 --- a/inngest/_internal/net_test.py +++ b/inngest/_internal/net_test.py @@ -109,7 +109,7 @@ def test_success(self) -> None: } assert not isinstance( - net.validate_request( + net.validate_sig( body=body, headers=headers, mode=server_lib.ServerKind.CLOUD, @@ -128,7 +128,7 @@ def test_escape_sequences(self) -> None: } assert not isinstance( - net.validate_request( + net.validate_sig( body=b'{"msg":"a \\u0026 b"}', headers=headers, mode=server_lib.ServerKind.CLOUD, @@ -153,7 +153,7 @@ def test_body_tamper(self) -> None: body = json.dumps({"msg": "you've been hacked"}).encode("utf-8") - validation = net.validate_request( + validation = net.validate_sig( body=body, headers=headers, mode=server_lib.ServerKind.CLOUD, @@ -177,7 +177,7 @@ def test_rotation(self) -> None: } assert not isinstance( - net.validate_request( + net.validate_sig( body=body, headers=headers, mode=server_lib.ServerKind.CLOUD, @@ -201,7 +201,7 @@ def test_fails_for_both_signing_keys(self) -> None: } assert isinstance( - net.validate_request( + net.validate_sig( body=body, headers=headers, mode=server_lib.ServerKind.CLOUD, diff --git a/inngest/_internal/server_lib/__init__.py b/inngest/_internal/server_lib/__init__.py index 9475663c..7c46733d 100644 --- a/inngest/_internal/server_lib/__init__.py +++ b/inngest/_internal/server_lib/__init__.py @@ -11,6 +11,7 @@ Probe, QueryParamKey, ServerKind, + SyncKind, ) from .event import Event from .execution_request import ServerRequest @@ -21,6 +22,8 @@ Concurrency, Debounce, FunctionConfig, + InBandSynchronizeRequest, + InBandSynchronizeResponse, Priority, RateLimit, Retries, @@ -44,6 +47,8 @@ "Framework", "FunctionConfig", "HeaderKey", + "InBandSynchronizeRequest", + "InBandSynchronizeResponse", "InternalEvents", "Opcode", "PREFERRED_EXECUTION_VERSION", @@ -58,6 +63,7 @@ "ServerKind", "ServerRequest", "Step", + "SyncKind", "Throttle", "TriggerCron", "TriggerEvent", diff --git a/inngest/_internal/server_lib/consts.py b/inngest/_internal/server_lib/consts.py index a41e3319..4044451e 100644 --- a/inngest/_internal/server_lib/consts.py +++ b/inngest/_internal/server_lib/consts.py @@ -56,6 +56,7 @@ class HeaderKey(enum.Enum): SERVER_KIND = "x-inngest-server-kind" SERVER_TIMING = "server-timing" SIGNATURE = "x-inngest-signature" + SYNC_KIND = "x-inngest-sync-kind" USER_AGENT = "user-agent" @@ -88,6 +89,11 @@ class ServerKind(enum.Enum): DEV_SERVER = "dev" +class SyncKind(enum.Enum): + IN_BAND = "in_band" + OUT_OF_BAND = "out_of_band" + + # If the Server sends this step ID then it isn't targeting a specific step UNSPECIFIED_STEP_ID: typing.Final = "step" diff --git a/inngest/_internal/server_lib/inspection.py b/inngest/_internal/server_lib/inspection.py index 6e5084e1..1d54306a 100644 --- a/inngest/_internal/server_lib/inspection.py +++ b/inngest/_internal/server_lib/inspection.py @@ -6,13 +6,14 @@ class Capabilities(types.BaseModel): + in_band_sync: str = "v1" trust_probe: str = "v1" class UnauthenticatedInspection(types.BaseModel): schema_version: str = "2024-05-24" - authentication_succeeded: typing.Optional[bool] + authentication_succeeded: typing.Optional[typing.Literal[False]] function_count: int has_event_key: bool has_signing_key: bool @@ -20,15 +21,22 @@ class UnauthenticatedInspection(types.BaseModel): mode: ServerKind -class AuthenticatedInspection(UnauthenticatedInspection): +class AuthenticatedInspection(types.BaseModel): + schema_version: str = "2024-05-24" + api_origin: str app_id: str - authentication_succeeded: bool = True + authentication_succeeded: typing.Literal[True] = True capabilities: Capabilities = Capabilities() env: typing.Optional[str] event_api_origin: str event_key_hash: typing.Optional[str] framework: str + function_count: int + has_event_key: bool + has_signing_key: bool + has_signing_key_fallback: bool + mode: ServerKind sdk_language: str = const.LANGUAGE sdk_version: str = const.VERSION serve_origin: typing.Optional[str] diff --git a/inngest/_internal/server_lib/registration.py b/inngest/_internal/server_lib/registration.py index a0160d2d..63b86897 100644 --- a/inngest/_internal/server_lib/registration.py +++ b/inngest/_internal/server_lib/registration.py @@ -5,10 +5,10 @@ import pydantic -from inngest._internal import errors, transforms, types +from inngest._internal import const, errors, transforms, types from .consts import DeployType, Framework -from .inspection import Capabilities +from .inspection import AuthenticatedInspection, Capabilities class _BaseConfig(types.BaseModel): @@ -174,3 +174,20 @@ class SynchronizeRequest(types.BaseModel): sdk: str url: str v: str + + +class InBandSynchronizeRequest(types.BaseModel): + url: str + + +class InBandSynchronizeResponse(types.BaseModel): + app_id: str + env: typing.Optional[str] + framework: Framework + functions: list[FunctionConfig] + inspection: AuthenticatedInspection + platform: typing.Optional[str] + sdk_author: str = const.AUTHOR + sdk_language: str = const.LANGUAGE + sdk_version: str = const.VERSION + url: str diff --git a/inngest/flask.py b/inngest/flask.py index 3b0675ba..48daedf5 100644 --- a/inngest/flask.py +++ b/inngest/flask.py @@ -5,13 +5,7 @@ import flask -from inngest._internal import ( - client_lib, - comm_lib, - function, - server_lib, - transforms, -) +from inngest._internal import client_lib, comm_lib, function, server_lib FRAMEWORK = server_lib.Framework.FLASK @@ -76,7 +70,7 @@ def _create_handler_async( @app.route("/api/inngest", methods=["GET", "POST", "PUT"]) async def inngest_api() -> typing.Union[flask.Response, str]: comm_req = comm_lib.CommRequest( - body=flask.request.data, + body=_get_body_bytes(), headers=dict(flask.request.headers.items()), query_params=flask.request.args, raw_request=flask.request, @@ -118,7 +112,7 @@ def _create_handler_sync( @app.route("/api/inngest", methods=["GET", "POST", "PUT"]) def inngest_api() -> typing.Union[flask.Response, str]: comm_req = comm_lib.CommRequest( - body=flask.request.data, + body=_get_body_bytes(), headers=dict(flask.request.headers.items()), query_params=flask.request.args, raw_request=flask.request, @@ -149,17 +143,22 @@ def inngest_api() -> typing.Union[flask.Response, str]: return "" +def _get_body_bytes() -> bytes: + flask.request.get_data(as_text=True) + return flask.request.data + + def _to_response( client: client_lib.Inngest, comm_res: comm_lib.CommResponse, ) -> flask.Response: - body = transforms.dump_json(comm_res.body) + body = comm_res.body_bytes() if isinstance(body, Exception): comm_res = comm_lib.CommResponse.from_error(client.logger, body) - body = json.dumps(comm_res.body) + body = json.dumps(comm_res.body).encode("utf-8") return flask.Response( headers=comm_res.headers, - response=body.encode("utf-8"), + response=body, status=comm_res.status_code, ) diff --git a/tests/base.py b/tests/base.py index 55ee78f1..41bce07b 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,6 +1,4 @@ import datetime -import hashlib -import hmac import os import time import typing @@ -10,7 +8,7 @@ import httpx import inngest -from inngest._internal import const, server_lib, transforms +from inngest._internal import const, server_lib from . import http_proxy, net @@ -129,48 +127,6 @@ def set_signing_key_fallback_env_var(self) -> str: ) return signing_key - def create_signature(self, signing_key: typing.Optional[str] = None) -> str: - if signing_key is None: - signing_key = self.signing_key - - mac = hmac.new( - transforms.remove_signing_key_prefix(signing_key).encode("utf-8"), - b"", - hashlib.sha256, - ) - unix_ms = round(time.time() * 1000) - mac.update(str(unix_ms).encode("utf-8")) - sig = mac.hexdigest() - return f"s={sig}&t={unix_ms}" - - def validate_signature( - self, - sig: str, - body: bytes, - signing_key: typing.Optional[str] = None, - ) -> None: - canonicalized = transforms.canonicalize(body) - assert not isinstance(canonicalized, Exception) - - if signing_key is None: - signing_key = self.signing_key - - parsed = urllib.parse.parse_qs(sig) - timestamp = int(parsed["t"][0]) - signature = parsed["s"][0] - - mac = hmac.new( - transforms.remove_signing_key_prefix(signing_key).encode("utf-8"), - canonicalized, - hashlib.sha256, - ) - - if timestamp: - mac.update(str(timestamp).encode("utf-8")) - - if not hmac.compare_digest(signature, mac.hexdigest()): - raise Exception("invalid signature") - class BaseTestIntrospection(BaseTest): framework: server_lib.Framework @@ -192,6 +148,7 @@ def setUp(self) -> None: "app_id": "my-app", "authentication_succeeded": True, "capabilities": { + "in_band_sync": "v1", "trust_probe": "v1", }, "env": None, diff --git a/tests/net.py b/tests/net.py index f172da5f..ab183656 100644 --- a/tests/net.py +++ b/tests/net.py @@ -1,5 +1,3 @@ -import hashlib -import hmac import random import socket import time @@ -10,17 +8,6 @@ _used_ports: set[int] = set() -def create_signature(body: bytes, signing_key: str, unix_ms: int) -> str: - mac = hmac.new( - signing_key.encode("utf-8"), - body, - hashlib.sha256, - ) - mac.update(str(unix_ms).encode("utf-8")) - sig = mac.hexdigest() - return f"s={sig}&t={unix_ms}" - - def get_available_port() -> int: start_time = time.time() diff --git a/tests/test_execution/test_fast_api.py b/tests/test_execution/test_fast_api.py index b30bbde3..629c25ba 100644 --- a/tests/test_execution/test_fast_api.py +++ b/tests/test_execution/test_fast_api.py @@ -42,14 +42,14 @@ def test_invalid_signature(self) -> None: ) ) wrong_signing_key = "signkey-prod-111111" - - sig = net.sign(b"{}", wrong_signing_key) - assert not isinstance(sig, Exception) + req_sig = net.sign(b"{}", wrong_signing_key) + if isinstance(req_sig, Exception): + raise req_sig res = fast_api_client.post( "/api/inngest?fnId=my-fn&stepId=step", headers={ - server_lib.HeaderKey.SIGNATURE.value: sig, + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == http.HTTPStatus.UNAUTHORIZED diff --git a/tests/test_execution/test_flask.py b/tests/test_execution/test_flask.py index cdb4b7e4..a43a030d 100644 --- a/tests/test_execution/test_flask.py +++ b/tests/test_execution/test_flask.py @@ -44,14 +44,14 @@ def test_invalid_signature(self) -> None: ) ) wrong_signing_key = "signkey-prod-111111" - - sig = net.sign(b"{}", wrong_signing_key) - assert not isinstance(sig, Exception) + req_sig = net.sign(b"{}", wrong_signing_key) + if isinstance(req_sig, Exception): + raise req_sig res = flask_client.post( "/api/inngest?fnId=my-fn&stepId=step", headers={ - server_lib.HeaderKey.SIGNATURE.value: sig, + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == http.HTTPStatus.UNAUTHORIZED diff --git a/tests/test_introspection/test_digital_ocean.py b/tests/test_introspection/test_digital_ocean.py index a5bee50b..5d9aab78 100644 --- a/tests/test_introspection/test_digital_ocean.py +++ b/tests/test_introspection/test_digital_ocean.py @@ -6,7 +6,7 @@ import inngest import inngest.digital_ocean import inngest.fast_api -from inngest._internal import server_lib +from inngest._internal import net, server_lib from inngest.experimental import digital_ocean_simulator from tests import base @@ -51,10 +51,14 @@ def test_cloud_mode_with_signature(self) -> None: ) ) + req_sig = net.sign(b"", self.signing_key) + if isinstance(req_sig, Exception): + raise req_sig + res = app_client.get( digital_ocean_simulator.FULL_PATH, headers={ - server_lib.HeaderKey.SIGNATURE.value: self.create_signature(), + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == 200 @@ -62,9 +66,15 @@ def test_cloud_mode_with_signature(self) -> None: **self.expected_authed_body, "has_signing_key_fallback": True, } - self.validate_signature( - res.headers[server_lib.HeaderKey.SIGNATURE.value], - res.get_data(), + assert isinstance( + net.validate_sig( + body=res.get_data(), + headers=res.headers, + mode=server_lib.ServerKind.CLOUD, + signing_key=self.signing_key, + signing_key_fallback=None, + ), + str, ) def test_dev_mode_with_no_signature(self) -> None: diff --git a/tests/test_introspection/test_fast_api.py b/tests/test_introspection/test_fast_api.py index 060c3e0d..4f3ea8de 100644 --- a/tests/test_introspection/test_fast_api.py +++ b/tests/test_introspection/test_fast_api.py @@ -5,7 +5,7 @@ import inngest import inngest.fast_api -from inngest._internal import server_lib +from inngest._internal import net, server_lib from tests import base @@ -47,10 +47,15 @@ def test_cloud_mode_with_signature(self) -> None: signing_key=self.signing_key, ) ) + + req_sig = net.sign(b"", self.signing_key) + if isinstance(req_sig, Exception): + raise req_sig + res = fast_api_client.get( "/api/inngest", headers={ - server_lib.HeaderKey.SIGNATURE.value: self.create_signature(), + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == 200 @@ -58,9 +63,15 @@ def test_cloud_mode_with_signature(self) -> None: **self.expected_authed_body, "has_signing_key_fallback": True, } - self.validate_signature( - res.headers[server_lib.HeaderKey.SIGNATURE.value], - res.text.encode("utf-8"), + assert isinstance( + net.validate_sig( + body=res.content, + headers=dict(res.headers), + mode=server_lib.ServerKind.CLOUD, + signing_key=self.signing_key, + signing_key_fallback=None, + ), + str, ) def test_cloud_mode_with_signature_fallback(self) -> None: @@ -76,12 +87,15 @@ def test_cloud_mode_with_signature_fallback(self) -> None: signing_key=self.signing_key, ) ) + + req_sig = net.sign(b"", signing_key_fallback) + if isinstance(req_sig, Exception): + raise req_sig + res = fast_api_client.get( "/api/inngest", headers={ - server_lib.HeaderKey.SIGNATURE.value: self.create_signature( - signing_key_fallback - ), + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == 200 @@ -89,11 +103,16 @@ def test_cloud_mode_with_signature_fallback(self) -> None: **self.expected_authed_body, "has_signing_key_fallback": True, } - print(res.text) - self.validate_signature( - res.headers[server_lib.HeaderKey.SIGNATURE.value], - res.text.encode("utf-8"), - signing_key_fallback, + + assert isinstance( + net.validate_sig( + body=res.content, + headers=dict(res.headers), + mode=server_lib.ServerKind.CLOUD, + signing_key=signing_key_fallback, + signing_key_fallback=None, + ), + str, ) def test_dev_mode_with_no_signature(self) -> None: diff --git a/tests/test_introspection/test_flask.py b/tests/test_introspection/test_flask.py index 2aad5539..1171ad05 100644 --- a/tests/test_introspection/test_flask.py +++ b/tests/test_introspection/test_flask.py @@ -6,7 +6,7 @@ import inngest import inngest.flask -from inngest._internal import server_lib +from inngest._internal import net, server_lib from tests import base @@ -48,20 +48,33 @@ def test_cloud_mode_with_signature(self) -> None: signing_key=self.signing_key, ) ) + + req_sig = net.sign(b"", self.signing_key) + if isinstance(req_sig, Exception): + raise req_sig + res = flask_client.get( "/api/inngest", headers={ - server_lib.HeaderKey.SIGNATURE.value: self.create_signature(), + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) + assert res.status_code == 200 assert res.json == { **self.expected_authed_body, "has_signing_key_fallback": True, } - self.validate_signature( - res.headers[server_lib.HeaderKey.SIGNATURE.value], - res.get_data(), + + assert isinstance( + net.validate_sig( + body=res.get_data(), + headers=res.headers, + mode=server_lib.ServerKind.CLOUD, + signing_key=self.signing_key, + signing_key_fallback=None, + ), + str, ) def test_cloud_mode_with_signature_fallback(self) -> None: @@ -77,12 +90,15 @@ def test_cloud_mode_with_signature_fallback(self) -> None: signing_key=self.signing_key, ) ) + + req_sig = net.sign(b"", signing_key_fallback) + if isinstance(req_sig, Exception): + raise req_sig + res = flask_client.get( "/api/inngest", headers={ - server_lib.HeaderKey.SIGNATURE.value: self.create_signature( - signing_key_fallback - ), + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == 200 @@ -90,10 +106,15 @@ def test_cloud_mode_with_signature_fallback(self) -> None: **self.expected_authed_body, "has_signing_key_fallback": True, } - self.validate_signature( - res.headers[server_lib.HeaderKey.SIGNATURE.value], - res.get_data(), - signing_key_fallback, + assert isinstance( + net.validate_sig( + body=res.get_data(), + headers=res.headers, + mode=server_lib.ServerKind.CLOUD, + signing_key=signing_key_fallback, + signing_key_fallback=None, + ), + str, ) def test_dev_mode_with_no_signature(self) -> None: diff --git a/tests/test_probes/test_fast_api.py b/tests/test_probes/test_fast_api.py index c1dd2041..98668bd5 100644 --- a/tests/test_probes/test_fast_api.py +++ b/tests/test_probes/test_fast_api.py @@ -5,7 +5,7 @@ import inngest import inngest.fast_api -from inngest._internal import server_lib +from inngest._internal import net, server_lib from tests import base @@ -27,17 +27,31 @@ def test_signed(self) -> None: signing_key=self.signing_key, ) ) + + req_sig = net.sign(b"", self.signing_key) + if isinstance(req_sig, Exception): + raise req_sig + res = fast_api_client.post( "/api/inngest?probe=trust", headers={ - server_lib.HeaderKey.SIGNATURE.value: self.create_signature(), + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == 200 sig_header = res.headers.get(server_lib.HeaderKey.SIGNATURE.value) assert sig_header is not None - self.validate_signature(sig_header, res.content) + assert isinstance( + net.validate_sig( + body=res.content, + headers=dict(res.headers), + mode=server_lib.ServerKind.CLOUD, + signing_key=self.signing_key, + signing_key_fallback=None, + ), + str, + ) def test_unsigned(self) -> None: fast_api_client = self._serve( @@ -72,12 +86,15 @@ def test_incorrectly_signed(self) -> None: signing_key=self.signing_key, ) ) + + req_sig = net.sign(b"", "wrong") + if isinstance(req_sig, Exception): + raise req_sig + res = fast_api_client.post( "/api/inngest?probe=trust", headers={ - server_lib.HeaderKey.SIGNATURE.value: self.create_signature( - signing_key="wrong" - ), + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == 401 diff --git a/tests/test_probes/test_flask.py b/tests/test_probes/test_flask.py index b236d45f..f9a1dfa8 100644 --- a/tests/test_probes/test_flask.py +++ b/tests/test_probes/test_flask.py @@ -6,7 +6,7 @@ import inngest import inngest.flask -from inngest._internal import server_lib +from inngest._internal import net, server_lib from tests import base @@ -28,17 +28,31 @@ def test_signed(self) -> None: signing_key=self.signing_key, ) ) + + req_sig = net.sign(b"", self.signing_key) + if isinstance(req_sig, Exception): + raise req_sig + res = flask_client.post( "/api/inngest?probe=trust", headers={ - server_lib.HeaderKey.SIGNATURE.value: self.create_signature(), + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == 200 sig_header = res.headers.get(server_lib.HeaderKey.SIGNATURE.value) assert sig_header is not None - self.validate_signature(sig_header, res.get_data()) + assert isinstance( + net.validate_sig( + body=res.get_data(), + headers=res.headers, + mode=server_lib.ServerKind.CLOUD, + signing_key=self.signing_key, + signing_key_fallback=None, + ), + str, + ) def test_unsigned(self) -> None: flask_client = self._serve( @@ -73,12 +87,15 @@ def test_incorrectly_signed(self) -> None: signing_key=self.signing_key, ) ) + + req_sig = net.sign(b"", "wrong") + if isinstance(req_sig, Exception): + raise req_sig + res = flask_client.post( "/api/inngest?probe=trust", headers={ - server_lib.HeaderKey.SIGNATURE.value: self.create_signature( - signing_key="wrong" - ), + server_lib.HeaderKey.SIGNATURE.value: req_sig, }, ) assert res.status_code == 401 diff --git a/tests/test_registration/base.py b/tests/test_registration/base.py index d15ba210..23e57dc5 100644 --- a/tests/test_registration/base.py +++ b/tests/test_registration/base.py @@ -1,4 +1,5 @@ import dataclasses +import typing import unittest import inngest @@ -9,12 +10,18 @@ @dataclasses.dataclass class RegistrationResponse: - body: object + body: bytes + headers: dict[str, str] status_code: int class TestCase(unittest.TestCase): - def register(self, headers: dict[str, str]) -> RegistrationResponse: + def put( + self, + *, + body: typing.Union[dict[str, object], bytes], + headers: typing.Optional[dict[str, str]] = None, + ) -> RegistrationResponse: raise NotImplementedError() def serve( diff --git a/tests/test_registration/cases/__init__.py b/tests/test_registration/cases/__init__.py index 2a6a7836..27e71bd7 100644 --- a/tests/test_registration/cases/__init__.py +++ b/tests/test_registration/cases/__init__.py @@ -1,9 +1,19 @@ from inngest._internal import server_lib -from . import base, cloud_branch_env, server_kind_mismatch +from . import ( + base, + cloud_branch_env, + in_sync_invalid_sig, + in_sync_missing_sig, + out_of_band, + server_kind_mismatch, +) _modules = ( cloud_branch_env, + in_sync_invalid_sig, + in_sync_missing_sig, + out_of_band, server_kind_mismatch, ) diff --git a/tests/test_registration/cases/cloud_branch_env.py b/tests/test_registration/cases/cloud_branch_env.py index 32fd40d1..941fe7de 100644 --- a/tests/test_registration/cases/cloud_branch_env.py +++ b/tests/test_registration/cases/cloud_branch_env.py @@ -1,11 +1,8 @@ -import dataclasses import json -import typing import inngest import inngest.fast_api -from inngest._internal import const, server_lib -from tests import http_proxy +from inngest._internal import net, server_lib from . import base @@ -17,41 +14,14 @@ def run_test(self: base.TestCase) -> None: """ Test that the SDK correctly syncs itself with Cloud when using a branch environment. - - We need to use a mock Cloud since the Dev Server doesn't have a mode - that simulates Cloud. """ - @dataclasses.dataclass - class State: - headers: dict[str, list[str]] - - state = State(headers={}) - - def on_request( - *, - body: typing.Optional[bytes], - headers: dict[str, list[str]], - method: str, - path: str, - ) -> http_proxy.Response: - for k, v in headers.items(): - state.headers[k] = v - - return http_proxy.Response( - body=json.dumps({}).encode("utf-8"), - headers={}, - status_code=200, - ) - - mock_cloud = http_proxy.Proxy(on_request).start() - self.addCleanup(mock_cloud.stop) + signing_key = "signkey-prod-000000" client = inngest.Inngest( - api_base_url=f"http://localhost:{mock_cloud.port}", app_id=f"{framework.value}-{_TEST_NAME}", env="my-env", - signing_key="signkey-prod-0486c9", + signing_key=signing_key, ) @client.create_function( @@ -66,14 +36,39 @@ def fn( pass self.serve(client, [fn]) - res = self.register({}) + + req_body = json.dumps( + server_lib.InBandSynchronizeRequest( + url="http://test.local" + ).to_dict() + ).encode("utf-8") + + req_sig = net.sign(req_body, signing_key) + if isinstance(req_sig, Exception): + raise req_sig + + res = self.put( + body=req_body, + headers={ + server_lib.HeaderKey.SIGNATURE.value: req_sig, + server_lib.HeaderKey.SYNC_KIND.value: server_lib.SyncKind.IN_BAND.value, + }, + ) assert res.status_code == 200 - assert state.headers.get("authorization") is not None - assert state.headers.get("x-inngest-env") == ["my-env"] - assert state.headers.get("x-inngest-framework") == [framework.value] - assert state.headers.get("x-inngest-sdk") == [ - f"inngest-py:v{const.VERSION}" - ] + assert res.headers["x-inngest-env"] == "my-env" + assert res.headers["x-inngest-expected-server-kind"] == "cloud" + assert res.headers["x-inngest-sync-kind"] == "in_band" + + assert isinstance( + net.validate_sig( + body=res.body, + headers=res.headers, + mode=server_lib.ServerKind.CLOUD, + signing_key=signing_key, + signing_key_fallback=None, + ), + str, + ) return base.Case( name=_TEST_NAME, diff --git a/tests/test_registration/cases/in_sync_invalid_sig.py b/tests/test_registration/cases/in_sync_invalid_sig.py new file mode 100644 index 00000000..61fdad97 --- /dev/null +++ b/tests/test_registration/cases/in_sync_invalid_sig.py @@ -0,0 +1,66 @@ +import json + +import inngest +import inngest.fast_api +from inngest._internal import net, server_lib + +from . import base + +_TEST_NAME = base.create_test_name(__file__) + + +def create(framework: server_lib.Framework) -> base.Case: + def run_test(self: base.TestCase) -> None: + """ + Test that the SDK correctly syncs itself with Cloud when using a branch + environment. + """ + + signing_key = "signkey-prod-000000" + + client = inngest.Inngest( + app_id=f"{framework.value}-{_TEST_NAME}", + env="my-env", + signing_key=signing_key, + ) + + @client.create_function( + fn_id="foo", + retries=0, + trigger=inngest.TriggerEvent(event="app/foo"), + ) + def fn( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> None: + pass + + self.serve(client, [fn]) + + req_body = json.dumps( + server_lib.InBandSynchronizeRequest( + url="http://test.local" + ).to_dict() + ).encode("utf-8") + + wrong_signing_key = "signkey-prod-111111" + req_sig = net.sign(req_body, wrong_signing_key) + if isinstance(req_sig, Exception): + raise req_sig + + res = self.put( + body=req_body, + headers={ + server_lib.HeaderKey.SIGNATURE.value: req_sig, + server_lib.HeaderKey.SYNC_KIND.value: server_lib.SyncKind.IN_BAND.value, + }, + ) + assert res.status_code == 401 + assert res.headers["x-inngest-env"] == "my-env" + assert res.headers["x-inngest-expected-server-kind"] == "cloud" + assert "x-inngest-sync-kind" not in res.headers + + return base.Case( + name=_TEST_NAME, + run_test=run_test, + ) diff --git a/tests/test_registration/cases/in_sync_missing_sig.py b/tests/test_registration/cases/in_sync_missing_sig.py new file mode 100644 index 00000000..078b0651 --- /dev/null +++ b/tests/test_registration/cases/in_sync_missing_sig.py @@ -0,0 +1,59 @@ +import json + +import inngest +import inngest.fast_api +from inngest._internal import server_lib + +from . import base + +_TEST_NAME = base.create_test_name(__file__) + + +def create(framework: server_lib.Framework) -> base.Case: + def run_test(self: base.TestCase) -> None: + """ + Test that the SDK correctly syncs itself with Cloud when using a branch + environment. + """ + + signing_key = "signkey-prod-000000" + + client = inngest.Inngest( + app_id=f"{framework.value}-{_TEST_NAME}", + env="my-env", + signing_key=signing_key, + ) + + @client.create_function( + fn_id="foo", + retries=0, + trigger=inngest.TriggerEvent(event="app/foo"), + ) + def fn( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> None: + pass + + req_body = json.dumps( + server_lib.InBandSynchronizeRequest( + url="http://test.local" + ).to_dict() + ).encode("utf-8") + + self.serve(client, [fn]) + res = self.put( + body=req_body, + headers={ + server_lib.HeaderKey.SYNC_KIND.value: server_lib.SyncKind.IN_BAND.value, + }, + ) + assert res.status_code == 401 + assert res.headers["x-inngest-env"] == "my-env" + assert res.headers["x-inngest-expected-server-kind"] == "cloud" + assert "x-inngest-sync-kind" not in res.headers + + return base.Case( + name=_TEST_NAME, + run_test=run_test, + ) diff --git a/tests/test_registration/cases/out_of_band.py b/tests/test_registration/cases/out_of_band.py new file mode 100644 index 00000000..3d0876bf --- /dev/null +++ b/tests/test_registration/cases/out_of_band.py @@ -0,0 +1,81 @@ +import dataclasses +import json +import typing + +import inngest +import inngest.fast_api +from inngest._internal import const, server_lib +from tests import http_proxy + +from . import base + +_TEST_NAME = base.create_test_name(__file__) + + +def create(framework: server_lib.Framework) -> base.Case: + def run_test(self: base.TestCase) -> None: + """ + Test that the SDK correctly syncs itself with Cloud when using a branch + environment. + + We need to use a mock Cloud since the Dev Server doesn't have a mode + that simulates Cloud. + """ + + @dataclasses.dataclass + class State: + headers: dict[str, list[str]] + + state = State(headers={}) + + def on_request( + *, + body: typing.Optional[bytes], + headers: dict[str, list[str]], + method: str, + path: str, + ) -> http_proxy.Response: + for k, v in headers.items(): + state.headers[k] = v + + return http_proxy.Response( + body=json.dumps({}).encode("utf-8"), + headers={}, + status_code=200, + ) + + mock_cloud = http_proxy.Proxy(on_request).start() + self.addCleanup(mock_cloud.stop) + + client = inngest.Inngest( + api_base_url=f"http://localhost:{mock_cloud.port}", + app_id=f"{framework.value}-{_TEST_NAME}", + env="my-env", + signing_key="signkey-prod-0486c9", + ) + + @client.create_function( + fn_id="foo", + retries=0, + trigger=inngest.TriggerEvent(event="app/foo"), + ) + def fn( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> None: + pass + + self.serve(client, [fn]) + res = self.put(body={}) + assert res.status_code == 200 + assert state.headers.get("authorization") is not None + assert state.headers.get("x-inngest-env") == ["my-env"] + assert state.headers.get("x-inngest-framework") == [framework.value] + assert state.headers.get("x-inngest-sdk") == [ + f"inngest-py:v{const.VERSION}" + ] + + return base.Case( + name=_TEST_NAME, + run_test=run_test, + ) diff --git a/tests/test_registration/cases/server_kind_mismatch.py b/tests/test_registration/cases/server_kind_mismatch.py index 1f439cdb..9b9b8178 100644 --- a/tests/test_registration/cases/server_kind_mismatch.py +++ b/tests/test_registration/cases/server_kind_mismatch.py @@ -1,3 +1,5 @@ +import json + import inngest import inngest.fast_api from inngest._internal import server_lib @@ -36,11 +38,12 @@ async def fn( headers = { server_lib.HeaderKey.SERVER_KIND.value: server_lib.ServerKind.DEV_SERVER.value, } - res = self.register(headers) + res = self.put(body={}, headers=headers) assert res.status_code == 400 - assert isinstance(res.body, dict) + res_body = json.loads(res.body) + assert isinstance(res_body, dict) assert ( - res.body["code"] == server_lib.ErrorCode.SERVER_KIND_MISMATCH.value + res_body["code"] == server_lib.ErrorCode.SERVER_KIND_MISMATCH.value ) return base.Case( diff --git a/tests/test_registration/test_fast_api.py b/tests/test_registration/test_fast_api.py index 3dbd5431..554eacab 100644 --- a/tests/test_registration/test_fast_api.py +++ b/tests/test_registration/test_fast_api.py @@ -1,3 +1,6 @@ +import json +import os +import typing import unittest import fastapi @@ -5,7 +8,7 @@ import inngest import inngest.fast_api -from inngest._internal import server_lib +from inngest._internal import const, server_lib from . import base, cases @@ -14,16 +17,40 @@ class TestRegistration(base.TestCase): def setUp(self) -> None: + super().setUp() + + # TODO: Delete this when we default to allowing in-band sync + os.environ[const.EnvKey.ALLOW_IN_BAND_SYNC.value] = "1" + self.app = fastapi.FastAPI() self.app_client = fastapi.testclient.TestClient(self.app) - def register(self, headers: dict[str, str]) -> base.RegistrationResponse: + def tearDown(self) -> None: + super().tearDown() + + # TODO: Delete this when we default to allowing in-band sync + os.environ.pop(const.EnvKey.ALLOW_IN_BAND_SYNC.value, None) + + def put( + self, + *, + body: typing.Union[dict[str, object], bytes], + headers: typing.Optional[dict[str, str]] = None, + ) -> base.RegistrationResponse: + if isinstance(body, bytes): + body = json.loads(body) + + if headers is None: + headers = {} + res = self.app_client.put( "/api/inngest", + json=body, headers=headers, ) return base.RegistrationResponse( - body=res.json(), + body=res.content, + headers=dict(res.headers), status_code=res.status_code, ) diff --git a/tests/test_registration/test_flask.py b/tests/test_registration/test_flask.py index 78e3fc1a..fc72724b 100644 --- a/tests/test_registration/test_flask.py +++ b/tests/test_registration/test_flask.py @@ -1,3 +1,5 @@ +import os +import typing import unittest import flask @@ -6,7 +8,7 @@ import inngest import inngest.flask -from inngest._internal import server_lib +from inngest._internal import const, server_lib from . import base, cases @@ -15,16 +17,37 @@ class TestRegistration(base.TestCase): def setUp(self) -> None: + super().setUp() + + # TODO: Delete this when we default to allowing in-band sync + os.environ[const.EnvKey.ALLOW_IN_BAND_SYNC.value] = "1" + self.app = flask.Flask(__name__) self.app_client = self.app.test_client() - def register(self, headers: dict[str, str]) -> base.RegistrationResponse: + def tearDown(self) -> None: + super().tearDown() + + # TODO: Delete this when we default to allowing in-band sync + os.environ.pop(const.EnvKey.ALLOW_IN_BAND_SYNC.value, None) + + def put( + self, + *, + body: typing.Union[dict[str, object], bytes], + headers: typing.Optional[dict[str, str]] = None, + ) -> base.RegistrationResponse: + if headers is None: + headers = {} + res = self.app_client.put( "/api/inngest", + data=body, headers=headers, ) return base.RegistrationResponse( - body=res.json, + body=res.data, + headers=dict(res.headers.items()), status_code=res.status_code, ) diff --git a/tests/test_registration/test_tornado.py b/tests/test_registration/test_tornado.py deleted file mode 100644 index 44db203c..00000000 --- a/tests/test_registration/test_tornado.py +++ /dev/null @@ -1,55 +0,0 @@ -import json - -import tornado.httpclient -import tornado.ioloop -import tornado.log -import tornado.testing -import tornado.web - -import inngest -import inngest.tornado -from inngest._internal import server_lib - -from . import base, cases - -_framework = server_lib.Framework.TORNADO - - -class TestRegistration(tornado.testing.AsyncHTTPTestCase): - app: tornado.web.Application = tornado.web.Application() - - def get_app(self) -> tornado.web.Application: - return self.app - - def register(self, headers: dict[str, str]) -> base.RegistrationResponse: - res = self.fetch( - "/api/inngest", - body=json.dumps({}), - headers=headers, - method="PUT", - ) - return base.RegistrationResponse( - body=json.loads(res.body), - status_code=res.code, - ) - - def serve( - self, - client: inngest.Inngest, - fns: list[inngest.Function], - ) -> None: - inngest.tornado.serve( - self.app, - client, - fns, - ) - - -for case in cases.create_cases(_framework): - test_name = f"test_{case.name}" - setattr(TestRegistration, test_name, case.run_test) - - -if __name__ == "__main__": - tornado.testing.main() - tornado.testing.main()