diff --git a/inngest/_internal/client_lib.py b/inngest/_internal/client_lib.py index 89a515fc..a2f01186 100644 --- a/inngest/_internal/client_lib.py +++ b/inngest/_internal/client_lib.py @@ -32,7 +32,7 @@ def __init__( event_key = os.getenv(const.EnvKey.EVENT_KEY.value) if event_key is None: self.logger.error("missing event key") - raise errors.MissingEventKey("missing event key") + raise errors.MissingEventKey() self._event_key = event_key event_origin = base_url @@ -99,14 +99,10 @@ def send_sync( def _extract_ids(body: object) -> list[str]: if not isinstance(body, dict) or "ids" not in body: - raise errors.InvalidResponseShape( - "unexpected response when sending events" - ) + raise errors.InvalidBody("unexpected response when sending events") ids = body["ids"] if not isinstance(ids, list): - raise errors.InvalidResponseShape( - "unexpected response when sending events" - ) + raise errors.InvalidBody("unexpected response when sending events") return ids diff --git a/inngest/_internal/comm.py b/inngest/_internal/comm.py index cc188932..971071b0 100644 --- a/inngest/_internal/comm.py +++ b/inngest/_internal/comm.py @@ -56,7 +56,7 @@ def from_internal_error( ) -> CommResponse: return cls( body={ - "code": str(err), + "code": err.code, "message": str(err), }, headers=net.create_headers(framework=framework), @@ -76,7 +76,7 @@ class CommHandler: def __init__( self, *, - api_origin: str | None = None, + base_url: str | None = None, client: client_lib.Inngest, framework: str, functions: list[function.Function] | list[function.FunctionSync], @@ -89,33 +89,37 @@ def __init__( if not self._is_production: self._logger.info("Dev Server mode enabled") - api_origin = api_origin or os.getenv(const.EnvKey.BASE_URL.value) - - if api_origin is None: + base_url = base_url or os.getenv(const.EnvKey.BASE_URL.value) + if base_url is None: if not self._is_production: self._logger.info("Defaulting API origin to Dev Server") - api_origin = const.DEV_SERVER_ORIGIN + base_url = const.DEV_SERVER_ORIGIN else: - api_origin = const.DEFAULT_API_ORIGIN + base_url = const.DEFAULT_API_ORIGIN try: - self._base_url = net.parse_url(api_origin) + self._api_origin = net.parse_url(base_url) except Exception as err: raise errors.InvalidBaseURL() from err self._client = client self._fns = {fn.get_id(): fn for fn in functions} self._framework = framework - self._signing_key = signing_key or os.getenv( - const.EnvKey.SIGNING_KEY.value - ) + + if signing_key is None: + if self._client.is_production: + signing_key = os.getenv(const.EnvKey.SIGNING_KEY.value) + if signing_key is None: + self._logger.error("missing signing key") + raise errors.MissingSigningKey() + self._signing_key = signing_key def _build_registration_request( self, app_url: str, ) -> httpx.Request: registration_url = urllib.parse.urljoin( - self._base_url, + self._api_origin, "/fn/register", ) @@ -127,7 +131,6 @@ def _build_registration_request( functions=self.get_function_configs(app_url), sdk=f"{const.LANGUAGE}:v{const.VERSION}", url=app_url, - # TODO: Do this for real. v="0.1", ).to_dict() ) diff --git a/inngest/_internal/comm_test.py b/inngest/_internal/comm_test.py index a9a402cb..aa42c4d9 100644 --- a/inngest/_internal/comm_test.py +++ b/inngest/_internal/comm_test.py @@ -50,7 +50,7 @@ def fn(**_kwargs: object) -> int: return 1 handler = comm.CommHandler( - api_origin="http://foo.bar", + base_url="http://foo.bar", client=self.client, framework="test", functions=[fn], @@ -62,7 +62,7 @@ def test_no_functions(self) -> None: functions: list[inngest.FunctionSync] = [] handler = comm.CommHandler( - api_origin="http://foo.bar", + base_url="http://foo.bar", client=self.client, framework="test", functions=functions, diff --git a/inngest/_internal/const.py b/inngest/_internal/const.py index a831ec11..69575a3e 100644 --- a/inngest/_internal/const.py +++ b/inngest/_internal/const.py @@ -17,12 +17,11 @@ class EnvKey(enum.Enum): class ErrorCode(enum.Enum): - DEV_SERVER_REGISTRATION_NOT_ALLOWED = "dev_server_registration_not_allowed" + DISALLOWED_REGISTRATION_INITIATOR = "disallowed_registration_initiator" INVALID_BASE_URL = "invalid_base_url" + INVALID_BODY = "invalid_body" INVALID_FUNCTION_CONFIG = "invalid_function_config" - INVALID_PARAM = "invalid_param" INVALID_REQUEST_SIGNATURE = "invalid_request_signature" - INVALID_RESPONSE_SHAPE = "invalid_response_shape" MISMATCHED_SYNC = "mismatched_sync" MISSING_EVENT_KEY = "missing_event_key" MISSING_FUNCTION = "missing_function" diff --git a/inngest/_internal/errors.py b/inngest/_internal/errors.py index 285a0854..e2ba4987 100644 --- a/inngest/_internal/errors.py +++ b/inngest/_internal/errors.py @@ -21,7 +21,7 @@ class DevServerRegistrationNotAllowed(InternalError): def __init__(self, message: str | None = None) -> None: super().__init__( - code=const.ErrorCode.DEV_SERVER_REGISTRATION_NOT_ALLOWED, + code=const.ErrorCode.DISALLOWED_REGISTRATION_INITIATOR, message=message, ) @@ -89,16 +89,6 @@ def __init__(self, message: str | None = None) -> None: ) -class InvalidParam(InternalError): - status_code: int = 400 - - def __init__(self, message: str | None = None) -> None: - super().__init__( - code=const.ErrorCode.INVALID_PARAM, - message=message, - ) - - class InvalidRequestSignature(InternalError): status_code: int = 401 @@ -109,12 +99,12 @@ def __init__(self, message: str | None = None) -> None: ) -class InvalidResponseShape(InternalError): +class InvalidBody(InternalError): status_code: int = 500 def __init__(self, message: str | None = None) -> None: super().__init__( - code=const.ErrorCode.INVALID_RESPONSE_SHAPE, + code=const.ErrorCode.INVALID_BODY, message=message, ) diff --git a/inngest/fast_api.py b/inngest/fast_api.py index 421adebe..86ab395d 100644 --- a/inngest/fast_api.py +++ b/inngest/fast_api.py @@ -14,7 +14,7 @@ def serve( signing_key: str | None = None, ) -> None: handler = comm.CommHandler( - api_origin=base_url or client.base_url, + base_url=base_url or client.base_url, client=client, framework="flask", functions=functions, diff --git a/inngest/flask.py b/inngest/flask.py index ee942e40..7dfb2cb2 100644 --- a/inngest/flask.py +++ b/inngest/flask.py @@ -14,7 +14,7 @@ def serve( signing_key: str | None = None, ) -> None: handler = comm.CommHandler( - api_origin=base_url or client.base_url, + base_url=base_url or client.base_url, client=client, framework="flask", functions=functions, diff --git a/inngest/tornado.py b/inngest/tornado.py index df2ee250..21ecf65e 100644 --- a/inngest/tornado.py +++ b/inngest/tornado.py @@ -23,7 +23,7 @@ def serve( signing_key: str | None = None, ) -> None: handler = comm.CommHandler( - api_origin=base_url or client.base_url, + base_url=base_url or client.base_url, client=client, framework="flask", functions=functions, @@ -42,9 +42,7 @@ def post(self) -> None: raise errors.MissingParam("fnId") fn_id = raw_fn_id[0].decode("utf-8") - headers = net.normalize_headers( - {k: v[0] for k, v in self.request.headers.items()} - ) + headers = net.normalize_headers(dict(self.request.headers.items())) comm_res = handler.call_function_sync( call=execution.Call.from_dict(json.loads(self.request.body)), @@ -64,9 +62,7 @@ def post(self) -> None: self.set_status(comm_res.status_code) def put(self) -> None: - headers = net.normalize_headers( - {k: v[0] for k, v in self.request.headers.items()} - ) + headers = net.normalize_headers(dict(self.request.headers.items())) comm_res = handler.register_sync( app_url=self.request.full_url(), diff --git a/tests/test_fast_api.py b/tests/test_fast_api.py index 56caaa51..55c13722 100644 --- a/tests/test_fast_api.py +++ b/tests/test_fast_api.py @@ -6,6 +6,7 @@ import inngest import inngest.fast_api +from inngest._internal import const from . import base, cases, dev_server, http_proxy, net @@ -89,5 +90,50 @@ def on_proxy_request( test_name = f"test_{case.name}" setattr(TestFastAPI, test_name, case.run_test) + +class TestFastAPIRegistration(unittest.TestCase): + def test_dev_server_to_prod(self) -> None: + """ + Ensure that Dev Server cannot initiate a registration request when in + production mode. + """ + + client = inngest.Inngest( + app_id="fast_api", + event_key="test", + is_production=True, + ) + + @inngest.create_function( + fn_id="foo", + retries=0, + trigger=inngest.TriggerEvent(event="app/foo"), + ) + async def fn(**_kwargs: object) -> None: + pass + + app = fastapi.FastAPI() + inngest.fast_api.serve( + app, + client, + [fn], + signing_key="signkey-prod-0486c9", + ) + fast_api_client = fastapi.testclient.TestClient(app) + res = fast_api_client.put( + "/api/inngest", + headers={ + const.HeaderKey.SERVER_KIND.value.lower(): const.ServerKind.DEV_SERVER.value, + }, + ) + assert res.status_code == 400 + body: object = res.json() + assert ( + isinstance(body, dict) + and body["code"] + == const.ErrorCode.DISALLOWED_REGISTRATION_INITIATOR.value + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_flask.py b/tests/test_flask.py index dddc29c5..d1a32231 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -1,3 +1,4 @@ +import json import unittest import flask @@ -5,6 +6,7 @@ import inngest import inngest.flask +from inngest._internal import const from . import base, cases, dev_server, http_proxy, net @@ -74,5 +76,50 @@ def on_proxy_request( test_name = f"test_{case.name}" setattr(TestFlask, test_name, case.run_test) + +class TestFastAPIRegistration(unittest.TestCase): + def test_dev_server_to_prod(self) -> None: + """ + Ensure that Dev Server cannot initiate a registration request when in + production mode. + """ + + client = inngest.Inngest( + app_id="flask", + event_key="test", + is_production=True, + ) + + @inngest.create_function_sync( + fn_id="foo", + retries=0, + trigger=inngest.TriggerEvent(event="app/foo"), + ) + def fn(**_kwargs: object) -> None: + pass + + app = flask.Flask(__name__) + inngest.flask.serve( + app, + client, + [fn], + signing_key="signkey-prod-0486c9", + ) + flask_client = app.test_client() + res = flask_client.put( + "/api/inngest", + headers={ + const.HeaderKey.SERVER_KIND.value.lower(): const.ServerKind.DEV_SERVER.value, + }, + ) + assert res.status_code == 400 + body: object = json.loads(res.data) + assert ( + isinstance(body, dict) + and body["code"] + == const.ErrorCode.DISALLOWED_REGISTRATION_INITIATOR.value + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_registration.py b/tests/test_registration.py new file mode 100644 index 00000000..13d3544a --- /dev/null +++ b/tests/test_registration.py @@ -0,0 +1,56 @@ +import unittest + +import fastapi +import fastapi.testclient + +import inngest +import inngest.fast_api +from inngest._internal import const + + +class TestFastAPI(unittest.TestCase): + def test_dev_server_to_prod(self) -> None: + """ + Ensure that Dev Server cannot initiate a registration request when in + production mode. + """ + + client = inngest.Inngest( + app_id="test", + event_key="test", + is_production=True, + ) + + @inngest.create_function( + fn_id="foo", + retries=0, + trigger=inngest.TriggerEvent(event="app/foo"), + ) + async def fn(**_kwargs: object) -> None: + pass + + app = fastapi.FastAPI() + inngest.fast_api.serve( + app, + client, + [fn], + signing_key="signkey-prod-0486c9", + ) + fast_api_client = fastapi.testclient.TestClient(app) + res = fast_api_client.put( + "/api/inngest", + headers={ + const.HeaderKey.SERVER_KIND.value.lower(): const.ServerKind.DEV_SERVER.value, + }, + ) + assert res.status_code == 400 + body: object = res.json() + assert ( + isinstance(body, dict) + and body["code"] + == const.ErrorCode.DISALLOWED_REGISTRATION_INITIATOR.value + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tornado.py b/tests/test_tornado.py index e7711ef2..9e775236 100644 --- a/tests/test_tornado.py +++ b/tests/test_tornado.py @@ -1,9 +1,12 @@ +import json + import tornado.log import tornado.testing import tornado.web import inngest import inngest.tornado +from inngest._internal import const from . import base, cases, dev_server, http_proxy, net @@ -75,5 +78,55 @@ def on_proxy_request( setattr(TestTornado, test_name, case.run_test) +class TestTornadoRegistration(tornado.testing.AsyncHTTPTestCase): + app: tornado.web.Application = tornado.web.Application() + + def get_app(self) -> tornado.web.Application: + return self.app + + def test_dev_server_to_prod(self) -> None: + """ + Ensure that Dev Server cannot initiate a registration request when in + production mode. + """ + + client = inngest.Inngest( + app_id="flask", + event_key="test", + is_production=True, + ) + + @inngest.create_function_sync( + fn_id="foo", + retries=0, + trigger=inngest.TriggerEvent(event="app/foo"), + ) + def fn(**_kwargs: object) -> None: + pass + + inngest.tornado.serve( + self.get_app(), + client, + [fn], + signing_key="signkey-prod-0486c9", + ) + res = self.fetch( + "/api/inngest", + body=json.dumps({}), + headers={ + const.HeaderKey.SERVER_KIND.value.lower(): const.ServerKind.DEV_SERVER.value, + }, + method="PUT", + ) + assert res.code == 400 + print(res.body) + body: object = json.loads(res.body) + assert ( + isinstance(body, dict) + and body["code"] + == const.ErrorCode.DISALLOWED_REGISTRATION_INITIATOR.value + ) + + if __name__ == "__main__": tornado.testing.main()