Skip to content

Commit

Permalink
Add registration test
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Oct 31, 2023
1 parent 13912a0 commit ab8cf40
Show file tree
Hide file tree
Showing 12 changed files with 233 additions and 47 deletions.
10 changes: 3 additions & 7 deletions inngest/_internal/client_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
event_key = os.getenv(const.EnvKey.EVENT_KEY.value)
if event_key is None:
self.logger.error("missing event key")
raise errors.MissingEventKey("missing event key")
raise errors.MissingEventKey()
self._event_key = event_key

event_origin = base_url
Expand Down Expand Up @@ -99,14 +99,10 @@ def send_sync(

def _extract_ids(body: object) -> list[str]:
if not isinstance(body, dict) or "ids" not in body:
raise errors.InvalidResponseShape(
"unexpected response when sending events"
)
raise errors.InvalidBody("unexpected response when sending events")

ids = body["ids"]
if not isinstance(ids, list):
raise errors.InvalidResponseShape(
"unexpected response when sending events"
)
raise errors.InvalidBody("unexpected response when sending events")

return ids
29 changes: 16 additions & 13 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def from_internal_error(
) -> CommResponse:
return cls(
body={
"code": str(err),
"code": err.code,
"message": str(err),
},
headers=net.create_headers(framework=framework),
Expand All @@ -76,7 +76,7 @@ class CommHandler:
def __init__(
self,
*,
api_origin: str | None = None,
base_url: str | None = None,
client: client_lib.Inngest,
framework: str,
functions: list[function.Function] | list[function.FunctionSync],
Expand All @@ -89,33 +89,37 @@ def __init__(
if not self._is_production:
self._logger.info("Dev Server mode enabled")

api_origin = api_origin or os.getenv(const.EnvKey.BASE_URL.value)

if api_origin is None:
base_url = base_url or os.getenv(const.EnvKey.BASE_URL.value)
if base_url is None:
if not self._is_production:
self._logger.info("Defaulting API origin to Dev Server")
api_origin = const.DEV_SERVER_ORIGIN
base_url = const.DEV_SERVER_ORIGIN
else:
api_origin = const.DEFAULT_API_ORIGIN
base_url = const.DEFAULT_API_ORIGIN

try:
self._base_url = net.parse_url(api_origin)
self._api_origin = net.parse_url(base_url)
except Exception as err:
raise errors.InvalidBaseURL() from err

self._client = client
self._fns = {fn.get_id(): fn for fn in functions}
self._framework = framework
self._signing_key = signing_key or os.getenv(
const.EnvKey.SIGNING_KEY.value
)

if signing_key is None:
if self._client.is_production:
signing_key = os.getenv(const.EnvKey.SIGNING_KEY.value)
if signing_key is None:
self._logger.error("missing signing key")
raise errors.MissingSigningKey()
self._signing_key = signing_key

def _build_registration_request(
self,
app_url: str,
) -> httpx.Request:
registration_url = urllib.parse.urljoin(
self._base_url,
self._api_origin,
"/fn/register",
)

Expand All @@ -127,7 +131,6 @@ def _build_registration_request(
functions=self.get_function_configs(app_url),
sdk=f"{const.LANGUAGE}:v{const.VERSION}",
url=app_url,
# TODO: Do this for real.
v="0.1",
).to_dict()
)
Expand Down
4 changes: 2 additions & 2 deletions inngest/_internal/comm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def fn(**_kwargs: object) -> int:
return 1

handler = comm.CommHandler(
api_origin="http://foo.bar",
base_url="http://foo.bar",
client=self.client,
framework="test",
functions=[fn],
Expand All @@ -62,7 +62,7 @@ def test_no_functions(self) -> None:
functions: list[inngest.FunctionSync] = []

handler = comm.CommHandler(
api_origin="http://foo.bar",
base_url="http://foo.bar",
client=self.client,
framework="test",
functions=functions,
Expand Down
5 changes: 2 additions & 3 deletions inngest/_internal/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ class EnvKey(enum.Enum):


class ErrorCode(enum.Enum):
DEV_SERVER_REGISTRATION_NOT_ALLOWED = "dev_server_registration_not_allowed"
DISALLOWED_REGISTRATION_INITIATOR = "disallowed_registration_initiator"
INVALID_BASE_URL = "invalid_base_url"
INVALID_BODY = "invalid_body"
INVALID_FUNCTION_CONFIG = "invalid_function_config"
INVALID_PARAM = "invalid_param"
INVALID_REQUEST_SIGNATURE = "invalid_request_signature"
INVALID_RESPONSE_SHAPE = "invalid_response_shape"
MISMATCHED_SYNC = "mismatched_sync"
MISSING_EVENT_KEY = "missing_event_key"
MISSING_FUNCTION = "missing_function"
Expand Down
16 changes: 3 additions & 13 deletions inngest/_internal/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class DevServerRegistrationNotAllowed(InternalError):

def __init__(self, message: str | None = None) -> None:
super().__init__(
code=const.ErrorCode.DEV_SERVER_REGISTRATION_NOT_ALLOWED,
code=const.ErrorCode.DISALLOWED_REGISTRATION_INITIATOR,
message=message,
)

Expand Down Expand Up @@ -89,16 +89,6 @@ def __init__(self, message: str | None = None) -> None:
)


class InvalidParam(InternalError):
status_code: int = 400

def __init__(self, message: str | None = None) -> None:
super().__init__(
code=const.ErrorCode.INVALID_PARAM,
message=message,
)


class InvalidRequestSignature(InternalError):
status_code: int = 401

Expand All @@ -109,12 +99,12 @@ def __init__(self, message: str | None = None) -> None:
)


class InvalidResponseShape(InternalError):
class InvalidBody(InternalError):
status_code: int = 500

def __init__(self, message: str | None = None) -> None:
super().__init__(
code=const.ErrorCode.INVALID_RESPONSE_SHAPE,
code=const.ErrorCode.INVALID_BODY,
message=message,
)

Expand Down
2 changes: 1 addition & 1 deletion inngest/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def serve(
signing_key: str | None = None,
) -> None:
handler = comm.CommHandler(
api_origin=base_url or client.base_url,
base_url=base_url or client.base_url,
client=client,
framework="flask",
functions=functions,
Expand Down
2 changes: 1 addition & 1 deletion inngest/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def serve(
signing_key: str | None = None,
) -> None:
handler = comm.CommHandler(
api_origin=base_url or client.base_url,
base_url=base_url or client.base_url,
client=client,
framework="flask",
functions=functions,
Expand Down
10 changes: 3 additions & 7 deletions inngest/tornado.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def serve(
signing_key: str | None = None,
) -> None:
handler = comm.CommHandler(
api_origin=base_url or client.base_url,
base_url=base_url or client.base_url,
client=client,
framework="flask",
functions=functions,
Expand All @@ -42,9 +42,7 @@ def post(self) -> None:
raise errors.MissingParam("fnId")
fn_id = raw_fn_id[0].decode("utf-8")

headers = net.normalize_headers(
{k: v[0] for k, v in self.request.headers.items()}
)
headers = net.normalize_headers(dict(self.request.headers.items()))

comm_res = handler.call_function_sync(
call=execution.Call.from_dict(json.loads(self.request.body)),
Expand All @@ -64,9 +62,7 @@ def post(self) -> None:
self.set_status(comm_res.status_code)

def put(self) -> None:
headers = net.normalize_headers(
{k: v[0] for k, v in self.request.headers.items()}
)
headers = net.normalize_headers(dict(self.request.headers.items()))

comm_res = handler.register_sync(
app_url=self.request.full_url(),
Expand Down
46 changes: 46 additions & 0 deletions tests/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import inngest
import inngest.fast_api
from inngest._internal import const

from . import base, cases, dev_server, http_proxy, net

Expand Down Expand Up @@ -89,5 +90,50 @@ def on_proxy_request(
test_name = f"test_{case.name}"
setattr(TestFastAPI, test_name, case.run_test)


class TestFastAPIRegistration(unittest.TestCase):
def test_dev_server_to_prod(self) -> None:
"""
Ensure that Dev Server cannot initiate a registration request when in
production mode.
"""

client = inngest.Inngest(
app_id="fast_api",
event_key="test",
is_production=True,
)

@inngest.create_function(
fn_id="foo",
retries=0,
trigger=inngest.TriggerEvent(event="app/foo"),
)
async def fn(**_kwargs: object) -> None:
pass

app = fastapi.FastAPI()
inngest.fast_api.serve(
app,
client,
[fn],
signing_key="signkey-prod-0486c9",
)
fast_api_client = fastapi.testclient.TestClient(app)
res = fast_api_client.put(
"/api/inngest",
headers={
const.HeaderKey.SERVER_KIND.value.lower(): const.ServerKind.DEV_SERVER.value,
},
)
assert res.status_code == 400
body: object = res.json()
assert (
isinstance(body, dict)
and body["code"]
== const.ErrorCode.DISALLOWED_REGISTRATION_INITIATOR.value
)


if __name__ == "__main__":
unittest.main()
47 changes: 47 additions & 0 deletions tests/test_flask.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import unittest

import flask
import flask.testing

import inngest
import inngest.flask
from inngest._internal import const

from . import base, cases, dev_server, http_proxy, net

Expand Down Expand Up @@ -74,5 +76,50 @@ def on_proxy_request(
test_name = f"test_{case.name}"
setattr(TestFlask, test_name, case.run_test)


class TestFastAPIRegistration(unittest.TestCase):
def test_dev_server_to_prod(self) -> None:
"""
Ensure that Dev Server cannot initiate a registration request when in
production mode.
"""

client = inngest.Inngest(
app_id="flask",
event_key="test",
is_production=True,
)

@inngest.create_function_sync(
fn_id="foo",
retries=0,
trigger=inngest.TriggerEvent(event="app/foo"),
)
def fn(**_kwargs: object) -> None:
pass

app = flask.Flask(__name__)
inngest.flask.serve(
app,
client,
[fn],
signing_key="signkey-prod-0486c9",
)
flask_client = app.test_client()
res = flask_client.put(
"/api/inngest",
headers={
const.HeaderKey.SERVER_KIND.value.lower(): const.ServerKind.DEV_SERVER.value,
},
)
assert res.status_code == 400
body: object = json.loads(res.data)
assert (
isinstance(body, dict)
and body["code"]
== const.ErrorCode.DISALLOWED_REGISTRATION_INITIATOR.value
)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit ab8cf40

Please sign in to comment.