Skip to content

Commit

Permalink
Add request validation
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Oct 22, 2023
1 parent 6a823cd commit 8e84511
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
}
},
"editor.formatOnSave": true,
"isort.args": ["--profile", "black"],
"isort.args": ["--profile", "black", "--config", "pyproject.toml"],
"isort.check": true
}
6 changes: 3 additions & 3 deletions inngest/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from urllib.parse import urljoin

from .const import DEFAULT_EVENT_ORIGIN, DEV_SERVER_ORIGIN, EnvKey
from .env import allow_dev_server
from .env import is_prod
from .errors import InvalidResponseShape, MissingEventKey
from .event import Event
from .net import create_headers, requests_session
Expand All @@ -24,7 +24,7 @@ def __init__(
self.logger = logger or getLogger(__name__)

if event_key is None:
if allow_dev_server():
if not is_prod():
event_key = "NO_EVENT_KEY_SET"
else:
event_key = os.getenv(EnvKey.EVENT_KEY.value)
Expand All @@ -35,7 +35,7 @@ def __init__(

event_origin = base_url
if event_origin is None:
if allow_dev_server():
if not is_prod():
self.logger.info("Defaulting event origin to Dev Server")
event_origin = DEV_SERVER_ORIGIN
else:
Expand Down
72 changes: 65 additions & 7 deletions inngest/comm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import hashlib
import hmac
import os
from dataclasses import dataclass
from logging import Logger
from urllib.parse import urljoin
from urllib.parse import parse_qs, urljoin

import requests

Expand All @@ -15,15 +17,25 @@
VERSION,
EnvKey,
ErrorCode,
HeaderKey,
)
from .env import is_prod
from .errors import (
InvalidBaseURL,
InvalidRequestSignature,
MissingHeader,
MissingSigningKey,
)
from .env import allow_dev_server
from .errors import InvalidBaseURL
from .execution import Call, CallError, CallResponse
from .function import Function
from .function_config import FunctionConfig
from .net import create_headers, parse_url, requests_session
from .registration import DeployType, RegisterRequest
from .transforms import hash_signing_key, remove_none_deep
from .transforms import (
hash_signing_key,
remove_none_deep,
remove_signing_key_prefix,
)


@dataclass
Expand Down Expand Up @@ -59,13 +71,13 @@ def __init__(
) -> None:
self._logger = logger

if allow_dev_server():
if not is_prod():
self._logger.info("Dev Server mode enabled")

api_origin = api_origin or os.getenv(EnvKey.BASE_URL.value)

if api_origin is None:
if allow_dev_server():
if not is_prod():
self._logger.info("Defaulting API origin to Dev Server")
api_origin = DEV_SERVER_ORIGIN
else:
Expand All @@ -86,11 +98,14 @@ def call_function(
*,
call: Call,
fn_id: str,
req_sig: RequestSignature,
) -> CommResponse:
"""
Handles a function call from the Executor.
"""

req_sig.validate(self._signing_key)

if fn_id not in self._fns:
raise Exception(f"function {fn_id} not found")

Expand Down Expand Up @@ -181,7 +196,7 @@ def register(
Handles a registration call.
"""

if is_from_dev_server and not allow_dev_server():
if is_from_dev_server and is_prod():
self._logger.error(
"Dev Server registration not allowed in production mode"
)
Expand Down Expand Up @@ -225,3 +240,46 @@ def register(
)

return self._parse_registration_response(res)


class RequestSignature:
_signature: str | None = None
_timestamp: int | None = None

def __init__(
self,
body: bytes,
headers: dict[str, str],
) -> None:
self._body = body

sig_header = headers.get(HeaderKey.SIGNATURE.value)
if sig_header is not None:
parsed = parse_qs(sig_header)
if "t" in parsed:
self._timestamp = int(parsed["t"][0])
if "s" in parsed:
self._signature = parsed["s"][0]

def validate(self, signing_key: str | None) -> None:
if not is_prod():
return

if signing_key is None:
raise MissingSigningKey(
"cannot validate signature in production mode without a signing key"
)

if self._signature is None:
raise MissingHeader(
f"cannot validate signature in production mode without a {HeaderKey.SIGNATURE.value} header"
)

mac = hmac.new(
remove_signing_key_prefix(signing_key).encode("utf-8"),
self._body,
hashlib.sha256,
)
mac.update(str(self._timestamp).encode())
if not hmac.compare_digest(self._signature, mac.hexdigest()):
raise InvalidRequestSignature()
4 changes: 4 additions & 0 deletions inngest/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ class EnvKey(Enum):

class ErrorCode(Enum):
DEV_SERVER_REGISTRATION_NOT_ALLOWED = "DEV_SERVER_REGISTRATION_NOT_ALLOWED"


class HeaderKey(Enum):
SIGNATURE = "X-Inngest-Signature"
10 changes: 5 additions & 5 deletions inngest/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _starts_with(key: EnvKey, value: str) -> _EnvCheck:
]


def allow_dev_server() -> bool:
def is_prod() -> bool:
for check in _PROD_CHECKS:
value = os.getenv(check.key.value)
operator = check.operator
Expand All @@ -53,12 +53,12 @@ def allow_dev_server() -> bool:

if operator == "equals":
if value == expected:
return False
return True
elif operator == "is_truthy":
if value:
return False
return True
elif operator == "starts_with" and isinstance(expected, str):
if value.startswith(expected):
return False
return True

return True
return False
16 changes: 8 additions & 8 deletions inngest/env_test.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
import os

from .env import EnvKey, allow_dev_server
from .env import EnvKey, is_prod


def test_allow_dev_server() -> None:
assert allow_dev_server() is True
assert is_prod() is False

os.environ["CF_PAGES"] = "1"
assert allow_dev_server() is False
assert is_prod() is True
_clear()

os.environ["CONTEXT"] = "production"
assert allow_dev_server() is False
assert is_prod() is True
_clear()

os.environ["DENO_DEPLOYMENT_ID"] = "1"
assert allow_dev_server() is False
assert is_prod() is True
_clear()

os.environ["ENVIRONMENT"] = "production"
assert allow_dev_server() is False
assert is_prod() is True
_clear()

os.environ["FLASK_ENV"] = "production"
assert allow_dev_server() is False
assert is_prod() is True
_clear()

os.environ["VERCEL_ENV"] = "production"
assert allow_dev_server() is False
assert is_prod() is True
_clear()


Expand Down
24 changes: 20 additions & 4 deletions inngest/errors.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
class InvalidBaseURL(Exception):
class InngestError(Exception):
pass


class InvalidBaseURL(Exception):
code = "invalid_base_url"


class InvalidRequestSignature(Exception):
code = "invalid_request_signature"


class InvalidResponseShape(Exception):
pass
code = "invalid_response_shape"


class MissingEventKey(Exception):
pass
code = "missing_event_key"


class MissingHeader(Exception):
code = "missing_header"


class MissingSigningKey(Exception):
code = "missing_signing_key"


class NonRetriableError(Exception):
pass
code = "non_retriable_error"
6 changes: 5 additions & 1 deletion inngest/frameworks/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from flask import Flask, Response, make_response, request

from inngest.client import Inngest
from inngest.comm import CommHandler, CommResponse
from inngest.comm import CommHandler, CommResponse, RequestSignature
from inngest.execution import Call
from inngest.function import Function

Expand Down Expand Up @@ -36,6 +36,10 @@ def inngest_api() -> Response | str:
comm.call_function(
call=Call.from_dict(json.loads(request.data)),
fn_id=fn_id,
req_sig=RequestSignature(
body=request.data,
headers=dict(request.headers.items()),
),
)
)

Expand Down
12 changes: 11 additions & 1 deletion inngest/frameworks/tornado.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tornado.web import Application, RequestHandler

from inngest.client import Inngest
from inngest.comm import CommHandler
from inngest.comm import CommHandler, RequestSignature
from inngest.execution import Call
from inngest.function import Function

Expand Down Expand Up @@ -37,9 +37,19 @@ def post(self) -> None:
raise Exception("missing fnId")
fn_id = raw_fn_id[0].decode("utf-8")

headers: dict[str, str] = {}

for k, v in self.request.headers.items():
if isinstance(k, str) and isinstance(v[0], str):
headers[k] = v[0]

comm_res = comm.call_function(
call=Call.from_dict(json.loads(self.request.body)),
fn_id=fn_id,
req_sig=RequestSignature(
body=self.request.body,
headers=headers,
),
)

self.write(json.dumps(comm_res.body))
Expand Down
17 changes: 11 additions & 6 deletions inngest/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,22 @@


def hash_signing_key(key: str) -> str:
return hashlib.sha256(
bytearray.fromhex(remove_signing_key_prefix(key))
).hexdigest()


def hash_step_id(step_id: str) -> str:
return hashlib.sha1(step_id.encode("utf-8")).hexdigest()


def remove_signing_key_prefix(key: str) -> str:
prefix_match = re.match(r"^signkey-[\w]+-", key)
prefix = ""
if prefix_match:
prefix = prefix_match.group(0)

key_without_prefix = key[len(prefix) :]
return hashlib.sha256(bytearray.fromhex(key_without_prefix)).hexdigest()


def hash_step_id(step_id: str) -> str:
return hashlib.sha1(step_id.encode("utf-8")).hexdigest()
return key[len(prefix) :]


def remove_none_deep(obj: T) -> T:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ extra = [
line-length = 80

[tool.isort]
line_length = 80
profile = "black"

[tool.mypy]
Expand All @@ -51,6 +52,7 @@ disable = [
'duplicate-code',
'fixme',
'invalid-envvar-value',
'line-too-long',
'missing-docstring',
'too-few-public-methods',
'too-many-arguments',
Expand Down

0 comments on commit 8e84511

Please sign in to comment.