diff --git a/README.md b/README.md index 50fce92..578b39e 100644 --- a/README.md +++ b/README.md @@ -47,12 +47,12 @@ pip install deps-injection ``` ## Compatibility between web frameworks and injection features -| Framework | Dependency injection with @inject | Overriding providers | Dependency injection with @autoinject (_experimental_) | -|--------------------------------------------------------------------------|:---------------------------------:|:--------------------:|:------------------------------------------------------:| -| [FastAPI](https://github.com/fastapi/fastapi) | ✅ | ✅ | ➖ | -| [Flask](https://github.com/pallets/flask) | ✅ | ✅ | ✅ | -| [Django REST Framework](https://github.com/encode/django-rest-framework) | ✅ | ✅ | ✅ | -| [Litestar](https://github.com/litestar-org/litestar) | ✅ | ⚠️ | ➖ | ➖ | +| Framework | Dependency injection with @inject | Overriding providers | Dependency injection with @autoinject | +|--------------------------------------------------------------------------|:---------------------------------:|:--------------------:|:-------------------------------------------:| +| [FastAPI](https://github.com/fastapi/fastapi) | ✅ | ✅ | ➖ | +| [Flask](https://github.com/pallets/flask) | ✅ | ✅ | ✅ | +| [Django REST Framework](https://github.com/encode/django-rest-framework) | ✅ | ✅ | ✅ | +| [Litestar](https://github.com/litestar-org/litestar) | ✅ | ⚠️ | ➖ | ➖ | ## Using example with FastAPI diff --git a/src/injection/base_container.py b/src/injection/base_container.py index 1e8848a..5b4950d 100644 --- a/src/injection/base_container.py +++ b/src/injection/base_container.py @@ -3,7 +3,10 @@ from contextlib import contextmanager from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, cast -from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError +from injection.inject.exceptions import ( + DuplicatedFactoryTypeAutoInjectionError, + UnknownProviderTypeAutoInjectionError, +) from injection.providers import Singleton from injection.providers.base import BaseProvider from injection.providers.base_factory import BaseFactoryProvider @@ -125,5 +128,4 @@ def resolve_by_type(cls, type_: Type[Any]) -> Any: if type_ is provider.factory: return provider() - msg = f"Provider with type {type_!s} not found" - raise Exception(msg) + raise UnknownProviderTypeAutoInjectionError(str(type_)) diff --git a/src/injection/inject/auto_inject.py b/src/injection/inject/auto_inject.py index 9a0d994..e94108c 100644 --- a/src/injection/inject/auto_inject.py +++ b/src/injection/inject/auto_inject.py @@ -4,8 +4,6 @@ from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar, Union, cast from injection.base_container import DeclarativeContainer -from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError -from injection.inject.inject import _resolve_markers from injection.provide import Provide if sys.version_info < (3, 10): @@ -19,59 +17,29 @@ _ContainerType = Union[Type[DeclarativeContainer], DeclarativeContainer] -def _resolve_signature_args_with_types_from_container( - *, - signature: inspect.Signature, - target_container: _ContainerType, -) -> Dict[str, Any]: - resolved_signature_typed_args = {} - - for param_name, param in signature.parameters.items(): - if not (param.annotation is not param.empty and param.default is param.empty): - continue - - try: - resolved = target_container.resolve_by_type(param.annotation) - resolved_signature_typed_args[param_name] = resolved - except DuplicatedFactoryTypeAutoInjectionError: - raise - - # Ignore exceptions for cases for example django rest framework - # endpoint may have parameter 'request' - we don't know how to handle a variety of parameters. - # But anyway, after this the runtime will fail with an error if something goes wrong - except Exception: # noqa: S112 - continue - - return resolved_signature_typed_args - - def _get_sync_injected( *, f: Callable[P, T], - markers: Markers, signature: inspect.Signature, target_container: _ContainerType, ) -> Callable[P, T]: @wraps(f) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - resolved_signature_typed_args = ( - _resolve_signature_args_with_types_from_container( - signature=signature, - target_container=target_container, - ) - ) + resolved_signature_typed_args = {} + + for i, (param_name, param) in enumerate(signature.parameters.items()): + if i < len(args) or param_name in kwargs: + continue + + if not ( + param.annotation is not param.empty and param.default is param.empty + ): + continue - provide_markers = { - k: v - for k, v in kwargs.items() - if k not in markers and isinstance(v, Provide) - } - provide_markers.update(markers) - resolved_values = _resolve_markers(provide_markers) + resolved = target_container.resolve_by_type(param.annotation) + resolved_signature_typed_args[param_name] = resolved - kwargs.update(resolved_values) - kwargs.update(resolved_signature_typed_args) - return f(*args, **kwargs) + return f(*args, **resolved_signature_typed_args, **kwargs) return wrapper @@ -79,30 +47,26 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def _get_async_injected( *, f: Callable[P, Coroutine[Any, Any, T]], - markers: Markers, signature: inspect.Signature, target_container: _ContainerType, ) -> Callable[P, Coroutine[Any, Any, T]]: @wraps(f) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - resolved_signature_typed_args = ( - _resolve_signature_args_with_types_from_container( - signature=signature, - target_container=target_container, - ) - ) + resolved_signature_typed_args = {} + + for i, (param_name, param) in enumerate(signature.parameters.items()): + if i < len(args) or param_name in kwargs: + continue + + if not ( + param.annotation is not param.empty and param.default is param.empty + ): + continue - provide_markers = { - k: v - for k, v in kwargs.items() - if k not in markers and isinstance(v, Provide) - } - provide_markers.update(markers) - resolved_values = _resolve_markers(provide_markers) + resolved = target_container.resolve_by_type(param.annotation) + resolved_signature_typed_args[param_name] = resolved - kwargs.update(resolved_values) - kwargs.update(resolved_signature_typed_args) - return await f(*args, **kwargs) + return await f(*args, **resolved_signature_typed_args, **kwargs) return wrapper @@ -126,18 +90,10 @@ def auto_inject( target_container = container_subclasses[0] signature = inspect.signature(f) - parameters = signature.parameters - - markers = { - parameter_name: parameter_value.default - for parameter_name, parameter_value in parameters.items() - if isinstance(parameter_value.default, Provide) - } if inspect.iscoroutinefunction(f): func_with_injected_params = _get_async_injected( f=f, - markers=markers, signature=signature, target_container=target_container, ) @@ -145,7 +101,6 @@ def auto_inject( return _get_sync_injected( f=f, - markers=markers, signature=signature, target_container=target_container, ) diff --git a/src/injection/inject/exceptions.py b/src/injection/inject/exceptions.py index 6b02e60..e015e2d 100644 --- a/src/injection/inject/exceptions.py +++ b/src/injection/inject/exceptions.py @@ -5,3 +5,9 @@ def __init__(self, type_: str) -> None: f"more than one provider for type '{type_}'" ) super().__init__(message) + + +class UnknownProviderTypeAutoInjectionError(Exception): + def __init__(self, type_: str) -> None: + message = f"Provider with type {type_!r} not found" + super().__init__(message) diff --git a/tests/test_auto_inject.py b/tests/test_auto_inject.py index 09c753c..744ec5d 100644 --- a/tests/test_auto_inject.py +++ b/tests/test_auto_inject.py @@ -3,7 +3,10 @@ import pytest from injection import DeclarativeContainer, auto_inject -from injection.inject.exceptions import DuplicatedFactoryTypeAutoInjectionError +from injection.inject.exceptions import ( + DuplicatedFactoryTypeAutoInjectionError, + UnknownProviderTypeAutoInjectionError, +) from tests.container_objects import Redis, Service, SomeService @@ -64,3 +67,56 @@ async def test_auto_inject_expect_error_on_duplicated_provider_types(container): ): with pytest.raises(DuplicatedFactoryTypeAutoInjectionError): await _async_func(a=234, b="rnd") + + +def test_auto_injection_with_args_overriding(container) -> None: + @auto_inject + def _inner( + arg1: bool, # noqa: FBT001 + arg2: Service, + arg3: int = 100, + ) -> None: + _ = arg1 + _ = arg3 + original_obj = container.service() + assert arg2.a != original_obj.a + assert arg2.b != original_obj.b + + _inner(True, container.service(b="url", a=2000)) # noqa: FBT003 + _inner(arg1=True, arg2=container.service(b="urljyfuf", a=8400)) + _inner(True, arg2=container.service(b="afdsfsf", a=2242)) # noqa: FBT003 + + +async def test_auto_injection_with_args_overriding_async(container) -> None: + @auto_inject + async def _inner( + arg1: bool, # noqa: FBT001 + arg2: Service, + arg3: int = 100, + ) -> int: + _ = arg1 + _ = arg3 + original_obj = container.service() + assert arg2.a != original_obj.a + assert arg2.b != original_obj.b + return arg3 + + assert await _inner(True, container.service(b="url", a=2000)) == 100 # noqa: FBT003 + assert await _inner(arg1=True, arg2=container.service(b="url", a=2000)) == 100 + assert await _inner(True, arg2=container.service(b="url", a=2000)) == 100 # noqa: FBT003 + + +def test_auto_injection_expect_error_on_unknown_provider(): + @auto_inject + def inner(_: object): ... + + with pytest.raises(UnknownProviderTypeAutoInjectionError): + inner() + + +async def test_auto_injection_expect_error_on_unknown_provider_async(): + @auto_inject + async def inner(_: object): ... + + with pytest.raises(UnknownProviderTypeAutoInjectionError): + await inner() diff --git a/tests/test_inject.py b/tests/test_inject.py index ed42b1e..d759d84 100644 --- a/tests/test_inject.py +++ b/tests/test_inject.py @@ -7,8 +7,10 @@ def test_injection_with_args_overriding(container) -> None: def _inner( arg1: bool, # noqa: FBT001 arg2: Service = Provide[container.service], + arg3: int = 100, ) -> None: _ = arg1 + _ = arg3 original_obj = container.service() assert arg2.a != original_obj.a assert arg2.b != original_obj.b @@ -23,8 +25,10 @@ async def test_injection_with_args_overriding_async(container) -> None: async def _inner( arg1: bool, # noqa: FBT001 arg2: Service = Provide[container.service], + arg3: int = 100, ) -> None: _ = arg1 + _ = arg3 original_obj = container.service() assert arg2.a != original_obj.a assert arg2.b != original_obj.b @@ -32,3 +36,17 @@ async def _inner( await _inner(True, container.service(b="url", a=2000)) # noqa: FBT003 await _inner(arg1=True, arg2=container.service(b="url", a=2000)) await _inner(True, arg2=container.service(b="url", a=2000)) # noqa: FBT003 + + +async def test_injection_async(container) -> None: + @inject + async def _inner( + arg1: bool, # noqa: FBT001 + arg2: Service = Provide[container.service], + arg3: int = 100, + ) -> None: + assert arg1 + assert isinstance(arg2, Service) + assert arg3 == 100 + + await _inner(True) # noqa: FBT003