Skip to content

Commit

Permalink
fix: injecting None into takes optional (#106)
Browse files Browse the repository at this point in the history
* fix: injecting into takes optional

* style(pre-commit.ci): auto fixes [...]

* add test

* fix: test

* fix py38

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tlambert03 and pre-commit-ci[bot] authored Mar 18, 2024
1 parent 27a3f9e commit bc485a6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/in_n_out/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

from ._type_resolution import _resolve_sig_or_inform, resolve_type_hints
from ._util import _split_union, issubclassable
from ._util import _split_union, is_optional, issubclassable

logger = getLogger("in_n_out")

Expand Down Expand Up @@ -781,7 +781,7 @@ def _exec(*args: P.args, **kwargs: P.kwargs) -> R:
for param in sig.parameters.values():
if param.name not in bound.arguments:
provided = self.provide(param.annotation)
if provided is not None:
if is_optional(param.annotation) or provided is not None:
logger.debug(
" injecting %s: %s = %r",
param.name,
Expand Down
6 changes: 6 additions & 0 deletions src/in_n_out/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@


def _is_union(type_: Any) -> bool:
"""Return True if `type_` is a Union type."""
return get_origin(type_) in UNION_TYPES


def is_optional(type_: Any) -> bool:
"""Return True if `type_` is Optional[T]."""
return _split_union(type_)[1]


def _split_union(type_: Any) -> Tuple[List[Type], bool]:
optional = False
if _is_union(type_):
Expand Down
22 changes: 21 additions & 1 deletion tests/test_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_injection_errors(in_func, on_unresolved, on_unannotated):
if on_unresolved == "raise":
ctx = pytest.raises(NameError, match=UNRESOLVED_MSG)
elif on_unresolved == "warn":
ctx = pytest.warns(UserWarning, match=UNRESOLVED_MSG)
ctx = pytest.warns(UserWarning) # will warn both
if "unannotated" in in_func.__name__:
if on_unannotated == "raise":
ctxb = pytest.raises(TypeError, match=UNANNOTATED_MSG)
Expand Down Expand Up @@ -271,3 +271,23 @@ def func2(bar: "Bar", foo: "Foo"): # noqa
with test_store.register(providers={Foo: lambda: foo}):
assert injected(bar=2) == (foo, 2) # type: ignore
assert injected2(2) == (foo, 2) # type: ignore


def test_inject_into_required_optional() -> None:
class Thing:
...

def f(i: Optional[Thing]) -> Optional[Thing]:
return i

with pytest.raises(TypeError, match="missing 1 required positional argument"):
f() # type: ignore

assert inject(f)() is None # no provider needed

with register(providers={Optional[Thing]: lambda: None}):
assert inject(f)() is None

thing = Thing()
with register(providers={Optional[Thing]: lambda: thing}):
assert inject(f)() is thing

0 comments on commit bc485a6

Please sign in to comment.