diff --git a/src/selva/di/container.py b/src/selva/di/container.py index e47d5f3..1959818 100644 --- a/src/selva/di/container.py +++ b/src/selva/di/container.py @@ -27,12 +27,18 @@ def __init__(self): self.registry = ServiceRegistry() self.store: dict[tuple[type, str | None], Any] = {} self.finalizers: list[Awaitable] = [] + self.startup: list[tuple[Type, str | None]] = [] self.interceptors: list[Type[Interceptor]] = [] def register( - self, service: InjectableType, *, provides: type = None, name: str = None + self, + service: InjectableType, + *, + provides: type = None, + name: str = None, + startup: bool = False, ): - self._register_service_spec(service, provides, name) + self._register_service_spec(service, provides, name, startup) def service(self, service: type): service_info = getattr(service, DI_ATTRIBUTE_SERVICE, None) @@ -40,17 +46,19 @@ def service(self, service: type): if not service_info: raise ServiceWithoutDecoratorError(service) - provides, name = service_info - self._register_service_spec(service, provides, name) + self._register_service_spec(service, *service_info) def _register_service_spec( - self, service: type, provides: type | None, name: str | None + self, service: type, provides: type | None, name: str | None, startup: bool ): service_spec = parse_service_spec(service, provides, name) provided_service = service_spec.provides self.registry[provided_service, name] = service_spec + if startup: + self.startup.append((service_spec.provides, name)) + if provides: logger.trace( "service registered: {}.{}; provided-by={}.{} name={}", @@ -90,8 +98,7 @@ def predicate_services(item: Any): return hasattr(item, DI_ATTRIBUTE_SERVICE) for service in scan_packages(packages, predicate_services): - provides, name = getattr(service, DI_ATTRIBUTE_SERVICE) - self.register(service, provides=provides, name=name) + self.service(service) def has(self, service: type, name: str = None) -> bool: definition = self.registry.get(service, name=name) @@ -226,6 +233,10 @@ async def _run_interceptors(self, instance: Any, service_type: type): ) await maybe_async(interceptor.intercept, instance, service_type) + async def _run_startup(self): + for service, name in self.startup: + await self.get(service, name=name) + async def _run_finalizers(self): for finalizer in reversed(self.finalizers): await finalizer diff --git a/src/selva/di/decorator.py b/src/selva/di/decorator.py index e6dcf2e..32818e7 100644 --- a/src/selva/di/decorator.py +++ b/src/selva/di/decorator.py @@ -26,7 +26,12 @@ def _is_inject(value) -> bool: @dataclass_transform(eq_default=False) def service( - injectable: T = None, /, *, provides: type = None, name: str = None + injectable: T = None, + /, + *, + provides: type = None, + name: str = None, + startup: bool = False, ) -> T | Callable[[T], T]: """Declare a class or function as a service @@ -35,7 +40,9 @@ def service( """ def inner(inner_injectable: InjectableType) -> T: - setattr(inner_injectable, DI_ATTRIBUTE_SERVICE, ServiceInfo(provides, name)) + setattr( + inner_injectable, DI_ATTRIBUTE_SERVICE, ServiceInfo(provides, name, startup) + ) if inspect.isclass(inner_injectable): dependencies = [ diff --git a/src/selva/di/service/model.py b/src/selva/di/service/model.py index 6db1941..998e855 100644 --- a/src/selva/di/service/model.py +++ b/src/selva/di/service/model.py @@ -8,6 +8,7 @@ class ServiceInfo(NamedTuple): provides: type | None name: str | None + startup: bool = False class ServiceDependency(NamedTuple): diff --git a/src/selva/ext/data/redis/settings.py b/src/selva/ext/data/redis/settings.py index 5f62339..42c3ece 100644 --- a/src/selva/ext/data/redis/settings.py +++ b/src/selva/ext/data/redis/settings.py @@ -1,5 +1,5 @@ from types import NoneType -from typing import Self, Type, Literal +from typing import Literal, Self, Type from pydantic import BaseModel, ConfigDict, model_serializer, model_validator from redis import RedisError diff --git a/src/selva/web/application.py b/src/selva/web/application.py index 8fd5714..1899825 100644 --- a/src/selva/web/application.py +++ b/src/selva/web/application.py @@ -134,6 +134,7 @@ async def _initialize_middleware(self): self.handler = chain async def _lifespan_startup(self): + await self.di._run_startup() await self._initialize_extensions() await self._initialize_middleware() diff --git a/tests/di/test_startup.py b/tests/di/test_startup.py new file mode 100644 index 0000000..61fee52 --- /dev/null +++ b/tests/di/test_startup.py @@ -0,0 +1,35 @@ +import pytest + +from selva.di import service, Container + + +@pytest.fixture +def service_class(): + class ServiceClass: + startup_called = False + + def initialize(self): + ServiceClass.startup_called = not ServiceClass.startup_called + + return ServiceClass + + +@pytest.fixture +def service_factory(service_class): + def service_factory() -> service_class: + service_class.startup_called = not service_class.startup_called + return service_class() + + return service_factory + + +async def test_startup_class(ioc: Container, service_class): + ioc.register(service_class, startup=True) + await ioc._run_startup() + assert service_class.startup_called + + +async def test_startup_factory(ioc: Container, service_class, service_factory): + ioc.register(service_factory, startup=True) + await ioc._run_startup() + assert service_class.startup_called diff --git a/tests/web/application/test_service_startup.py b/tests/web/application/test_service_startup.py new file mode 100644 index 0000000..b8a4a2b --- /dev/null +++ b/tests/web/application/test_service_startup.py @@ -0,0 +1,26 @@ +from selva.configuration.defaults import default_settings +from selva.configuration.settings import Settings +from selva.di import service +from selva.web.application import Selva + + +async def test_application(): + @service(startup=True) + class Service: + startup_called = False + + def initialize(self): + Service.startup_called = True + + settings = Settings( + default_settings | { + "application": f"{__package__}.application", + } + ) + + app = Selva(settings) + app.di.service(Service) + + await app._lifespan_startup() + + assert Service.startup_called