diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f0b5cc1..ce919ff 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -85,7 +85,7 @@ jobs: - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - - uses: softprops/action-gh-release@v1 + - uses: softprops/action-gh-release@v2 with: generate_release_notes: true files: dist/* diff --git a/src/in_n_out/_store.py b/src/in_n_out/_store.py index 3469a33..2ef4e25 100644 --- a/src/in_n_out/_store.py +++ b/src/in_n_out/_store.py @@ -27,7 +27,7 @@ ) from ._type_resolution import _resolve_sig_or_inform, resolve_type_hints -from ._util import _split_union, issubclassable +from ._util import _split_union, is_optional, issubclassable logger = getLogger("in_n_out") @@ -775,7 +775,7 @@ def _exec(*args: P.args, **kwargs: P.kwargs) -> R: for param in sig.parameters.values(): if param.name not in bound.arguments: provided = self.provide(param.annotation) - if provided is not None: + if is_optional(param.annotation) or provided is not None: logger.debug( " injecting %s: %s = %r", param.name, diff --git a/src/in_n_out/_util.py b/src/in_n_out/_util.py index d89dbc4..c3fad03 100644 --- a/src/in_n_out/_util.py +++ b/src/in_n_out/_util.py @@ -11,9 +11,15 @@ def _is_union(type_: Any) -> bool: + """Return True if `type_` is a Union type.""" return get_origin(type_) in UNION_TYPES +def is_optional(type_: Any) -> bool: + """Return True if `type_` is Optional[T].""" + return _split_union(type_)[1] + + def _split_union(type_: Any) -> Tuple[List[Type], bool]: optional = False if _is_union(type_): diff --git a/tests/test_injection.py b/tests/test_injection.py index c8de5a3..e7c8c55 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -138,7 +138,7 @@ def test_injection_errors(in_func, on_unresolved, on_unannotated): if on_unresolved == "raise": ctx = pytest.raises(NameError, match=UNRESOLVED_MSG) elif on_unresolved == "warn": - ctx = pytest.warns(UserWarning, match=UNRESOLVED_MSG) + ctx = pytest.warns(UserWarning) # will warn both if "unannotated" in in_func.__name__: if on_unannotated == "raise": ctxb = pytest.raises(TypeError, match=UNANNOTATED_MSG) @@ -270,3 +270,23 @@ def func2(bar: "Bar", foo: "Foo") -> tuple["Foo", "Bar"]: # type: ignore # noqa with test_store.register(providers={Foo: lambda: foo}): assert injected(bar=2) == (foo, 2) # type: ignore assert injected2(2) == (foo, 2) # type: ignore + + +def test_inject_into_required_optional() -> None: + class Thing: + ... + + def f(i: Optional[Thing]) -> Optional[Thing]: + return i + + with pytest.raises(TypeError, match="missing 1 required positional argument"): + f() # type: ignore + + assert inject(f)() is None # no provider needed + + with register(providers={Optional[Thing]: lambda: None}): + assert inject(f)() is None + + thing = Thing() + with register(providers={Optional[Thing]: lambda: thing}): + assert inject(f)() is thing