Skip to content

Commit

Permalink
use proxy on asgi callbacks on request
Browse files Browse the repository at this point in the history
  • Loading branch information
livioribeiro committed Apr 17, 2024
1 parent 139daec commit 6be2f0c
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/asgikit/headers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Iterable, Optional

from asgikit.multi_value_dict import MultiStrValueDict
from asgikit.util.multi_value_dict import MultiStrValueDict

__all__ = ("Headers", "MutableHeaders")

Expand Down
2 changes: 1 addition & 1 deletion src/asgikit/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import urllib.parse
from itertools import chain

from .multi_value_dict import MultiStrValueDict
from asgikit.util.multi_value_dict import MultiStrValueDict

__all__ = ("Query",)

Expand Down
19 changes: 7 additions & 12 deletions src/asgikit/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import re
from collections.abc import AsyncIterable, Awaitable, Callable
from functools import partial
from http import HTTPMethod
from http.cookies import SimpleCookie
from typing import Any
Expand All @@ -21,6 +20,7 @@
from asgikit.headers import Headers
from asgikit.query import Query
from asgikit.responses import Response
from asgikit.util.callable_proxy import CallableProxy
from asgikit.websockets import WebSocket

__all__ = (
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(self, scope: AsgiScope, receive: AsgiReceive, send: AsgiSend):
scope[SCOPE_ASGIKIT][SCOPE_REQUEST].setdefault(SCOPE_REQUEST_ATTRIBUTES, {})
scope[SCOPE_ASGIKIT][SCOPE_REQUEST].setdefault(SCOPE_REQUEST_IS_CONSUMED, False)

self.asgi = AsgiProtocol(scope, receive, send)
self.asgi = AsgiProtocol(scope, CallableProxy(receive), CallableProxy(send))

self._headers: Headers | None = None
self._query: Query | None = None
Expand Down Expand Up @@ -178,16 +178,11 @@ def wrap_asgi(
receive: Callable[[AsgiReceive], Awaitable] = None,
send: Callable[[AsgiSend, dict], Awaitable] = None,
):
new_receive = (
partial(receive, self.asgi.receive) if receive else self.asgi.receive
)
new_send = partial(send, self.asgi.send) if send else self.asgi.send
self.asgi = AsgiProtocol(self.asgi.scope, new_receive, new_send)

if self.response:
self.response.asgi = self.asgi
if self.websocket:
self.websocket.asgi = self.asgi
if receive:
self.asgi.receive.wrap(receive)

if send:
self.asgi.send.wrap(send)

def __getitem__(self, item):
return self.attributes[item]
Expand Down
Empty file added src/asgikit/util/__init__.py
Empty file.
File renamed without changes.
20 changes: 20 additions & 0 deletions src/asgikit/util/callable_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from collections.abc import Callable
from functools import partial
from typing import Any, Concatenate, ParamSpec

__all__ = ("CallableProxy",)

P = ParamSpec("P")


class CallableProxy:
__slots__ = ("func",)

def __init__(self, func: Callable[P, Any]):
self.func = func

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any:
return self.func(*args, **kwargs)

def wrap(self, wrapper: Callable[Concatenate[Callable[P, Any], P], Any]):
self.func = partial(wrapper, self.func)
File renamed without changes.
73 changes: 73 additions & 0 deletions tests/test_callable_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from asgikit.util.callable_proxy import CallableProxy


def test_call():
def func() -> int:
return 1

proxy = CallableProxy(func)
assert proxy() == 1


async def test_call_async():
async def func() -> int:
return 1

proxy = CallableProxy(func)
assert await proxy() == 1


def test_call_params():
def func(a: int, b: int) -> int:
return a + b

proxy = CallableProxy(func)
assert proxy(1, 2) == 3


def test_wrap():
def func() -> int:
return 1

def wrapper(f) -> int:
return f() + 1

proxy = CallableProxy(func)
proxy.wrap(wrapper)
assert proxy() == 2


async def test_wrap_async():
async def func() -> int:
return 1

async def wrapper(f) -> int:
return await f() + 1

proxy = CallableProxy(func)
proxy.wrap(wrapper)
assert await proxy() == 2


async def test_wrap_async_non_async_wrapper():
async def func() -> int:
return 1

def wrapper(f) -> int:
return f()

proxy = CallableProxy(func)
proxy.wrap(wrapper)
assert await proxy() == 1


def test_wrap_params():
def func(a: int, b: int) -> int:
return a + b

def wrapper(f, *args, **kwargs) -> int:
return f(*args, **kwargs) + 1

proxy = CallableProxy(func)
proxy.wrap(wrapper)
assert proxy(2, 3) == 6
8 changes: 4 additions & 4 deletions tests/test_files.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pytest import fixture

from asgikit.files import AsyncFile
from asgikit.util.async_file import AsyncFile


@fixture
Expand Down Expand Up @@ -41,11 +41,11 @@ async def test_read_file_chunks(tmp_file, monkeypatch):

import importlib

from asgikit import files
from asgikit.util import async_file

importlib.reload(files)
importlib.reload(async_file)

from asgikit.files import AsyncFile
from asgikit.util.async_file import AsyncFile

file = AsyncFile(str(tmp_file))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_multi_str_value_dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from asgikit.multi_value_dict import MultiStrValueDict
from asgikit.util.multi_value_dict import MultiStrValueDict


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_multi_value_dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from asgikit.multi_value_dict import MultiValueDict
from asgikit.util.multi_value_dict import MultiValueDict


@pytest.mark.parametrize(
Expand Down
4 changes: 1 addition & 3 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from http import HTTPStatus

from asgikit.requests import Request
from asgikit.responses import (
Response,
respond_file,
Expand All @@ -12,9 +13,6 @@
respond_text,
stream_writer,
)

from asgikit.requests import Request

from tests.utils.asgi import HttpSendInspector


Expand Down

0 comments on commit 6be2f0c

Please sign in to comment.