Skip to content

Commit

Permalink
Remove container registry and feature with pass provider via string
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan committed Jul 27, 2024
1 parent e4517b1 commit 189bfc4
Show file tree
Hide file tree
Showing 13 changed files with 21 additions and 116 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def create_app():
return app


RedisDependency = Annotated[Redis, Depends(Provide["redis"])]
RedisDependencyExplicit = Annotated[Redis, Depends(Provide[Container.redis])]
RedisDependency = Annotated[Redis, Depends(Provide[Container.redis])]


@router.get("/values")
Expand All @@ -106,7 +105,7 @@ def some_get_endpoint_handler(redis: RedisDependency):

@router.post("/values")
@inject
async def some_get_async_endpoint_handler(redis: RedisDependencyExplicit):
async def some_get_async_endpoint_handler(redis: RedisDependency):
value = redis.get(399)
return {"detail": value}

Expand Down
2 changes: 1 addition & 1 deletion docs/testing/provider-overriding.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class DIContainer(DeclarativeContainer):


@inject
def exec_query_example(some_sqla_dao=Provide["some_sqla_dao"]):
def exec_query_example(some_sqla_dao=Provide[DIContainer.some_sqla_dao]):
with some_sqla_dao:
result = some_sqla_dao.exec_query('SELECT 234')

Expand Down
11 changes: 0 additions & 11 deletions src/injection/base_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,6 @@ def _get_providers_generator(cls) -> Iterator[BaseProvider]:
def get_providers(cls) -> List[BaseProvider]:
return list(cls.__get_providers().values())

@classmethod
def get_provider_by_attr_name(cls, provider_name: str) -> BaseProvider:
providers = cls.__get_providers()
provider = providers.get(provider_name)

if provider_name not in providers:
msg = f"Provider {provider_name!r} not found"
raise Exception(msg)

return provider

