Skip to content

Commit

Permalink
Merge branch 'main' into ruff-format
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Mar 18, 2024
2 parents 95520c8 + bf0bd26 commit fa865fd
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

- uses: softprops/action-gh-release@v1
- uses: softprops/action-gh-release@v2
with:
generate_release_notes: true
files: dist/*
4 changes: 2 additions & 2 deletions src/in_n_out/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

from ._type_resolution import _resolve_sig_or_inform, resolve_type_hints
from ._util import _split_union, issubclassable
from ._util import _split_union, is_optional, issubclassable

logger = getLogger("in_n_out")

Expand Down Expand Up @@ -775,7 +775,7 @@ def _exec(*args: P.args, **kwargs: P.kwargs) -> R:
for param in sig.parameters.values():
if param.name not in bound.arguments:
provided = self.provide(param.annotation)
if provided is not None:
if is_optional(param.annotation) or provided is not None:
logger.debug(
" injecting %s: %s = %r",
param.name,
Expand Down
6 changes: 6 additions & 0 deletions src/in_n_out/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@


def _is_union(type_: Any) -> bool:
"""Return True if `type_` is a Union type."""
return get_origin(type_) in UNION_TYPES


def is_optional(type_: Any) -> bool:
"""Return True if `type_` is Optional[T]."""
return _split_union(type_)[1]


def _split_union(type_: Any) -> Tuple[List[Type], bool]:
optional = False
if _is_union(type_):
Expand Down
22 changes: 21 additions & 1 deletion tests/test_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_injection_errors(in_func, on_unresolved, on_unannotated):
if on_unresolved == "raise":
ctx = pytest.raises(NameError, match=UNRESOLVED_MSG)
elif on_unresolved == "warn":
ctx = pytest.warns(UserWarning, match=UNRESOLVED_MSG)
ctx = pytest.warns(UserWarning) # will warn both
if "unannotated" in in_func.__name__:
if on_unannotated == "raise":
ctxb = pytest.raises(TypeError, match=UNANNOTATED_MSG)
Expand Down Expand Up @@ -270,3 +270,23 @@ def func2(bar: "Bar", foo: "Foo") -> tuple["Foo", "Bar"]: # type: ignore # noqa
with test_store.register(providers={Foo: lambda: foo}):
assert injected(bar=2) == (foo, 2) # type: ignore
assert injected2(2) == (foo, 2) # type: ignore


def test_inject_into_required_optional() -> None:
class Thing:
...

def f(i: Optional[Thing]) -> Optional[Thing]:
return i

with pytest.raises(TypeError, match="missing 1 required positional argument"):
f() # type: ignore

assert inject(f)() is None # no provider needed

with register(providers={Optional[Thing]: lambda: None}):
assert inject(f)() is None

thing = Thing()
with register(providers={Optional[Thing]: lambda: thing}):
assert inject(f)() is thing

0 comments on commit fa865fd

Please sign in to comment.