From bc485a6d233f9453065bce69d494dcc7c78a414e Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 18 Mar 2024 12:02:45 -0400 Subject: [PATCH] fix: injecting None into takes optional (#106) * fix: injecting into takes optional * style(pre-commit.ci): auto fixes [...] * add test * fix: test * fix py38 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/in_n_out/_store.py | 4 ++-- src/in_n_out/_util.py | 6 ++++++ tests/test_injection.py | 22 +++++++++++++++++++++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/in_n_out/_store.py b/src/in_n_out/_store.py index 9d049d5..7c03a1d 100644 --- a/src/in_n_out/_store.py +++ b/src/in_n_out/_store.py @@ -27,7 +27,7 @@ ) from ._type_resolution import _resolve_sig_or_inform, resolve_type_hints -from ._util import _split_union, issubclassable +from ._util import _split_union, is_optional, issubclassable logger = getLogger("in_n_out") @@ -781,7 +781,7 @@ def _exec(*args: P.args, **kwargs: P.kwargs) -> R: for param in sig.parameters.values(): if param.name not in bound.arguments: provided = self.provide(param.annotation) - if provided is not None: + if is_optional(param.annotation) or provided is not None: logger.debug( " injecting %s: %s = %r", param.name, diff --git a/src/in_n_out/_util.py b/src/in_n_out/_util.py index d89dbc4..c3fad03 100644 --- a/src/in_n_out/_util.py +++ b/src/in_n_out/_util.py @@ -11,9 +11,15 @@ def _is_union(type_: Any) -> bool: + """Return True if `type_` is a Union type.""" return get_origin(type_) in UNION_TYPES +def is_optional(type_: Any) -> bool: + """Return True if `type_` is Optional[T].""" + return _split_union(type_)[1] + + def _split_union(type_: Any) -> Tuple[List[Type], bool]: optional = False if _is_union(type_): diff --git a/tests/test_injection.py b/tests/test_injection.py index 6b999c1..7a55be7 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -139,7 +139,7 @@ def test_injection_errors(in_func, on_unresolved, on_unannotated): if on_unresolved == "raise": ctx = pytest.raises(NameError, match=UNRESOLVED_MSG) elif on_unresolved == "warn": - ctx = pytest.warns(UserWarning, match=UNRESOLVED_MSG) + ctx = pytest.warns(UserWarning) # will warn both if "unannotated" in in_func.__name__: if on_unannotated == "raise": ctxb = pytest.raises(TypeError, match=UNANNOTATED_MSG) @@ -271,3 +271,23 @@ def func2(bar: "Bar", foo: "Foo"): # noqa with test_store.register(providers={Foo: lambda: foo}): assert injected(bar=2) == (foo, 2) # type: ignore assert injected2(2) == (foo, 2) # type: ignore + + +def test_inject_into_required_optional() -> None: + class Thing: + ... + + def f(i: Optional[Thing]) -> Optional[Thing]: + return i + + with pytest.raises(TypeError, match="missing 1 required positional argument"): + f() # type: ignore + + assert inject(f)() is None # no provider needed + + with register(providers={Optional[Thing]: lambda: None}): + assert inject(f)() is None + + thing = Thing() + with register(providers={Optional[Thing]: lambda: thing}): + assert inject(f)() is thing