@classmethod
@contextmanager
def override_providers(
Expand Down
28 changes: 0 additions & 28 deletions src/injection/container_registry.py

This file was deleted.

20 changes: 3 additions & 17 deletions src/injection/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from functools import wraps
from typing import Any, Callable, Dict, TypeVar, Union

from injection.container_registry import ContainerRegistry
from injection.provide import Provide
from injection.providers.base import BaseProvider

Expand Down Expand Up @@ -64,24 +63,11 @@ def _resolve_provide_marker(marker: Provide) -> BaseProvider:

marker_provider = marker.provider

if not isinstance(marker_provider, (str, BaseProvider)):
msg = f"Incorrect marker type: {type(marker_provider)!r}. Marker parameter must be either str or BaseProvider."
if not isinstance(marker_provider, BaseProvider):
msg = f"Incorrect marker type: {type(marker_provider)!r}. Marker parameter must be either BaseProvider."
raise TypeError(msg)

if isinstance(marker_provider, BaseProvider):
return marker_provider

containers_count = ContainerRegistry.get_containers_count()

if isinstance(marker_provider, str):
if containers_count > 1:
msg = "Please specify the container and its provider explicitly"
raise Exception(msg)

if containers_count == 1:
container = ContainerRegistry.get_default_container()
provider = container.get_provider_by_attr_name(marker_provider)
return provider
return marker_provider


def _extract_provider_values_from_markers(markers: Markers) -> Dict[str, Any]:
Expand Down
6 changes: 3 additions & 3 deletions src/injection/provide.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import Generic, TypeVar, Union
from typing import Generic, TypeVar

from injection.providers.base import BaseProvider

T = TypeVar("T")


class ClassGetItemMeta(Generic[T], type):
def __getitem__(cls, item: Union[str, BaseProvider[T]]) -> T:
def __getitem__(cls, item: BaseProvider[T]) -> T:
return cls(item)


class Provide(metaclass=ClassGetItemMeta):
def __init__(self, provider: Union[str, BaseProvider[T]]) -> None:
def __init__(self, provider: BaseProvider[T]) -> None:
self.provider = provider

def __call__(self) -> T:
Expand Down
12 changes: 0 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
import pytest
from injection.container_registry import ContainerRegistry

from tests.container_objects import Container


@pytest.fixture(scope="session")
def container_registry():
return ContainerRegistry


@pytest.fixture(scope="session")
def container():
return Container


# need this because container registry is singleton
# @pytest.fixture(autouse=True)
# def _force_clean_container_registry(container):
# container.reset_override()
4 changes: 2 additions & 2 deletions tests/container_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def func_with_injections(
*,
ddd,
redis=Provide[Container.redis],
svc1=Provide["service"],
svc2=Provide["some_service"],
svc1=Provide[Container.service],
svc2=Provide[Container.some_service],
numms=Provide[Container.num],
partial_callable_param=Provide[Container.partial_callable],
):
Expand Down
16 changes: 0 additions & 16 deletions tests/test_container_registry.py

This file was deleted.

18 changes: 1 addition & 17 deletions tests/test_inject.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from unittest.mock import patch

import pytest
from injection import Provide
from injection.container_registry import ContainerRegistry
from injection.inject import _resolve_provide_marker


Expand All @@ -20,18 +17,5 @@ def test_resolve_provide_marker_fail_when_marker_parameter_has_incorrect_type():
with pytest.raises(Exception) as e:
_resolve_provide_marker(Provide[object])

error_msg = f"Incorrect marker type: {type(object)!r}. Marker parameter must be either str or BaseProvider."
assert e.value.args[0] == error_msg


@patch.object(ContainerRegistry, "get_containers_count")
def test_container_registry_fail_with_string_marker_when_containers_more_than_one(
mock_get_containers_count_method,
):
error_msg = "Please specify the container and its provider explicitly"
mock_get_containers_count_method.return_value = 2

with pytest.raises(Exception) as e:
_resolve_provide_marker(Provide["redis"])

error_msg = f"Incorrect marker type: {type(object)!r}. Marker parameter must be either BaseProvider."
assert e.value.args[0] == error_msg
6 changes: 4 additions & 2 deletions tests/test_integrations/test_drf/drf_test_project/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
from rest_framework.response import Response
from rest_framework.views import APIView

from tests.container_objects import Container


class PostEndpointBodySerializer(serializers.Serializer):
key = serializers.IntegerField()


class View(APIView):
@inject
def get(self, _: Request, redis=Provide["redis"]):
def get(self, _: Request, redis=Provide[Container.redis]):
response_body = {"redis_url": redis.url}
return Response(response_body, status=status.HTTP_200_OK)

@inject
def post(self, request: Request, redis=Provide["redis"]):
def post(self, request: Request, redis=Provide[Container.redis]):
body_serializer = PostEndpointBodySerializer(data=request.data)
body_serializer.is_valid()
key = body_serializer.validated_data["key"]
Expand Down
5 changes: 2 additions & 3 deletions tests/test_integrations/test_fastapi/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@

router = APIRouter(prefix="/api")

RedisDependency = Annotated[Redis, Depends(Provide["redis"])]
RedisDependencyExplicit = Annotated[Redis, Depends(Provide[Container.redis])]
RedisDependency = Annotated[Redis, Depends(Provide[Container.redis])]


@router.get("/values")
Expand All @@ -25,6 +24,6 @@ def some_get_endpoint_handler(redis: RedisDependency):

@router.post("/values")
@inject
async def some_get_async_endpoint_handler(redis: RedisDependencyExplicit):
async def some_get_async_endpoint_handler(redis: RedisDependency):
value = redis.get(399)
return {"detail": value}
4 changes: 3 additions & 1 deletion tests/test_integrations/test_flask/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from flask import Flask
from injection import Provide, inject

from tests.container_objects import Container

app = Flask(__name__)
app.config.update({"TESTING": True})


@app.route("/some_resource")
@inject
def flask_endpoint(redis=Provide["redis"]):
def flask_endpoint(redis=Provide[Container.redis]):
value = redis.get(-900)
return {"detail": value}

Expand Down

0 comments on commit 189bfc4

Please sign in to comment.