diff --git a/rodi/__init__.py b/rodi/__init__.py index 7f27091..139e39a 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -837,6 +837,9 @@ def __call__(self, context, activating_type): return self.factory(context) +_ContainerSelf = TypeVar("_ContainerSelf", bound="Container") + + class Container(ContainerProtocol): """ Configuration class for a collection of services. @@ -844,7 +847,7 @@ class Container(ContainerProtocol): __slots__ = ("_map", "_aliases", "_exact_aliases", "strict") - def __init__(self, *, strict: bool = False): + def __init__(self, *, strict: bool = False) -> None: self._map: Dict[Type, Callable] = {} self._aliases: DefaultDict[str, Set[Type]] = defaultdict(set) self._exact_aliases: Dict[str, Type] = {} @@ -857,18 +860,18 @@ def provider(self) -> Services: self._provider = self.build_provider() return self._provider - def __iter__(self): + def __iter__(self) -> Iterator[tuple[Type, Callable]]: yield from self._map.items() - def __contains__(self, key): + def __contains__(self, key: object) -> bool: return key in self._map def bind_types( - self, + self: _ContainerSelf, obj_type: Any, concrete_type: Any = None, life_style: ServiceLifeStyle = ServiceLifeStyle.TRANSIENT, - ): + ) -> _ContainerSelf: try: assert issubclass(concrete_type, obj_type), ( f"Cannot register {class_name(obj_type)} for abstract class " @@ -881,13 +884,13 @@ def bind_types( return self def register( - self, + self: _ContainerSelf, obj_type: Any, sub_type: Any = None, instance: Any = None, *args, **kwargs, - ) -> "Container": + ) -> _ContainerSelf: """ Registers a type in this container. """ @@ -913,7 +916,11 @@ def resolve( """ return self.provider.get(obj_type, scope=scope) - def add_alias(self, name: str, desired_type: Type): + def add_alias( + self: _ContainerSelf, + name: str, + desired_type: Type, + ) -> _ContainerSelf: """ Adds an alias to the set of inferred aliases. @@ -928,7 +935,7 @@ def add_alias(self, name: str, desired_type: Type): self._aliases[name].add(desired_type) return self - def add_aliases(self, values: AliasesTypeHint): + def add_aliases(self: _ContainerSelf, values: AliasesTypeHint) -> _ContainerSelf: """ Adds aliases to the set of inferred aliases. @@ -939,7 +946,12 @@ def add_aliases(self, values: AliasesTypeHint): self.add_alias(key, value) return self - def set_alias(self, name: str, desired_type: Type, override: bool = False): + def set_alias( + self: _ContainerSelf, + name: str, + desired_type: Type, + override: bool = False, + ) -> _ContainerSelf: """ Sets an exact alias for a desired type. @@ -955,7 +967,11 @@ def set_alias(self, name: str, desired_type: Type, override: bool = False): self._exact_aliases[name] = desired_type return self - def set_aliases(self, values: AliasesTypeHint, override: bool = False): + def set_aliases( + self: _ContainerSelf, + values: AliasesTypeHint, + override: bool = False, + ) -> _ContainerSelf: """Sets many exact aliases for desired types. :param values: mapping object (parameter name: class) @@ -984,8 +1000,8 @@ def _bind(self, key: Type, value: Any) -> None: self._aliases[to_standard_param_name(key_name)].add(key) def add_instance( - self, instance: Any, declared_class: Optional[Type] = None - ) -> "Container": + self: _ConstainerSelf, instance: Any, declared_class: Optional[Type] = None + ) -> _ContainerSelf: """ Registers an exact instance, optionally by declared class. @@ -1001,8 +1017,8 @@ def add_instance( return self def add_singleton( - self, base_type: Type, concrete_type: Optional[Type] = None - ) -> "Container": + self: _ContainerSelf, base_type: Type, concrete_type: Optional[Type] = None + ) -> _ContainerSelf: """ Registers a type by base type, to be instantiated with singleton lifetime. If a single type is given, the method `add_exact_singleton` is used. @@ -1018,8 +1034,10 @@ def add_singleton( return self.bind_types(base_type, concrete_type, ServiceLifeStyle.SINGLETON) def add_scoped( - self, base_type: Type, concrete_type: Optional[Type] = None - ) -> "Container": + self: _ConstainerSelf, + base_type: Type, + concrete_type: Optional[Type] = None, + ) -> _ContainerSelf: """ Registers a type by base type, to be instantiated with scoped lifetime. If a single type is given, the method `add_exact_scoped` is used. @@ -1035,8 +1053,10 @@ def add_scoped( return self.bind_types(base_type, concrete_type, ServiceLifeStyle.SCOPED) def add_transient( - self, base_type: Type, concrete_type: Optional[Type] = None - ) -> "Container": + self: _ContainerSelf, + base_type: Type, + concrete_type: Optional[Type] = None, + ) -> _ContainerSelf: """ Registers a type by base type, to be instantiated with transient lifetime. If a single type is given, the method `add_exact_transient` is used. @@ -1051,7 +1071,7 @@ def add_transient( return self.bind_types(base_type, concrete_type, ServiceLifeStyle.TRANSIENT) - def _add_exact_singleton(self, concrete_type: Type) -> "Container": + def _add_exact_singleton(self: _ContainerSelf, concrete_type: Type) -> _ContaineSelf: """ Registers an exact type, to be instantiated with singleton lifetime. @@ -1065,7 +1085,7 @@ def _add_exact_singleton(self, concrete_type: Type) -> "Container": ) return self - def _add_exact_scoped(self, concrete_type: Type) -> "Container": + def _add_exact_scoped(self: _ContainerSelf, concrete_type: Type) -> _ContainerSelf: """ Registers an exact type, to be instantiated with scoped lifetime. @@ -1078,7 +1098,7 @@ def _add_exact_scoped(self, concrete_type: Type) -> "Container": ) return self - def _add_exact_transient(self, concrete_type: Type) -> "Container": + def _add_exact_transient(self: _ContainerSelf, concrete_type: Type) -> _ContainerSelf: """ Registers an exact type, to be instantiated with transient lifetime. @@ -1093,20 +1113,26 @@ def _add_exact_transient(self, concrete_type: Type) -> "Container": return self def add_singleton_by_factory( - self, factory: FactoryCallableType, return_type: Optional[Type] = None - ) -> "Container": + self: _ContainerSelf, + factory: FactoryCallableType, + return_type: Optional[Type] = None, + ) -> _ContainerSelf: self.register_factory(factory, return_type, ServiceLifeStyle.SINGLETON) return self def add_transient_by_factory( - self, factory: FactoryCallableType, return_type: Optional[Type] = None - ) -> "Container": + self: _ContainerSelf, + factory: FactoryCallableType, + return_type: Optional[Type] = None, + ) -> _ContainerSelf: self.register_factory(factory, return_type, ServiceLifeStyle.TRANSIENT) return self def add_scoped_by_factory( - self, factory: FactoryCallableType, return_type: Optional[Type] = None - ) -> "Container": + self: _ContainerSelf, + factory: FactoryCallableType, + return_type: Optional[Type] = None, + ) -> _ContainerSelf: self.register_factory(factory, return_type, ServiceLifeStyle.SCOPED) return self