Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/di #42

Merged
merged 4 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ optional = true

[tool.poetry.group.test.dependencies]
pytest = "^8"
# TODO: remove "allow-prereleases" once pytest-asyncio supports pytest 8
pytest-asyncio = { version = "^0.23", allow-prereleases = true }
pytest-asyncio = "^0.23"
pytest-cov = "^4"
coverage = { version = "^7", extras = ["toml"] }
httpx = "^0.26"
Expand Down
5 changes: 1 addition & 4 deletions src/selva/_util/pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
# flake8: noqa: F401
# ruff: noqa: F401

from .dotted_path import DottedPath
from .dotted_path import DottedPath # noqa: F401
5 changes: 1 addition & 4 deletions src/selva/configuration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
# flake8: noqa: F401
# ruff: noqa: F401

from selva.configuration.settings import Settings
from selva.configuration.settings import Settings # noqa: F401
105 changes: 54 additions & 51 deletions src/selva/di/container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import inspect
from collections import defaultdict
from collections.abc import AsyncGenerator, Awaitable, Generator, Iterable
from types import FunctionType, ModuleType
from typing import Any, Type, TypeVar
Expand All @@ -9,8 +10,10 @@
from selva._util.maybe_async import maybe_async
from selva._util.package_scan import scan_packages
from selva.di.decorator import DI_ATTRIBUTE_SERVICE
from selva.di.decorator import service as service_decorator
from selva.di.error import (
DependencyLoopError,
NonInjectableTypeError,
ServiceNotFoundError,
ServiceWithoutDecoratorError,
)
Expand All @@ -30,28 +33,17 @@ def __init__(self):
self.startup: list[tuple[Type, str | None]] = []
self.interceptors: list[Type[Interceptor]] = []

def register(
self,
service: InjectableType,
*,
provides: type = None,
name: str = None,
startup: bool = False,
):
self._register_service_spec(service, provides, name, startup)

def service(self, service: type):
service_info = getattr(service, DI_ATTRIBUTE_SERVICE, None)
def register(self, injectable: InjectableType):
service_info = getattr(injectable, DI_ATTRIBUTE_SERVICE, None)

if not service_info:
raise ServiceWithoutDecoratorError(service)
if inspect.isfunction(injectable) or inspect.isclass(injectable):
raise ServiceWithoutDecoratorError(injectable)

self._register_service_spec(service, *service_info)
raise NonInjectableTypeError(injectable)

def _register_service_spec(
self, service: type, provides: type | None, name: str | None, startup: bool
):
service_spec = parse_service_spec(service, provides, name)
provides, name, startup = service_info
service_spec = parse_service_spec(injectable, provides, name)
provided_service = service_spec.provides

