Skip to content

Commit

Permalink
change response start method
Browse files Browse the repository at this point in the history
  • Loading branch information
livioribeiro committed May 13, 2024
1 parent 0f0a520 commit 3b6deee
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 46 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ using an alternative json parser, you just need to write a function that reads t
Similarly, to write another data format into the response, you just write a function that
writes to the response.

## Custom JSON encoder and decoder

By default, asgikit uses `json.dumps` and `json.loads` for dealing with JSON. If
you want to use other libraries like `orjson`, just define the environment variable
`ASGIKIT_JSON_ENCODER` of the module compatible with `json`, or the full path to
the functions that perform encoding and decoding, in that order:

```dotenv
ASGIKIT_JSON_ENCODER=orjson
# or
ASGIKIT_JSON_ENCODER=msgspc.json.encode,msgspc.json.encode
```

## Example request and response

```python
Expand Down
15 changes: 9 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "asgikit"
version = "0.0.0"
version = "0.8.0"
description = "Toolkit for building ASGI applications and libraries"
authors = ["Livio Ribeiro <[email protected]>"]
license = "MIT"
Expand Down Expand Up @@ -40,23 +40,26 @@ optional = true

[tool.poetry.group.dev.dependencies]
uvicorn = { version = "^0.29", extras = ["standard"] }
granian = "^1.3"
pylint = "^3.1"
flake8 = "^7.0"
mypy = "^1.9"
mypy = "^1.10"
isort = "^5.13"
black = "^24.3"
ruff = "^0.3"
black = "^24.4"
ruff = "^0.4"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^8.1"
pytest = "^8.2"
pytest-asyncio = "^0.23"
pytest-cov = "^5.0"
coverage = { version = "^7.4", extras = ["toml"] }
coverage = { version = "^7.5", extras = ["toml"] }
httpx = "^0.27"
asgiref = "^3.8"
orjson = "^3.10"
msgspec = "^0.18"

[tool.pytest.ini_options]
asyncio_mode = "auto"
Expand Down
2 changes: 1 addition & 1 deletion src/asgikit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = (
"errors",
"headers",
"multi_value_dict",
"util",
"query",
"requests",
"responses",
Expand Down
32 changes: 32 additions & 0 deletions src/asgikit/_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import importlib
import os


def _import(dotted_path: str):
if "." not in dotted_path:
raise ValueError(dotted_path)

module_name, attibute_name = dotted_path.rsplit(".", maxsplit=1)
module = importlib.import_module(module_name)
return getattr(module, attibute_name)


if json_encoder := os.environ.get("ASGIKIT_JSON_ENCODER"):
if "," in json_encoder:
encoder, decoder = [
name.strip() for name in json_encoder.split(",", maxsplit=1)
]
else:
name = json_encoder.strip()
encoder = f"{name}.dumps"
decoder = f"{name}.loads"
try:
JSON_ENCODER = _import(encoder)
JSON_DECODER = _import(decoder)
except ImportError as err:
raise ValueError(f"Invalid ASGIKIT_JSON_ENCODER: {json_encoder}") from err
else:
import json

JSON_ENCODER = json.dumps
JSON_DECODER = json.loads
2 changes: 1 addition & 1 deletion src/asgikit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
SCOPE_REQUEST_ATTRIBUTES = "attributes"
SCOPE_REQUEST_IS_CONSUMED = "is_consumed"
SCOPE_RESPONSE = "response"
SCOPE_RESPONSE_STATUS = "status"
SCOPE_RESPONSE_HEADERS = "headers"
SCOPE_RESPONSE_COOKIES = "cookies"
SCOPE_RESPONSE_CONTENT_TYPE = "content_type"
SCOPE_RESPONSE_CONTENT_LENGTH = "content_length"
SCOPE_RESPONSE_ENCODING = "encoding"
SCOPE_RESPONSE_IS_STARTED = "is_started"
SCOPE_RESPONSE_IS_FINISHED = "is_finished"
SCOPE_RESPONSE_STATUS = "status"
4 changes: 2 additions & 2 deletions src/asgikit/requests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import json
import re
from collections.abc import AsyncIterable, Awaitable, Callable
from http import HTTPMethod
Expand All @@ -9,6 +8,7 @@

from multipart import multipart

from asgikit._json import JSON_DECODER
from asgikit.asgi import AsgiProtocol, AsgiReceive, AsgiScope, AsgiSend
from asgikit.constants import (
SCOPE_ASGIKIT,
Expand Down Expand Up @@ -233,7 +233,7 @@ async def read_json(request: Request) -> dict | list:
if not body:
return {}

return json.loads(body)
return JSON_DECODER(body)


def _is_form_multipart(content_type: str) -> bool:
Expand Down
73 changes: 41 additions & 32 deletions src/asgikit/responses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import json
import mimetypes
import os
from collections.abc import AsyncIterable
Expand All @@ -14,6 +13,7 @@
import aiofiles
import aiofiles.os

from asgikit._json import JSON_ENCODER
from asgikit.asgi import AsgiProtocol, AsgiReceive, AsgiScope, AsgiSend
from asgikit.constants import (
SCOPE_ASGIKIT,
Expand Down Expand Up @@ -58,6 +58,10 @@ class Response:
def __init__(self, scope: AsgiScope, receive: AsgiReceive, send: AsgiSend):
scope.setdefault(SCOPE_ASGIKIT, {})
scope[SCOPE_ASGIKIT].setdefault(SCOPE_RESPONSE, {})

scope[SCOPE_ASGIKIT][SCOPE_RESPONSE].setdefault(
SCOPE_RESPONSE_STATUS, HTTPStatus.OK
)
scope[SCOPE_ASGIKIT][SCOPE_RESPONSE].setdefault(
SCOPE_RESPONSE_HEADERS, MutableHeaders()
)
Expand All @@ -76,6 +80,14 @@ def __init__(self, scope: AsgiScope, receive: AsgiReceive, send: AsgiSend):

self.asgi = AsgiProtocol(scope, receive, send)

@property
def status(self) -> HTTPStatus | None:
return self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][SCOPE_RESPONSE_STATUS]

@status.setter
def status(self, status: HTTPStatus):
self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][SCOPE_RESPONSE_STATUS] = status

@property
def headers(self) -> MutableHeaders:
return self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][SCOPE_RESPONSE_HEADERS]
Expand Down Expand Up @@ -134,13 +146,6 @@ def __set_finished(self):
SCOPE_RESPONSE_IS_FINISHED
] = True

@property
def status(self) -> HTTPStatus | None:
return self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE].get(SCOPE_RESPONSE_STATUS)

def __set_status(self, status: HTTPStatus):
self.asgi.scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][SCOPE_RESPONSE_STATUS] = status

def header(self, name: str, value: str):
self.headers.set(name, value)

Expand Down Expand Up @@ -184,17 +189,18 @@ def __build_headers(self) -> list[tuple[bytes, bytes]]:

return self.headers.encode()

async def start(self, status=HTTPStatus.OK):
async def start(self):
if self.is_started:
raise RuntimeError("response has already started")

if self.is_finished:
raise RuntimeError("response has already ended")

self.__set_started()
self.__set_status(status)

status = self.status
headers = self.__build_headers()

await self.asgi.send(
{
"type": "http.response.start",
Expand Down Expand Up @@ -230,21 +236,24 @@ async def end(self):
await self.write(b"", more_body=False)


async def respond_text(
response: Response, content: str | bytes, *, status: HTTPStatus = HTTPStatus.OK
):
data = content.encode(response.encoding) if isinstance(content, str) else content
async def respond_text(response: Response, content: str | bytes):
if isinstance(content, str):
data = content.encode(response.encoding)
else:
data = content

if not response.content_type:
response.content_type = "text/plain"

response.content_length = len(data)

await response.start(status)
await response.start()
await response.write(data, more_body=False)


async def respond_status(response: Response, status: HTTPStatus):
await response.start(status)
response.status = status
await response.start()
await response.end()


Expand All @@ -264,14 +273,13 @@ async def respond_redirect_post_get(response: Response, location: str):
await respond_status(response, HTTPStatus.SEE_OTHER)


async def respond_json(response: Response, content: Any, status=HTTPStatus.OK):
data = json.dumps(content).encode(response.encoding)
async def respond_json(response: Response, content: Any):
data = JSON_ENCODER(content)
if isinstance(data, str):
data = data.encode(response.encoding)

response.content_type = "application/json"
response.content_length = len(data)

await response.start(status)
await response.write(data, more_body=False)
await respond_text(response, data)


async def __listen_for_disconnect(receive):
Expand All @@ -283,6 +291,8 @@ async def __listen_for_disconnect(receive):

@asynccontextmanager
async def stream_writer(response: Response):
await response.start()

client_disconect = asyncio.create_task(
__listen_for_disconnect(response.asgi.receive)
)
Expand All @@ -299,11 +309,7 @@ async def write(data: bytes | str):
client_disconect.cancel()


async def respond_stream(
response: Response, stream: AsyncIterable[bytes | str], *, status=HTTPStatus.OK
):
await response.start(status)

async def respond_stream(response: Response, stream: AsyncIterable[bytes | str]):
async with stream_writer(response) as write:
async for chunk in stream:
await write(chunk)
Expand All @@ -326,20 +332,22 @@ def __supports_zerocopysend(scope):
return "extensions" in scope and "http.response.zerocopysend" in scope["extensions"]


async def respond_file(
response: Response, path: str | PathLike[str], status=HTTPStatus.OK
):
async def respond_file(response: Response, path: str | PathLike[str]):
if not response.content_type:
response.content_type = __guess_mimetype(path)

stat = await asyncio.to_thread(os.stat, path)
stat = await aiofiles.os.stat(path)
content_length = stat.st_size
last_modified = __file_last_modified(stat)

response.content_length = content_length
response.headers.set("last-modified", last_modified)

if not isinstance(path, str):
path = str(path)

if __supports_pathsend(response.asgi.scope):
await response.start()
await response.asgi.send(
{
"type": "http.response.pathsend",
Expand All @@ -349,6 +357,7 @@ async def respond_file(
return

if __supports_zerocopysend(response.asgi.scope):
await response.start()
file = await asyncio.to_thread(open, path, "rb")
await response.asgi.send(
{
Expand All @@ -359,4 +368,4 @@ async def respond_file(
return

async with aiofiles.open(path, "rb") as stream:
await respond_stream(response, stream, status=status)
await respond_stream(response, stream)
28 changes: 27 additions & 1 deletion tests/test_requests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import copy
import importlib
import sys
from http import HTTPMethod

import pytest
Expand Down Expand Up @@ -134,7 +136,24 @@ async def receive() -> HTTPRequestEvent:
assert result == "12345"


async def test_request_json():
@pytest.mark.parametrize(
"name, encoder",
[
("json", None),
("orjson", "orjson"),
("msgspec", "msgspec.json.decode,msgspec.json.decode"),
],
ids=["json", "orjson", "msgspec"],
)
async def test_request_json(name, encoder, monkeypatch):
if encoder:
monkeypatch.setenv("ASGIKIT_JSON_ENCODER", encoder)

importlib.reload(sys.modules["asgikit._json"])
from asgikit._json import JSON_DECODER

assert JSON_DECODER.__module__.startswith(name)

async def receive() -> HTTPRequestEvent:
return {
"type": "http.request",
Expand All @@ -148,6 +167,13 @@ async def receive() -> HTTPRequestEvent:
assert result == {"name": "Selva", "rank": 1}


@pytest.mark.parametrize("encoder", ["invalid", "module.invalid"])
def test_json_invalid_decoder_should_fail(encoder, monkeypatch):
monkeypatch.setenv("ASGIKIT_JSON_ENCODER", encoder)
with pytest.raises(ValueError, match=f"Invalid ASGIKIT_JSON_ENCODER: {encoder}"):
importlib.reload(sys.modules["asgikit._json"])


async def test_request_invalid_json_should_fail():
async def receive() -> HTTPRequestEvent:
return {
Expand Down
Loading

0 comments on commit 3b6deee

Please sign in to comment.