Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix configured serve path in serve functions #165

Merged
merged 5 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions inngest/_internal/config_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import typing

from inngest._internal import const


def get_serve_origin(code_value: typing.Optional[str]) -> typing.Optional[str]:
if code_value is not None:
return code_value

env_var_value = os.getenv(const.EnvKey.SERVE_ORIGIN.value)
if env_var_value:
return env_var_value

return None


def get_serve_path(code_value: typing.Optional[str]) -> typing.Optional[str]:
if code_value is not None:
return code_value

env_var_value = os.getenv(const.EnvKey.SERVE_PATH.value)
if env_var_value:
return env_var_value

return None
1 change: 1 addition & 0 deletions inngest/_internal/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
AUTHOR: typing.Final = "inngest"
DEFAULT_API_ORIGIN: typing.Final = "https://api.inngest.com/"
DEFAULT_EVENT_API_ORIGIN: typing.Final = "https://inn.gs/"
DEFAULT_SERVE_PATH: typing.Final = "/api/inngest"
DEV_SERVER_ORIGIN: typing.Final = "http://127.0.0.1:8288/"
LANGUAGE: typing.Final = "py"
VERSION: typing.Final = importlib.metadata.version("inngest")
Expand Down
6 changes: 3 additions & 3 deletions inngest/_internal/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import hashlib
import hmac
import http
import os
import threading
import time
import typing
Expand All @@ -13,6 +12,7 @@

