diff --git a/docs/middleware.md b/docs/middleware.md index 596675c..461eda0 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -30,14 +30,16 @@ in the processing of the request: from datetime import datetime from asgikit.requests import Request - from selva.web.middleware import Middleware + from selva.di import service + from selva.web.middleware import Middleware, CallNext from loguru import logger + @service class TimingMiddleware(Middleware): - async def __call__(self, chain: Callable, request: Request): + async def __call__(self, call_next: CallNext, request: Request): request_start = datetime.now() - await chain(request) # (1) + await call_next(request) # (1) request_end = datetime.now() delta = request_end - request_start @@ -77,20 +79,22 @@ the timings using a service instead of printing to the console: ```python from collections.abc import Callable from datetime import datetime + from typing import Annotated from asgikit.requests import Request - from selva.di import Inject - from selva.web.middleware import Middleware + from selva.di import service, Inject + from selva.web.middleware import Middleware, CallNext from application.service import TimingService - class TimingMiddleware(Middleware): - timing_service: TimingService = Inject() + @service + class TimingMiddleware: + timing_service: Annotated[TimingService, Inject] - async def __call__(self, chain: Callable, request: Request): + async def __call__(self, call_next: CallNext, request: Request): request_start = datetime.now() - await chain(request) + await call_next(request) request_end = datetime.now() await self.timing_service.save(request_start, request_end) diff --git a/src/selva/configuration/settings.py b/src/selva/configuration/settings.py index 949e8a7..d0930df 100644 --- a/src/selva/configuration/settings.py +++ b/src/selva/configuration/settings.py @@ -26,7 +26,7 @@ SELVA_PROFILE = "SELVA_PROFILE" -class Settings(Mapping): +class Settings(Mapping[str, Any]): def __init__(self, data: dict): for key, value in data.items(): if isinstance(value, dict): diff --git a/src/selva/di/container.py b/src/selva/di/container.py index 3d2c228..8340b3b 100644 --- a/src/selva/di/container.py +++ b/src/selva/di/container.py @@ -1,6 +1,5 @@ import asyncio import inspect -from collections import defaultdict from collections.abc import AsyncGenerator, Awaitable, Generator, Iterable from types import FunctionType, ModuleType from typing import Any, Type, TypeVar @@ -115,21 +114,13 @@ def iter_all_services( for name, definition in record.providers.items(): yield interface, definition.service, name - async def get(self, service_type: T, *, name: str = None, optional=False) -> T: + async def get( + self, service_type: Type[T], *, name: str = None, optional=False + ) -> T: dependency = ServiceDependency(service_type, name=name, optional=optional) return await self._get(dependency) - async def create(self, service_type: type) -> Any: - assert inspect.isclass(service_type) - instance = service_type() - - for name, dep_spec in get_dependencies(service_type): - dependency = await self._get(dep_spec) - setattr(instance, name, dependency) - - return instance - - def _get_from_cache(self, service_type: type, name: str | None) -> Any | None: + def _get_from_cache(self, service_type: Type[T], name: str | None) -> T | None: if instance := self.store.get((service_type, name)): return instance @@ -170,8 +161,7 @@ async def _get_dependent_services( self, service_spec: ServiceSpec, stack: list ) -> dict[str, Any]: return { - name: await self._get(dep, stack) - for name, dep in service_spec.dependencies + name: await self._get(dep, stack) for name, dep in service_spec.dependencies } async def _create_service( diff --git a/src/selva/web/application.py b/src/selva/web/application.py index 74fcb2f..ddd2793 100644 --- a/src/selva/web/application.py +++ b/src/selva/web/application.py @@ -124,7 +124,7 @@ async def _initialize_middleware(self): ) for cls in reversed(middleware): - mid = await self.di.create(cls) + mid = await self.di.get(cls) chain = functools.partial(mid, self.handler) self.handler = chain @@ -206,7 +206,7 @@ async def _handle_request(self, scope, receive, send): logger.error("Response is finished") return - await respond_status(response, err.status) + await respond_text(response, str(err), status=err.status) except Exception as err: logger.exception("Error processing request") await respond_text( diff --git a/src/selva/web/middleware.py b/src/selva/web/middleware.py index 963b30f..0411d24 100644 --- a/src/selva/web/middleware.py +++ b/src/selva/web/middleware.py @@ -1,16 +1,19 @@ -from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable +from typing import Protocol, runtime_checkable from asgikit.requests import Request -__all__ = ("Middleware",) +__all__ = ("Middleware", "CallNext") -class Middleware(ABC): - @abstractmethod +CallNext = Callable[[Request], Awaitable] + + +@runtime_checkable +class Middleware(Protocol): async def __call__( self, - call: Callable[[Request], Awaitable], + call_next: CallNext, request: Request, ): raise NotImplementedError() diff --git a/tests/di/test_create.py b/tests/di/test_create.py deleted file mode 100644 index 6fe3d3b..0000000 --- a/tests/di/test_create.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Annotated - -from selva.di.container import Container -from selva.di.decorator import service -from selva.di.inject import Inject - - -@service -class Service1: - pass - - -@service -def service1_factory() -> Service1: - return Service1() - - -@service -class Service2: - service1: Annotated[Service1, Inject] - - -@service -def service2_factory(service1: Service1) -> Service2: - service = Service2() - setattr(service, "service1", service1) - return service - - -class Creatable: - service2: Annotated[Service2, Inject] - - -async def test_create_object_with_class(ioc: Container): - ioc.register(Service1) - ioc.register(Service2) - - result = await ioc.create(Creatable) - - assert isinstance(result.service2, Service2) - assert isinstance(result.service2.service1, Service1) - - -async def test_create_object_with_factory(ioc: Container): - ioc.register(service1_factory) - ioc.register(service2_factory) - - result = await ioc.create(Creatable) - - assert isinstance(result.service2, Service2) - assert isinstance(result.service2.service1, Service1) diff --git a/tests/web/application/test_middleware.py b/tests/web/application/test_middleware.py index 09006e1..492fcf5 100644 --- a/tests/web/application/test_middleware.py +++ b/tests/web/application/test_middleware.py @@ -5,11 +5,10 @@ from selva.configuration.defaults import default_settings from selva.di import service from selva.web.application import Selva -from selva.web.middleware import Middleware @service -class MyMiddleware(Middleware): +class MyMiddleware: async def __call__(self, call, request): send = request.asgi.send @@ -31,6 +30,7 @@ async def test_middleware(): } ) app = Selva(settings) + app.di.register(MyMiddleware) await app._lifespan_startup() client = AsyncClient(app=app) diff --git a/tests/web/routing/router.py b/tests/web/routing/router.py index dd62c27..aa6173a 100644 --- a/tests/web/routing/router.py +++ b/tests/web/routing/router.py @@ -1,7 +1,10 @@ import pytest from selva.web.routing.decorator import controller, get -from selva.web.routing.exception import ControllerWithoutDecoratorError, DuplicateRouteError +from selva.web.routing.exception import ( + ControllerWithoutDecoratorError, + DuplicateRouteError, +) from selva.web.routing.router import Router @@ -29,4 +32,4 @@ async def route2(self, request): router = Router() with pytest.raises(DuplicateRouteError): - router.route(Controller) \ No newline at end of file + router.route(Controller)