From 3b6deeebcff148e631f52da1abda2711e8e7e91a Mon Sep 17 00:00:00 2001 From: Livio Ribeiro Date: Mon, 13 May 2024 15:00:32 -0300 Subject: [PATCH] change response start method --- README.md | 13 +++++++ pyproject.toml | 15 +++++---- src/asgikit/__init__.py | 2 +- src/asgikit/_json.py | 32 ++++++++++++++++++ src/asgikit/constants.py | 2 +- src/asgikit/requests.py | 4 +-- src/asgikit/responses.py | 73 ++++++++++++++++++++++------------------ tests/test_requests.py | 28 ++++++++++++++- tests/test_responses.py | 32 ++++++++++++++++-- 9 files changed, 155 insertions(+), 46 deletions(-) create mode 100644 src/asgikit/_json.py diff --git a/README.md b/README.md index 414a85b..c49ad59 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,19 @@ using an alternative json parser, you just need to write a function that reads t Similarly, to write another data format into the response, you just write a function that writes to the response. +## Custom JSON encoder and decoder + +By default, asgikit uses `json.dumps` and `json.loads` for dealing with JSON. If +you want to use other libraries like `orjson`, just define the environment variable +`ASGIKIT_JSON_ENCODER` of the module compatible with `json`, or the full path to +the functions that perform encoding and decoding, in that order: + +```dotenv +ASGIKIT_JSON_ENCODER=orjson +# or +ASGIKIT_JSON_ENCODER=msgspc.json.encode,msgspc.json.encode +``` + ## Example request and response ```python diff --git a/pyproject.toml b/pyproject.toml index e64ce4d..54d8b22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "asgikit" -version = "0.0.0" +version = "0.8.0" description = "Toolkit for building ASGI applications and libraries" authors = ["Livio Ribeiro "] license = "MIT" @@ -40,23 +40,26 @@ optional = true [tool.poetry.group.dev.dependencies] uvicorn = { version = "^0.29", extras = ["standard"] } +granian = "^1.3" pylint = "^3.1" flake8 = "^7.0" -mypy = "^1.9" +mypy = "^1.10" isort = "^5.13" -black = "^24.3" -ruff = "^0.3" +black = "^24.4" +ruff = "^0.4" [tool.poetry.group.test] optional = true [tool.poetry.group.test.dependencies] -pytest = "^8.1" +pytest = "^8.2" pytest-asyncio = "^0.23" pytest-cov = "^5.0" -coverage = { version = "^7.4", extras = ["toml"] } +coverage = { version = "^7.5", extras = ["toml"] } httpx = "^0.27" asgiref = "^3.8" +orjson = "^3.10" +msgspec = "^0.18" [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/src/asgikit/__init__.py b/src/asgikit/__init__.py index 98afce2..3697d43 100644 --- a/src/asgikit/__init__.py +++ b/src/asgikit/__init__.py @@ -1,7 +1,7 @@ __all__ = ( "errors", "headers", - "multi_value_dict", + "util", "query", "requests", "responses", diff --git a/src/asgikit/_json.py b/src/asgikit/_json.py new file mode 100644 index 0000000..0b9b314 --- /dev/null +++ b/src/asgikit/_json.py @@ -0,0 +1,32 @@ +import importlib +import os + + +def _import(dotted_path: str): + if "." not in dotted_path: + raise ValueError(dotted_path) + + module_name, attibute_name = dotted_path.rsplit(".", maxsplit=1) + module = importlib.import_module(module_name) + return getattr(module, attibute_name) + + +if json_encoder := os.environ.get("ASGIKIT_JSON_ENCODER"): + if "," in json_encoder: + encoder, decoder = [ + name.strip() for name in json_encoder.split(",", maxsplit=1) + ] + else: + name = json_encoder.strip() + encoder = f"{name}.dumps" + decoder = f"{name}.loads" + try: + JSON_ENCODER = _import(encoder) + JSON_DECODER = _import(decoder) + except ImportError as err: + raise ValueError(f"Invalid ASGIKIT_JSON_ENCODER: {json_encoder}") from err +else: + import json + + JSON_ENCODER = json.dumps + JSON_DECODER = json.loads diff --git a/src/asgikit/constants.py b/src/asgikit/constants.py index 89a9bfc..030e343 100644 --- a/src/asgikit/constants.py +++ b/src/asgikit/constants.py @@ -3,6 +3,7 @@ SCOPE_REQUEST_ATTRIBUTES = "attributes" SCOPE_REQUEST_IS_CONSUMED = "is_consumed" SCOPE_RESPONSE = "response" +SCOPE_RESPONSE_STATUS = "status" SCOPE_RESPONSE_HEADERS = "headers" SCOPE_RESPONSE_COOKIES = "cookies" SCOPE_RESPONSE_CONTENT_TYPE = "content_type" @@ -10,4 +11,3 @@ SCOPE_RESPONSE_ENCODING = "encoding" SCOPE_RESPONSE_IS_STARTED = "is_started" SCOPE_RESPONSE_IS_FINISHED = "is_finished" -SCOPE_RESPONSE_STATUS = "status" diff --git a/src/asgikit/requests.py b/src/asgikit/requests.py index 6d3cfd2..ed35619 100644 --- a/src/asgikit/requests.py +++ b/src/asgikit/requests.py @@ -1,5 +1,4 @@ import asyncio -import json import re from collections.abc import AsyncIterable, Awaitable, Callable from http import HTTPMethod @@ -9,6 +8,7 @@ from multipart import multipart +from asgikit._json import JSON_DECODER from asgikit.asgi import AsgiProtocol, AsgiReceive, AsgiScope, AsgiSend from asgikit.constants import ( SCOPE_ASGIKIT, @@ -233,7 +233,7 @@ async def read_json(request: Request) -> dict | list: if not body: return {} - return json.loads(body) + return JSON_DECODER(body) def _is_form_multipart(content_type: str) -> bool: diff --git a/src/asgikit/responses.py b/src/asgikit/responses.py index 9307af2..82560e1 100644 --- a/src/asgikit/responses.py +++ b/src/asgikit/responses.py @@ -1,5 +1,4 @@ import asyncio -import json import mimetypes import os from collections.abc import AsyncIterable @@ -14,6 +13,7 @@ import aiofiles import aiofiles.os +from asgikit._json import JSON_ENCODER from asgikit.asgi import AsgiProtocol, AsgiReceive, AsgiScope, AsgiSend from asgikit.constants import ( SCOPE_ASGIKIT, @@ -58,6 +58,10 @@ class Response: def __init__(self, scope: AsgiScope, receive: AsgiReceive, send: AsgiSend): scope.setdefault(SCOPE_ASGIKIT, {}) scope[SCOPE_ASGIKIT].setdefault(SCOPE_RESPONSE, {}) + + scope[SCOPE_ASGIKIT][SCOPE_RESPONSE].setdefault( + SCOPE_RESPONSE_STATUS, HTTPStatus.OK + ) scope[SCOPE_ASGIKIT][SCOPE_RESPONSE].setdefault( SCOPE_RESPONSE_HEADERS, MutableHeaders() ) @@ -76,6 +80,14 @@ def __init__(self, scope: AsgiScope, receive: AsgiReceive, send: AsgiSend): self.asgi = AsgiProtocol(scope, receive, send) + @property + def status(self) -> HTTPStatus | None: + return self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][SCOPE_RESPONSE_STATUS] + + @status.setter + def status(self, status: HTTPStatus): + self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][SCOPE_RESPONSE_STATUS] = status + @property def headers(self) -> MutableHeaders: return self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][SCOPE_RESPONSE_HEADERS] @@ -134,13 +146,6 @@ def __set_finished(self): SCOPE_RESPONSE_IS_FINISHED ] = True - @property - def status(self) -> HTTPStatus | None: - return self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE].get(SCOPE_RESPONSE_STATUS) - - def __set_status(self, status: HTTPStatus): - self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][SCOPE_RESPONSE_STATUS] = status - def header(self, name: str, value: str): self.headers.set(name, value) @@ -184,7 +189,7 @@ def __build_headers(self) -> list[tuple[bytes, bytes]]: return self.headers.encode() - async def start(self, status=HTTPStatus.OK): + async def start(self): if self.is_started: raise RuntimeError("response has already started") @@ -192,9 +197,10 @@ async def start(self, status=HTTPStatus.OK): raise RuntimeError("response has already ended") self.__set_started() - self.__set_status(status) + status = self.status headers = self.__build_headers() + await self.asgi.send( { "type": "http.response.start", @@ -230,21 +236,24 @@ async def end(self): await self.write(b"", more_body=False) -async def respond_text( - response: Response, content: str | bytes, *, status: HTTPStatus = HTTPStatus.OK -): - data = content.encode(response.encoding) if isinstance(content, str) else content +async def respond_text(response: Response, content: str | bytes): + if isinstance(content, str): + data = content.encode(response.encoding) + else: + data = content + if not response.content_type: response.content_type = "text/plain" response.content_length = len(data) - await response.start(status) + await response.start() await response.write(data, more_body=False) async def respond_status(response: Response, status: HTTPStatus): - await response.start(status) + response.status = status + await response.start() await response.end() @@ -264,14 +273,13 @@ async def respond_redirect_post_get(response: Response, location: str): await respond_status(response, HTTPStatus.SEE_OTHER) -async def respond_json(response: Response, content: Any, status=HTTPStatus.OK): - data = json.dumps(content).encode(response.encoding) +async def respond_json(response: Response, content: Any): + data = JSON_ENCODER(content) + if isinstance(data, str): + data = data.encode(response.encoding) response.content_type = "application/json" - response.content_length = len(data) - - await response.start(status) - await response.write(data, more_body=False) + await respond_text(response, data) async def __listen_for_disconnect(receive): @@ -283,6 +291,8 @@ async def __listen_for_disconnect(receive): @asynccontextmanager async def stream_writer(response: Response): + await response.start() + client_disconect = asyncio.create_task( __listen_for_disconnect(response.asgi.receive) ) @@ -299,11 +309,7 @@ async def write(data: bytes | str): client_disconect.cancel() -async def respond_stream( - response: Response, stream: AsyncIterable[bytes | str], *, status=HTTPStatus.OK -): - await response.start(status) - +async def respond_stream(response: Response, stream: AsyncIterable[bytes | str]): async with stream_writer(response) as write: async for chunk in stream: await write(chunk) @@ -326,20 +332,22 @@ def __supports_zerocopysend(scope): return "extensions" in scope and "http.response.zerocopysend" in scope["extensions"] -async def respond_file( - response: Response, path: str | PathLike[str], status=HTTPStatus.OK -): +async def respond_file(response: Response, path: str | PathLike[str]): if not response.content_type: response.content_type = __guess_mimetype(path) - stat = await asyncio.to_thread(os.stat, path) + stat = await aiofiles.os.stat(path) content_length = stat.st_size last_modified = __file_last_modified(stat) response.content_length = content_length response.headers.set("last-modified", last_modified) + if not isinstance(path, str): + path = str(path) + if __supports_pathsend(response.asgi.scope): + await response.start() await response.asgi.send( { "type": "http.response.pathsend", @@ -349,6 +357,7 @@ async def respond_file( return if __supports_zerocopysend(response.asgi.scope): + await response.start() file = await asyncio.to_thread(open, path, "rb") await response.asgi.send( { @@ -359,4 +368,4 @@ async def respond_file( return async with aiofiles.open(path, "rb") as stream: - await respond_stream(response, stream, status=status) + await respond_stream(response, stream) diff --git a/tests/test_requests.py b/tests/test_requests.py index a3c9b4c..19c38d6 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,4 +1,6 @@ import copy +import importlib +import sys from http import HTTPMethod import pytest @@ -134,7 +136,24 @@ async def receive() -> HTTPRequestEvent: assert result == "12345" -async def test_request_json(): +@pytest.mark.parametrize( + "name, encoder", + [ + ("json", None), + ("orjson", "orjson"), + ("msgspec", "msgspec.json.decode,msgspec.json.decode"), + ], + ids=["json", "orjson", "msgspec"], +) +async def test_request_json(name, encoder, monkeypatch): + if encoder: + monkeypatch.setenv("ASGIKIT_JSON_ENCODER", encoder) + + importlib.reload(sys.modules["asgikit._json"]) + from asgikit._json import JSON_DECODER + + assert JSON_DECODER.__module__.startswith(name) + async def receive() -> HTTPRequestEvent: return { "type": "http.request", @@ -148,6 +167,13 @@ async def receive() -> HTTPRequestEvent: assert result == {"name": "Selva", "rank": 1} +@pytest.mark.parametrize("encoder", ["invalid", "module.invalid"]) +def test_json_invalid_decoder_should_fail(encoder, monkeypatch): + monkeypatch.setenv("ASGIKIT_JSON_ENCODER", encoder) + with pytest.raises(ValueError, match=f"Invalid ASGIKIT_JSON_ENCODER: {encoder}"): + importlib.reload(sys.modules["asgikit._json"]) + + async def test_request_invalid_json_should_fail(): async def receive() -> HTTPRequestEvent: return { diff --git a/tests/test_responses.py b/tests/test_responses.py index e73e8c6..264dfdf 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -1,7 +1,10 @@ import asyncio +import importlib +import sys from http import HTTPStatus -from asgikit.requests import Request +import pytest + from asgikit.responses import ( Response, respond_file, @@ -26,7 +29,24 @@ async def test_respond_plain_text(): assert inspector.body == "Hello, World!" -async def test_respond_json(): +@pytest.mark.parametrize( + "name, encoder", + [ + ("json", None), + ("orjson", "orjson"), + ("msgspec", "msgspec.json.decode,msgspec.json.decode"), + ], + ids=["json", "orjson", "msgspec"], +) +async def test_respond_json(name, encoder, monkeypatch): + if encoder: + monkeypatch.setenv("ASGIKIT_JSON_ENCODER", encoder) + + importlib.reload(sys.modules["asgikit._json"]) + from asgikit._json import JSON_ENCODER + + assert JSON_ENCODER.__module__.startswith(name) + inspector = HttpSendInspector() scope = {"type": "http"} response = Response(scope, None, inspector) @@ -35,6 +55,13 @@ async def test_respond_json(): assert inspector.body == """{"message": "Hello, World!"}""" +@pytest.mark.parametrize("encoder", ["invalid", "module.invalid"]) +def test_json_invalid_decoder_should_fail(encoder, monkeypatch): + monkeypatch.setenv("ASGIKIT_JSON_ENCODER", encoder) + with pytest.raises(ValueError, match=r"Invalid ASGIKIT_JSON_ENCODER"): + importlib.reload(sys.modules["asgikit._json"]) + + async def test_stream(): async def stream_data(): yield "Hello, " @@ -53,7 +80,6 @@ async def test_stream_context_manager(): scope = {"type": "http", "http_version": "1.1"} response = Response(scope, None, inspector) - await response.start() async with stream_writer(response) as write: await write("Hello, ") await write("World!")