Hello, {user}!
" \ No newline at end of file + return f"Hello, {user}!
" diff --git a/docs/docs_src/usages/starlette.py b/docs/docs_src/usages/starlette.py index 64aa1d26..f081f307 100644 --- a/docs/docs_src/usages/starlette.py +++ b/docs/docs_src/usages/starlette.py @@ -1,9 +1,11 @@ # Is that FastAPI??? +from pydantic import Field from starlette.applications import Starlette from starlette.responses import PlainTextResponse from starlette.routing import Route -from pydantic import Field -from fast_depends import inject, Depends + +from fast_depends import Depends, inject + def unwrap_path(func): async def wrapper(request): # unwrap incoming params to **kwargs here @@ -20,4 +22,4 @@ async def hello(user: str = Depends(get_user)): app = Starlette(debug=True, routes=[ Route("/{id}", hello) -]) \ No newline at end of file +]) diff --git a/fast_depends/__about__.py b/fast_depends/__about__.py index 513e73f5..2fc3727e 100644 --- a/fast_depends/__about__.py +++ b/fast_depends/__about__.py @@ -1,3 +1,3 @@ """FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System""" -__version__ = "2.4.2" +__version__ = "3.0.0a3" diff --git a/fast_depends/__init__.py b/fast_depends/__init__.py index 5fa239cd..aa28076a 100644 --- a/fast_depends/__init__.py +++ b/fast_depends/__init__.py @@ -1,9 +1,10 @@ -from fast_depends.dependencies import Provider, dependency_provider +from fast_depends.dependencies import Provider +from fast_depends.exceptions import ValidationError from fast_depends.use import Depends, inject __all__ = ( "Depends", - "dependency_provider", + "ValidationError", "Provider", "inject", ) diff --git a/fast_depends/_compat.py b/fast_depends/_compat.py index fd2a7904..c45f3618 100644 --- a/fast_depends/_compat.py +++ b/fast_depends/_compat.py @@ -1,73 +1,104 @@ import sys +import typing from importlib.metadata import version as get_version -from typing import Any, Dict, Optional, Tuple, Type - -from pydantic import BaseModel, create_model -from pydantic.version import VERSION as PYDANTIC_VERSION __all__ = ( - "BaseModel", - "create_model", - "evaluate_forwardref", - "PYDANTIC_V2", - "get_config_base", - "ConfigDict", "ExceptionGroup", + "evaluate_forwardref", ) -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") - -default_pydantic_config = {"arbitrary_types_allowed": True} - -evaluate_forwardref: Any -# isort: off -if PYDANTIC_V2: - from pydantic import ConfigDict - from pydantic._internal._typing_extra import ( # type: ignore[no-redef] - eval_type_lenient as evaluate_forwardref, - ) - - def model_schema(model: Type[BaseModel]) -> Dict[str, Any]: - return model.model_json_schema() - - def get_config_base(config_data: Optional[ConfigDict] = None) -> ConfigDict: - return config_data or ConfigDict(**default_pydantic_config) # type: ignore[typeddict-item] - - def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]: - return tuple(f.alias or name for name, f in model.model_fields.items()) - - class CreateBaseModel(BaseModel): - """Just to support FastStream < 0.3.7.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) +ANYIO_V3 = get_version("anyio").startswith("3.") +if ANYIO_V3: + from anyio import ExceptionGroup as ExceptionGroup else: - from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef] - from pydantic.config import get_config, ConfigDict, BaseConfig - - def get_config_base(config_data: Optional[ConfigDict] = None) -> Type[BaseConfig]: # type: ignore[misc] - return get_config(config_data or ConfigDict(**default_pydantic_config)) # type: ignore[typeddict-item] + if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup as ExceptionGroup + else: + ExceptionGroup = ExceptionGroup - def model_schema(model: Type[BaseModel]) -> Dict[str, Any]: - return model.schema() - def get_aliases(model: Type[BaseModel]) -> Tuple[str, ...]: - return tuple(f.alias or name for name, f in model.__fields__.items()) +def evaluate_forwardref( + value: typing.Any, + globalns: typing.Optional[dict[str, typing.Any]] = None, + localns: typing.Optional[dict[str, typing.Any]] = None, +) -> typing.Any: + """Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved.""" + if value is None: + value = NoneType + elif isinstance(value, str): + value = _make_forward_ref(value, is_argument=False, is_class=True) + + try: + return eval_type_backport(value, globalns, localns) + except NameError: + # the point of this function is to be tolerant to this case + return value + + +def eval_type_backport( + value: typing.Any, + globalns: typing.Optional[dict[str, typing.Any]] = None, + localns: typing.Optional[dict[str, typing.Any]] = None, +) -> typing.Any: + """Like `typing._eval_type`, but falls back to the `eval_type_backport` package if it's + installed to let older Python versions use newer typing features. + Specifically, this transforms `X | Y` into `typing.Union[X, Y]` + and `list[X]` into `typing.List[X]` etc. (for all the types made generic in PEP 585) + if the original syntax is not supported in the current Python version. + """ + try: + return typing._eval_type( # type: ignore + value, globalns, localns + ) + except TypeError as e: + if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)): + raise + try: + from eval_type_backport import eval_type_backport + except ImportError: + raise TypeError( + f"You have a type annotation {value.__forward_arg__!r} " + f"which makes use of newer typing features than are supported in your version of Python. " + f"To handle this error, you should either remove the use of new syntax " + f"or install the `eval_type_backport` package." + ) from e + + return eval_type_backport(value, globalns, localns, try_default=False) + + +def is_backport_fixable_error(e: TypeError) -> bool: + msg = str(e) + return msg.startswith("unsupported operand type(s) for |: ") or "' object is not subscriptable" in msg + + +if sys.version_info < (3, 10): + NoneType = type(None) +else: + from types import NoneType as NoneType - class CreateBaseModel(BaseModel): # type: ignore[no-redef] - """Just to support FastStream < 0.3.7.""" - class Config: - arbitrary_types_allowed = True +if sys.version_info < (3, 9, 8) or (3, 10) <= sys.version_info < (3, 10, 1): + def _make_forward_ref( + arg: typing.Any, + is_argument: bool = True, + *, + is_class: bool = False, + ) -> typing.ForwardRef: + """Wrapper for ForwardRef that accounts for the `is_class` argument missing in older versions. + The `module` argument is omitted as it breaks <3.9.8, =3.10.0 and isn't used in the calls below. + See https://github.com/python/cpython/pull/28560 for some background. + The backport happened on 3.9.8, see: + https://github.com/pydantic/pydantic/discussions/6244#discussioncomment-6275458, + and on 3.10.1 for the 3.10 branch, see: + https://github.com/pydantic/pydantic/issues/6912 -ANYIO_V3 = get_version("anyio").startswith("3.") + Implemented as EAFP with memory. + """ + return typing.ForwardRef(arg, is_argument) -if ANYIO_V3: - from anyio import ExceptionGroup as ExceptionGroup else: - if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup as ExceptionGroup - else: - ExceptionGroup = ExceptionGroup + _make_forward_ref = typing.ForwardRef + diff --git a/fast_depends/core/__init__.py b/fast_depends/core/__init__.py index 99799a77..08d566cc 100644 --- a/fast_depends/core/__init__.py +++ b/fast_depends/core/__init__.py @@ -1,5 +1,5 @@ -from fast_depends.core.build import build_call_model -from fast_depends.core.model import CallModel +from .builder import build_call_model +from .model import CallModel __all__ = ( "CallModel", diff --git a/fast_depends/core/build.py b/fast_depends/core/build.py deleted file mode 100644 index 8006eb95..00000000 --- a/fast_depends/core/build.py +++ /dev/null @@ -1,212 +0,0 @@ -import inspect -from typing import ( - Any, - Awaitable, - Callable, - Dict, - List, - Optional, - Sequence, - Tuple, - Type, - Union, -) - -from typing_extensions import ( - Annotated, - ParamSpec, - TypeVar, - get_args, - get_origin, -) - -from fast_depends._compat import ConfigDict, create_model, get_config_base -from fast_depends.core.model import CallModel, ResponseModel -from fast_depends.dependencies import Depends -from fast_depends.library import CustomField -from fast_depends.utils import ( - get_typed_signature, - is_async_gen_callable, - is_coroutine_callable, - is_gen_callable, -) - -CUSTOM_ANNOTATIONS = (Depends, CustomField) - - -P = ParamSpec("P") -T = TypeVar("T") - - -def build_call_model( - call: Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - *, - cast: bool = True, - use_cache: bool = True, - is_sync: Optional[bool] = None, - extra_dependencies: Sequence[Depends] = (), - pydantic_config: Optional[ConfigDict] = None, -) -> CallModel[P, T]: - name = getattr(call, "__name__", type(call).__name__) - - is_call_async = is_coroutine_callable(call) - if is_sync is None: - is_sync = not is_call_async - else: - assert not ( - is_sync and is_call_async - ), f"You cannot use async dependency `{name}` at sync main" - - typed_params, return_annotation = get_typed_signature(call) - if ( - (is_call_generator := is_gen_callable(call) or - is_async_gen_callable(call)) and - (return_args := get_args(return_annotation)) - ): - return_annotation = return_args[0] - - class_fields: Dict[str, Tuple[Any, Any]] = {} - dependencies: Dict[str, CallModel[..., Any]] = {} - custom_fields: Dict[str, CustomField] = {} - positional_args: List[str] = [] - keyword_args: List[str] = [] - - for param_name, param in typed_params.parameters.items(): - dep: Optional[Depends] = None - custom: Optional[CustomField] = None - - if param.annotation is inspect.Parameter.empty: - annotation = Any - - elif get_origin(param.annotation) is Annotated: - annotated_args = get_args(param.annotation) - type_annotation = annotated_args[0] - custom_annotations = [ - arg for arg in annotated_args[1:] if isinstance(arg, CUSTOM_ANNOTATIONS) - ] - - assert ( - len(custom_annotations) <= 1 - ), f"Cannot specify multiple `Annotated` Custom arguments for `{param_name}`!" - - next_custom = next(iter(custom_annotations), None) - if next_custom is not None: - if isinstance(next_custom, Depends): - dep = next_custom - elif isinstance(next_custom, CustomField): - custom = next_custom - else: # pragma: no cover - raise AssertionError("unreachable") - - annotation = type_annotation - else: - annotation = param.annotation - else: - annotation = param.annotation - - default: Any - if param_name == "args": - default = () - elif param_name == "kwargs": - default = {} - else: - default = param.default - - if isinstance(default, Depends): - assert ( - not dep - ), "You can not use `Depends` with `Annotated` and default both" - dep = default - - elif isinstance(default, CustomField): - assert ( - not custom - ), "You can not use `CustomField` with `Annotated` and default both" - custom = default - - elif default is inspect.Parameter.empty: - class_fields[param_name] = (annotation, ...) - - else: - class_fields[param_name] = (annotation, default) - - if dep: - if not cast: - dep.cast = False - - dependencies[param_name] = build_call_model( - dep.dependency, - cast=dep.cast, - use_cache=dep.use_cache, - is_sync=is_sync, - pydantic_config=pydantic_config, - ) - - if dep.cast is True: - class_fields[param_name] = (annotation, ...) - keyword_args.append(param_name) - - elif custom: - assert not ( - is_sync and is_coroutine_callable(custom.use) - ), f"You cannot use async custom field `{type(custom).__name__}` at sync `{name}`" - - custom.set_param_name(param_name) - custom_fields[param_name] = custom - - if custom.cast is False: - annotation = Any - - if custom.required: - class_fields[param_name] = (annotation, ...) - else: - class_fields[param_name] = (Optional[annotation], None) - keyword_args.append(param_name) - - else: - if param.kind is param.KEYWORD_ONLY: - keyword_args.append(param_name) - elif param_name not in ("args", "kwargs"): - positional_args.append(param_name) - - func_model = create_model( # type: ignore[call-overload] - name, - __config__=get_config_base(pydantic_config), - **class_fields, - ) - - response_model: Optional[Type[ResponseModel[T]]] = None - if cast and return_annotation and return_annotation is not inspect.Parameter.empty: - response_model = create_model( - "ResponseModel", - __config__=get_config_base(pydantic_config), # type: ignore[assignment] - response=(return_annotation, ...), - ) - - return CallModel( - call=call, - model=func_model, - response_model=response_model, - params=class_fields, - cast=cast, - use_cache=use_cache, - is_async=is_call_async, - is_generator=is_call_generator, - dependencies=dependencies, - custom_fields=custom_fields, - positional_args=positional_args, - keyword_args=keyword_args, - extra_dependencies=[ - build_call_model( - d.dependency, - cast=d.cast, - use_cache=d.use_cache, - is_sync=is_sync, - pydantic_config=pydantic_config, - ) - for d in extra_dependencies - ], - ) diff --git a/fast_depends/core/builder.py b/fast_depends/core/builder.py new file mode 100644 index 00000000..413ed407 --- /dev/null +++ b/fast_depends/core/builder.py @@ -0,0 +1,267 @@ +import inspect +from collections.abc import Sequence +from copy import deepcopy +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Optional, + TypeVar, +) + +from typing_extensions import ( + ParamSpec, + get_args, + get_origin, +) + +from fast_depends.dependencies.model import Dependant +from fast_depends.library import CustomField +from fast_depends.library.serializer import OptionItem, Serializer, SerializerProto +from fast_depends.utils import ( + get_typed_signature, + is_async_gen_callable, + is_coroutine_callable, + is_gen_callable, +) + +from .model import CallModel + +if TYPE_CHECKING: + from fast_depends.dependencies.provider import Key, Provider + + +CUSTOM_ANNOTATIONS = (Dependant, CustomField,) + + +P = ParamSpec("P") +T = TypeVar("T") + + +def build_call_model( + call: Callable[..., Any], + *, + dependency_provider: "Provider", + use_cache: bool = True, + is_sync: Optional[bool] = None, + extra_dependencies: Sequence[Dependant] = (), + serializer_cls: Optional["SerializerProto"] = None, + serialize_result: bool = True, +) -> CallModel: + name = getattr(inspect.unwrap(call), "__name__", type(call).__name__) + + is_call_async = is_coroutine_callable(call) or is_async_gen_callable(call) + if is_sync is None: + is_sync = not is_call_async + else: + assert not ( + is_sync and is_call_async + ), f"You cannot use async dependency `{name}` at sync main" + + typed_params, return_annotation = get_typed_signature(call) + if ( + (is_call_generator := is_gen_callable(call) or + is_async_gen_callable(call)) and + (return_args := get_args(return_annotation)) + ): + return_annotation = return_args[0] + + if not serialize_result: + return_annotation = inspect.Parameter.empty + + class_fields: list[OptionItem] = [] + dependencies: dict[str, Key] = {} + custom_fields: dict[str, CustomField] = {} + positional_args: list[str] = [] + keyword_args: list[str] = [] + args_name: Optional[str] = None + kwargs_name: Optional[str] = None + + for param_name, param in typed_params.parameters.items(): + dep: Optional[Dependant] = None + custom: Optional[CustomField] = None + + if param.annotation is inspect.Parameter.empty: + annotation = Any + + elif get_origin(param.annotation) is Annotated: + annotated_args = get_args(param.annotation) + type_annotation = annotated_args[0] + + custom_annotations = [] + regular_annotations = [] + for arg in annotated_args[1:]: + if isinstance(arg, CUSTOM_ANNOTATIONS): + custom_annotations.append(arg) + else: + regular_annotations.append(arg) + + assert ( + len(custom_annotations) <= 1 + ), f"Cannot specify multiple `Annotated` Custom arguments for `{param_name}`!" + + next_custom = next(iter(custom_annotations), None) + if next_custom is not None: + if isinstance(next_custom, Dependant): + dep = next_custom + elif isinstance(next_custom, CustomField): + custom = deepcopy(next_custom) + else: # pragma: no cover + raise AssertionError("unreachable") + + if regular_annotations: + annotation = param.annotation + else: + annotation = type_annotation + else: + annotation = param.annotation + else: + annotation = param.annotation + + default: Any + if param.kind is inspect.Parameter.VAR_POSITIONAL: + default = () + elif param.kind is inspect.Parameter.VAR_KEYWORD: + default = {} + else: + default = param.default + + if isinstance(default, Dependant): + assert ( + not dep + ), "You can not use `Depends` with `Annotated` and default both" + dep, default = default, Ellipsis + + elif isinstance(default, CustomField): + assert ( + not custom + ), "You can not use `CustomField` with `Annotated` and default both" + custom, default = default, Ellipsis + + elif not dep and not custom: + class_fields.append(OptionItem( + field_name=param_name, + field_type=annotation, + default_value=... if default is inspect.Parameter.empty else default + )) + + if dep: + dependency = build_call_model( + dep.dependency, + dependency_provider=dependency_provider, + use_cache=dep.use_cache, + is_sync=is_sync, + serializer_cls=serializer_cls, + serialize_result=dep.cast_result, + ) + + key = dependency_provider.add_dependant(dependency) + + overrided_dependency = dependency_provider.get_dependant(key) + + assert not ( + is_sync and is_coroutine_callable(overrided_dependency.call) + ), f"You cannot use async dependency `{overrided_dependency.call_name}` at sync main" + + dependencies[param_name] = key + + if not dep.cast: + annotation = Any + + class_fields.append(OptionItem( + field_name=param_name, + field_type=annotation, + source=dep, + )) + + keyword_args.append(param_name) + + elif custom: + assert not ( + is_sync and is_coroutine_callable(custom.use) + ), f"You cannot use async custom field `{type(custom).__name__}` at sync `{name}`" + + custom.set_param_name(param_name) + custom_fields[param_name] = custom + + if not custom.cast: + annotation = Any + + if custom.required: + class_fields.append(OptionItem( + field_name=param_name, + field_type=annotation, + default_value=default, + source=custom, + )) + + else: + class_fields.append(OptionItem( + field_name=param_name, + field_type=Optional[annotation], + default_value=None if default is Ellipsis else default, + source=custom, + )) + + keyword_args.append(param_name) + + else: + if param.kind is param.KEYWORD_ONLY: + keyword_args.append(param_name) + elif param.kind is param.VAR_KEYWORD: + kwargs_name = param_name + elif param.kind is param.VAR_POSITIONAL: + args_name = param_name + else: + positional_args.append(param_name) + + serializer: Optional[Serializer] = None + if serializer_cls is not None: + serializer = serializer_cls( + name=name, + options=class_fields, + response_type=return_annotation, + ) + + solved_extra_dependencies: list[Key] = [] + for dep in extra_dependencies: + dependency = build_call_model( + dep.dependency, + dependency_provider=dependency_provider, + use_cache=dep.use_cache, + is_sync=is_sync, + serializer_cls=serializer_cls, + ) + + key = dependency_provider.add_dependant(dependency) + + overrided_dependency = dependency_provider.get_dependant(key) + + assert not ( + is_sync and is_coroutine_callable(overrided_dependency.call) + ), f"You cannot use async dependency `{overrided_dependency.call_name}` at sync main" + + solved_extra_dependencies.append(key) + + return CallModel( + call=call, + serializer=serializer, + params=tuple( + i for i in class_fields if ( + i.field_name not in dependencies and + i.field_name not in custom_fields + ) + ), + use_cache=use_cache, + is_async=is_call_async, + is_generator=is_call_generator, + dependencies=dependencies, + custom_fields=custom_fields, + positional_args=positional_args, + keyword_args=keyword_args, + args_name=args_name, + kwargs_name=kwargs_name, + extra_dependencies=solved_extra_dependencies, + dependency_provider=dependency_provider, + ) diff --git a/fast_depends/core/model.py b/fast_depends/core/model.py index 2e9d41f7..7ff0b31b 100644 --- a/fast_depends/core/model.py +++ b/fast_depends/core/model.py @@ -1,29 +1,18 @@ -from collections import namedtuple +from collections.abc import Generator, Iterable, Sequence from contextlib import AsyncExitStack, ExitStack -from functools import partial from inspect import Parameter, unwrap -from itertools import chain from typing import ( + TYPE_CHECKING, Any, - Awaitable, Callable, - Dict, - Generator, - Generic, - Iterable, - List, Optional, - Sequence, - Tuple, - Type, - Union, ) import anyio -from typing_extensions import ParamSpec, TypeVar -from fast_depends._compat import BaseModel, ExceptionGroup, get_aliases -from fast_depends.library import CustomField +from fast_depends._compat import ExceptionGroup +from fast_depends.library.model import CustomField +from fast_depends.library.serializer import OptionItem, Serializer from fast_depends.utils import ( async_map, is_async_gen_callable, @@ -34,354 +23,223 @@ solve_generator_sync, ) -P = ParamSpec("P") -T = TypeVar("T") +if TYPE_CHECKING: + from fast_depends.dependencies.provider import Key, Provider -PriorityPair = namedtuple( - "PriorityPair", ("call", "dependencies_number", "dependencies_names") -) - - -class ResponseModel(BaseModel, Generic[T]): - response: T - - -class CallModel(Generic[P, T]): - call: Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ] - is_async: bool - is_generator: bool - model: Optional[Type[BaseModel]] - response_model: Optional[Type[ResponseModel[T]]] - - params: Dict[str, Tuple[Any, Any]] - alias_arguments: Tuple[str, ...] - - dependencies: Dict[str, "CallModel[..., Any]"] - extra_dependencies: Iterable["CallModel[..., Any]"] - sorted_dependencies: Tuple[Tuple["CallModel[..., Any]", int], ...] - custom_fields: Dict[str, CustomField] - keyword_args: Tuple[str, ...] - positional_args: Tuple[str, ...] - - # Dependencies and custom fields - use_cache: bool - cast: bool - +class CallModel: __slots__ = ( "call", "is_async", "is_generator", - "model", - "response_model", "params", "alias_arguments", - "keyword_args", + "args_name", "positional_args", + "kwargs_name", + "keyword_args", "dependencies", "extra_dependencies", - "sorted_dependencies", "custom_fields", "use_cache", - "cast", + "serializer", + "dependency_provider", ) + alias_arguments: tuple[str, ...] + @property def call_name(self) -> str: call = unwrap(self.call) return getattr(call, "__name__", type(call).__name__) @property - def flat_params(self) -> Dict[str, Tuple[Any, Any]]: - params = self.params - for d in (*self.dependencies.values(), *self.extra_dependencies): - params.update(d.flat_params) + def flat_params(self) -> list[OptionItem]: + params = list(self.params) + for d in map( + self.dependency_provider.get_dependant, + (*self.dependencies.values(), *self.extra_dependencies), + ): + for p in d.flat_params: + if p.field_name not in (i.field_name for i in params): + params.append(p) return params - @property - def flat_dependencies( - self, - ) -> Dict[ - Callable[..., Any], - Tuple[ - "CallModel[..., Any]", - Tuple[Callable[..., Any], ...], - ], - ]: - flat: Dict[ - Callable[..., Any], - Tuple[ - "CallModel[..., Any]", - Tuple[Callable[..., Any], ...], - ], - ] = {} - - for i in (*self.dependencies.values(), *self.extra_dependencies): - flat.update( - { - i.call: ( - i, - tuple(j.call for j in i.dependencies.values()), - ) - } - ) - - flat.update(i.flat_dependencies) - - return flat - def __init__( self, - /, - call: Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - model: Optional[Type[BaseModel]], - params: Dict[str, Tuple[Any, Any]], - response_model: Optional[Type[ResponseModel[T]]] = None, - use_cache: bool = True, - cast: bool = True, - is_async: bool = False, - is_generator: bool = False, - dependencies: Optional[Dict[str, "CallModel[..., Any]"]] = None, - extra_dependencies: Optional[Iterable["CallModel[..., Any]"]] = None, - keyword_args: Optional[List[str]] = None, - positional_args: Optional[List[str]] = None, - custom_fields: Optional[Dict[str, CustomField]] = None, + *, + call: Callable[..., Any], + serializer: Optional[Serializer], + params: tuple[OptionItem, ...], + use_cache: bool, + is_async: bool, + is_generator: bool, + args_name: Optional[str], + kwargs_name: Optional[str], + dependencies: dict[str, "Key"], + extra_dependencies: Iterable["Key"], + keyword_args: list[str], + positional_args: list[str], + custom_fields: dict[str, CustomField], + dependency_provider: "Provider", ): self.call = call - self.model = model + self.serializer = serializer - if model: - self.alias_arguments = get_aliases(model) + if serializer is not None: + self.alias_arguments = serializer.get_aliases() else: # pragma: no cover self.alias_arguments = () + self.args_name = args_name self.keyword_args = tuple(keyword_args or ()) + self.kwargs_name = kwargs_name self.positional_args = tuple(positional_args or ()) - self.response_model = response_model self.use_cache = use_cache - self.cast = cast self.is_async = ( - is_async or is_coroutine_callable(call) or is_async_gen_callable(call) + is_async or + is_coroutine_callable(call) or + is_async_gen_callable(call) ) self.is_generator = ( - is_generator or is_gen_callable(call) or is_async_gen_callable(call) + is_generator or + is_gen_callable(call) or + is_async_gen_callable(call) ) self.dependencies = dependencies or {} - self.extra_dependencies = extra_dependencies or () + self.extra_dependencies = tuple(extra_dependencies or ()) self.custom_fields = custom_fields or {} - sorted_dep: List["CallModel[..., Any]"] = [] - flat = self.flat_dependencies - for calls in flat.values(): - _sort_dep(sorted_dep, calls, flat) - - self.sorted_dependencies = tuple( - (i, len(i.sorted_dependencies)) for i in sorted_dep if i.use_cache - ) - - for name in chain(self.dependencies.keys(), self.custom_fields.keys()): - params.pop(name, None) self.params = params + self.dependency_provider = dependency_provider def _solve( self, /, - *args: Tuple[Any, ...], - cache_dependencies: Dict[ - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - T, - ], - dependency_overrides: Optional[ - Dict[ - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - ] - ] = None, - **kwargs: Dict[str, Any], + *args: tuple[Any, ...], + cache_dependencies: dict[Callable[..., Any], Any], + **kwargs: dict[str, Any], ) -> Generator[ - Tuple[ + tuple[ Sequence[Any], - Dict[str, Any], - Callable[..., Any], + dict[str, Any], ], Any, - T, + Any, ]: - if dependency_overrides: - call = dependency_overrides.get(self.call, self.call) - assert self.is_async or not is_coroutine_callable( - call - ), f"You cannot use async dependency `{self.call_name}` at sync main" - - else: - call = self.call - - if self.use_cache and call in cache_dependencies: - return cache_dependencies[call] - - kw: Dict[str, Any] = {} + if self.use_cache and self.call in cache_dependencies: + return cache_dependencies[self.call] + kw: dict[str, Any] = {} for arg in self.keyword_args: if (v := kwargs.pop(arg, Parameter.empty)) is not Parameter.empty: kw[arg] = v - if "kwargs" in self.alias_arguments: - kw["kwargs"] = kwargs + if self.kwargs_name in self.alias_arguments: + kw[self.kwargs_name] = kwargs else: kw.update(kwargs) for arg in self.positional_args: - if args: - kw[arg], args = args[0], args[1:] - else: - break + if arg not in kw: + if args: + kw[arg], args = args[0], args[1:] + else: + break - if has_args := "args" in self.alias_arguments: - kw["args"] = args + keyword_args: Iterable[str] + if self.args_name in self.alias_arguments: + kw[self.args_name] = args keyword_args = self.keyword_args else: - keyword_args = set(self.keyword_args + self.positional_args) - for arg in keyword_args - set(self.dependencies.keys()): - if args: - kw[arg], args = args[0], args[1:] - else: + keyword_args = self.keyword_args + self.positional_args + for arg in keyword_args: + if not args: break - solved_kw: Dict[str, Any] - solved_kw = yield args, kw, call + if arg not in self.dependencies and arg not in kw: + kw[arg], args = args[0], args[1:] - args_: Sequence[Any] - if self.cast: - assert self.model, "Cast should be used only with model" - casted_model = self.model(**solved_kw) - - kwargs_ = { - arg: getattr(casted_model, arg, solved_kw.get(arg)) - for arg in keyword_args - } - kwargs_.update(getattr(casted_model, "kwargs", {})) - - if has_args: - args_ = [ - getattr(casted_model, arg, solved_kw.get(arg)) - for arg in self.positional_args - ] - args_.extend(getattr(casted_model, "args", ())) - else: - args_ = () + solved_kw: dict[str, Any] + solved_kw = yield args, kw + args_: Sequence[Any] + if self.serializer is not None: + casted_options = self.serializer(solved_kw) + solved_kw.update(casted_options) + + if self.args_name: + args_ = ( + *map(solved_kw.pop, self.positional_args), + *solved_kw.get(self.args_name, args), + ) else: - kwargs_ = {arg: solved_kw.get(arg) for arg in keyword_args} + args_ = () - if has_args: - args_ = tuple(map(solved_kw.get, self.positional_args)) - else: - args_ = () + kwargs_ = { + arg: solved_kw.pop(arg) + for arg in keyword_args + if arg in solved_kw + } + if self.kwargs_name: + kwargs_.update(solved_kw.get(self.kwargs_name, solved_kw)) - response: T - response = yield args_, kwargs_, call + response = yield args_, kwargs_ - if self.cast and not self.is_generator: + if not self.is_generator: response = self._cast_response(response) if self.use_cache: # pragma: no branch - cache_dependencies[call] = response + cache_dependencies[self.call] = response return response def _cast_response(self, /, value: Any) -> Any: - if self.response_model is not None: - return self.response_model(response=value).response - else: - return value + if self.serializer is not None: + return self.serializer.response(value) + return value def solve( self, /, - *args: Tuple[Any, ...], + *args: tuple[Any, ...], stack: ExitStack, - cache_dependencies: Dict[ - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - T, - ], - dependency_overrides: Optional[ - Dict[ - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - ] - ] = None, + cache_dependencies: dict[Callable[..., Any], Any], nested: bool = False, - **kwargs: Dict[str, Any], - ) -> T: + **kwargs: dict[str, Any], + ) -> Any: cast_gen = self._solve( *args, cache_dependencies=cache_dependencies, - dependency_overrides=dependency_overrides, **kwargs, ) try: - args, kwargs, _ = next(cast_gen) + args, kwargs = next(cast_gen) except StopIteration as e: - cached_value: T = e.value + cached_value = e.value return cached_value - # Heat cache and solve extra dependencies - for dep, _ in self.sorted_dependencies: + for dep in map(self.dependency_provider.get_dependant, self.extra_dependencies): dep.solve( *args, stack=stack, cache_dependencies=cache_dependencies, - dependency_overrides=dependency_overrides, nested=True, **kwargs, ) - # Always get from cache - for dep in self.extra_dependencies: - dep.solve( - *args, - stack=stack, - cache_dependencies=cache_dependencies, - dependency_overrides=dependency_overrides, - nested=True, - **kwargs, - ) - - for dep_arg, dep in self.dependencies.items(): - kwargs[dep_arg] = dep.solve( - stack=stack, - cache_dependencies=cache_dependencies, - dependency_overrides=dependency_overrides, - nested=True, - **kwargs, - ) + for dep_arg, dep_key in self.dependencies.items(): + if dep_arg not in kwargs: + kwargs[dep_arg] = self.dependency_provider.get_dependant(dep_key).solve( + *args, + stack=stack, + cache_dependencies=cache_dependencies, + nested=True, + **kwargs, + ) for custom in self.custom_fields.values(): if custom.field: @@ -389,117 +247,73 @@ def solve( else: kwargs = custom.use(**kwargs) - final_args, final_kwargs, call = cast_gen.send(kwargs) + final_args, final_kwargs = cast_gen.send(kwargs) if self.is_generator and nested: response = solve_generator_sync( *final_args, - call=call, + call=self.call, stack=stack, **final_kwargs, ) else: - response = call(*final_args, **final_kwargs) + response = self.call(*final_args, **final_kwargs) try: cast_gen.send(response) except StopIteration as e: - value: T = e.value + value = e.value - if not self.cast or nested or not self.is_generator: + if self.serializer is None or nested or not self.is_generator: return value else: - return map(self._cast_response, value) # type: ignore[no-any-return, call-overload] + return map(self._cast_response, value) raise AssertionError("unreachable") async def asolve( self, /, - *args: Tuple[Any, ...], + *args: tuple[Any, ...], stack: AsyncExitStack, - cache_dependencies: Dict[ - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - T, - ], - dependency_overrides: Optional[ - Dict[ - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - Union[ - Callable[P, T], - Callable[P, Awaitable[T]], - ], - ] - ] = None, + cache_dependencies: dict[Callable[..., Any], Any], nested: bool = False, - **kwargs: Dict[str, Any], - ) -> T: + **kwargs: dict[str, Any], + ) -> Any: cast_gen = self._solve( *args, cache_dependencies=cache_dependencies, - dependency_overrides=dependency_overrides, **kwargs, ) try: - args, kwargs, _ = next(cast_gen) + args, kwargs = next(cast_gen) except StopIteration as e: - cached_value: T = e.value + cached_value = e.value return cached_value - # Heat cache and solve extra dependencies - dep_to_solve: List[Callable[..., Awaitable[Any]]] = [] - try: - async with anyio.create_task_group() as tg: - for dep, subdep in self.sorted_dependencies: - solve = partial( - dep.asolve, - *args, - stack=stack, - cache_dependencies=cache_dependencies, - dependency_overrides=dependency_overrides, - nested=True, - **kwargs, - ) - if not subdep: - tg.start_soon(solve) - else: - dep_to_solve.append(solve) - except ExceptionGroup as exgr: - for ex in exgr.exceptions: - raise ex from None - - for i in dep_to_solve: - await i() - - # Always get from cache - for dep in self.extra_dependencies: + for dep in map(self.dependency_provider.get_dependant, self.extra_dependencies): + # TODO: run concurrently await dep.asolve( *args, stack=stack, cache_dependencies=cache_dependencies, - dependency_overrides=dependency_overrides, nested=True, **kwargs, ) - for dep_arg, dep in self.dependencies.items(): - kwargs[dep_arg] = await dep.asolve( - stack=stack, - cache_dependencies=cache_dependencies, - dependency_overrides=dependency_overrides, - nested=True, - **kwargs, - ) + for dep_arg, dep_key in self.dependencies.items(): + if dep_arg not in kwargs: + kwargs[dep_arg] = await self.dependency_provider.get_dependant(dep_key).asolve( + *args, + stack=stack, + cache_dependencies=cache_dependencies, + nested=True, + **kwargs, + ) - custom_to_solve: List[CustomField] = [] + custom_to_solve: list[CustomField] = [] try: async with anyio.create_task_group() as tg: @@ -516,60 +330,27 @@ async def asolve( for j in custom_to_solve: kwargs = await run_async(j.use, **kwargs) - final_args, final_kwargs, call = cast_gen.send(kwargs) + final_args, final_kwargs = cast_gen.send(kwargs) if self.is_generator and nested: response = await solve_generator_async( *final_args, - call=call, + call=self.call, stack=stack, **final_kwargs, ) else: - response = await run_async(call, *final_args, **final_kwargs) + response = await run_async(self.call, *final_args, **final_kwargs) try: cast_gen.send(response) except StopIteration as e: - value: T = e.value + value = e.value - if not self.cast or nested or not self.is_generator: + if self.serializer is None or nested or not self.is_generator: return value else: - return async_map(self._cast_response, value) # type: ignore[return-value, arg-type] + return async_map(self._cast_response, value) raise AssertionError("unreachable") - - -def _sort_dep( - collector: List["CallModel[..., Any]"], - items: Tuple[ - "CallModel[..., Any]", - Tuple[Callable[..., Any], ...], - ], - flat: Dict[ - Callable[..., Any], - Tuple[ - "CallModel[..., Any]", - Tuple[Callable[..., Any], ...], - ], - ], -) -> None: - model, calls = items - - if model in collector: - return - - if not calls: - position = -1 - - else: - for i in calls: - sub_model, _ = flat[i] - if sub_model not in collector: # pragma: no branch - _sort_dep(collector, flat[i], flat) - - position = max(collector.index(flat[i][0]) for i in calls) - - collector.insert(position + 1, model) diff --git a/fast_depends/dependencies/__init__.py b/fast_depends/dependencies/__init__.py index cfb43807..42a43498 100644 --- a/fast_depends/dependencies/__init__.py +++ b/fast_depends/dependencies/__init__.py @@ -1,8 +1,7 @@ -from fast_depends.dependencies.model import Depends -from fast_depends.dependencies.provider import Provider, dependency_provider +from .model import Dependant +from .provider import Provider __all__ = ( - "Depends", + "Dependant", "Provider", - "dependency_provider", ) diff --git a/fast_depends/dependencies/model.py b/fast_depends/dependencies/model.py index 2163f45f..ed8d3e14 100644 --- a/fast_depends/dependencies/model.py +++ b/fast_depends/dependencies/model.py @@ -1,7 +1,8 @@ +from inspect import unwrap from typing import Any, Callable -class Depends: +class Dependant: use_cache: bool cast: bool @@ -9,14 +10,17 @@ def __init__( self, dependency: Callable[..., Any], *, - use_cache: bool = True, - cast: bool = True, + use_cache: bool, + cast: bool, + cast_result: bool, ) -> None: self.dependency = dependency self.use_cache = use_cache self.cast = cast + self.cast_result = cast_result def __repr__(self) -> str: - attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) + call = unwrap(self.dependency) + attr = getattr(call, "__name__", type(call).__name__) cache = "" if self.use_cache else ", use_cache=False" return f"{self.__class__.__name__}({attr}{cache})" diff --git a/fast_depends/dependencies/provider.py b/fast_depends/dependencies/provider.py index f3e2f0fc..4f2e3edb 100644 --- a/fast_depends/dependencies/provider.py +++ b/fast_depends/dependencies/provider.py @@ -1,22 +1,63 @@ +from collections.abc import Hashable, Iterator from contextlib import contextmanager -from typing import Any, Callable, Dict, Iterator +from typing import TYPE_CHECKING, Any, Callable + +from typing_extensions import TypeAlias + +from fast_depends.core import build_call_model + +if TYPE_CHECKING: + from fast_depends.core import CallModel + + +Key: TypeAlias = Hashable class Provider: - dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] + dependencies: dict[Key, "CallModel"] + overrides: dict[Key, "CallModel"] def __init__(self) -> None: - self.dependency_overrides = {} + self.dependencies = {} + self.overrides = {} def clear(self) -> None: - self.dependency_overrides = {} + self.dependencies = {} + self.overrides = {} + + def add_dependant( + self, + dependant: "CallModel", + ) -> Key: + key = self.__get_original_key(dependant.call) + self.dependencies[key] = dependant + return key + + def get_dependant(self, key: Key) -> "CallModel": + return self.overrides.get(key) or self.dependencies[key] def override( self, original: Callable[..., Any], override: Callable[..., Any], ) -> None: - self.dependency_overrides[original] = override + key = self.__get_original_key(original) + + override_model = build_call_model( + override, + dependency_provider=self, + ) + + if (original_dependant := self.dependencies.get(key)): + override_model.serializer = original_dependant.serializer + + else: + self.dependencies[key] = build_call_model( + original, + dependency_provider=self, + ) + + self.overrides[key] = override_model @contextmanager def scope( @@ -24,9 +65,9 @@ def scope( original: Callable[..., Any], override: Callable[..., Any], ) -> Iterator[None]: - self.dependency_overrides[original] = override + self.override(original, override) yield - self.dependency_overrides.pop(original, None) - + self.overrides.pop(self.__get_original_key(original), None) -dependency_provider = Provider() + def __get_original_key(self, original: Callable[..., Any]) -> Key: + return original diff --git a/fast_depends/exceptions.py b/fast_depends/exceptions.py new file mode 100644 index 00000000..f2d72562 --- /dev/null +++ b/fast_depends/exceptions.py @@ -0,0 +1,42 @@ +from collections.abc import Sequence +from typing import Any + +from fast_depends.library.serializer import OptionItem + + +class FastDependsError(Exception): + pass + + +class ValidationError(ValueError, FastDependsError): + def __init__( + self, + *, + incoming_options: Any, + locations: Sequence[Any], + expected: dict[str, OptionItem], + original_error: Exception, + ) -> None: + self.original_error = original_error + self.incoming_options = incoming_options + + self.error_fields: tuple[OptionItem, ...] = tuple( + expected[x] for x in locations if x in expected + ) + if not self.error_fields: + self.error_fields = tuple(expected.values()) + + super().__init__() + + def __str__(self) -> str: + if isinstance(self.incoming_options, dict): + content = ", ".join(f"{k}=`{v}`" for k, v in self.incoming_options.items()) + else: + content = f"`{self.incoming_options}`" + + return ( + "\n Incoming options: " + + content + + "\n In the following option types error occured:\n " + + "\n ".join(map(str, self.error_fields)) + ) diff --git a/fast_depends/library/__init__.py b/fast_depends/library/__init__.py index 4cb1f427..474a66bd 100644 --- a/fast_depends/library/__init__.py +++ b/fast_depends/library/__init__.py @@ -1,3 +1,4 @@ from fast_depends.library.model import CustomField +from fast_depends.library.serializer import Serializer -__all__ = ("CustomField",) +__all__ = ("CustomField", "Serializer",) diff --git a/fast_depends/library/model.py b/fast_depends/library/model.py index 8b18ea46..d448592a 100644 --- a/fast_depends/library/model.py +++ b/fast_depends/library/model.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, Optional, TypeVar +from typing import Any, Optional, TypeVar Cls = TypeVar("Cls", bound="CustomField") @@ -31,9 +31,12 @@ def set_param_name(self: Cls, name: str) -> Cls: self.param_name = name return self - def use(self, /, **kwargs: Any) -> Dict[str, Any]: + def use(self, /, **kwargs: Any) -> dict[str, Any]: assert self.param_name, "You should specify `param_name` before using" return kwargs - def use_field(self, kwargs: Dict[str, Any]) -> None: - raise NotImplementedError("You should implement `use_field` method.") + def use_field(self, kwargs: dict[str, Any]) -> None: + raise NotImplementedError + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(required={self.required}, cast={self.cast})" diff --git a/fast_depends/library/serializer.py b/fast_depends/library/serializer.py new file mode 100644 index 00000000..021f1c06 --- /dev/null +++ b/fast_depends/library/serializer.py @@ -0,0 +1,71 @@ +from abc import ABC, abstractmethod +from typing import Any, Protocol + + +class OptionItem: + __slots__ = ( + "field_name", + "field_type", + "default_value", + "source", + ) + + def __init__( + self, + field_name: str, + field_type: Any, + source: Any = None, + default_value: Any = ..., + ) -> None: + self.field_name = field_name + self.field_type = field_type + self.default_value = default_value + self.source = source + + def __repr__(self) -> str: + type_name = getattr(self.field_type, "__name__", str(self.field_type)) + content = f"{self.field_name}, type=`{type_name}`" + if self.default_value is not Ellipsis: + content = f"{content}, default=`{self.default_value}`" + if self.source: + content = f"{content}, source=`{self.source}`" + return f"OptionItem[{content}]" + + +class Serializer(ABC): + def __init__( + self, + *, + name: str, + options: list[OptionItem], + response_type: Any, + ): + self.name = name + self.options = { + i.field_name: i for i in options + } + self.response_option = { + "return": OptionItem(field_name="return", field_type=response_type), + } + + + @abstractmethod + def __call__(self, call_kwargs: dict[str, Any]) -> dict[str, Any]: + raise NotImplementedError + + def get_aliases(self) -> tuple[str, ...]: + return () + + def response(self, value: Any) -> Any: + return value + + +class SerializerProto(Protocol): + def __call__( + self, + *, + name: str, + options: list[OptionItem], + response_type: Any, + ) -> Serializer: + ... diff --git a/fast_depends/msgspec/__init__.py b/fast_depends/msgspec/__init__.py new file mode 100644 index 00000000..dfe18f36 --- /dev/null +++ b/fast_depends/msgspec/__init__.py @@ -0,0 +1,5 @@ +from fast_depends.msgspec.serializer import MsgSpecSerializer + +__all__ = ( + "MsgSpecSerializer", +) diff --git a/fast_depends/msgspec/serializer.py b/fast_depends/msgspec/serializer.py new file mode 100644 index 00000000..88161795 --- /dev/null +++ b/fast_depends/msgspec/serializer.py @@ -0,0 +1,95 @@ +import inspect +import re +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +from typing import Any + +import msgspec + +from fast_depends.exceptions import ValidationError +from fast_depends.library.serializer import OptionItem, Serializer + + +class MsgSpecSerializer(Serializer): + __slots__ = ("aliases", "model", "response_type", "name", "options", "response_option",) + + def __init__( + self, + *, + name: str, + options: list[OptionItem], + response_type: Any, + ): + model_options = [] + aliases = {} + for i in options: + if isinstance(msgspec.inspect.type_info(i.field_type), msgspec.inspect.CustomType): + continue + + default_value = i.default_value + + if isinstance(default_value, msgspec._core.Field) and default_value.name: + aliases[i.field_name] = default_value.name + else: + aliases[i.field_name] = i.field_name + + if default_value is Ellipsis: + model_options.append(( + i.field_name, + i.field_type, + )) + else: + model_options.append(( + i.field_name, + i.field_type, + default_value, + )) + + self.aliases = aliases + self.model = msgspec.defstruct(name, model_options, kw_only=True) + + if response_type is not inspect.Parameter.empty: + self.response_type = response_type + else: + self.response_type = None + + super().__init__(name=name, options=options, response_type=response_type) + + def __call__(self, call_kwargs: dict[str, Any]) -> dict[str, Any]: + with self._try_msgspec(call_kwargs, self.options): + casted_model = msgspec.convert( + call_kwargs, + type=self.model, + strict=False, + str_keys=True, + ) + + return { + out_field: getattr(casted_model, out_field, None) + for out_field in self.aliases.keys() + } + + def get_aliases(self) -> tuple[str, ...]: + return tuple(self.aliases.values()) + + def response(self, value: Any) -> Any: + if self.response_type is not None: + with self._try_msgspec(value, self.response_option, ("return",)): + return msgspec.convert(value, type=self.response_type, strict=False) + return value + + @property + def json_schema(self) -> dict[str, Any]: + return msgspec.json.schema(self.model) + + @contextmanager + def _try_msgspec(self, call_kwargs: Any, options: dict[str, OptionItem], locations: Sequence[str] = (),) -> Iterator[None]: + try: + yield + except msgspec.ValidationError as er: + raise ValidationError( + incoming_options=call_kwargs, + expected=options, + locations=locations or re.findall(r"at `\$\.(.)`", str(er.args)), + original_error=er, + ) from er diff --git a/fast_depends/pydantic/__init__.py b/fast_depends/pydantic/__init__.py new file mode 100644 index 00000000..89891385 --- /dev/null +++ b/fast_depends/pydantic/__init__.py @@ -0,0 +1,5 @@ +from fast_depends.pydantic.serializer import PydanticSerializer + +__all__ = ( + "PydanticSerializer", +) diff --git a/fast_depends/pydantic/_compat.py b/fast_depends/pydantic/_compat.py new file mode 100644 index 00000000..7369d2c0 --- /dev/null +++ b/fast_depends/pydantic/_compat.py @@ -0,0 +1,56 @@ +from typing import Any, Optional + +from pydantic import BaseModel, create_model +from pydantic.version import VERSION as PYDANTIC_VERSION + +__all__ = ( + "BaseModel", + "create_model", + "PYDANTIC_V2", + "get_config_base", + "ConfigDict", + "TypeAdapter", + "PydanticUserError", +) + + +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + +default_pydantic_config = {"arbitrary_types_allowed": True} + +# isort: off +if PYDANTIC_V2: + from pydantic import ConfigDict, TypeAdapter + from pydantic.fields import FieldInfo + from pydantic.errors import PydanticUserError + + def model_schema(model: type[BaseModel]) -> dict[str, Any]: + return model.model_json_schema() + + def get_config_base(config_data: Optional[ConfigDict] = None) -> ConfigDict: + return config_data or ConfigDict(**default_pydantic_config) # type: ignore[typeddict-item] + + def get_aliases(model: type[BaseModel]) -> tuple[str, ...]: + return tuple(f.alias or name for name, f in model.model_fields.items()) + + def get_model_fields(model: type[BaseModel]) -> dict[str, FieldInfo]: + return model.model_fields + +else: + from pydantic.config import get_config, ConfigDict, BaseConfig + from pydantic.fields import ModelField + + TypeAdapter = None + PydanticUserError = Exception + + def get_config_base(config_data: Optional[ConfigDict] = None) -> type[BaseConfig]: # type: ignore[misc] + return get_config(config_data or ConfigDict(**default_pydantic_config)) # type: ignore[typeddict-item] + + def model_schema(model: type[BaseModel]) -> dict[str, Any]: + return model.schema() + + def get_aliases(model: type[BaseModel]) -> tuple[str, ...]: + return tuple(f.alias or name for name, f in model.__fields__.items()) + + def get_model_fields(model: type[BaseModel]) -> dict[str, ModelField]: + return model.__fields__ diff --git a/fast_depends/schema.py b/fast_depends/pydantic/schema.py similarity index 62% rename from fast_depends/schema.py rename to fast_depends/pydantic/schema.py index 2f065ef9..a2940b03 100644 --- a/fast_depends/schema.py +++ b/fast_depends/pydantic/schema.py @@ -1,24 +1,30 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Optional -from fast_depends._compat import PYDANTIC_V2, create_model, model_schema from fast_depends.core import CallModel +from fast_depends.pydantic._compat import PYDANTIC_V2, create_model, model_schema def get_schema( - call: CallModel[Any, Any], + call: CallModel, embed: bool = False, resolve_refs: bool = False, -) -> Dict[str, Any]: - assert call.model, "Call should has a model" - params_model = create_model( # type: ignore[call-overload] - call.model.__name__, - **call.flat_params - ) +) -> dict[str, Any]: + class_options: dict[str, Any] = { + i.field_name: (i.field_type, i.default_value) + for i in call.flat_params + } + + name = getattr(call.serializer, "name", "Undefined") - body: Dict[str, Any] = model_schema(params_model) + if not class_options: + return {"title": name, "type": "null"} + + params_model = create_model( + name, + **class_options + ) - if not call.flat_params: - body = {"title": body["title"], "type": "null"} + body = model_schema(params_model) if resolve_refs: pydantic_key = "$defs" if PYDANTIC_V2 else "definitions" @@ -34,9 +40,9 @@ def get_schema( def _move_pydantic_refs( original: Any, key: str, - refs: Optional[Dict[str, Any]] = None + refs: Optional[dict[str, Any]] = None ) -> Any: - if not isinstance(original, Dict): + if not isinstance(original, dict): return original data = original.copy() @@ -53,7 +59,7 @@ def _move_pydantic_refs( elif isinstance(data[k], dict): data[k] = _move_pydantic_refs(data[k], key, refs) - elif isinstance(data[k], List): + elif isinstance(data[k], list): for i in range(len(data[k])): data[k][i] = _move_pydantic_refs(data[k][i], key, refs) diff --git a/fast_depends/pydantic/serializer.py b/fast_depends/pydantic/serializer.py new file mode 100644 index 00000000..5c9c5277 --- /dev/null +++ b/fast_depends/pydantic/serializer.py @@ -0,0 +1,130 @@ +import inspect +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +from itertools import chain +from typing import Any, Callable, Optional + +from pydantic import ValidationError as PValidationError + +from fast_depends.exceptions import ValidationError +from fast_depends.library.serializer import OptionItem, Serializer, SerializerProto +from fast_depends.pydantic._compat import ( + PYDANTIC_V2, + BaseModel, + ConfigDict, + PydanticUserError, + TypeAdapter, + create_model, + get_aliases, + get_config_base, + get_model_fields, +) + + +class PydanticSerializer(SerializerProto): + __slots__ = ("pydantic_config",) + + def __init__(self, pydantic_config: Optional[ConfigDict] = None) -> None: + self.config = pydantic_config + + def __call__( + self, + *, + name: str, + options: list[OptionItem], + response_type: Any, + ) -> "_PydanticSerializer": + return _PydanticSerializer( + name=name, + options=options, + response_type=response_type, + pydantic_config=self.config, + ) + + +class _PydanticSerializer(Serializer): + __slots__ = ("model", "response_callback", "name", "options", "response_option",) + + def __init__( + self, + *, + name: str, + options: list[OptionItem], + response_type: Any, + pydantic_config: Optional[ConfigDict] = None, + ): + class_options: dict[str, Any] = { + i.field_name: (i.field_type, i.default_value) + for i in options + } + + config = get_config_base(pydantic_config) + + self.model = create_model( + name, + __config__=config, + **class_options, + ) + + self.response_callback: Optional[Callable[[Any], Any]] = None + + if response_type is not inspect.Parameter.empty: + try: + is_model = issubclass(response_type or object, BaseModel) + except Exception: + is_model = False + + if is_model: + if PYDANTIC_V2: + self.response_callback = response_type.model_validate + else: + self.response_callback = response_type.validate + + elif PYDANTIC_V2: + try: + response_pydantic_type = TypeAdapter(response_type, config=config) + except PydanticUserError: + pass + else: + self.response_callback = response_pydantic_type.validate_python + + if self.response_callback is None and not (response_type is None and not PYDANTIC_V2): + response_model = create_model( + "ResponseModel", + __config__=config, + r=(response_type or Any, ...), + ) + + self.response_callback = lambda x: response_model(r=x).r # type: ignore[attr-defined] + + super().__init__(name=name, options=options, response_type=response_type) + + def __call__(self, call_kwargs: dict[str, Any]) -> dict[str, Any]: + with self._try_pydantic(call_kwargs, self.options): + casted_model = self.model(**call_kwargs) + + return { + i: getattr(casted_model, i) + for i in get_model_fields(casted_model).keys() + } + + def get_aliases(self) -> tuple[str, ...]: + return get_aliases(self.model) + + def response(self, value: Any) -> Any: + if self.response_callback is not None: + with self._try_pydantic(value, self.response_option, ("return",)): + return self.response_callback(value) + return value + + @contextmanager + def _try_pydantic(self, call_kwargs: Any, options: dict[str, OptionItem], locations: Sequence[str] = (),) -> Iterator[None]: + try: + yield + except PValidationError as er: + raise ValidationError( + incoming_options=call_kwargs, + expected=options, + locations=locations or tuple(chain(*(one_error["loc"] for one_error in er.errors()))), + original_error=er, + ) from er diff --git a/fast_depends/use.py b/fast_depends/use.py index d30728f1..b05ebf24 100644 --- a/fast_depends/use.py +++ b/fast_depends/use.py @@ -1,93 +1,129 @@ +from collections.abc import AsyncIterator, Iterator, Sequence from contextlib import AsyncExitStack, ExitStack from functools import partial, wraps from typing import ( + TYPE_CHECKING, Any, - AsyncIterator, Callable, - Iterator, Optional, - Sequence, + Protocol, + TypeVar, Union, cast, overload, ) -from typing_extensions import ParamSpec, Protocol, TypeVar +from typing_extensions import Literal, ParamSpec -from fast_depends._compat import ConfigDict from fast_depends.core import CallModel, build_call_model -from fast_depends.dependencies import dependency_provider, model +from fast_depends.dependencies import Dependant, Provider +from fast_depends.library.serializer import SerializerProto + +SerializerCls: Optional["SerializerProto"] = None + +if SerializerCls is None: + try: + from fast_depends.pydantic import PydanticSerializer + SerializerCls = PydanticSerializer() + except ImportError: + pass + +if SerializerCls is None: + try: + from fast_depends.msgspec import MsgSpecSerializer + SerializerCls = MsgSpecSerializer + except ImportError: + pass + P = ParamSpec("P") T = TypeVar("T") +if TYPE_CHECKING: + from fast_depends.library.serializer import SerializerProto + + class InjectWrapper(Protocol): + def __call__( + self, + func: Callable[..., T], + model: Optional[CallModel] = None, + ) -> Callable[..., T]: + ... + def Depends( - dependency: Callable[P, T], + dependency: Callable[..., Any], *, use_cache: bool = True, cast: bool = True, + cast_result: bool = False, ) -> Any: - return model.Depends( + return Dependant( dependency=dependency, use_cache=use_cache, cast=cast, + cast_result=cast_result, ) -class _InjectWrapper(Protocol[P, T]): - def __call__( - self, - func: Callable[P, T], - model: Optional[CallModel[P, T]] = None, - ) -> Callable[P, T]: - ... - - @overload -def inject( # pragma: no cover - func: None, +def inject( + func: Callable[..., T], *, cast: bool = True, - extra_dependencies: Sequence[model.Depends] = (), - pydantic_config: Optional[ConfigDict] = None, - dependency_overrides_provider: Optional[Any] = dependency_provider, - wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x, -) -> _InjectWrapper[P, T]: + cast_result: bool = True, + extra_dependencies: Sequence[Dependant] = (), + dependency_provider: Optional["Provider"] = None, + wrap_model: Callable[["CallModel"], "CallModel"] = lambda x: x, + serializer_cls: Optional["SerializerProto"] = SerializerCls, + **call_extra: Any, +) -> Callable[..., T]: ... - @overload -def inject( # pragma: no cover - func: Callable[P, T], +def inject( + func: Literal[None] = None, *, cast: bool = True, - extra_dependencies: Sequence[model.Depends] = (), - pydantic_config: Optional[ConfigDict] = None, - dependency_overrides_provider: Optional[Any] = dependency_provider, - wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x, -) -> Callable[P, T]: + cast_result: bool = True, + extra_dependencies: Sequence[Dependant] = (), + dependency_provider: Optional["Provider"] = None, + wrap_model: Callable[["CallModel"], "CallModel"] = lambda x: x, + serializer_cls: Optional["SerializerProto"] = SerializerCls, + **call_extra: Any, +) -> "InjectWrapper": ... - def inject( - func: Optional[Callable[P, T]] = None, + func: Optional[Callable[..., T]] = None, *, cast: bool = True, - extra_dependencies: Sequence[model.Depends] = (), - pydantic_config: Optional[ConfigDict] = None, - dependency_overrides_provider: Optional[Any] = dependency_provider, - wrap_model: Callable[[CallModel[P, T]], CallModel[P, T]] = lambda x: x, + cast_result: bool = True, + extra_dependencies: Sequence[Dependant] = (), + dependency_provider: Optional["Provider"] = None, + wrap_model: Callable[["CallModel"], "CallModel"] = lambda x: x, + serializer_cls: Optional["SerializerProto"] = SerializerCls, + **call_extra: Any, ) -> Union[ - Callable[P, T], - _InjectWrapper[P, T], + Callable[..., T], + Callable[ + [Callable[..., T]], + Callable[..., T] + ], ]: + if dependency_provider is None: + dependency_provider = Provider() + + if not cast: + serializer_cls = None + decorator = _wrap_inject( - dependency_overrides_provider=dependency_overrides_provider, + dependency_provider=dependency_provider, wrap_model=wrap_model, extra_dependencies=extra_dependencies, - cast=cast, - pydantic_config=pydantic_config, + serializer_cls=serializer_cls, + cast_result=cast_result, + **call_extra, ) if func is None: @@ -98,103 +134,95 @@ def inject( def _wrap_inject( - dependency_overrides_provider: Optional[Any], - wrap_model: Callable[ - [CallModel[P, T]], - CallModel[P, T], - ], - extra_dependencies: Sequence[model.Depends], - cast: bool, - pydantic_config: Optional[ConfigDict], -) -> _InjectWrapper[P, T]: - if ( - dependency_overrides_provider - and getattr(dependency_overrides_provider, "dependency_overrides", None) - is not None - ): - overrides = dependency_overrides_provider.dependency_overrides - else: - overrides = None - + *, + dependency_provider: "Provider", + wrap_model: Callable[["CallModel"], "CallModel"], + extra_dependencies: Sequence[Dependant], + serializer_cls: Optional["SerializerProto"], + cast_result: bool, + **call_extra: Any, +) -> Callable[ + [Callable[P, T]], + Callable[..., T] +]: def func_wrapper( func: Callable[P, T], - model: Optional[CallModel[P, T]] = None, - ) -> Callable[P, T]: + model: Optional["CallModel"] = None, + ) -> Callable[..., T]: if model is None: real_model = wrap_model( build_call_model( call=func, extra_dependencies=extra_dependencies, - cast=cast, - pydantic_config=pydantic_config, + dependency_provider=dependency_provider, + serializer_cls=serializer_cls, + serialize_result=cast_result, ) ) else: real_model = model if real_model.is_async: - injected_wrapper: Callable[P, T] + injected_wrapper: Callable[..., T] if real_model.is_generator: - injected_wrapper = partial(solve_async_gen, real_model, overrides) # type: ignore[assignment] + injected_wrapper = partial( # type: ignore[assignment] + solve_async_gen, + real_model, + ) else: - @wraps(func) - async def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + async def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # type: ignore[misc] async with AsyncExitStack() as stack: - r = await real_model.asolve( + return await real_model.asolve( # type: ignore[no-any-return] *args, stack=stack, - dependency_overrides=overrides, cache_dependencies={}, nested=False, - **kwargs, + **(call_extra | kwargs), ) - return r raise AssertionError("unreachable") else: if real_model.is_generator: - injected_wrapper = partial(solve_gen, real_model, overrides) # type: ignore[assignment] + injected_wrapper = partial( # type: ignore[assignment] + solve_gen, + real_model, + ) else: - @wraps(func) def injected_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: with ExitStack() as stack: - r = real_model.solve( + return real_model.solve( # type: ignore[no-any-return] *args, stack=stack, - dependency_overrides=overrides, cache_dependencies={}, nested=False, - **kwargs, + **(call_extra | kwargs), ) - return r raise AssertionError("unreachable") - return injected_wrapper + return wraps(func)(injected_wrapper) return func_wrapper class solve_async_gen: - _iter: Optional[AsyncIterator[Any]] + _iter: Optional[AsyncIterator[Any]] = None def __init__( self, - model: "CallModel[..., Any]", - overrides: Optional[Any], + model: "CallModel", *args: Any, **kwargs: Any, ): self.call = model self.args = args self.kwargs = kwargs - self.overrides = overrides def __aiter__(self) -> "solve_async_gen": self._iter = None @@ -211,7 +239,6 @@ async def __anext__(self) -> Any: await self.call.asolve( *self.args, stack=stack, - dependency_overrides=self.overrides, cache_dependencies={}, nested=False, **self.kwargs, @@ -221,27 +248,25 @@ async def __anext__(self) -> Any: try: r = await self._iter.__anext__() - except StopAsyncIteration as e: + except StopAsyncIteration: await self.stack.__aexit__(None, None, None) - raise e + raise else: return r class solve_gen: - _iter: Optional[Iterator[Any]] + _iter: Optional[Iterator[Any]] = None def __init__( self, - model: "CallModel[..., Any]", - overrides: Optional[Any], + model: "CallModel", *args: Any, **kwargs: Any, ): self.call = model self.args = args self.kwargs = kwargs - self.overrides = overrides def __iter__(self) -> "solve_gen": self._iter = None @@ -258,7 +283,6 @@ def __next__(self) -> Any: self.call.solve( *self.args, stack=stack, - dependency_overrides=self.overrides, cache_dependencies={}, nested=False, **self.kwargs, @@ -268,8 +292,8 @@ def __next__(self) -> Any: try: r = next(self._iter) - except StopIteration as e: + except StopIteration: self.stack.__exit__(None, None, None) - raise e + raise else: return r diff --git a/fast_depends/utils.py b/fast_depends/utils.py index 2d19e891..9d3eb72b 100644 --- a/fast_depends/utils.py +++ b/fast_depends/utils.py @@ -1,28 +1,28 @@ import asyncio import functools import inspect -from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager +from collections.abc import AsyncGenerator, AsyncIterable, Awaitable +from contextlib import ( + AbstractContextManager, + AsyncExitStack, + ExitStack, + asynccontextmanager, + contextmanager, +) from typing import ( TYPE_CHECKING, + Annotated, Any, - AsyncGenerator, - AsyncIterable, - Awaitable, Callable, - ContextManager, - Dict, ForwardRef, - List, - Tuple, + TypeVar, Union, cast, ) import anyio from typing_extensions import ( - Annotated, ParamSpec, - TypeVar, get_args, get_origin, ) @@ -51,7 +51,7 @@ async def run_async( async def run_in_threadpool( - func: Callable[P, T], *args: P.args, **kwargs: P.kwargs + func: Callable[P, T], *args: P.args, **kwargs: P.kwargs, ) -> T: if kwargs: func = functools.partial(func, **kwargs) @@ -59,7 +59,7 @@ async def run_in_threadpool( async def solve_generator_async( - *sub_args: Any, call: Callable[..., Any], stack: AsyncExitStack, **sub_values: Any + *sub_args: Any, call: Callable[..., Any], stack: AsyncExitStack, **sub_values: Any, ) -> Any: if is_gen_callable(call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) @@ -69,17 +69,20 @@ async def solve_generator_async( def solve_generator_sync( - *sub_args: Any, call: Callable[..., Any], stack: ExitStack, **sub_values: Any + *sub_args: Any, call: Callable[..., Any], stack: ExitStack, **sub_values: Any, ) -> Any: cm = contextmanager(call)(*sub_args, **sub_values) return stack.enter_context(cm) -def get_typed_signature(call: Callable[..., Any]) -> Tuple[inspect.Signature, Any]: +def get_typed_signature(call: Callable[..., Any]) -> tuple[inspect.Signature, Any]: signature = inspect.signature(call) locals = collect_outer_stack_locals() + # We unwrap call to get the original unwrapped function + call = inspect.unwrap(call) + globalns = getattr(call, "__globals__", {}) typed_params = [ inspect.Parameter( @@ -102,10 +105,10 @@ def get_typed_signature(call: Callable[..., Any]) -> Tuple[inspect.Signature, An ) -def collect_outer_stack_locals() -> Dict[str, Any]: +def collect_outer_stack_locals() -> dict[str, Any]: frame = inspect.currentframe() - frames: List["FrameType"] = [] + frames: list[FrameType] = [] while frame is not None: if "fast_depends" not in frame.f_code.co_filename: frames.append(frame) @@ -120,8 +123,8 @@ def collect_outer_stack_locals() -> Dict[str, Any]: def get_typed_annotation( annotation: Any, - globalns: Dict[str, Any], - locals: Dict[str, Any], + globalns: dict[str, Any], + locals: dict[str, Any], ) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) @@ -141,7 +144,7 @@ def get_typed_annotation( @asynccontextmanager async def contextmanager_in_threadpool( - cm: ContextManager[T], + cm: AbstractContextManager[T], ) -> AsyncGenerator[T, None]: exit_limiter = anyio.CapacityLimiter(1) try: diff --git a/pyproject.toml b/pyproject.toml index 0250790c..696bff49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ keywords = ["fastapi", "dependency injection"] -requires-python = ">=3.8" +requires-python = ">=3.9" classifiers = [ "Development Status :: 5 - Production/Stable", @@ -21,11 +21,11 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Libraries", @@ -41,8 +41,17 @@ classifiers = [ dynamic = ["version"] dependencies = [ - "pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0", "anyio>=3.0.0,<5.0.0", + "typing-extensions!=4.12.1", +] + +[project.optional-dependencies] +pydantic = [ + "pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0", +] + +msgspec = [ + "msgspec", ] [project.urls] @@ -68,12 +77,16 @@ exclude = [ [tool.mypy] strict = true ignore_missing_imports = true +python_version = "3.9" [tool.isort] profile = "black" known_third_party = ["pydantic", "anyio"] [tool.ruff] +target-version = "py39" + +[tool.ruff.lint] select = [ "E", # pycodestyle errors https://docs.astral.sh/ruff/rules/#error-e "W", # pycodestyle warnings https://docs.astral.sh/ruff/rules/#warning-w @@ -91,7 +104,7 @@ ignore = [ "C901", # too complex ] -[tool.ruff.flake8-bugbear] +[tool.ruff.lint.flake8-bugbear] extend-immutable-calls = [ "fast_depends.Depends", "AsyncHeader", "Header", "MyDep", diff --git a/requirements.dev.txt b/requirements.dev.txt index 78ae4469..9bfd390c 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,4 +1,4 @@ --e . +-e .[pydantic] -r requirements.docs.txt -r requirements.test.txt diff --git a/requirements.test.txt b/requirements.test.txt index c77973ae..b21c3d58 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -2,4 +2,6 @@ coverage[toml] >=7.2.0,<8.0.0 pytest >=8.0.0,<9 -dirty-equals >=0.7.0,<0.8 \ No newline at end of file +dirty-equals >=0.7.0,<0.9 + +annotated_types diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/async/test_cast.py b/tests/async/test_cast.py index f6ac5ae0..cf06bc0d 100644 --- a/tests/async/test_cast.py +++ b/tests/async/test_cast.py @@ -1,12 +1,10 @@ -from typing import Dict, Iterator, Tuple +from collections.abc import Iterator import pytest -from annotated_types import Ge -from pydantic import BaseModel, Field, ValidationError -from typing_extensions import Annotated from fast_depends import inject -from tests.marks import pydanticV2 +from fast_depends.exceptions import ValidationError +from tests.marks import serializer @pytest.mark.anyio @@ -18,16 +16,6 @@ async def some_func(a, b): assert isinstance(await some_func("1", "2"), str) -@pytest.mark.anyio -async def test_annotated_partial(): - @inject - async def some_func(a, b: int): - assert isinstance(b, int) - return a + b - - assert isinstance(await some_func(1, "2"), int) - - @pytest.mark.anyio async def test_arbitrary_args(): class ArbitraryType: @@ -55,85 +43,12 @@ async def some_func(a: ArbitraryType) -> ArbitraryType: @pytest.mark.anyio -async def test_types_casting(): - @inject - async def some_func(a: int, b: int) -> float: - assert isinstance(a, int) - assert isinstance(b, int) - r = a + b - assert isinstance(r, int) - return r - - assert isinstance(await some_func("1", "2"), float) - - -@pytest.mark.anyio -async def test_types_casting_from_str(): - @inject - async def some_func(a: "int") -> float: - return a - - assert isinstance(await some_func("1"), float) - - -@pytest.mark.anyio -async def test_pydantic_types_casting(): - class SomeModel(BaseModel): - field: int - - @inject - async def some_func(a: SomeModel): - return a.field - - assert isinstance(await some_func({"field": "31"}), int) - - -@pytest.mark.anyio -async def test_pydantic_field_types_casting(): - @inject - async def some_func(a: int = Field(..., alias="b")) -> float: - assert isinstance(a, int) - return a - - @inject - async def another_func(a=Field(..., alias="b")) -> float: - assert isinstance(a, str) - return a - - assert isinstance(await some_func(b="2", c=3), float) - assert isinstance(await another_func(b="2"), float) - - -@pytest.mark.anyio -async def test_wrong_incoming_types(): +async def test_args(): @inject - async def some_func(a: int): # pragma: no cover - return a + async def some_func(a, *ar): + return a, ar - with pytest.raises(ValidationError): - await some_func({"key", 1}) - - -@pytest.mark.anyio -async def test_wrong_return_types(): - @inject - async def some_func(a: int) -> dict: - return a - - with pytest.raises(ValidationError): - await some_func("2") - - -@pytest.mark.anyio -async def test_annotated(): - A = Annotated[int, Field(..., alias="b")] - - @inject - async def some_func(a: A) -> float: - assert isinstance(a, int) - return a - - assert isinstance(await some_func(b="2"), float) + assert (1, (2,)) == await some_func(1, 2) @pytest.mark.anyio @@ -141,9 +56,9 @@ async def test_args_kwargs_1(): @inject async def simple_func( a: int, - *args: Tuple[float, ...], + *args: tuple[float, ...], b: int, - **kwargs: Dict[str, int], + **kwargs: dict[str, int], ): return a, args, b, kwargs @@ -157,7 +72,7 @@ async def test_args_kwargs_2(): @inject async def simple_func( a: int, - *args: Tuple[float, ...], + *args: tuple[float, ...], b: int, ): return a, args, b @@ -183,37 +98,107 @@ async def simple_func(a: int, *, b: int): @pytest.mark.anyio -async def test_generator(): +async def test_args_kwargs_4(): @inject - async def simple_func(a: str) -> int: - for _ in range(2): - yield a + async def simple_func( + *args: tuple[float, ...], + **kwargs: dict[str, int], + ): + return args, kwargs - async for i in simple_func("1"): - assert i == 1 + assert ( + (1.0, 2.0, 3.0), + { + "key": 1, + "b": 3, + }, + ) == await simple_func(1.0, 2.0, 3, b=3.0, key=1.0) @pytest.mark.anyio -async def test_generator_iterator_type(): +async def test_args_kwargs_5(): @inject - async def simple_func(a: str) -> Iterator[int]: - for _ in range(2): - yield a - - async for i in simple_func("1"): - assert i == 1 - - -@pytest.mark.anyio -@pydanticV2 -async def test_multi_annotated(): - from pydantic.functional_validators import AfterValidator - - @inject() - async def f(a: Annotated[int, Ge(10), AfterValidator(lambda x: x + 10)]) -> int: - return a - - with pytest.raises(ValidationError): - await f(1) - - assert await f(10) == 20 + async def simple_func( + *a: tuple[float, ...], + **kw: dict[str, int], + ): + return a, kw + + assert ( + (1.0, 2.0, 3.0), + { + "key": 1, + "b": 3, + }, + ) == await simple_func(1.0, 2.0, 3, b=3.0, key=1.0) + + +@serializer +@pytest.mark.anyio +class TestSerialization: + async def test_no_cast_result(self): + @inject(cast_result=False) + async def some_func(a: int, b: int) -> str: + return a + b + + assert await some_func("1", "2") == 3 + + async def test_annotated_partial(self): + @inject + async def some_func(a, b: int): + assert isinstance(b, int) + return a + b + + assert isinstance(await some_func(1, "2"), int) + + async def test_types_casting(self): + @inject + async def some_func(a: int, b: int) -> float: + assert isinstance(a, int) + assert isinstance(b, int) + r = a + b + assert isinstance(r, int) + return r + + assert isinstance(await some_func("1", "2"), float) + + async def test_types_casting_from_str(self): + @inject + async def some_func(a: "int") -> float: + return a + + assert isinstance(await some_func("1"), float) + + async def test_wrong_incoming_types(self): + @inject + async def some_func(a: int): # pragma: no cover + return a + + with pytest.raises(ValidationError): + await some_func({"key", 1}) + + async def test_wrong_return_types(self): + @inject + async def some_func(a: int) -> dict: + return a + + with pytest.raises(ValidationError): + await some_func("2") + + async def test_generator(self): + @inject + async def simple_func(a: str) -> int: + for _ in range(2): + yield a + + async for i in simple_func("1"): + assert i == 1 + + async def test_generator_iterator_type(self): + @inject + async def simple_func(a: str) -> Iterator[int]: + for _ in range(2): + yield a + + async for i in simple_func("1"): + assert i == 1 diff --git a/tests/async/test_depends.py b/tests/async/test_depends.py index 753c0a0e..aa0b0eb2 100644 --- a/tests/async/test_depends.py +++ b/tests/async/test_depends.py @@ -1,129 +1,108 @@ import logging +from contextlib import asynccontextmanager from dataclasses import dataclass from functools import partial +from typing import Annotated from unittest.mock import Mock import pytest -from pydantic import ValidationError -from typing_extensions import Annotated from fast_depends import Depends, inject +from fast_depends.exceptions import ValidationError +from tests.marks import serializer @pytest.mark.anyio async def test_depends(): - async def dep_func(b: int, a: int = 3) -> float: + async def dep_func(b, a=3): return a + b @inject - async def some_func(b: int, c=Depends(dep_func)) -> int: - assert isinstance(c, float) + async def some_func(b: int, c=Depends(dep_func)): return b + c - assert (await some_func("2")) == 7 + assert (await some_func(2)) == 7 @pytest.mark.anyio -async def test_empty_main_body(): - async def dep_func(a: int) -> float: - return a - - @inject - async def some_func(c=Depends(dep_func)): - assert isinstance(c, float) - assert c == 1.0 - - await some_func("1") - +async def test_ignore_depends_if_setted_manual(): + mock = Mock() -@pytest.mark.anyio -async def test_sync_depends(): - def sync_dep_func(a: int) -> float: - return a + async def dep_func(a, b) -> int: + mock(a, b) + return a + b @inject - async def some_func(a: int, b: int, c=Depends(sync_dep_func)) -> float: - assert isinstance(c, float) - return a + b + c + async def some_func(c=Depends(dep_func)) -> int: + return c + + assert (await some_func(c=2)) == 2 + assert not mock.called - assert await some_func("1", "2") + assert (await some_func(1, 2)) == 3 + mock.assert_called_once_with(1, 2) @pytest.mark.anyio -async def test_depends_response_cast(): +async def test_empty_main_body(): async def dep_func(a): return a @inject - async def some_func(a: int, b: int, c: int = Depends(dep_func)) -> float: - assert isinstance(c, int) - return a + b + c + async def some_func(c=Depends(dep_func)): + return c - assert await some_func("1", "2") + assert await some_func("1") == "1" @pytest.mark.anyio -async def test_depends_error(): - async def dep_func(b: dict, a: int = 3) -> float: # pragma: no cover - return a + b +async def test_empty_main_body_multiple_args(): + def dep2(b): + return b - async def another_dep_func(b: int, a: int = 3) -> dict: # pragma: no cover - return a + b + async def dep(a): + return a - @inject - async def some_func( - b: int, c=Depends(dep_func), d=Depends(another_dep_func) - ) -> int: # pragma: no cover - assert c is None - return b + @inject() + async def handler(d=Depends(dep2), c=Depends(dep)): + return d, c - with pytest.raises(ValidationError): - assert await some_func("2") + await handler(a=1, b=2) == (2, 1) + await handler(1, b=2) == (2, 1) + await handler(1, a=2) == (1, 2) + await handler(1, 2) == (1, 1) # all dependencies takes the first arg @pytest.mark.anyio -async def test_depends_annotated(): - async def dep_func(a): +async def test_sync_depends(): + def sync_dep_func(a): return a - D = Annotated[int, Depends(dep_func)] - @inject - async def some_func(a: int, b: int, c: D = None) -> float: - assert isinstance(c, int) + async def some_func(a: int, b: int, c=Depends(sync_dep_func)): return a + b + c - @inject - async def another_func(a: int, c: D): - return a + c - - assert await some_func("1", "2") - assert (await another_func(3)) == 6.0 + assert await some_func(1, 2) == 4 @pytest.mark.anyio -async def test_async_depends_annotated_str(): +async def test_depends_annotated(): async def dep_func(a): return a + D = Annotated[int, Depends(dep_func)] + @inject - async def some_func( - a: int, - b: int, - c: "Annotated[int, Depends(dep_func)]", - ) -> float: + async def some_func(a: int, b: int, c: D): assert isinstance(c, int) return a + b + c @inject - async def another_func( - a: int, - c: "Annotated[int, Depends(dep_func)]", - ): + async def another_func(a: int, c: D): return a + c - assert await some_func("1", "2") - assert await another_func("3") == 6.0 + assert await some_func(1, 2) == 4 + assert (await another_func(3)) == 6 @pytest.mark.anyio @@ -136,8 +115,7 @@ async def some_func( a: int, b: int, c: Annotated["float", Depends(adep_func)], - ) -> float: - assert isinstance(c, float) + ): return a + b + c @inject @@ -147,8 +125,8 @@ async def another_func( ): return a + c - assert await some_func("1", "2") - assert await another_func("3") == 6.0 + assert await some_func(1, 2) == 4 + assert (await another_func(3)) == 6 @pytest.mark.anyio @@ -347,31 +325,6 @@ async def some_func(a=Depends(MyDep(3).call)): # noqa: B008 await some_func() -@pytest.mark.anyio -async def test_not_cast(): - @dataclass - class A: - a: int - - async def dep() -> A: - return A(a=1) - - async def get_logger() -> logging.Logger: - return logging.getLogger(__file__) - - @inject - async def some_func( - b, - a: A = Depends(dep, cast=False), - logger: logging.Logger = Depends(get_logger, cast=False), - ): - assert a.a == 1 - assert logger - return b - - assert (await some_func(1)) == 1 - - @pytest.mark.anyio async def test_not_cast_main(): @dataclass @@ -421,21 +374,34 @@ async def some_func(): async def test_generator(): mock = Mock() + def sync_simple_func(): + mock.sync_simple() + + async def simple_func(): + mock.simple() + async def func(): mock.start() yield mock.end() @inject - async def simple_func(a: str, d=Depends(func)) -> int: + async def simple_func( + a: str, + d3=Depends(sync_simple_func), + d2=Depends(simple_func), + d=Depends(func), + ): for _ in range(2): yield a async for i in simple_func("1"): mock.start.assert_called_once() assert not mock.end.called - assert i == 1 + assert i == "1" + mock.sync_simple.assert_called_once() + mock.simple.assert_called_once() mock.end.assert_called_once() @@ -449,3 +415,118 @@ async def func(a=Depends(partial(dep, 10))): # noqa D008 return a assert await func() == 10 + + +@serializer +@pytest.mark.anyio +class TestSerializer: + @pytest.mark.anyio + async def test_not_cast(self): + @dataclass + class A: + a: int + + async def dep1() -> A: + return {"a": 1} + + async def dep2() -> A: + return {"a": 1} + + async def dep3() -> A: + return 1 + + async def get_logger() -> logging.Logger: + return logging.getLogger(__file__) + + @inject + async def some_func( + b, + a1: A = Depends(dep1, cast=False, cast_result=True), + a2: A = Depends(dep2, cast=True, cast_result=False), + a3: A = Depends(dep3, cast=False, cast_result=False), + logger: logging.Logger = Depends(get_logger), + ): + assert a1.a == 1 + assert a2.a == 1 + assert a3 == 1 + assert logger + return b + + assert (await some_func(1)) == 1 + + async def test_depends_error(self): + async def dep_func(b: dict, a: int = 3) -> float: # pragma: no cover + return a + b + + async def another_dep_func(b: int, a: int = 3) -> dict: # pragma: no cover + return a + b + + @inject + async def some_func( + b: int, c=Depends(dep_func), d=Depends(another_dep_func) + ) -> int: # pragma: no cover + assert c is None + return b + + with pytest.raises(ValidationError): + assert await some_func("2") + + async def test_depends_response_cast(self): + async def dep_func(a): + return a + + @inject + async def some_func(a: int, b: int, c: int = Depends(dep_func)) -> float: + assert isinstance(c, int) + return a + b + c + + assert await some_func("1", "2") + + async def test_async_depends_annotated_str(self): + async def dep_func(a): + return a + + @inject + async def some_func( + a: int, + b: int, + c: "Annotated[int, Depends(dep_func)]", + ) -> float: + assert isinstance(c, int) + return a + b + c + + @inject + async def another_func( + a: int, + c: "Annotated[int, Depends(dep_func)]", + ): + return a + c + + assert await some_func("1", "2") + assert await another_func("3") == 6.0 + + +@pytest.mark.anyio +async def test_default_key_value(): + async def dep(a: str = "a"): + return a + + @inject(cast=False) + async def func(a=Depends(dep)): + return a + + assert await func() == "a" + + +@pytest.mark.anyio +async def test_asynccontextmanager(): + async def dep(a: str): + return a + + @asynccontextmanager + @inject + async def func(a: str, b: str = Depends(dep)): + yield a == b + + async with func("a") as is_equal: + assert is_equal diff --git a/tests/library/test_custom.py b/tests/library/test_custom.py index 79096b85..40ab9fc3 100644 --- a/tests/library/test_custom.py +++ b/tests/library/test_custom.py @@ -1,18 +1,18 @@ import logging from time import monotonic_ns -from typing import Any, Dict +from typing import Annotated, Any import anyio -import pydantic import pytest -from typing_extensions import Annotated from fast_depends import Depends, inject +from fast_depends.exceptions import ValidationError from fast_depends.library import CustomField +from tests.marks import serializer class Header(CustomField): - def use(self, /, **kwargs: Any) -> Dict[str, Any]: + def use(self, /, **kwargs: Any) -> dict[str, Any]: kwargs = super().use(**kwargs) if v := kwargs.get("headers", {}).get(self.param_name): kwargs[self.param_name] = v @@ -30,7 +30,7 @@ def use_field(self, kwargs: Any) -> None: class AsyncHeader(Header): - async def use(self, /, **kwargs: Any) -> Dict[str, Any]: + async def use(self, /, **kwargs: Any) -> dict[str, Any]: return super().use(**kwargs) @@ -50,7 +50,7 @@ def test_header(): def sync_catch(key: int = Header()): # noqa: B008 return key - assert sync_catch(headers={"key": "1"}) == 1 + assert sync_catch(headers={"key": 1}) == 1 def test_custom_with_class(): @@ -59,112 +59,143 @@ class T: def __init__(self, key: int = Header()): self.key = key - assert T(headers={"key": "1"}).key == 1 + assert T(headers={"key": 1}).key == 1 + +def test_reusable_annotated() -> None: + HeaderKey = Annotated[float, Header(cast=False)] -@pytest.mark.anyio -async def test_header_async(): @inject - async def async_catch(key: int = Header()): # noqa: B008 + def sync_catch(key: HeaderKey) -> float: return key - assert (await async_catch(headers={"key": "1"})) == 1 - - -def test_multiple_header(): @inject - def sync_catch(key: str = Header(), key2: int = Header()): # noqa: B008 - assert key == "1" - assert key2 == 2 + def sync_catch2(key2: HeaderKey) -> float: + return key2 - sync_catch(headers={"key": "1", "key2": "2"}) + assert sync_catch(headers={"key": 1}) == 1 + assert sync_catch2(headers={"key2": 1}) == 1 -@pytest.mark.anyio -async def test_async_header_async(): +def test_arguments_mapping(): @inject - async def async_catch( # noqa: B008 - key: float = AsyncHeader(), key2: int = AsyncHeader() + def func( + d: int = CustomField(cast=False), + b: int = CustomField(cast=False), + c: int = CustomField(cast=False), + a: int = CustomField(cast=False), ): - return key, key2 + assert d == 4 + assert b == 2 + assert c == 3 + assert a == 1 - assert (await async_catch(headers={"key": "1", "key2": 1})) == (1.0, 1) + for _ in range(50): + func(4, 2, 3, 1) -def test_sync_field_header(): - @inject - def sync_catch(key: float = FieldHeader(), key2: int = FieldHeader()): # noqa: B008 - return key, key2 +@serializer +class TestSerializer: + @pytest.mark.anyio + async def test_header_async(self): + @inject + async def async_catch(key: int = Header()): # noqa: B008 + return key - assert sync_catch(headers={"key": "1", "key2": 1}) == (1.0, 1) + assert (await async_catch(headers={"key": "1"})) == 1 + def test_multiple_header(self): + @inject + def sync_catch(key: str = Header(), key2: int = Header()): # noqa: B008 + assert key == "1" + assert key2 == 2 -@pytest.mark.anyio -async def test_async_field_header(): - @inject - async def async_catch( # noqa: B008 - key: float = AsyncFieldHeader(), key2: int = AsyncFieldHeader() - ): - return key, key2 + sync_catch(headers={"key": "1", "key2": "2"}) + + @pytest.mark.anyio + async def test_async_header_async(self): + @inject + async def async_catch( # noqa: B008 + key: float = AsyncHeader(), key2: int = AsyncHeader() + ): + return key, key2 - start = monotonic_ns() - assert (await async_catch(headers={"key": "1", "key2": 1})) == (1.0, 1) - assert (monotonic_ns() - start) / 10**9 < 0.2 + assert (await async_catch(headers={"key": "1", "key2": 1})) == (1.0, 1) + def test_sync_field_header(self): + @inject + def sync_catch(key: float = FieldHeader(), key2: int = FieldHeader()): # noqa: B008 + return key, key2 -def test_async_header_sync(): - with pytest.raises(AssertionError): + assert sync_catch(headers={"key": "1", "key2": 1}) == (1.0, 1) + @pytest.mark.anyio + async def test_async_field_header(self): @inject - def sync_catch(key: str = AsyncHeader()): # pragma: no cover # noqa: B008 - return key + async def async_catch( # noqa: B008 + key: float = AsyncFieldHeader(), key2: int = AsyncFieldHeader() + ): + return key, key2 + start = monotonic_ns() + assert (await async_catch(headers={"key": "1", "key2": 1})) == (1.0, 1) + assert (monotonic_ns() - start) / 10**9 < 0.2 -def test_header_annotated(): - @inject - def sync_catch(key: Annotated[int, Header()]): - return key + def test_async_header_sync(self): + with pytest.raises(AssertionError): - assert sync_catch(headers={"key": "1"}) == 1 + @inject + def sync_catch(key: str = AsyncHeader()): # pragma: no cover # noqa: B008 + return key + def test_header_annotated(self): + @inject + def sync_catch(key: Annotated[int, Header()]): + return key -def test_header_required(): - @inject - def sync_catch(key2=Header()): # pragma: no cover # noqa: B008 - return key2 + assert sync_catch(headers={"key": 1}) == 1 - with pytest.raises(pydantic.ValidationError): - sync_catch() + def test_header_required(self): + @inject + def sync_catch(key=Header()): # pragma: no cover # noqa: B008 + return key + with pytest.raises(ValidationError): + sync_catch() -def test_header_not_required(): - @inject - def sync_catch(key2=Header(required=False)): # noqa: B008 - assert key2 is None + def test_header_not_required(self): + @inject + def sync_catch(key2=Header(required=False)): # noqa: B008 + assert key2 is None - sync_catch() + sync_catch() + def test_header_not_required_with_default(self): + @inject + def sync_catch(key2: Annotated[str, Header(required=False)] = "1"): # noqa: B008 + return key2 == "1" -def test_depends(): - def dep(key: Annotated[int, Header()]): - return key + assert sync_catch() - @inject - def sync_catch(k=Depends(dep)): - return k + def test_depends(self): + def dep(key: Annotated[int, Header()]): + return key - assert sync_catch(headers={"key": "1"}) == 1 + @inject + def sync_catch(k=Depends(dep)): + return k + assert sync_catch(headers={"key": 1}) == 1 -def test_not_cast(): - @inject - def sync_catch(key: Annotated[float, Header(cast=False)]): - return key + def test_not_cast(self): + @inject + def sync_catch(key: Annotated[float, Header(cast=False)]): + return key - assert sync_catch(headers={"key": 1}) == 1 + assert sync_catch(headers={"key": 1}) == 1 - @inject - def sync_catch(key: logging.Logger = Header(cast=False)): # noqa: B008 - return key + @inject + def sync_catch(key: logging.Logger = Header(cast=False)): # noqa: B008 + return key - assert sync_catch(headers={"key": 1}) == 1 + assert sync_catch(headers={"key": 1}) == 1 diff --git a/tests/marks.py b/tests/marks.py index 7e4dcaa4..e7541b40 100644 --- a/tests/marks.py +++ b/tests/marks.py @@ -1,7 +1,34 @@ import pytest -from fast_depends._compat import PYDANTIC_V2 +try: + from fast_depends.pydantic._compat import PYDANTIC_V2 -pydanticV1 = pytest.mark.skipif(PYDANTIC_V2, reason="requires PydanticV2") # noqa: N816 + HAS_PYDANTIC = True -pydanticV2 = pytest.mark.skipif(not PYDANTIC_V2, reason="requires PydanticV1") # noqa: N816 +except ImportError: + HAS_PYDANTIC = False + PYDANTIC_V2 = False + +try: + from fast_depends.msgspec import MsgSpecSerializer # noqa: F401 + + HAS_MSGSPEC = True +except ImportError: + HAS_MSGSPEC = False + + +serializer = pytest.mark.skipif( + not HAS_MSGSPEC and not HAS_PYDANTIC, reason="requires serializer" +) # noqa: N816 + +msgspec = pytest.mark.skipif(not HAS_MSGSPEC, reason="requires Msgspec") # noqa: N816 + +pydantic = pytest.mark.skipif(not HAS_PYDANTIC, reason="requires Pydantic") # noqa: N816 + +pydanticV1 = pytest.mark.skipif( + not HAS_PYDANTIC or PYDANTIC_V2, reason="requires PydanticV2" +) # noqa: N816 + +pydanticV2 = pytest.mark.skipif( + not HAS_PYDANTIC or not PYDANTIC_V2, reason="requires PydanticV1" +) # noqa: N816 diff --git a/tests/pydantic_specific/__init__.py b/tests/pydantic_specific/__init__.py new file mode 100644 index 00000000..e2ff924f --- /dev/null +++ b/tests/pydantic_specific/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("pydantic") diff --git a/tests/pydantic_specific/async/__init__.py b/tests/pydantic_specific/async/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pydantic_specific/async/test_cast.py b/tests/pydantic_specific/async/test_cast.py new file mode 100644 index 00000000..d21810ec --- /dev/null +++ b/tests/pydantic_specific/async/test_cast.py @@ -0,0 +1,90 @@ +from typing import Annotated + +import pytest +from annotated_types import Ge +from pydantic import BaseModel, Field + +from fast_depends import inject +from fast_depends.exceptions import ValidationError +from tests.marks import pydanticV2 + + +@pytest.mark.anyio +async def test_pydantic_types_casting(): + class SomeModel(BaseModel): + field: int + + @inject + async def some_func(a: SomeModel): + return a.field + + assert isinstance(await some_func({"field": "31"}), int) + + +@pytest.mark.anyio +async def test_pydantic_field_types_casting(): + @inject + async def some_func(a: int = Field(..., alias="b")) -> float: + assert isinstance(a, int) + return a + + @inject + async def another_func(a=Field(..., alias="b")) -> float: + assert isinstance(a, str) + return a + + assert isinstance(await some_func(b="2", c=3), float) + assert isinstance(await another_func(b="2"), float) + + +@pytest.mark.anyio +async def test_annotated(): + A = Annotated[int, Field(..., alias="b")] + + @inject + async def some_func(a: A) -> float: + assert isinstance(a, int) + return a + + assert isinstance(await some_func(b="2"), float) + + +@pytest.mark.anyio +async def test_generator(): + @inject + async def simple_func(a: str) -> int: + for _ in range(2): + yield a + + async for i in simple_func("1"): + assert i == 1 + + +@pytest.mark.anyio +async def test_validation_error(): + @inject + async def some_func(a, b: str = Field(..., max_length=1)): + return 1 + + assert await some_func(1, "a") == 1 + + with pytest.raises(ValidationError): + assert await some_func() + + with pytest.raises(ValidationError): + assert await some_func(1, "dsdas") + + +@pytest.mark.anyio +@pydanticV2 +async def test_multi_annotated(): + from pydantic.functional_validators import AfterValidator + + @inject() + async def f(a: Annotated[int, Ge(10), AfterValidator(lambda x: x + 10)]) -> int: + return a + + with pytest.raises(ValidationError): + await f(1) + + assert await f(10) == 20 diff --git a/tests/async/test_config.py b/tests/pydantic_specific/async/test_config.py similarity index 56% rename from tests/async/test_config.py rename to tests/pydantic_specific/async/test_config.py index 25bfd4ba..9664918f 100644 --- a/tests/async/test_config.py +++ b/tests/pydantic_specific/async/test_config.py @@ -1,15 +1,20 @@ import pytest -from pydantic import ValidationError from fast_depends import Depends, inject -from fast_depends._compat import PYDANTIC_V2 +from fast_depends.exceptions import ValidationError +from fast_depends.pydantic import PydanticSerializer +from tests.marks import PYDANTIC_V2 async def dep(a: str): return a -@inject(pydantic_config={"str_max_length" if PYDANTIC_V2 else "max_anystr_length": 1}) +@inject( + serializer_cls=PydanticSerializer( + {"str_max_length" if PYDANTIC_V2 else "max_anystr_length": 1} + ) +) async def limited_str(a=Depends(dep)): ... diff --git a/tests/pydantic_specific/sync/__init__.py b/tests/pydantic_specific/sync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pydantic_specific/sync/test_cast.py b/tests/pydantic_specific/sync/test_cast.py new file mode 100644 index 00000000..1579cf75 --- /dev/null +++ b/tests/pydantic_specific/sync/test_cast.py @@ -0,0 +1,84 @@ +from typing import Annotated + +import pytest +from annotated_types import Ge +from pydantic import BaseModel, Field + +from fast_depends import inject +from fast_depends.exceptions import ValidationError +from tests.marks import pydanticV2 + + +def test_pydantic_types_casting(): + class SomeModel(BaseModel): + field: int + + @inject + def some_func(a: SomeModel): + return a.field + + assert isinstance(some_func({"field": "31"}), int) + + +def test_pydantic_field_types_casting(): + @inject + def some_func(a: int = Field(..., alias="b")) -> float: + assert isinstance(a, int) + return a + + @inject + def another_func(a=Field(..., alias="b")) -> float: + assert isinstance(a, str) + return a + + assert isinstance(some_func(b="2", c=3), float) + assert isinstance(another_func(b="2"), float) + + +def test_annotated(): + A = Annotated[int, Field(..., alias="b")] + + @inject + def some_func(a: A) -> float: + assert isinstance(a, int) + return a + + assert isinstance(some_func(b="2"), float) + + +def test_generator(): + @inject + def simple_func(a: str) -> int: + for _ in range(2): + yield a + + for i in simple_func("1"): + assert i == 1 + + +def test_validation_error(): + @inject + def some_func(a, b: str = Field(..., max_length=1)): + return 1 + + assert some_func(1, "a") == 1 + + with pytest.raises(ValidationError): + assert some_func() + + with pytest.raises(ValidationError): + assert some_func(1, "dsdas") + + +@pydanticV2 +def test_multi_annotated(): + from pydantic.functional_validators import AfterValidator + + @inject() + def f(a: Annotated[int, Ge(10), AfterValidator(lambda x: x + 10)]) -> int: + return a + + with pytest.raises(ValidationError): + f(1) + + assert f(10) == 20 diff --git a/tests/sync/test_config.py b/tests/pydantic_specific/sync/test_config.py similarity index 52% rename from tests/sync/test_config.py rename to tests/pydantic_specific/sync/test_config.py index eba08a33..8b792077 100644 --- a/tests/sync/test_config.py +++ b/tests/pydantic_specific/sync/test_config.py @@ -1,15 +1,20 @@ import pytest -from pydantic import ValidationError from fast_depends import Depends, inject -from fast_depends._compat import PYDANTIC_V2 +from fast_depends.exceptions import ValidationError +from fast_depends.pydantic import PydanticSerializer +from tests.marks import PYDANTIC_V2 def dep(a: str): return a -@inject(pydantic_config={"str_max_length" if PYDANTIC_V2 else "max_anystr_length": 1}) +@inject( + serializer_cls=PydanticSerializer( + {"str_max_length" if PYDANTIC_V2 else "max_anystr_length": 1} + ) +) def limited_str(a=Depends(dep)): ... diff --git a/tests/pydantic_specific/test_custom.py b/tests/pydantic_specific/test_custom.py new file mode 100644 index 00000000..03b44095 --- /dev/null +++ b/tests/pydantic_specific/test_custom.py @@ -0,0 +1,31 @@ +from typing import Annotated, Any + +import pytest +from annotated_types import Ge + +from fast_depends import inject +from fast_depends.exceptions import ValidationError +from fast_depends.library import CustomField +from tests.marks import pydanticV2 + + +class Header(CustomField): + def use(self, /, **kwargs: Any) -> dict[str, Any]: + kwargs = super().use(**kwargs) + if v := kwargs.get("headers", {}).get(self.param_name): + kwargs[self.param_name] = v + return kwargs + + +@pydanticV2 +def test_annotated_header_with_meta(): + @inject + def sync_catch(key: Annotated[int, Header(), Ge(3)] = 3): # noqa: B008 + return key + + assert sync_catch(headers={"key": "4"}) == 4 + + assert sync_catch(headers={}) == 3 + + with pytest.raises(ValidationError): + sync_catch(headers={"key": "2"}) diff --git a/tests/test_locals.py b/tests/pydantic_specific/test_locals.py similarity index 100% rename from tests/test_locals.py rename to tests/pydantic_specific/test_locals.py diff --git a/tests/pydantic_specific/test_prebuild.py b/tests/pydantic_specific/test_prebuild.py new file mode 100644 index 00000000..2a1da292 --- /dev/null +++ b/tests/pydantic_specific/test_prebuild.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from fast_depends import Provider +from fast_depends.core import build_call_model +from fast_depends.pydantic import PydanticSerializer +from fast_depends.pydantic._compat import PYDANTIC_V2 + +from .wrapper import noop_wrap + + +class Model(BaseModel): + a: str + + +def model_func(m: Model) -> str: + return m.a + + +def test_prebuild_with_wrapper(): + func = noop_wrap(model_func) + assert func(Model(a="Hi!")) == "Hi!" + + # build_call_model should work even if function is wrapped with a + # wrapper that is imported from different module + call_model = build_call_model( + func, + dependency_provider=Provider(), + serializer_cls=PydanticSerializer(), + ) + + model = call_model.serializer.model + assert model + # Fails if function unwrapping is not done at type introspection + + if PYDANTIC_V2: + model.model_rebuild() + else: + # pydantic v1 + model.update_forward_refs() diff --git a/tests/test_schema.py b/tests/pydantic_specific/test_schema.py similarity index 67% rename from tests/test_schema.py rename to tests/pydantic_specific/test_schema.py index ff685eef..f3b717e0 100644 --- a/tests/test_schema.py +++ b/tests/pydantic_specific/test_schema.py @@ -1,21 +1,31 @@ from typing import Optional from dirty_equals import IsDict, IsPartialDict -from pydantic import BaseModel, Field -from fast_depends import Depends -from fast_depends._compat import PYDANTIC_V2 +from fast_depends import Depends, Provider from fast_depends.core import build_call_model -from fast_depends.schema import get_schema -REF_KEY = "$defs" if PYDANTIC_V2 else "definitions" +try: + from pydantic import BaseModel, Field + + from fast_depends.pydantic._compat import PYDANTIC_V2 + from fast_depends.pydantic.schema import get_schema + from fast_depends.pydantic.serializer import PydanticSerializer + + REF_KEY = "$defs" if PYDANTIC_V2 else "definitions" +except ImportError: + REF_KEY = "" def test_base(): def handler(): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, serializer_cls=PydanticSerializer(), dependency_provider=Provider() + ) + ) assert schema == {"title": "handler", "type": "null"}, schema @@ -24,7 +34,13 @@ def test_no_type(self): def handler(a): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ) + ) assert schema == { "properties": {"a": {"title": "A"}}, "required": ["a"], @@ -36,14 +52,27 @@ def test_no_type_embeded(self): def handler(a): pass - schema = get_schema(build_call_model(handler), embed=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + embed=True, + ) assert schema == {"title": "A"}, schema def test_no_type_with_default(self): def handler(a=None): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ) + ) assert schema == { "properties": {"a": IsPartialDict({"title": "A"})}, "title": "handler", @@ -54,7 +83,14 @@ def test_no_type_with_default_and_embed(self): def handler(a=None): pass - schema = get_schema(build_call_model(handler), embed=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + embed=True, + ) assert schema == IsPartialDict({"title": "A"}), schema @@ -63,7 +99,13 @@ def test_one_arg(self): def handler(a: int): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ) + ) assert schema == { "properties": {"a": {"title": "A", "type": "integer"}}, "required": ["a"], @@ -75,14 +117,27 @@ def test_one_arg_with_embed(self): def handler(a: int): pass - schema = get_schema(build_call_model(handler), embed=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + embed=True, + ) assert schema == {"title": "A", "type": "integer"}, schema def test_one_arg_with_optional(self): def handler(a: Optional[int]): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ) + ) assert schema == { "properties": { "a": IsDict( @@ -99,7 +154,13 @@ def test_one_arg_with_default(self): def handler(a: int = 0): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ) + ) assert schema == { "properties": {"a": {"default": 0, "title": "A", "type": "integer"}}, "title": "handler", @@ -110,7 +171,14 @@ def test_one_arg_with_default_and_embed(self): def handler(a: int = 0): pass - schema = get_schema(build_call_model(handler), embed=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + embed=True, + ) assert schema == {"default": 0, "title": "A", "type": "integer"}, schema @@ -122,7 +190,13 @@ class Model(BaseModel): def handler(a: Model): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ) + ) assert schema == { REF_KEY: { "Model": { @@ -145,7 +219,14 @@ class Model(BaseModel): def handler(a: Model): pass - schema = get_schema(build_call_model(handler), resolve_refs=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + resolve_refs=True, + ) assert schema == { "properties": { "a": { @@ -167,7 +248,14 @@ class Model(BaseModel): def handler(a: Optional[Model] = None): pass - schema = get_schema(build_call_model(handler), resolve_refs=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + resolve_refs=True, + ) assert schema == { "properties": { "a": IsDict( @@ -204,7 +292,15 @@ class Model(BaseModel): def handler(a: Optional[Model]): pass - schema = get_schema(build_call_model(handler), resolve_refs=True, embed=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + resolve_refs=True, + embed=True, + ) assert schema == IsDict( { "anyOf": [ @@ -236,7 +332,14 @@ class Model(BaseModel): def handler(a: Model): pass - schema = get_schema(build_call_model(handler), resolve_refs=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + resolve_refs=True, + ) assert schema == { "properties": { "a": { @@ -265,7 +368,15 @@ class Model(BaseModel): def handler(a: Model): pass - schema = get_schema(build_call_model(handler), resolve_refs=True, embed=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + resolve_refs=True, + embed=True, + ) assert schema == { "properties": {"a": {"title": "A", "type": "integer"}}, "required": ["a"], @@ -283,7 +394,15 @@ class Model(BaseModel): def handler(a: Model): pass - schema = get_schema(build_call_model(handler), resolve_refs=True, embed=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + resolve_refs=True, + embed=True, + ) assert schema == { "properties": { "a": { @@ -304,7 +423,13 @@ def test_base(self): def handler(a, b): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ) + ) assert schema == { "properties": {"a": {"title": "A"}, "b": {"title": "B"}}, "required": ["a", "b"], @@ -316,7 +441,13 @@ def test_types_and_default(self): def handler(a: str, b: int = 0): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ) + ) assert schema == { "properties": { "a": {"title": "A", "type": "string"}, @@ -331,7 +462,14 @@ def test_ignores_embed(self): def handler(a: str, b: int = 0): pass - schema = get_schema(build_call_model(handler), embed=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + embed=True, + ) assert schema == { "properties": { "a": {"title": "A", "type": "string"}, @@ -349,7 +487,13 @@ class Model(BaseModel): def handler(a: str, b: Model): pass - schema = get_schema(build_call_model(handler)) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ) + ) assert schema == { REF_KEY: { "Model": { @@ -381,7 +525,14 @@ class Model(BaseModel): def handler(a: str, b: Model): pass - schema = get_schema(build_call_model(handler), resolve_refs=True) + schema = get_schema( + build_call_model( + handler, + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), + resolve_refs=True, + ) assert schema == { "properties": { "a": {"title": "A", "type": "string"}, @@ -427,7 +578,12 @@ def handler( ... schema = get_schema( - build_call_model(handler, extra_dependencies=(Depends(dep4),)), + build_call_model( + handler, + extra_dependencies=(Depends(dep4),), + serializer_cls=PydanticSerializer(), + dependency_provider=Provider(), + ), resolve_refs=True, embed=True, ) diff --git a/tests/pydantic_specific/wrapper.py b/tests/pydantic_specific/wrapper.py new file mode 100644 index 00000000..0f83b577 --- /dev/null +++ b/tests/pydantic_specific/wrapper.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from functools import wraps + + +def noop_wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper diff --git a/tests/sync/test_cast.py b/tests/sync/test_cast.py index b50b1335..3332a6b6 100644 --- a/tests/sync/test_cast.py +++ b/tests/sync/test_cast.py @@ -1,12 +1,10 @@ -from typing import Dict, Iterator, Tuple +from collections.abc import Iterator import pytest -from annotated_types import Ge -from pydantic import BaseModel, Field, ValidationError -from typing_extensions import Annotated from fast_depends import inject -from tests.marks import pydanticV2 +from fast_depends.exceptions import ValidationError +from tests.marks import serializer def test_not_annotated(): @@ -17,15 +15,6 @@ def some_func(a, b): assert isinstance(some_func("1", "2"), str) -def test_annotated_partial(): - @inject - def some_func(a, b: int): - assert isinstance(b, int) - return a + b - - assert isinstance(some_func(1, "2"), int) - - def test_arbitrary_args(): class ArbitraryType: def __init__(self): @@ -50,100 +39,21 @@ def some_func(a: ArbitraryType) -> ArbitraryType: assert isinstance(some_func(ArbitraryType()), ArbitraryType) -def test_validation_error(): - @inject - def some_func(a, b: str = Field(..., max_length=1)): # pragma: no cover - pass - - with pytest.raises(ValidationError): - assert some_func() - - with pytest.raises(ValidationError): - assert some_func(1, "dsdas") - - -def test_types_casting(): - @inject - def some_func(a: int, b: int) -> float: - assert isinstance(a, int) - assert isinstance(b, int) - r = a + b - assert isinstance(r, int) - return r - - assert isinstance(some_func("1", "2"), float) - - -def test_types_casting_from_str(): +def test_args(): @inject - def some_func(a: "int") -> float: - return a - - assert isinstance(some_func("1"), float) - - -def test_pydantic_types_casting(): - class SomeModel(BaseModel): - field: int - - @inject - def some_func(a: SomeModel): - return a.field - - assert isinstance(some_func({"field": "31"}), int) - - -def test_pydantic_field_types_casting(): - @inject - def some_func(a: int = Field(..., alias="b")) -> float: - assert isinstance(a, int) - return a - - @inject - def another_func(a=Field(..., alias="b")) -> float: - assert isinstance(a, str) - return a + def some_func(a, *ar): + return a, ar - assert isinstance(some_func(b="2"), float) - assert isinstance(another_func(b="2"), float) - - -def test_wrong_incoming_types(): - @inject - def some_func(a: int): # pragma: no cover - return a - - with pytest.raises(ValidationError): - some_func({"key", 1}) - - -def test_wrong_return_types(): - @inject - def some_func(a: int) -> dict: - return a - - with pytest.raises(ValidationError): - some_func("2") - - -def test_annotated(): - A = Annotated[int, Field(..., alias="b")] - - @inject - def some_func(a: A) -> float: - assert isinstance(a, int) - return a - - assert isinstance(some_func(b="2"), float) + assert (1, (2,)) == some_func(1, 2) def test_args_kwargs_1(): @inject def simple_func( a: int, - *args: Tuple[float, ...], + *args: tuple[float, ...], b: int, - **kwargs: Dict[str, int], + **kwargs: dict[str, int], ): return a, args, b, kwargs @@ -154,7 +64,7 @@ def test_args_kwargs_2(): @inject def simple_func( a: int, - *args: Tuple[float, ...], + *args: tuple[float, ...], b: int, ): return a, args, b @@ -178,35 +88,105 @@ def simple_func(a: int, *, b: int): ) -def test_generator(): +def test_args_kwargs_4(): @inject - def simple_func(a: str) -> int: - for _ in range(2): - yield a + def simple_func( + *args: tuple[float, ...], + **kwargs: dict[str, int], + ): + return args, kwargs - for i in simple_func("1"): - assert i == 1 + assert ( + (1.0, 2.0, 3.0), + { + "key": 1, + "b": 3, + }, + ) == simple_func(1.0, 2.0, 3, b=3.0, key=1.0) -def test_generator_iterator_type(): +def test_args_kwargs_5(): @inject - def simple_func(a: str) -> Iterator[int]: - for _ in range(2): - yield a - - for i in simple_func("1"): - assert i == 1 - - -@pydanticV2 -def test_multi_annotated(): - from pydantic.functional_validators import AfterValidator - - @inject() - def f(a: Annotated[int, Ge(10), AfterValidator(lambda x: x + 10)]) -> int: - return a - - with pytest.raises(ValidationError): - f(1) - - assert f(10) == 20 + def simple_func( + *a: tuple[float, ...], + **kw: dict[str, int], + ): + return a, kw + + assert ( + (1.0, 2.0, 3.0), + { + "key": 1, + "b": 3, + }, + ) == simple_func(1.0, 2.0, 3, b=3.0, key=1.0) + + +@serializer +class TestSerializer: + def test_no_cast_result(self): + @inject(cast_result=False) + def some_func(a: int, b: int) -> str: + return a + b + + assert some_func("1", "2") == 3 + + def test_annotated_partial(self): + @inject + def some_func(a, b: int): + assert isinstance(b, int) + return a + b + + assert isinstance(some_func(1, "2"), int) + + def test_types_casting(self): + @inject + def some_func(a: int, b: int) -> float: + assert isinstance(a, int) + assert isinstance(b, int) + r = a + b + assert isinstance(r, int) + return r + + assert isinstance(some_func("1", "2"), float) + + def test_types_casting_from_str(self): + @inject + def some_func(a: "int") -> float: + return a + + assert isinstance(some_func("1"), float) + + def test_wrong_incoming_types(self): + @inject + def some_func(a: int): # pragma: no cover + return a + + with pytest.raises(ValidationError): + some_func({"key", 1}) + + def test_wrong_return_type(self): + @inject + def some_func(a: int) -> dict: + return a + + with pytest.raises(ValidationError): + some_func("2") + + def test_generator(self): + @inject + def simple_func(a: str) -> int: + for _ in range(2): + yield a + + for i in simple_func("1"): + assert i == 1 + + def test_generator_iterator_type(self): + @inject + def simple_func(a: str) -> Iterator[int]: + for _ in range(2): + yield a + + for i in simple_func("1"): + assert i == 1 diff --git a/tests/sync/test_depends.py b/tests/sync/test_depends.py index 8b0aca0e..2603c584 100644 --- a/tests/sync/test_depends.py +++ b/tests/sync/test_depends.py @@ -1,113 +1,89 @@ import logging +from contextlib import contextmanager from dataclasses import dataclass from functools import partial +from typing import Annotated from unittest.mock import Mock import pytest -from pydantic import ValidationError -from typing_extensions import Annotated from fast_depends import Depends, inject +from fast_depends.exceptions import ValidationError +from tests.marks import serializer def test_depends(): - def dep_func(b: int, a: int = 3) -> float: + def dep_func(b: int, a: int = 3): return a + b @inject - def some_func(b: int, c=Depends(dep_func)) -> int: - assert isinstance(c, float) + def some_func(b: int, c=Depends(dep_func)): return b + c - assert some_func("2") == 7 + assert some_func(2) == 7 def test_empty_main_body(): - def dep_func(a: int) -> float: + def dep_func(a): return a @inject def some_func(c=Depends(dep_func)): - assert isinstance(c, float) - assert c == 1.0 - - some_func("1") + return c + assert some_func(1) == 1 -def test_depends_error(): - def dep_func(b: dict, a: int = 3) -> float: # pragma: no cover - return a + b - def another_func(b: int, a: int = 3) -> dict: # pragma: no cover - return a + b +def test_class_depends(): + class MyDep: + def __init__(self, a): + self.a = a @inject - def some_func( - b: int, c=Depends(dep_func), d=Depends(another_func) - ) -> int: # pragma: no cover - assert c is None - return b - - with pytest.raises(ValidationError): - assert some_func("2") == 7 - - -def test_depends_response_cast(): - def dep_func(a): + def some_func(a=Depends(MyDep)): + assert isinstance(a, MyDep) + assert a.a == 3 return a - @inject - def some_func(a: int, b: int, c: int = Depends(dep_func)) -> float: - assert isinstance(c, int) - return a + b + c + some_func(3) - assert some_func("1", "2") +def test_empty_main_body_multiple_args(): + def dep2(b): + return b -def test_depends_annotated(): - def dep_func(a): + def dep(a): return a - D = Annotated[int, Depends(dep_func)] + @inject() + def handler(d=Depends(dep2), c=Depends(dep)): + return d, c - @inject - def some_func(a: int, b: int, c: D = None) -> float: - assert isinstance(c, int) - return a + b + c + assert handler(a=1, b=2) == (2, 1) + assert handler(1, b=2) == (2, 1) + assert handler(1, a=2) == (1, 2) + assert handler(1, 2) == (1, 1) # all dependencies takes the first arg - @inject - def another_func(a: int, c: D): - return a + c - - assert some_func("1", "2") - assert another_func("3") == 6.0 +def test_ignore_depends_if_setted_manual(): + mock = Mock() -def test_depends_annotated_str(): - def dep_func(a): - return a + def dep_func(a, b) -> int: + mock(a, b) + return a + b @inject - def some_func( - a: int, - b: int, - c: "Annotated[int, Depends(dep_func)]", - ) -> float: - assert isinstance(c, int) - return a + b + c + def some_func(c=Depends(dep_func)) -> int: + return c - @inject - def another_func( - a: int, - c: "Annotated[int, Depends(dep_func)]", - ): - return a + c + assert some_func(c=2) == 2 + assert not mock.called - assert some_func("1", "2") - assert another_func("3") == 6.0 + assert some_func(1, 2) == 3 + mock.assert_called_once_with(1, 2) -def test_depends_annotated_str_partial(): +def test_depends_annotated_type_str(): def dep_func(a): return a @@ -115,20 +91,19 @@ def dep_func(a): def some_func( a: int, b: int, - c: Annotated["float", Depends(dep_func)], - ) -> float: - assert isinstance(c, float) + c: Annotated["int", Depends(dep_func)], + ): return a + b + c @inject def another_func( a: int, - c: Annotated["float", Depends(dep_func)], + c: Annotated["int", Depends(dep_func)], ): return a + c - assert some_func("1", "2") - assert another_func("3") == 6.0 + assert some_func(1, 2) == 4 + assert another_func(3) == 6 def test_cache(): @@ -194,20 +169,6 @@ def some_func(a=Depends(dep_func)): mock.exit.assert_called_once() -def test_class_depends(): - class MyDep: - def __init__(self, a: int): - self.a = a - - @inject - def some_func(a=Depends(MyDep)): - assert isinstance(a, MyDep) - assert a.a == 3 - return a - - some_func(3) - - def test_callable_class_depends(): class MyDep: def __init__(self, a: int): @@ -224,30 +185,6 @@ def some_func(a: int = Depends(MyDep(3))): # noqa: B008 some_func() -def test_not_cast(): - @dataclass - class A: - a: int - - def dep() -> A: - return A(a=1) - - def get_logger() -> logging.Logger: - return logging.getLogger(__file__) - - @inject - def some_func( - b, - a: A = Depends(dep, cast=False), - logger: logging.Logger = Depends(get_logger, cast=False), - ): - assert a.a == 1 - assert logger - return b - - assert some_func(1) == 1 - - def test_not_cast_main(): @dataclass class A: @@ -311,24 +248,39 @@ def some_func(a: int, b: int, c=Depends(dep_func)) -> str: # pragma: no cover return a + b + c +def test_async_extra_depends(): + async def dep_func(a: int) -> float: # pragma: no cover + return a + + with pytest.raises(AssertionError): + + @inject(extra_dependencies=(Depends(dep_func),)) + def some_func(a: int, b: int) -> str: # pragma: no cover + return a + b + + def test_generator(): mock = Mock() + def simple_func(): + mock.simple() + def func(): mock.start() yield mock.end() @inject - def simple_func(a: str, d=Depends(func)) -> int: + def simple_func(a: str, d2=Depends(simple_func), d=Depends(func)): for _ in range(2): yield a for i in simple_func("1"): mock.start.assert_called_once() assert not mock.end.called - assert i == 1 + assert i == "1" + mock.simple.assert_called_once() mock.end.assert_called_once() @@ -341,3 +293,115 @@ def func(a=Depends(partial(dep, 10))): # noqa: B008 return a assert func() == 10 + + +@serializer +class TestSerializer: + def test_not_cast(self): + @dataclass + class A: + a: int + + def dep1() -> A: + return {"a": 1} + + def dep2() -> A: + return {"a": 1} + + def dep3() -> A: + return 1 + + def get_logger() -> logging.Logger: + return logging.getLogger(__file__) + + @inject + def some_func( + b, + a1: A = Depends(dep1, cast=False, cast_result=True), + a2: A = Depends(dep2, cast=True, cast_result=False), + a3: A = Depends(dep3, cast=False, cast_result=False), + logger: logging.Logger = Depends(get_logger), + ): + assert a1.a == 1 + assert a2.a == 1 + assert a3 == 1 + assert logger + return b + + assert some_func(1) == 1 + + def test_depends_error(self): + def dep_func(b: dict, a: int = 3) -> float: # pragma: no cover + return a + b + + def another_func(b: int, a: int = 3) -> dict: # pragma: no cover + return a + b + + @inject + def some_func( + b: int, c=Depends(dep_func), d=Depends(another_func) + ) -> int: # pragma: no cover + assert c is None + return b + + with pytest.raises(ValidationError): + assert some_func("2") == 7 + + def test_depends_response_cast(self): + def dep_func(a): + return a + + @inject + def some_func(a: int, b: int, c: int = Depends(dep_func)) -> float: + assert isinstance(c, int) + assert c == a + return a + b + c + + assert some_func("1", "2") + + def test_depends_annotated_str(self): + def dep_func(a): + return a + + @inject + def some_func( + a: int, + b: int, + c: "Annotated[int, Depends(dep_func)]", + ) -> float: + assert isinstance(c, int) + return a + b + c + + @inject + def another_func( + a: int, + c: "Annotated[int, Depends(dep_func)]", + ): + return a + c + + assert some_func("1", "2") + assert another_func("3") == 6.0 + + +def test_default_key_value(): + def dep(a: str = "a"): + return a + + @inject(cast=False) + def func(a=Depends(dep)): + return a + + assert func() == "a" + + +def test_contextmanager(): + def dep(a: str): + return a + + @contextmanager + @inject + def func(a: str, b: str = Depends(dep)): + yield a == b + + with func("a") as is_equal: + assert is_equal diff --git a/tests/test_no_serializer.py b/tests/test_no_serializer.py new file mode 100644 index 00000000..42d2a016 --- /dev/null +++ b/tests/test_no_serializer.py @@ -0,0 +1,11 @@ +from fast_depends import inject + + +def test_generator(): + @inject(serializer_cls=None) + def simple_func(a: str) -> str: + for _ in range(2): + yield a + + for i in simple_func(1): + assert i == 1 diff --git a/tests/test_overrides.py b/tests/test_overrides.py index da3f8999..edaaf5c7 100644 --- a/tests/test_overrides.py +++ b/tests/test_overrides.py @@ -2,13 +2,14 @@ import pytest -from fast_depends import Depends, dependency_provider, inject +from fast_depends import Depends, Provider, inject @pytest.fixture -def provider(): - yield dependency_provider - dependency_provider.clear() +def provider() -> Provider: + provider = Provider() + yield provider + provider.clear() def test_not_override(provider): @@ -18,7 +19,7 @@ def base_dep(): # pragma: no cover mock.original() return 1 - @inject(dependency_overrides_provider=None) + @inject(dependency_provider=provider) def func(d=Depends(base_dep)): assert d == 1 @@ -40,7 +41,7 @@ def override_dep(): provider.override(base_dep, override_dep) - @inject + @inject(dependency_provider=provider) def func(d=Depends(base_dep)): assert d == 2 @@ -57,7 +58,7 @@ def base_dep(): def override_dep(): return 2 - @inject + @inject(dependency_provider=provider) def func(d=Depends(base_dep)): return d @@ -76,12 +77,30 @@ async def override_dep(): # pragma: no cover provider.override(base_dep, override_dep) - @inject - def func(d=Depends(base_dep)): - pass + with pytest.raises(AssertionError): + + @inject(dependency_provider=provider) + def func(d=Depends(base_dep)): + pass + + +def test_sync_by_async_override_in_extra(provider): + def base_dep(): # pragma: no cover + return 1 + + async def override_dep(): # pragma: no cover + return 2 + + provider.override(base_dep, override_dep) with pytest.raises(AssertionError): - func() + + @inject( + dependency_provider=provider, + extra_dependencies=(Depends(base_dep),), + ) + def func(): + pass @pytest.mark.anyio @@ -98,7 +117,7 @@ async def override_dep(): provider.override(base_dep, override_dep) - @inject + @inject(dependency_provider=provider) async def func(d=Depends(base_dep)): assert d == 2 @@ -122,7 +141,7 @@ def override_dep(): provider.override(base_dep, override_dep) - @inject + @inject(dependency_provider=provider) async def func(d=Depends(base_dep)): assert d == 2 @@ -130,3 +149,34 @@ async def func(d=Depends(base_dep)): mock.override.assert_called_once() assert not mock.original.called + + +def test_deep_overrides(provider): + mock = Mock() + + def dep1(c=Depends(mock.dep2)): + mock.dep1() + + def dep3(c=Depends(mock.dep4)): + mock.dep3() + + @inject( + dependency_provider=provider, + extra_dependencies=(Depends(dep1),), + ) + def func(): + return + + func() + mock.dep1.assert_called_once() + mock.dep2.assert_called_once() + assert not mock.dep3.called + assert not mock.dep4.called + mock.reset_mock() + + with provider.scope(dep1, dep3): + func() + assert not mock.dep1.called + assert not mock.dep2.called + mock.dep3.assert_called_once() + mock.dep4.assert_called_once() diff --git a/tests/test_params.py b/tests/test_params.py index 9e428c33..c739da1a 100644 --- a/tests/test_params.py +++ b/tests/test_params.py @@ -1,4 +1,4 @@ -from fast_depends import Depends +from fast_depends import Depends, Provider from fast_depends.core import build_call_model from fast_depends.library import CustomField @@ -19,7 +19,91 @@ def main(a, b, m=Depends(func2), k=Depends(func3)): def extra_func(n): ... - model = build_call_model(main, extra_dependencies=(Depends(extra_func),)) + model = build_call_model( + main, extra_dependencies=(Depends(extra_func),), dependency_provider=Provider() + ) - assert set(model.params.keys()) == {"a", "b"} - assert set(model.flat_params.keys()) == {"a", "b", "c", "m", "n"} + assert {p.field_name for p in model.params} == {"a", "b"} + assert {p.field_name for p in model.flat_params} == {"a", "b", "c", "m", "n"} + + +def test_args_kwargs_params(): + def func1(m): + ... + + def func2(c, b=Depends(func1), d=CustomField()): # noqa: B008 + ... + + def func3(b): + ... + + def default_var_names(a, *args, b, m=Depends(func2), k=Depends(func3), **kwargs): + return a, args, b, kwargs + + def extra_func(n): + ... + + model = build_call_model( + default_var_names, + extra_dependencies=(Depends(extra_func),), + dependency_provider=Provider(), + ) + + assert {p.field_name for p in model.params} == {"a", "args", "b", "kwargs"} + assert {p.field_name for p in model.flat_params} == { + "a", + "args", + "b", + "kwargs", + "c", + "m", + "n", + } + + assert default_var_names(1, *("a"), b=2, **{"kw": "kw"}) == ( + 1, + ("a",), + 2, + {"kw": "kw"}, + ) + + +def test_custom_args_kwargs_params(): + def func1(m): + ... + + def func2(c, b=Depends(func1), d=CustomField()): # noqa: B008 + ... + + def func3(b): + ... + + def extra_func(n): + ... + + def custom_var_names(a, *args_, b, m=Depends(func2), k=Depends(func3), **kwargs_): + return a, args_, b, kwargs_ + + model = build_call_model( + custom_var_names, + extra_dependencies=(Depends(extra_func),), + dependency_provider=Provider(), + ) + + assert {p.field_name for p in model.params} == {"a", "args_", "b", "kwargs_"} + assert {p.field_name for p in model.flat_params} == { + "a", + "args_", + "b", + "kwargs_", + "c", + "m", + "n", + } + + assert custom_var_names(1, *("a"), b=2, **{"kw": "kw"}) == ( + 1, + ("a",), + 2, + {"kw": "kw"}, + ) diff --git a/tests/test_prebuild.py b/tests/test_prebuild.py index c6e01775..903d8472 100644 --- a/tests/test_prebuild.py +++ b/tests/test_prebuild.py @@ -1,5 +1,5 @@ +from fast_depends import Provider, inject from fast_depends.core import build_call_model -from fast_depends.use import inject def base_func(a: int) -> str: @@ -7,5 +7,5 @@ def base_func(a: int) -> str: def test_prebuild(): - model = build_call_model(base_func) + model = build_call_model(base_func, dependency_provider=Provider()) inject()(None, model)(1)