diff --git a/.vscode/settings.json b/.vscode/settings.json index b2a34a58..3d59617d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -5,6 +5,6 @@ } }, "editor.formatOnSave": true, - "isort.args": ["--profile", "black"], + "isort.args": ["--profile", "black", "--config", "pyproject.toml"], "isort.check": true } diff --git a/inngest/client.py b/inngest/client.py index 24b02a1f..6650943b 100644 --- a/inngest/client.py +++ b/inngest/client.py @@ -4,7 +4,7 @@ from urllib.parse import urljoin from .const import DEFAULT_EVENT_ORIGIN, DEV_SERVER_ORIGIN, EnvKey -from .env import allow_dev_server +from .env import is_prod from .errors import InvalidResponseShape, MissingEventKey from .event import Event from .net import create_headers, requests_session @@ -24,7 +24,7 @@ def __init__( self.logger = logger or getLogger(__name__) if event_key is None: - if allow_dev_server(): + if not is_prod(): event_key = "NO_EVENT_KEY_SET" else: event_key = os.getenv(EnvKey.EVENT_KEY.value) @@ -35,7 +35,7 @@ def __init__( event_origin = base_url if event_origin is None: - if allow_dev_server(): + if not is_prod(): self.logger.info("Defaulting event origin to Dev Server") event_origin = DEV_SERVER_ORIGIN else: diff --git a/inngest/comm.py b/inngest/comm.py index f323f774..ac219d63 100644 --- a/inngest/comm.py +++ b/inngest/comm.py @@ -1,9 +1,11 @@ from __future__ import annotations +import hashlib +import hmac import os from dataclasses import dataclass from logging import Logger -from urllib.parse import urljoin +from urllib.parse import parse_qs, urljoin import requests @@ -15,15 +17,25 @@ VERSION, EnvKey, ErrorCode, + HeaderKey, +) +from .env import is_prod +from .errors import ( + InvalidBaseURL, + InvalidRequestSignature, + MissingHeader, + MissingSigningKey, ) -from .env import allow_dev_server -from .errors import InvalidBaseURL from .execution import Call, CallError, CallResponse from .function import Function from .function_config import FunctionConfig from .net import create_headers, parse_url, requests_session from .registration import DeployType, RegisterRequest -from .transforms import hash_signing_key, remove_none_deep +from .transforms import ( + hash_signing_key, + remove_none_deep, + remove_signing_key_prefix, +) @dataclass @@ -59,13 +71,13 @@ def __init__( ) -> None: self._logger = logger - if allow_dev_server(): + if not is_prod(): self._logger.info("Dev Server mode enabled") api_origin = api_origin or os.getenv(EnvKey.BASE_URL.value) if api_origin is None: - if allow_dev_server(): + if not is_prod(): self._logger.info("Defaulting API origin to Dev Server") api_origin = DEV_SERVER_ORIGIN else: @@ -86,11 +98,14 @@ def call_function( *, call: Call, fn_id: str, + req_sig: RequestSignature, ) -> CommResponse: """ Handles a function call from the Executor. """ + req_sig.validate(self._signing_key) + if fn_id not in self._fns: raise Exception(f"function {fn_id} not found") @@ -181,7 +196,7 @@ def register( Handles a registration call. """ - if is_from_dev_server and not allow_dev_server(): + if is_from_dev_server and is_prod(): self._logger.error( "Dev Server registration not allowed in production mode" ) @@ -225,3 +240,46 @@ def register( ) return self._parse_registration_response(res) + + +class RequestSignature: + _signature: str | None = None + _timestamp: int | None = None + + def __init__( + self, + body: bytes, + headers: dict[str, str], + ) -> None: + self._body = body + + sig_header = headers.get(HeaderKey.SIGNATURE.value) + if sig_header is not None: + parsed = parse_qs(sig_header) + if "t" in parsed: + self._timestamp = int(parsed["t"][0]) + if "s" in parsed: + self._signature = parsed["s"][0] + + def validate(self, signing_key: str | None) -> None: + if not is_prod(): + return + + if signing_key is None: + raise MissingSigningKey( + "cannot validate signature in production mode without a signing key" + ) + + if self._signature is None: + raise MissingHeader( + f"cannot validate signature in production mode without a {HeaderKey.SIGNATURE.value} header" + ) + + mac = hmac.new( + remove_signing_key_prefix(signing_key).encode("utf-8"), + self._body, + hashlib.sha256, + ) + mac.update(str(self._timestamp).encode()) + if not hmac.compare_digest(self._signature, mac.hexdigest()): + raise InvalidRequestSignature() diff --git a/inngest/const.py b/inngest/const.py index 9ae69966..addc8033 100644 --- a/inngest/const.py +++ b/inngest/const.py @@ -16,3 +16,7 @@ class EnvKey(Enum): class ErrorCode(Enum): DEV_SERVER_REGISTRATION_NOT_ALLOWED = "DEV_SERVER_REGISTRATION_NOT_ALLOWED" + + +class HeaderKey(Enum): + SIGNATURE = "X-Inngest-Signature" diff --git a/inngest/env.py b/inngest/env.py index 6254288a..87c13e9b 100644 --- a/inngest/env.py +++ b/inngest/env.py @@ -42,7 +42,7 @@ def _starts_with(key: EnvKey, value: str) -> _EnvCheck: ] -def allow_dev_server() -> bool: +def is_prod() -> bool: for check in _PROD_CHECKS: value = os.getenv(check.key.value) operator = check.operator @@ -53,12 +53,12 @@ def allow_dev_server() -> bool: if operator == "equals": if value == expected: - return False + return True elif operator == "is_truthy": if value: - return False + return True elif operator == "starts_with" and isinstance(expected, str): if value.startswith(expected): - return False + return True - return True + return False diff --git a/inngest/env_test.py b/inngest/env_test.py index a59e616b..52878e7d 100644 --- a/inngest/env_test.py +++ b/inngest/env_test.py @@ -1,33 +1,33 @@ import os -from .env import EnvKey, allow_dev_server +from .env import EnvKey, is_prod def test_allow_dev_server() -> None: - assert allow_dev_server() is True + assert is_prod() is False os.environ["CF_PAGES"] = "1" - assert allow_dev_server() is False + assert is_prod() is True _clear() os.environ["CONTEXT"] = "production" - assert allow_dev_server() is False + assert is_prod() is True _clear() os.environ["DENO_DEPLOYMENT_ID"] = "1" - assert allow_dev_server() is False + assert is_prod() is True _clear() os.environ["ENVIRONMENT"] = "production" - assert allow_dev_server() is False + assert is_prod() is True _clear() os.environ["FLASK_ENV"] = "production" - assert allow_dev_server() is False + assert is_prod() is True _clear() os.environ["VERCEL_ENV"] = "production" - assert allow_dev_server() is False + assert is_prod() is True _clear() diff --git a/inngest/errors.py b/inngest/errors.py index d96ec90a..fc7053bb 100644 --- a/inngest/errors.py +++ b/inngest/errors.py @@ -1,14 +1,30 @@ -class InvalidBaseURL(Exception): +class InngestError(Exception): pass +class InvalidBaseURL(Exception): + code = "invalid_base_url" + + +class InvalidRequestSignature(Exception): + code = "invalid_request_signature" + + class InvalidResponseShape(Exception): - pass + code = "invalid_response_shape" class MissingEventKey(Exception): - pass + code = "missing_event_key" + + +class MissingHeader(Exception): + code = "missing_header" + + +class MissingSigningKey(Exception): + code = "missing_signing_key" class NonRetriableError(Exception): - pass + code = "non_retriable_error" diff --git a/inngest/frameworks/flask.py b/inngest/frameworks/flask.py index bf5930e7..df7db638 100644 --- a/inngest/frameworks/flask.py +++ b/inngest/frameworks/flask.py @@ -3,7 +3,7 @@ from flask import Flask, Response, make_response, request from inngest.client import Inngest -from inngest.comm import CommHandler, CommResponse +from inngest.comm import CommHandler, CommResponse, RequestSignature from inngest.execution import Call from inngest.function import Function @@ -36,6 +36,10 @@ def inngest_api() -> Response | str: comm.call_function( call=Call.from_dict(json.loads(request.data)), fn_id=fn_id, + req_sig=RequestSignature( + body=request.data, + headers=dict(request.headers.items()), + ), ) ) diff --git a/inngest/frameworks/tornado.py b/inngest/frameworks/tornado.py index 1905acaa..75cabc46 100644 --- a/inngest/frameworks/tornado.py +++ b/inngest/frameworks/tornado.py @@ -4,7 +4,7 @@ from tornado.web import Application, RequestHandler from inngest.client import Inngest -from inngest.comm import CommHandler +from inngest.comm import CommHandler, RequestSignature from inngest.execution import Call from inngest.function import Function @@ -37,9 +37,19 @@ def post(self) -> None: raise Exception("missing fnId") fn_id = raw_fn_id[0].decode("utf-8") + headers: dict[str, str] = {} + + for k, v in self.request.headers.items(): + if isinstance(k, str) and isinstance(v[0], str): + headers[k] = v[0] + comm_res = comm.call_function( call=Call.from_dict(json.loads(self.request.body)), fn_id=fn_id, + req_sig=RequestSignature( + body=self.request.body, + headers=headers, + ), ) self.write(json.dumps(comm_res.body)) diff --git a/inngest/transforms.py b/inngest/transforms.py index 9d2b509e..d0797fd5 100644 --- a/inngest/transforms.py +++ b/inngest/transforms.py @@ -5,17 +5,22 @@ def hash_signing_key(key: str) -> str: + return hashlib.sha256( + bytearray.fromhex(remove_signing_key_prefix(key)) + ).hexdigest() + + +def hash_step_id(step_id: str) -> str: + return hashlib.sha1(step_id.encode("utf-8")).hexdigest() + + +def remove_signing_key_prefix(key: str) -> str: prefix_match = re.match(r"^signkey-[\w]+-", key) prefix = "" if prefix_match: prefix = prefix_match.group(0) - key_without_prefix = key[len(prefix) :] - return hashlib.sha256(bytearray.fromhex(key_without_prefix)).hexdigest() - - -def hash_step_id(step_id: str) -> str: - return hashlib.sha1(step_id.encode("utf-8")).hexdigest() + return key[len(prefix) :] def remove_none_deep(obj: T) -> T: diff --git a/pyproject.toml b/pyproject.toml index c48f4bad..c7da709c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ extra = [ line-length = 80 [tool.isort] +line_length = 80 profile = "black" [tool.mypy] @@ -51,6 +52,7 @@ disable = [ 'duplicate-code', 'fixme', 'invalid-envvar-value', + 'line-too-long', 'missing-docstring', 'too-few-public-methods', 'too-many-arguments',