diff --git a/README.md b/README.md index f1c4937..a5b27f0 100644 --- a/README.md +++ b/README.md @@ -93,8 +93,7 @@ def create_app(): return app -RedisDependency = Annotated[Redis, Depends(Provide["redis"])] -RedisDependencyExplicit = Annotated[Redis, Depends(Provide[Container.redis])] +RedisDependency = Annotated[Redis, Depends(Provide[Container.redis])] @router.get("/values") @@ -106,7 +105,7 @@ def some_get_endpoint_handler(redis: RedisDependency): @router.post("/values") @inject -async def some_get_async_endpoint_handler(redis: RedisDependencyExplicit): +async def some_get_async_endpoint_handler(redis: RedisDependency): value = redis.get(399) return {"detail": value} diff --git a/docs/testing/provider-overriding.md b/docs/testing/provider-overriding.md index f20f5c1..069edca 100644 --- a/docs/testing/provider-overriding.md +++ b/docs/testing/provider-overriding.md @@ -43,7 +43,7 @@ class DIContainer(DeclarativeContainer): @inject -def exec_query_example(some_sqla_dao=Provide["some_sqla_dao"]): +def exec_query_example(some_sqla_dao=Provide[DIContainer.some_sqla_dao]): with some_sqla_dao: result = some_sqla_dao.exec_query('SELECT 234') diff --git a/src/injection/base_container.py b/src/injection/base_container.py index 58a49ac..c383fd8 100644 --- a/src/injection/base_container.py +++ b/src/injection/base_container.py @@ -38,17 +38,6 @@ def _get_providers_generator(cls) -> Iterator[BaseProvider]: def get_providers(cls) -> List[BaseProvider]: return list(cls.__get_providers().values()) - @classmethod - def get_provider_by_attr_name(cls, provider_name: str) -> BaseProvider: - providers = cls.__get_providers() - provider = providers.get(provider_name) - - if provider_name not in providers: - msg = f"Provider {provider_name!r} not found" - raise Exception(msg) - - return provider - @classmethod @contextmanager def override_providers( diff --git a/src/injection/container_registry.py b/src/injection/container_registry.py deleted file mode 100644 index 1aec0fc..0000000 --- a/src/injection/container_registry.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import ClassVar, List, Type, Union - -from injection.base_container import DeclarativeContainer - -ContainersType = List[Union[DeclarativeContainer, Type[DeclarativeContainer]]] - - -class ContainerRegistry: - __containers: ClassVar[ContainersType] = [] - - @classmethod - def __get_containers(cls) -> ContainersType: - if not cls.__containers: - cls.__containers = DeclarativeContainer.__subclasses__() - return cls.__containers - - @classmethod - def get_containers_count(cls) -> int: - return len(cls.__get_containers()) - - @classmethod - def get_default_container(cls) -> Type[DeclarativeContainer]: - containers = cls.__get_containers() - if len(containers) == 0: - msg = "You should create at least one container" - raise Exception(msg) - - return containers[0] diff --git a/src/injection/inject.py b/src/injection/inject.py index 76933da..62e5bbc 100644 --- a/src/injection/inject.py +++ b/src/injection/inject.py @@ -3,7 +3,6 @@ from functools import wraps from typing import Any, Callable, Dict, TypeVar, Union -from injection.container_registry import ContainerRegistry from injection.provide import Provide from injection.providers.base import BaseProvider @@ -64,24 +63,11 @@ def _resolve_provide_marker(marker: Provide) -> BaseProvider: marker_provider = marker.provider - if not isinstance(marker_provider, (str, BaseProvider)): - msg = f"Incorrect marker type: {type(marker_provider)!r}. Marker parameter must be either str or BaseProvider." + if not isinstance(marker_provider, BaseProvider): + msg = f"Incorrect marker type: {type(marker_provider)!r}. Marker parameter must be either BaseProvider." raise TypeError(msg) - if isinstance(marker_provider, BaseProvider): - return marker_provider - - containers_count = ContainerRegistry.get_containers_count() - - if isinstance(marker_provider, str): - if containers_count > 1: - msg = "Please specify the container and its provider explicitly" - raise Exception(msg) - - if containers_count == 1: - container = ContainerRegistry.get_default_container() - provider = container.get_provider_by_attr_name(marker_provider) - return provider + return marker_provider def _extract_provider_values_from_markers(markers: Markers) -> Dict[str, Any]: diff --git a/src/injection/provide.py b/src/injection/provide.py index f517398..6373a98 100644 --- a/src/injection/provide.py +++ b/src/injection/provide.py @@ -1,4 +1,4 @@ -from typing import Generic, TypeVar, Union +from typing import Generic, TypeVar from injection.providers.base import BaseProvider @@ -6,12 +6,12 @@ class ClassGetItemMeta(Generic[T], type): - def __getitem__(cls, item: Union[str, BaseProvider[T]]) -> T: + def __getitem__(cls, item: BaseProvider[T]) -> T: return cls(item) class Provide(metaclass=ClassGetItemMeta): - def __init__(self, provider: Union[str, BaseProvider[T]]) -> None: + def __init__(self, provider: BaseProvider[T]) -> None: self.provider = provider def __call__(self) -> T: diff --git a/tests/conftest.py b/tests/conftest.py index c53cdd0..af3c561 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,8 @@ import pytest -from injection.container_registry import ContainerRegistry from tests.container_objects import Container -@pytest.fixture(scope="session") -def container_registry(): - return ContainerRegistry - - @pytest.fixture(scope="session") def container(): return Container - - -# need this because container registry is singleton -# @pytest.fixture(autouse=True) -# def _force_clean_container_registry(container): -# container.reset_override() diff --git a/tests/container_objects.py b/tests/container_objects.py index e28289e..995e0fd 100644 --- a/tests/container_objects.py +++ b/tests/container_objects.py @@ -83,8 +83,8 @@ def func_with_injections( *, ddd, redis=Provide[Container.redis], - svc1=Provide["service"], - svc2=Provide["some_service"], + svc1=Provide[Container.service], + svc2=Provide[Container.some_service], numms=Provide[Container.num], partial_callable_param=Provide[Container.partial_callable], ): diff --git a/tests/test_container_registry.py b/tests/test_container_registry.py deleted file mode 100644 index 873cf27..0000000 --- a/tests/test_container_registry.py +++ /dev/null @@ -1,16 +0,0 @@ -from unittest import mock - -import pytest -from injection.container_registry import ContainerRegistry - - -@mock.patch.object(ContainerRegistry, "_ContainerRegistry__get_containers") -def test_container_registry_fail_on_get_default_container( - mock_get_containers_method, - container_registry, -): - match = "You should create at least one container" - mock_get_containers_method.return_value = [] - - with pytest.raises(Exception, match=match): - container_registry.get_default_container() diff --git a/tests/test_inject.py b/tests/test_inject.py index d2086e1..16bae19 100644 --- a/tests/test_inject.py +++ b/tests/test_inject.py @@ -1,8 +1,5 @@ -from unittest.mock import patch - import pytest from injection import Provide -from injection.container_registry import ContainerRegistry from injection.inject import _resolve_provide_marker @@ -20,18 +17,5 @@ def test_resolve_provide_marker_fail_when_marker_parameter_has_incorrect_type(): with pytest.raises(Exception) as e: _resolve_provide_marker(Provide[object]) - error_msg = f"Incorrect marker type: {type(object)!r}. Marker parameter must be either str or BaseProvider." - assert e.value.args[0] == error_msg - - -@patch.object(ContainerRegistry, "get_containers_count") -def test_container_registry_fail_with_string_marker_when_containers_more_than_one( - mock_get_containers_count_method, -): - error_msg = "Please specify the container and its provider explicitly" - mock_get_containers_count_method.return_value = 2 - - with pytest.raises(Exception) as e: - _resolve_provide_marker(Provide["redis"]) - + error_msg = f"Incorrect marker type: {type(object)!r}. Marker parameter must be either BaseProvider." assert e.value.args[0] == error_msg diff --git a/tests/test_integrations/test_drf/drf_test_project/views.py b/tests/test_integrations/test_drf/drf_test_project/views.py index be22bf8..30ca7c8 100644 --- a/tests/test_integrations/test_drf/drf_test_project/views.py +++ b/tests/test_integrations/test_drf/drf_test_project/views.py @@ -4,6 +4,8 @@ from rest_framework.response import Response from rest_framework.views import APIView +from tests.container_objects import Container + class PostEndpointBodySerializer(serializers.Serializer): key = serializers.IntegerField() @@ -11,12 +13,12 @@ class PostEndpointBodySerializer(serializers.Serializer): class View(APIView): @inject - def get(self, _: Request, redis=Provide["redis"]): + def get(self, _: Request, redis=Provide[Container.redis]): response_body = {"redis_url": redis.url} return Response(response_body, status=status.HTTP_200_OK) @inject - def post(self, request: Request, redis=Provide["redis"]): + def post(self, request: Request, redis=Provide[Container.redis]): body_serializer = PostEndpointBodySerializer(data=request.data) body_serializer.is_valid() key = body_serializer.validated_data["key"] diff --git a/tests/test_integrations/test_fastapi/handlers.py b/tests/test_integrations/test_fastapi/handlers.py index ddb1844..fb54a4d 100644 --- a/tests/test_integrations/test_fastapi/handlers.py +++ b/tests/test_integrations/test_fastapi/handlers.py @@ -12,8 +12,7 @@ router = APIRouter(prefix="/api") -RedisDependency = Annotated[Redis, Depends(Provide["redis"])] -RedisDependencyExplicit = Annotated[Redis, Depends(Provide[Container.redis])] +RedisDependency = Annotated[Redis, Depends(Provide[Container.redis])] @router.get("/values") @@ -25,6 +24,6 @@ def some_get_endpoint_handler(redis: RedisDependency): @router.post("/values") @inject -async def some_get_async_endpoint_handler(redis: RedisDependencyExplicit): +async def some_get_async_endpoint_handler(redis: RedisDependency): value = redis.get(399) return {"detail": value} diff --git a/tests/test_integrations/test_flask/test_integration.py b/tests/test_integrations/test_flask/test_integration.py index b11cf33..224cf4d 100644 --- a/tests/test_integrations/test_flask/test_integration.py +++ b/tests/test_integrations/test_flask/test_integration.py @@ -1,13 +1,15 @@ from flask import Flask from injection import Provide, inject +from tests.container_objects import Container + app = Flask(__name__) app.config.update({"TESTING": True}) @app.route("/some_resource") @inject -def flask_endpoint(redis=Provide["redis"]): +def flask_endpoint(redis=Provide[Container.redis]): value = redis.get(-900) return {"detail": value}