self.registry[provided_service, name] = service_spec
Expand All @@ -61,18 +53,18 @@ def _register_service_spec(

if provides:
logger.trace(
"service registered: {}.{}; provided-by={}.{} name={}",
"service registered: {}.{}; provides={}.{} name={}",
injectable.__module__,
injectable.__qualname__,
provides.__module__,
provides.__qualname__,
service.__module__,
service.__qualname__,
name or "",
)
else:
logger.trace(
"service registered: {}.{}; name={}",
service.__module__,
service.__qualname__,
injectable.__module__,
injectable.__qualname__,
name or "",
)

Expand All @@ -85,9 +77,11 @@ def define(self, service_type: type, instance: Any, *, name: str = None):

def interceptor(self, interceptor: Type[Interceptor]):
self.register(
interceptor,
provides=Interceptor,
name=f"{interceptor.__module__}.{interceptor.__qualname__}",
service_decorator(
interceptor,
provides=Interceptor,
name=f"{interceptor.__module__}.{interceptor.__qualname__}",
)
)
self.interceptors.append(interceptor)

Expand All @@ -97,11 +91,11 @@ def scan(self, *packages: str | ModuleType):
def predicate_services(item: Any):
return hasattr(item, DI_ATTRIBUTE_SERVICE)

for service in scan_packages(packages, predicate_services):
self.service(service)
for found_service in scan_packages(packages, predicate_services):
self.register(found_service)

def has(self, service: type, name: str = None) -> bool:
definition = self.registry.get(service, name=name)
def has(self, service_type: type, name: str = None) -> bool:
definition = self.registry.get(service_type, name=name)
return definition is not None

def iter_service(
Expand All @@ -121,20 +115,21 @@ def iter_all_services(
for name, definition in record.providers.items():
yield interface, definition.service, name

async def get(self, service: T, *, name: str = None, optional=False) -> T:
dependency = ServiceDependency(service, name=name, optional=optional)
async def get(self, service_type: T, *, name: str = None, optional=False) -> T:
dependency = ServiceDependency(service_type, name=name, optional=optional)
return await self._get(dependency)

async def create(self, service: type) -> Any:
instance = service()
for name, dep_spec in get_dependencies(service):
async def create(self, service_type: type) -> Any:
assert inspect.isclass(service_type)
instance = service_type()

for name, dep_spec in get_dependencies(service_type):
dependency = await self._get(dep_spec)
setattr(instance, name, dependency)

return instance

def _get_from_cache(self, service_type: type, name: str | None) -> Any | None:
# try getting service from store
if instance := self.store.get((service_type, name)):
return instance

Expand All @@ -147,11 +142,14 @@ async def _get(
) -> Any | None:
service_type, name = dependency.service, dependency.name

# check if service exists in cache
if instance := self._get_from_cache(service_type, name):
return instance

try:
service_spec = self.registry[service_type, name]
service_spec = self.registry.get(service_type, name)
if not service_spec:
raise ServiceNotFoundError(service_type, name)
except ServiceNotFoundError:
if dependency.optional:
return None
Expand All @@ -168,21 +166,28 @@ async def _get(

return instance

async def _get_dependent_services(
self, service_spec: ServiceSpec, stack: list
) -> dict[str, Any]:
deps = await asyncio.gather(
*[self._get(d, stack) for _, d in service_spec.dependencies]
)
names = [n for n, _ in service_spec.dependencies]
return dict(zip(names, deps))

async def _create_service(
self,
service_spec: ServiceSpec,
stack: list[type],
) -> Any:
# check if service exists in cache
name = service_spec.name

# check if service exists in cache
if instance := self._get_from_cache(service_spec.provides, name):
return instance

if factory := service_spec.factory:
dependencies = {
name: await self._get(dep, stack)
for name, dep in service_spec.dependencies
}
dependencies = await self._get_dependent_services(service_spec, stack)

instance = await maybe_async(factory, **dependencies)
if inspect.isgenerator(instance):
Expand All @@ -199,23 +204,21 @@ async def _create_service(
instance = service_spec.service()
self.store[service_spec.provides, name] = instance

for name, dep_service in service_spec.dependencies:
dependency = await self._get(dep_service, stack)
setattr(instance, name, dependency)
dependencies = await self._get_dependent_services(service_spec, stack)

for name, dep_service in dependencies.items():
setattr(instance, name, dep_service)

if initializer := service_spec.initializer:
await maybe_async(initializer, instance)

await self._run_initializer(service_spec, instance)
self._setup_finalizer(service_spec, instance)

if service_spec.provides is not Interceptor:
await self._run_interceptors(instance, service_spec.provides)

return instance

@staticmethod
async def _run_initializer(service_spec: ServiceSpec, instance: Any):
if initializer := service_spec.initializer:
await maybe_async(initializer, instance)

def _setup_finalizer(self, service_spec: ServiceSpec, instance: Any):
if finalizer := service_spec.finalizer:
self.finalizers.append(maybe_async(finalizer, instance))
Expand Down
96 changes: 51 additions & 45 deletions src/selva/di/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,55 @@ def _is_inject(value) -> bool:
return isinstance(args[1], Inject) or args[1] is Inject


def _service(
injectable: InjectableType,
attribute_name: str,
attribute_value,
) -> InjectableType:
setattr(injectable, attribute_name, attribute_value)

if inspect.isclass(injectable):
dependencies = [
dependency
for dependency, annotation in inspect.get_annotations(injectable).items()
if _is_inject(annotation)
]

# save a reference to the original constructor
original_init = getattr(injectable, "__init__", None)

def init(self, *args, **kwargs):
"""Generated init method for service

Positional and keyword arguments will be set to declared dependencies.
Dependencies without an argument to set their value will be None.
Remaining arguments will be ignored.
"""

# call original constructor
if original_init:
original_init(self)

positional_params = [d for d in dependencies if d not in kwargs]

# set positional argument values
values = dict(zip(positional_params, args))

# set keyword argument values
values |= {k: v for k, v in kwargs.items() if k in dependencies}

# the rest of the dependencies will be set to None
param_keys = values.keys()
values |= {d: None for d in dependencies if d not in param_keys}

for k, v in values.items():
setattr(self, k, v)

setattr(injectable, "__init__", init)

return injectable


@dataclass_transform(eq_default=False)
def service(
injectable: T = None,
Expand All @@ -39,52 +88,9 @@ def service(
outside the dependency injection context
"""

def inner(inner_injectable: InjectableType) -> T:
setattr(
def inner(inner_injectable) -> T:
return _service(
inner_injectable, DI_ATTRIBUTE_SERVICE, ServiceInfo(provides, name, startup)
)

if inspect.isclass(inner_injectable):
dependencies = [
dependency
for dependency, annotation in inspect.get_annotations(
inner_injectable
).items()
if _is_inject(annotation)
]

# save a reference to the original constructor
original_init = getattr(inner_injectable, "__init__", None)

def init(self, *args, **kwargs):
"""Generated init method for service

Positional and keyword arguments will be set to declared dependencies.
Dependencies without an argument to set their value will be None.
Remaining arguments will be ignored.
"""

# call original constructor
if original_init:
original_init(self)

positional_params = [d for d in dependencies if d not in kwargs]

# set positional argument values
values = dict(zip(positional_params, args))

# set keyword argument values
values |= {k: v for k, v in kwargs.items() if k in dependencies}

# the rest of the dependencies will be set to None
param_keys = values.keys()
values |= {d: None for d in dependencies if d not in param_keys}

for k, v in values.items():
setattr(self, k, v)

setattr(inner_injectable, "__init__", init)

return inner_injectable

return inner(injectable) if injectable else inner
2 changes: 1 addition & 1 deletion src/selva/di/service/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ServiceDependency(NamedTuple):
class ServiceSpec(NamedTuple):
service: type
provides: type
factory: FunctionType
factory: FunctionType | None
name: str | None
dependencies: list[tuple[str, ServiceDependency]]
initializer: Callable = None
Expand Down
5 changes: 3 additions & 2 deletions src/selva/di/service/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ def __init__(self):
self.services: dict[type, ServiceRecord] = defaultdict(ServiceRecord)

def get(self, key: type, name: str = None) -> ServiceSpec | None:
if (key, name) not in self:
try:
return self[key, name]
except ServiceNotFoundError:
return None
return self[key, name]

def __getitem__(self, key: type | tuple[type, str]):
inner_key, name = _get_key_with_name(key)
Expand Down
3 changes: 2 additions & 1 deletion src/selva/ext/data/redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from selva.configuration.settings import Settings
from selva.di.container import Container
from selva.di.decorator import service as service_decorator

from .service import make_service

Expand All @@ -13,4 +14,4 @@ def selva_extension(container: Container, settings: Settings):
for name in settings.data.redis:
service_name = name if name != "default" else None

container.register(make_service(name), name=service_name)
container.register(service_decorator(make_service(name), name=service_name))
2 changes: 2 additions & 0 deletions src/selva/ext/data/redis/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def build_backoff(data: BackoffSchema) -> AbstractBackoff:
if value := data.decorrelated_jitter:
return DecorrelatedJitterBackoff(**value.model_dump(exclude_unset=True))

raise ValueError("No value defined for 'backoff'")


def build_retry(data: RetrySchema):
kwargs = {
Expand Down
Loading