Skip to content

Commit

Permalink
read data from Body class
Browse files Browse the repository at this point in the history
  • Loading branch information
livioribeiro committed Jan 1, 2025
1 parent 64e8248 commit 1e47884
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 78 deletions.
146 changes: 77 additions & 69 deletions src/asgikit/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
FORM_MULTIPART_CONTENT_TYPE = "multipart/form-data"
FORM_CONTENT_TYPES = (FORM_URLENCODED_CONTENT_TYPE, FORM_MULTIPART_CONTENT_TYPE)

RE_CHARSET = re.compile(r"charset=([\w-]+)")
RE_CHARSET = re.compile(r"""charset=([\w-]+|"[\w-]+")""")


def _parse_cookie(data: str) -> dict[str, str]:
Expand All @@ -45,13 +45,29 @@ def _parse_cookie(data: str) -> dict[str, str]:


class Body:
"""Async iterator over request body"""
"""Provides an async iterator over request body"""

__slots__ = ("_scope", "_receive")
content_type: str | None
content_length: int | None
charset: str | None

def __init__(self, scope: AsgiScope, receive: AsgiReceive):
__slots__ = ("_scope", "_receive", "content_type", "content_length", "charset")

def __init__(self, scope: AsgiScope, receive: AsgiReceive, headers: Headers):
self._scope = scope
self._receive = receive
self.content_type = headers.get("content-type")

if content_length := headers.get("content-length"):
self.content_length = int(content_length)
else:
self.content_length = None

if self.content_type:
values = RE_CHARSET.findall(self.content_type)
self.charset = values[0] if values else "utf-8"
else:
self.charset = "utf-8"

@property
def is_consumed(self) -> bool:
Expand Down Expand Up @@ -89,14 +105,13 @@ class Request:
"""Represents the incoming request"""

__slots__ = (
"_scope",
"_receive",
"_send",
"asgi_scope",
"asgi_receive",
"asgi_send",
"_headers",
"_query",
"_cookie",
"_charset",
"body",
"_body",
"response",
"websocket",
)
Expand All @@ -109,99 +124,98 @@ def __init__(self, scope: AsgiScope, receive: AsgiReceive, send: AsgiSend):
scope[SCOPE_ASGIKIT][SCOPE_REQUEST].setdefault(SCOPE_REQUEST_ATTRIBUTES, {})
scope[SCOPE_ASGIKIT][SCOPE_REQUEST].setdefault(SCOPE_REQUEST_IS_CONSUMED, False)

self._scope = scope
self._receive = CallableProxy(receive)
self._send = CallableProxy(send)
self.asgi_scope = scope
self.asgi_receive = CallableProxy(receive)
self.asgi_send = CallableProxy(send)

self._headers: Headers | None = None
self._query: Query | None = None
self._charset = None
self._cookie = None
self._body = None

self.body = Body(self._scope, self._receive)
self.response = (
Response(self._scope, self._receive, self._send)
Response(self.asgi_scope, self.asgi_receive, self.asgi_send)
if self.is_http
else None
)
self.websocket = (
WebSocket(self._scope, self._receive, self._send)
WebSocket(self.asgi_scope, self.asgi_receive, self.asgi_send)
if self.is_websocket
else None
)

@property
def attributes(self) -> dict[str, Any]:
"""Request attributes in the scope of asgikit"""
return self._scope[SCOPE_ASGIKIT][SCOPE_REQUEST][SCOPE_REQUEST_ATTRIBUTES]
return self.asgi_scope[SCOPE_ASGIKIT][SCOPE_REQUEST][SCOPE_REQUEST_ATTRIBUTES]

@property
def is_http(self) -> bool:
"""Tell if the request is an HTTP request
Returns False for websocket requests
"""
return self._scope["type"] == "http"
return self.asgi_scope["type"] == "http"

@property
def is_websocket(self) -> bool:
"""Tell if the request is a WebSocket request
Returns False for HTTP requests
"""
return self._scope["type"] == "websocket"
return self.asgi_scope["type"] == "websocket"

@property
def http_version(self) -> str:
return self._scope["http_version"]
return self.asgi_scope["http_version"]

@property
def server(self) -> tuple[str, int | None]:
return self._scope["server"]
return self.asgi_scope["server"]

@property
def client(self) -> tuple[str, int] | None:
return self._scope["client"]
return self.asgi_scope["client"]

@property
def scheme(self) -> str:
return self._scope["scheme"]
return self.asgi_scope["scheme"]

@property
def method(self) -> HTTPMethod | None:
"""Return None when request is websocket"""

if method := self._scope.get("method"):
if method := self.asgi_scope.get("method"):
return HTTPMethod(method)

return None

@property
def root_path(self) -> str:
return self._scope["root_path"]
return self.asgi_scope["root_path"]

@property
def path(self) -> str:
return self._scope["path"]
return self.asgi_scope["path"]

