Skip to content

Commit

Permalink
add router exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
livioribeiro committed Mar 15, 2024
1 parent 97efa1f commit cffc090
Show file tree
Hide file tree
Showing 22 changed files with 117 additions and 360 deletions.
5 changes: 1 addition & 4 deletions src/selva/_util/pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
# flake8: noqa: F401
# ruff: noqa: F401

from .dotted_path import DottedPath
from .dotted_path import DottedPath # noqa: F401
5 changes: 1 addition & 4 deletions src/selva/configuration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
# flake8: noqa: F401
# ruff: noqa: F401

from selva.configuration.settings import Settings
from selva.configuration.settings import Settings # noqa: F401
139 changes: 42 additions & 97 deletions src/selva/di/container.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,49 @@
import asyncio
import functools
import inspect
from collections import defaultdict
from collections.abc import AsyncGenerator, Awaitable, Generator, Iterable
from types import FunctionType, ModuleType
from typing import Any, Type, TypeVar
from weakref import WeakKeyDictionary

from loguru import logger

from selva._util.maybe_async import maybe_async
from selva._util.package_scan import scan_packages
from selva.di.decorator import DI_ATTRIBUTE_SERVICE, service as service_decorator
from selva.di.decorator import DI_ATTRIBUTE_SERVICE
from selva.di.decorator import service as service_decorator
from selva.di.error import (
DependencyLoopError,
NonInjectableTypeError,
ServiceNotFoundError,
ServiceWithoutDecoratorError,
NonInjectableTypeError,
ServiceWithResourceDependencyError,
)
from selva.di.inspect import is_resource, is_service
from selva.di.interceptor import Interceptor
from selva.di.service.model import InjectableType, ServiceDependency, ServiceInfo, ServiceSpec
from selva.di.service.model import InjectableType, ServiceDependency, ServiceSpec
from selva.di.service.parse import get_dependencies, parse_service_spec
from selva.di.service.registry import ServiceRegistry

T = TypeVar("T")


def _check_resource_dependency(service_spec: ServiceSpec, dependencies: dict[str, Any]):
resource_dependencies = [type(d) for d in dependencies.values() if is_resource(d)]
if not service_spec.resource and resource_dependencies:
raise ServiceWithResourceDependencyError(service_spec.service, resource_dependencies)


class Container:
def __init__(self):
self.registry = ServiceRegistry()
self.store: dict[tuple[type, str | None], Any] = {}
self.finalizers: list[Awaitable] = []
self.startup: list[tuple[Type, str | None]] = []
self.interceptors: list[Type[Interceptor]] = []
self.resource_store: dict[Any, dict[tuple[type, str | None], Any]] = defaultdict(dict)
self.resource_finalizers: dict[int, list[Awaitable]] = defaultdict(list)

def register(self, injectable: InjectableType):
service_info = getattr(injectable, DI_ATTRIBUTE_SERVICE, None)

if not service_info:
if inspect.isfunction(injectable) or inspect.isclass(injectable):
raise ServiceWithoutDecoratorError(injectable)
else:
raise NonInjectableTypeError(injectable)

kind = "resource" if service_info.resource else "service"
raise NonInjectableTypeError(injectable)

provides, name, startup, resource = service_info
service_spec = parse_service_spec(injectable, provides, name, resource=resource)
provides, name, startup = service_info
service_spec = parse_service_spec(injectable, provides, name)
provided_service = service_spec.provides

self.registry[provided_service, name] = service_spec
Expand All @@ -66,8 +53,7 @@ def register(self, injectable: InjectableType):

