Skip to content

Commit

Permalink
add flag to @service decorator to create and initialize the service o…
Browse files Browse the repository at this point in the history
…n application startup (#41)
  • Loading branch information
livioribeiro committed Feb 29, 2024
1 parent 0826225 commit 82b06e6
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 9 deletions.
25 changes: 18 additions & 7 deletions src/selva/di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,38 @@ def __init__(self):
self.registry = ServiceRegistry()
self.store: dict[tuple[type, str | None], Any] = {}
self.finalizers: list[Awaitable] = []
self.startup: list[tuple[Type, str | None]] = []
self.interceptors: list[Type[Interceptor]] = []

def register(
self, service: InjectableType, *, provides: type = None, name: str = None
self,
service: InjectableType,
*,
provides: type = None,
name: str = None,
startup: bool = False,
):
self._register_service_spec(service, provides, name)
self._register_service_spec(service, provides, name, startup)

def service(self, service: type):
service_info = getattr(service, DI_ATTRIBUTE_SERVICE, None)

if not service_info:
raise ServiceWithoutDecoratorError(service)

provides, name = service_info
self._register_service_spec(service, provides, name)
self._register_service_spec(service, *service_info)

def _register_service_spec(
self, service: type, provides: type | None, name: str | None
self, service: type, provides: type | None, name: str | None, startup: bool
):
service_spec = parse_service_spec(service, provides, name)
provided_service = service_spec.provides

self.registry[provided_service, name] = service_spec

if startup:
self.startup.append((service_spec.provides, name))

if provides:
logger.trace(
"service registered: {}.{}; provided-by={}.{} name={}",
Expand Down Expand Up @@ -90,8 +98,7 @@ def predicate_services(item: Any):
return hasattr(item, DI_ATTRIBUTE_SERVICE)

for service in scan_packages(packages, predicate_services):
provides, name = getattr(service, DI_ATTRIBUTE_SERVICE)
self.register(service, provides=provides, name=name)
self.service(service)

def has(self, service: type, name: str = None) -> bool:
definition = self.registry.get(service, name=name)
Expand Down Expand Up @@ -226,6 +233,10 @@ async def _run_interceptors(self, instance: Any, service_type: type):
)
await maybe_async(interceptor.intercept, instance, service_type)

async def _run_startup(self):
for service, name in self.startup:
await self.get(service, name=name)

async def _run_finalizers(self):
for finalizer in reversed(self.finalizers):
await finalizer
Expand Down
11 changes: 9 additions & 2 deletions src/selva/di/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def _is_inject(value) -> bool:

@dataclass_transform(eq_default=False)
def service(
injectable: T = None, /, *, provides: type = None, name: str = None
injectable: T = None,
/,
*,
provides: type = None,
name: str = None,
startup: bool = False,
) -> T | Callable[[T], T]:
"""Declare a class or function as a service
Expand All @@ -35,7 +40,9 @@ def service(
"""

def inner(inner_injectable: InjectableType) -> T:
setattr(inner_injectable, DI_ATTRIBUTE_SERVICE, ServiceInfo(provides, name))
setattr(
inner_injectable, DI_ATTRIBUTE_SERVICE, ServiceInfo(provides, name, startup)
)

if inspect.isclass(inner_injectable):
dependencies = [
Expand Down
1 change: 1 addition & 0 deletions src/selva/di/service/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class ServiceInfo(NamedTuple):
provides: type | None
name: str | None
startup: bool = False


class ServiceDependency(NamedTuple):
Expand Down
1 change: 1 addition & 0 deletions src/selva/web/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ async def _initialize_middleware(self):
self.handler = chain

async def _lifespan_startup(self):
await self.di._run_startup()
await self._initialize_extensions()
await self._initialize_middleware()

Expand Down
35 changes: 35 additions & 0 deletions tests/di/test_startup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest

from selva.di import service, Container


@pytest.fixture
def service_class():
class ServiceClass:
startup_called = False

def initialize(self):
ServiceClass.startup_called = not ServiceClass.startup_called

return ServiceClass


@pytest.fixture
def service_factory(service_class):
def service_factory() -> service_class:
service_class.startup_called = not service_class.startup_called
return service_class()

return service_factory


async def test_startup_class(ioc: Container, service_class):
ioc.register(service_class, startup=True)
await ioc._run_startup()
assert service_class.startup_called


async def test_startup_factory(ioc: Container, service_class, service_factory):
ioc.register(service_factory, startup=True)
await ioc._run_startup()
assert service_class.startup_called
26 changes: 26 additions & 0 deletions tests/web/application/test_service_startup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from selva.configuration.defaults import default_settings
from selva.configuration.settings import Settings
from selva.di import service
from selva.web.application import Selva


async def test_application():
@service(startup=True)
class Service:
startup_called = False

def initialize(self):
Service.startup_called = True

settings = Settings(
default_settings | {
"application": f"{__package__}.application",
}
)

app = Selva(settings)
app.di.service(Service)

await app._lifespan_startup()

assert Service.startup_called

0 comments on commit 82b06e6

Please sign in to comment.