Skip to content

Commit

Permalink
tests for auto injecting
Browse files Browse the repository at this point in the history
  • Loading branch information
nightblure committed Nov 18, 2024
1 parent 4e32d8e commit bc7b55a
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 81 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ pip install deps-injection
```

## Compatibility between web frameworks and injection features
| Framework | Dependency injection with @inject | Overriding providers | Dependency injection with @autoinject (_experimental_) |
|--------------------------------------------------------------------------|:---------------------------------:|:--------------------:|:------------------------------------------------------:|
| [FastAPI](https://github.com/fastapi/fastapi) ||| |
| [Flask](https://github.com/pallets/flask) ||| |
| [Django REST Framework](https://github.com/encode/django-rest-framework) ||| |
| [Litestar](https://github.com/litestar-org/litestar) || ⚠️ | ||
| Framework | Dependency injection with @inject | Overriding providers | Dependency injection with @autoinject |
|--------------------------------------------------------------------------|:---------------------------------:|:--------------------:|:-------------------------------------------:|
| [FastAPI](https://github.com/fastapi/fastapi) ||| |
| [Flask](https://github.com/pallets/flask) ||| |
| [Django REST Framework](https://github.com/encode/django-rest-framework) ||| |
| [Litestar](https://github.com/litestar-org/litestar) || ⚠️ | ||


## Using example with FastAPI
Expand Down
8 changes: 5 additions & 3 deletions src/injection/base_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, cast

from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError
from injection.inject.exceptions import (
DuplicatedFactoryTypeAutoInjectionError,
UnknownProviderTypeAutoInjectionError,
)
from injection.providers import Singleton
from injection.providers.base import BaseProvider
from injection.providers.base_factory import BaseFactoryProvider
Expand Down Expand Up @@ -125,5 +128,4 @@ def resolve_by_type(cls, type_: Type[Any]) -> Any:
if type_ is provider.factory:
return provider()

msg = f"Provider with type {type_!s} not found"
raise Exception(msg)
raise UnknownProviderTypeAutoInjectionError(str(type_))
97 changes: 26 additions & 71 deletions src/injection/inject/auto_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar, Union, cast

from injection.base_container import DeclarativeContainer
from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError
from injection.inject.inject import _resolve_markers
from injection.provide import Provide

if sys.version_info < (3, 10):
Expand All @@ -19,90 +17,56 @@
_ContainerType = Union[Type[DeclarativeContainer], DeclarativeContainer]


def _resolve_signature_args_with_types_from_container(
*,
signature: inspect.Signature,
target_container: _ContainerType,
) -> Dict[str, Any]:
resolved_signature_typed_args = {}

for param_name, param in signature.parameters.items():
if not (param.annotation is not param.empty and param.default is param.empty):
continue

try:
resolved = target_container.resolve_by_type(param.annotation)
resolved_signature_typed_args[param_name] = resolved
except DuplicatedFactoryTypeAutoInjectionError:
raise

# Ignore exceptions for cases for example django rest framework
# endpoint may have parameter 'request' - we don't know how to handle a variety of parameters.
# But anyway, after this the runtime will fail with an error if something goes wrong
except Exception: # noqa: S112
continue

return resolved_signature_typed_args


def _get_sync_injected(
*,
f: Callable[P, T],
markers: Markers,
signature: inspect.Signature,
target_container: _ContainerType,
) -> Callable[P, T]:
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
resolved_signature_typed_args = (
_resolve_signature_args_with_types_from_container(
signature=signature,
target_container=target_container,
)
)
resolved_signature_typed_args = {}

for i, (param_name, param) in enumerate(signature.parameters.items()):
if i < len(args) or param_name in kwargs:
continue

if not (
param.annotation is not param.empty and param.default is param.empty
):
continue

provide_markers = {
k: v
for k, v in kwargs.items()
if k not in markers and isinstance(v, Provide)
}
provide_markers.update(markers)
resolved_values = _resolve_markers(provide_markers)
resolved = target_container.resolve_by_type(param.annotation)
resolved_signature_typed_args[param_name] = resolved

kwargs.update(resolved_values)
kwargs.update(resolved_signature_typed_args)
return f(*args, **kwargs)
return f(*args, **resolved_signature_typed_args, **kwargs)

return wrapper


def _get_async_injected(
*,
f: Callable[P, Coroutine[Any, Any, T]],
markers: Markers,
signature: inspect.Signature,
target_container: _ContainerType,
) -> Callable[P, Coroutine[Any, Any, T]]:
@wraps(f)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
resolved_signature_typed_args = (
_resolve_signature_args_with_types_from_container(
signature=signature,
target_container=target_container,
)
)
resolved_signature_typed_args = {}

for i, (param_name, param) in enumerate(signature.parameters.items()):
if i < len(args) or param_name in kwargs:
continue

if not (
param.annotation is not param.empty and param.default is param.empty
):
continue

provide_markers = {
k: v
for k, v in kwargs.items()
if k not in markers and isinstance(v, Provide)
}
provide_markers.update(markers)
resolved_values = _resolve_markers(provide_markers)
resolved = target_container.resolve_by_type(param.annotation)
resolved_signature_typed_args[param_name] = resolved

kwargs.update(resolved_values)
kwargs.update(resolved_signature_typed_args)
return await f(*args, **kwargs)
return await f(*args, **resolved_signature_typed_args, **kwargs)

return wrapper

Expand All @@ -126,26 +90,17 @@ def auto_inject(
target_container = container_subclasses[0]

signature = inspect.signature(f)
parameters = signature.parameters

markers = {
parameter_name: parameter_value.default
for parameter_name, parameter_value in parameters.items()
if isinstance(parameter_value.default, Provide)
}

if inspect.iscoroutinefunction(f):
func_with_injected_params = _get_async_injected(
f=f,
markers=markers,
signature=signature,
target_container=target_container,
)
return cast(Callable[P, T], func_with_injected_params)

return _get_sync_injected(
f=f,
markers=markers,
signature=signature,
target_container=target_container,
)
6 changes: 6 additions & 0 deletions src/injection/inject/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@ def __init__(self, type_: str) -> None:
f"more than one provider for type '{type_}'"
)
super().__init__(message)


class UnknownProviderTypeAutoInjectionError(Exception):
def __init__(self, type_: str) -> None:
message = f"Provider with type {type_!r} not found"
super().__init__(message)
58 changes: 57 additions & 1 deletion tests/test_auto_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import pytest

from injection import DeclarativeContainer, auto_inject
from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError
from injection.inject.exceptions import (
DuplicatedFactoryTypeAutoInjectionError,
UnknownProviderTypeAutoInjectionError,
)
from tests.container_objects import Redis, Service, SomeService


Expand Down Expand Up @@ -64,3 +67,56 @@ async def test_auto_inject_expect_error_on_duplicated_provider_types(container):
):
with pytest.raises(DuplicatedFactoryTypeAutoInjectionError):
await _async_func(a=234, b="rnd")


def test_auto_injection_with_args_overriding(container) -> None:
@auto_inject
def _inner(
arg1: bool, # noqa: FBT001
arg2: Service,
arg3: int = 100,
) -> None:
_ = arg1
_ = arg3
original_obj = container.service()
assert arg2.a != original_obj.a
assert arg2.b != original_obj.b

_inner(True, container.service(b="url", a=2000)) # noqa: FBT003
_inner(arg1=True, arg2=container.service(b="urljyfuf", a=8400))
_inner(True, arg2=container.service(b="afdsfsf", a=2242)) # noqa: FBT003


async def test_auto_injection_with_args_overriding_async(container) -> None:
@auto_inject
async def _inner(
arg1: bool, # noqa: FBT001
arg2: Service,
arg3: int = 100,
) -> int:
_ = arg1
_ = arg3
original_obj = container.service()
assert arg2.a != original_obj.a
assert arg2.b != original_obj.b
return arg3

assert await _inner(True, container.service(b="url", a=2000)) == 100 # noqa: FBT003
assert await _inner(arg1=True, arg2=container.service(b="url", a=2000)) == 100
assert await _inner(True, arg2=container.service(b="url", a=2000)) == 100 # noqa: FBT003


def test_auto_injection_expect_error_on_unknown_provider():
@auto_inject
def inner(_: object): ...

with pytest.raises(UnknownProviderTypeAutoInjectionError):
inner()


async def test_auto_injection_expect_error_on_unknown_provider_async():
@auto_inject
async def inner(_: object): ...

with pytest.raises(UnknownProviderTypeAutoInjectionError):
await inner()
18 changes: 18 additions & 0 deletions tests/test_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ def test_injection_with_args_overriding(container) -> None:
def _inner(
arg1: bool, # noqa: FBT001
arg2: Service = Provide[container.service],
arg3: int = 100,
) -> None:
_ = arg1
_ = arg3
original_obj = container.service()
assert arg2.a != original_obj.a
assert arg2.b != original_obj.b
Expand All @@ -23,12 +25,28 @@ async def test_injection_with_args_overriding_async(container) -> None:
async def _inner(
arg1: bool, # noqa: FBT001
arg2: Service = Provide[container.service],
arg3: int = 100,
) -> None:
_ = arg1
_ = arg3
original_obj = container.service()
assert arg2.a != original_obj.a
assert arg2.b != original_obj.b

await _inner(True, container.service(b="url", a=2000)) # noqa: FBT003
await _inner(arg1=True, arg2=container.service(b="url", a=2000))
await _inner(True, arg2=container.service(b="url", a=2000)) # noqa: FBT003


async def test_injection_async(container) -> None:
@inject
async def _inner(
arg1: bool, # noqa: FBT001
arg2: Service = Provide[container.service],
arg3: int = 100,
) -> None:
assert arg1
assert isinstance(arg2, Service)
assert arg3 == 100

await _inner(True) # noqa: FBT003

0 comments on commit bc7b55a

Please sign in to comment.