@property
def raw_path(self) -> str | None:
return self._scope["raw_path"]
return self.asgi_scope["raw_path"]

@property
def headers(self) -> Headers:
if not self._headers:
self._headers = Headers(self._scope["headers"])
self._headers = Headers(self.asgi_scope["headers"])
return self._headers

@property
def raw_query(self) -> str:
return unquote_plus(self._scope["query_string"].decode("ascii"))
return unquote_plus(self.asgi_scope["query_string"].decode("ascii"))

@property
def query(self) -> Query:
if not self._query:
self._query = Query(self._scope["query_string"])
self._query = Query(self.asgi_scope["query_string"])
return self._query

@property
Expand All @@ -211,21 +225,10 @@ def cookie(self) -> dict[str, str]:
return self._cookie

@property
def content_type(self) -> str | None:
return self.headers.get("content-type")

@property
def content_length(self) -> int | None:
if content_length := self.headers.get("content-length"):
return int(content_length)
return None

@property
def charset(self) -> str:
if not self._charset:
values = RE_CHARSET.findall(self.content_type)
self._charset = values[0] if values else "utf-8"
return self._charset
def body(self) -> Body:
if not self._body:
self._body = Body(self.asgi_scope, self.asgi_receive, self.headers)
return self._body

@property
def accept(self) -> str:
Expand All @@ -238,10 +241,10 @@ def wrap_asgi(
send: AsgiSend = None,
):
if receive:
self._receive.wrap(receive)
self.asgi_receive.wrap(receive)

if send:
self._send.wrap(send)
self.asgi_send.wrap(send)

def __getitem__(self, item):
return self.attributes[item]
Expand All @@ -256,45 +259,47 @@ def __contains__(self, item):
return item in self.attributes


async def read_body(request: Request) -> bytes:
async def read_body(obj: Body | Request) -> bytes:
"""Read the full request body"""

body = bytearray()
body = obj.body if isinstance(obj, Request) else obj
data = bytearray()

async for chunk in request.body:
body.extend(chunk)
async for chunk in body:
data.extend(chunk)

return bytes(body)
return bytes(data)


async def read_text(request: Request, encoding: str = None) -> str:
async def read_text(obj: Body | Request, encoding: str = None) -> str:
"""Read the full request body as str"""

body = await read_body(request)
return body.decode(encoding or request.charset)
body = obj.body if isinstance(obj, Request) else obj
data = await read_body(body)
return data.decode(encoding or body.charset)


async def read_json(request: Request) -> dict | list:
async def read_json(obj: Body | Request) -> dict | list:
"""Read the full request body and parse it as json"""

body = await read_body(request)
if not body:
return {}

return JSON_DECODER(body)
if data := await read_body(obj):
return JSON_DECODER(data)
return {}


def _is_form_multipart(content_type: str) -> bool:
return content_type.startswith(FORM_MULTIPART_CONTENT_TYPE)


async def read_form(request: Request) -> dict[str, str | multipart.File]:
async def read_form(obj: Body | Request) -> dict[str, str | multipart.File]:
"""Read the full request body and parse it as form encoded"""

if _is_form_multipart(request.content_type):
return await _read_form_multipart(request)
body = obj.body if isinstance(obj, Request) else obj

if _is_form_multipart(body.content_type or ""):
return await _read_form_multipart(obj)

data = await read_text(request)
data = await read_text(body)
if not data:
return {}

Expand All @@ -305,12 +310,14 @@ async def read_form(request: Request) -> dict[str, str | multipart.File]:


async def _read_form_multipart(
request: Request,
obj: Body | Request,
) -> dict[str, str | multipart.File]:
fields: dict[str, str] = {}
files: dict[str, multipart.File] = {}

charset = request.charset
body = obj.body if isinstance(obj, Request) else obj
content_type = body.content_type or ""
charset = body.charset

def on_field(field: multipart.Field):
fields[field.field_name.decode(charset)] = field.value.decode(charset)
Expand All @@ -319,9 +326,10 @@ def on_file(file: multipart.File):
file.file_object.seek(0)
files[file.field_name.decode(charset)] = file

parser = multipart.create_form_parser(request.headers, on_field, on_file)
headers = {"Content-Type": content_type}
parser = multipart.create_form_parser(headers, on_field, on_file)

async for data in request.body:
async for data in body:
# `parser.write` can potentially write to a file,
# therefore we need to call it using `asyncio.to_thread`
await asyncio.to_thread(parser.write, data)
Expand Down
4 changes: 1 addition & 3 deletions src/asgikit/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ def content_length(self) -> int | None:

@content_length.setter
def content_length(self, value: str):
self._scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][
SCOPE_RESPONSE_CONTENT_LENGTH
] = value
self._scope[SCOPE_ASGIKIT][SCOPE_RESPONSE][SCOPE_RESPONSE_CONTENT_LENGTH] = value

@property
def encoding(self) -> str:
Expand Down
Loading

0 comments on commit 1e47884

Please sign in to comment.