if provides:
logger.trace(
"{} registered: {}.{}; provides={}.{} name={}",
kind,
"service registered: {}.{}; provides={}.{} name={}",
injectable.__module__,
injectable.__qualname__,
provides.__module__,
Expand All @@ -76,8 +62,7 @@ def register(self, injectable: InjectableType):
)
else:
logger.trace(
"{} registered: {}.{}; name={}",
kind,
"service registered: {}.{}; name={}",
injectable.__module__,
injectable.__qualname__,
name or "",
Expand Down Expand Up @@ -123,42 +108,42 @@ def iter_service(
for name, definition in record.providers.items():
yield definition.service, name

def iter_all_services(self) -> Iterable[tuple[type, type | FunctionType | None, str | None]]:
def iter_all_services(
self,
) -> Iterable[tuple[type, type | FunctionType | None, str | None]]:
for interface, record in self.registry.services.items():
for name, definition in record.providers.items():
yield interface, definition.service, name

async def get(self, service_type: T, *, name: str = None, optional=False, context: Any = None) -> T:
async def get(self, service_type: T, *, name: str = None, optional=False) -> T:
dependency = ServiceDependency(service_type, name=name, optional=optional)
return await self._get(dependency, context=context)
return await self._get(dependency)

async def create(self, service_type: type) -> Any:
assert inspect.isclass(service_type)
instance = service_type()

for name, dep_spec in get_dependencies(service_type):
dependency = await self._get(dep_spec, context=None)
dependency = await self._get(dep_spec)
setattr(instance, name, dependency)

return instance

def _get_from_cache(self, service_type: type, name: str | None, context: Any = None) -> Any | None:
store = self.resource_store[id(context)] if context else self.store
if instance := store.get((service_type, name)):
def _get_from_cache(self, service_type: type, name: str | None) -> Any | None:
if instance := self.store.get((service_type, name)):
return instance

return None

async def _get(
self,
dependency: ServiceDependency,
context: Any,
stack: list = None,
) -> Any | None:
service_type, name = dependency.service, dependency.name

# check if service exists in cache
if instance := self._get_from_cache(service_type, name, context):
if instance := self._get_from_cache(service_type, name):
return instance

try:
Expand All @@ -176,83 +161,73 @@ async def _get(
raise DependencyLoopError(stack + [service_type])

stack.append(service_type)
instance = await self._create_service(service_spec, context, stack)
instance = await self._create_service(service_spec, stack)
stack.pop()

return instance

async def _get_dependent_services(self, service_spec: ServiceSpec, context: Any, stack: list) -> dict[str, Any]:
deps = await asyncio.gather(*[self._get(d, context, stack) for _, d in service_spec.dependencies])
async def _get_dependent_services(
self, service_spec: ServiceSpec, stack: list
) -> dict[str, Any]:
deps = await asyncio.gather(
*[self._get(d, stack) for _, d in service_spec.dependencies]
)
names = [n for n, _ in service_spec.dependencies]
return dict(zip(names, deps))

async def _create_service(
self,
service_spec: ServiceSpec,
context: Any,
stack: list[type],
) -> Any:
name = service_spec.name

# check if service exists in cache
if instance := self._get_from_cache(service_spec.provides, name, context):
if instance := self._get_from_cache(service_spec.provides, name):
return instance

if factory := service_spec.factory:
dependencies = await self._get_dependent_services(service_spec, context, stack)
_check_resource_dependency(service_spec, dependencies)
dependencies = await self._get_dependent_services(service_spec, stack)

instance = await maybe_async(factory, **dependencies)
if inspect.isgenerator(instance):
generator = instance
instance = await asyncio.to_thread(next, generator)
self._setup_generator_finalizer(generator, context)
self._setup_generator_finalizer(generator)
elif inspect.isasyncgen(instance):
generator = instance
instance = await anext(generator)
self._setup_asyncgen_finalizer(generator, context)
self._setup_asyncgen_finalizer(generator)

if service_spec.resource:
self.resource_store[id(context)][service_spec.provides, name] = instance
else:
self.store[service_spec.provides, name] = instance
self.store[service_spec.provides, name] = instance
else:
instance = service_spec.service()
self.store[service_spec.provides, name] = instance

if service_spec.resource:
self.resource_store[id(context)][service_spec.provides, name] = instance
else:
self.store[service_spec.provides, name] = instance

dependencies = await self._get_dependent_services(service_spec, context, stack)
_check_resource_dependency(service_spec, dependencies)
dependencies = await self._get_dependent_services(service_spec, stack)

for name, dep_service in dependencies.items():
setattr(instance, name, dep_service)

if initializer := service_spec.initializer:
await maybe_async(initializer, instance)

self._setup_finalizer(service_spec, instance, context)
self._setup_finalizer(service_spec, instance)

if service_spec.provides is not Interceptor:
await self._run_interceptors(instance, service_spec.provides)

self._setup_resources(instance)
return instance

def _setup_finalizer(self, service_spec: ServiceSpec, instance: Any, context: Any = None):
def _setup_finalizer(self, service_spec: ServiceSpec, instance: Any):
if finalizer := service_spec.finalizer:
finalizer_list = self.resource_finalizers[id(context)] if context else self.finalizers
finalizer_list.append(maybe_async(finalizer, instance))
self.finalizers.append(maybe_async(finalizer, instance))

def _setup_generator_finalizer(self, gen: Generator, context: Any = None):
finalizer_list = self.resource_finalizers[id(context)] if context else self.finalizers
finalizer_list.append(asyncio.to_thread(next, gen, None))
def _setup_generator_finalizer(self, gen: Generator):
self.finalizers.append(asyncio.to_thread(next, gen, None))

def _setup_asyncgen_finalizer(self, gen: AsyncGenerator, context: Any = None):
finalizer_list = self.resource_finalizers[id(context)] if context else self.finalizers
finalizer_list.append(anext(gen, None))
def _setup_asyncgen_finalizer(self, gen: AsyncGenerator):
self.finalizers.append(anext(gen, None))

async def _run_interceptors(self, instance: Any, service_type: type):
for cls in self.interceptors:
Expand All @@ -265,38 +240,8 @@ async def _run_startup(self):
for service, name in self.startup:
await self.get(service, name=name)

async def _run_finalizers(self, context: Any = None):
finalizers_list = self.resource_finalizers[id(context)] if context else self.finalizers

for finalizer in reversed(finalizers_list):
async def _run_finalizers(self):
for finalizer in reversed(self.finalizers):
await finalizer

if context:
del self.resource_finalizers[id(context)]
del self.resource_store[id(context)]
else:
finalizers_list.clear()

def _setup_resources(self, instance: Any):
for method_name, method in inspect.getmembers(instance, lambda m: inspect.ismethod(m)):
signature = inspect.signature(method)

for param_name, param in signature.parameters.items():
if param.default is not ...:
continue

if not is_resource(param.annotation):
continue

assert inspect.iscoroutinefunction(method)

ioc = self

@functools.wraps(method)
async def wrapper(*args, **kwargs):
resource_instance = await ioc.get(param.annotation)
kwargs[param_name] = resource_instance

return await method.__call__(*args, **kwargs)

setattr(instance, method_name, wrapper)
self.finalizers.clear()
34 changes: 6 additions & 28 deletions src/selva/di/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from selva.di.inject import Inject
from selva.di.service.model import InjectableType, ServiceInfo

__all__ = ("resource", "service", "DI_ATTRIBUTE_SERVICE")
__all__ = ("service", "DI_ATTRIBUTE_SERVICE")

DI_ATTRIBUTE_SERVICE = "__selva_di_service__"

Expand All @@ -29,16 +29,12 @@ def _service(
attribute_name: str,
attribute_value,
) -> InjectableType:
setattr(
injectable, attribute_name, attribute_value
)
setattr(injectable, attribute_name, attribute_value)

if inspect.isclass(injectable):
dependencies = [
dependency
for dependency, annotation in inspect.get_annotations(
injectable
).items()
for dependency, annotation in inspect.get_annotations(injectable).items()
if _is_inject(annotation)
]

Expand Down Expand Up @@ -93,26 +89,8 @@ def service(
"""

def inner(inner_injectable) -> T:
return _service(inner_injectable, DI_ATTRIBUTE_SERVICE, ServiceInfo(provides, name, startup))

return inner(injectable) if injectable else inner


@dataclass_transform(eq_default=False)
def resource(
injectable: Callable[[], T] = None,
/,
*,
provides: type = None,
name: str = None,
) -> T | Callable[[T], T]:
"""Declare a class or function as a resource
For classes, a constructor will be generated to help create instances
outside the dependency injection context
"""

def inner(inner_injectable) -> T:
return _service(inner_injectable, DI_ATTRIBUTE_SERVICE, ServiceInfo(provides, name, resource=True))
return _service(
inner_injectable, DI_ATTRIBUTE_SERVICE, ServiceInfo(provides, name, startup)
)

return inner(injectable) if injectable else inner
8 changes: 0 additions & 8 deletions src/selva/di/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,3 @@ def __init__(self, service: InjectableType):
f"service {service.__module__}.{service.__qualname__}"
" must be decorated with @service"
)


class ServiceWithResourceDependencyError(DependencyInjectionError):
def __init__(self, service: InjectableType, resource_dependencies: list[type]):
super().__init__(
f"service {service.__module__}.{service.__qualname__}"
f" depends on resources: {', '.join(f'{d.__module__}.{d.__qualname__}' for d in resource_dependencies)}"
)
15 changes: 0 additions & 15 deletions src/selva/di/inspect.py

This file was deleted.

2 changes: 0 additions & 2 deletions src/selva/di/service/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ class ServiceInfo(NamedTuple):
provides: type | None
name: str | None
startup: bool = False
resource: bool = False


class ServiceDependency(NamedTuple):
Expand All @@ -26,4 +25,3 @@ class ServiceSpec(NamedTuple):
dependencies: list[tuple[str, ServiceDependency]]
initializer: Callable = None
finalizer: Callable = None
resource: bool = False
Loading

0 comments on commit cffc090

Please sign in to comment.