from inngest._internal import (
async_lib,
config_lib,
const,
errors,
server_lib,
Expand Down Expand Up @@ -90,8 +90,8 @@ def create_serve_url(
"""

# User can also specify these via env vars. The env vars take precedence.
serve_origin = os.getenv(const.EnvKey.SERVE_ORIGIN.value, serve_origin)
serve_path = os.getenv(const.EnvKey.SERVE_PATH.value, serve_path)
serve_origin = config_lib.get_serve_origin(serve_origin)
serve_path = config_lib.get_serve_path(serve_path)

parsed_url = urllib.parse.urlparse(request_url)
new_scheme = parsed_url.scheme
Expand Down
4 changes: 2 additions & 2 deletions inngest/_internal/net_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_serve_origin_env_var(self) -> None:
serve_origin="https://bar.test",
serve_path=None,
)
expected = "https://bar-env.test/api/inngest"
expected = "https://bar.test/api/inngest"
assert actual == expected

def test_serve_origin_missing_scheme(self) -> None:
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_serve_path_env_var(self) -> None:
serve_origin=None,
serve_path="/custom/path",
)
expected = "https://foo.test/env/path"
expected = "https://foo.test/custom/path"
assert actual == expected

def test_serve_origin_and_path(self) -> None:
Expand Down
22 changes: 19 additions & 3 deletions inngest/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
import django.urls
import django.views.decorators.csrf

from ._internal import client_lib, comm_lib, function, server_lib, transforms
from ._internal import (
client_lib,
comm_lib,
config_lib,
const,
function,
server_lib,
transforms,
)

FRAMEWORK = server_lib.Framework.DJANGO

Expand Down Expand Up @@ -107,7 +115,9 @@ def inngest_api(
)

return django.urls.path(
"api/inngest",
_trim_leading_slash(
config_lib.get_serve_path(serve_path) or const.DEFAULT_SERVE_PATH
),
django.views.decorators.csrf.csrf_exempt(inngest_api),
)

Expand Down Expand Up @@ -166,7 +176,9 @@ async def inngest_api(
)

return django.urls.path(
"api/inngest",
_trim_leading_slash(
config_lib.get_serve_path(serve_path) or const.DEFAULT_SERVE_PATH
),
django.views.decorators.csrf.csrf_exempt(inngest_api),
)

Expand All @@ -185,3 +197,7 @@ def _to_response(
headers=comm_res.headers,
status=comm_res.status_code,
)


def _trim_leading_slash(value: str) -> str:
return value.lstrip("/")
16 changes: 12 additions & 4 deletions inngest/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@

import fastapi

from ._internal import client_lib, comm_lib, function, server_lib, transforms
from ._internal import (
client_lib,
comm_lib,
config_lib,
const,
function,
server_lib,
transforms,
)

FRAMEWORK = server_lib.Framework.FAST_API

Expand Down Expand Up @@ -37,7 +45,7 @@ def serve(
functions=functions,
)

@app.get("/api/inngest")
@app.get(config_lib.get_serve_path(serve_path) or const.DEFAULT_SERVE_PATH)
async def get_api_inngest(
request: fastapi.Request,
) -> fastapi.Response:
Expand All @@ -56,7 +64,7 @@ async def get_api_inngest(
),
)

@app.post("/api/inngest")
@app.post(config_lib.get_serve_path(serve_path) or const.DEFAULT_SERVE_PATH)
async def post_inngest_api(
request: fastapi.Request,
) -> fastapi.Response:
Expand All @@ -75,7 +83,7 @@ async def post_inngest_api(
),
)

@app.put("/api/inngest")
@app.put(config_lib.get_serve_path(serve_path) or const.DEFAULT_SERVE_PATH)
async def put_inngest_api(
request: fastapi.Request,
) -> fastapi.Response:
Expand Down
19 changes: 16 additions & 3 deletions inngest/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@

import flask

from inngest._internal import client_lib, comm_lib, function, server_lib
from inngest._internal import (
client_lib,
comm_lib,
config_lib,
const,
function,
server_lib,
)

FRAMEWORK = server_lib.Framework.FLASK

Expand Down Expand Up @@ -67,7 +74,10 @@ def _create_handler_async(
serve_origin: typing.Optional[str],
serve_path: typing.Optional[str],
) -> None:
@app.route("/api/inngest", methods=["GET", "POST", "PUT"])
@app.route(
config_lib.get_serve_path(serve_path) or const.DEFAULT_SERVE_PATH,
methods=["GET", "POST", "PUT"],
)
async def inngest_api() -> typing.Union[flask.Response, str]:
comm_req = comm_lib.CommRequest(
body=_get_body_bytes(),
Expand Down Expand Up @@ -109,7 +119,10 @@ def _create_handler_sync(
serve_origin: typing.Optional[str],
serve_path: typing.Optional[str],
) -> None:
@app.route("/api/inngest", methods=["GET", "POST", "PUT"])
@app.route(
config_lib.get_serve_path(serve_path) or const.DEFAULT_SERVE_PATH,
methods=["GET", "POST", "PUT"],
)
def inngest_api() -> typing.Union[flask.Response, str]:
comm_req = comm_lib.CommRequest(
body=_get_body_bytes(),
Expand Down
14 changes: 13 additions & 1 deletion inngest/tornado.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from inngest._internal import (
client_lib,
comm_lib,
config_lib,
const,
function,
server_lib,
transforms,
Expand Down Expand Up @@ -36,6 +38,7 @@ def serve(
serve_origin: Origin to serve the functions from.
serve_path: Path to serve the functions from.
"""

handler = comm_lib.CommHandler(
client=client,
framework=FRAMEWORK,
Expand Down Expand Up @@ -115,7 +118,16 @@ def _write_comm_response(

self.set_status(comm_res.status_code)

app.add_handlers(r".*", [("/api/inngest", InngestHandler)])
app.add_handlers(
r".*",
[
(
config_lib.get_serve_path(serve_path)
or const.DEFAULT_SERVE_PATH,
InngestHandler,
)
],
)


def _parse_query_params(raw: dict[str, list[bytes]]) -> dict[str, str]:
Expand Down
27 changes: 26 additions & 1 deletion tests/test_introspection/test_fast_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
import unittest

import fastapi
Expand All @@ -12,12 +13,18 @@
class TestIntrospection(base.BaseTestIntrospection):
framework = server_lib.Framework.FAST_API

def _serve(self, client: inngest.Inngest) -> fastapi.testclient.TestClient:
def _serve(
self,
client: inngest.Inngest,
*,
serve_path: typing.Optional[str] = None,
) -> fastapi.testclient.TestClient:
app = fastapi.FastAPI()
inngest.fast_api.serve(
app,
client,
self.create_functions(client),
serve_path=serve_path,
)
return fastapi.testclient.TestClient(app)

Expand Down Expand Up @@ -132,6 +139,24 @@ def test_dev_mode_with_no_signature(self) -> None:
}
assert res.headers.get(server_lib.HeaderKey.SIGNATURE.value) is None

def test_serve_path(self) -> None:
flask_client = self._serve(
inngest.Inngest(
app_id="my-app",
event_key="test",
is_production=False,
signing_key=self.signing_key,
),
serve_path="/custom/path",
)
res = flask_client.get("/custom/path")
assert res.status_code == 200
assert res.json() == {
**self.expected_unauthed_body,
"mode": "dev",
}
assert res.headers.get(server_lib.HeaderKey.SIGNATURE.value) is None


if __name__ == "__main__":
unittest.main()
27 changes: 26 additions & 1 deletion tests/test_introspection/test_flask.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
import unittest

import flask
Expand All @@ -13,12 +14,18 @@
class TestIntrospection(base.BaseTestIntrospection):
framework = server_lib.Framework.FLASK

def _serve(self, client: inngest.Inngest) -> flask.testing.FlaskClient:
def _serve(
self,
client: inngest.Inngest,
*,
serve_path: typing.Optional[str] = None,
) -> flask.testing.FlaskClient:
app = flask.Flask(__name__)
inngest.flask.serve(
app,
client,
self.create_functions(client),
serve_path=serve_path,
)
return app.test_client()

Expand Down Expand Up @@ -134,6 +141,24 @@ def test_dev_mode_with_no_signature(self) -> None:
}
assert res.headers.get(server_lib.HeaderKey.SIGNATURE.value) is None

def test_serve_path(self) -> None:
flask_client = self._serve(
inngest.Inngest(
app_id="my-app",
event_key="test",
is_production=False,
signing_key=self.signing_key,
),
serve_path="/custom/path",
)
res = flask_client.get("/custom/path")
assert res.status_code == 200
assert res.json == {
**self.expected_unauthed_body,
"mode": "dev",
}
assert res.headers.get(server_lib.HeaderKey.SIGNATURE.value) is None


if __name__ == "__main__":
unittest.main()
Loading