Skip to content

Commit

Permalink
feat: Try to inject parameter value if default value is None (#110)
Browse files Browse the repository at this point in the history
* feat: Try to inject parameter value if default value is None

* style(pre-commit.ci): auto fixes [...]

* test: add test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Talley Lambert <[email protected]>
  • Loading branch information
3 people authored Apr 22, 2024
1 parent 190cdb9 commit 8f8ec35
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/in_n_out/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,10 @@ def _exec(*args: P.args, **kwargs: P.kwargs) -> R:
# first, get and call the provider functions for each parameter type:
_injected_names: set[str] = set()
for param in sig.parameters.values():
if param.name not in bound.arguments:
if (
param.name not in bound.arguments
or bound.arguments[param.name] is None
):
provided = self.provide(param.annotation)
if provided is not None or is_optional(param.annotation):
logger.debug(
Expand Down
15 changes: 15 additions & 0 deletions tests/test_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,18 @@ def f(i: Optional[Thing]) -> Optional[Thing]:
thing = Thing()
with register(providers={Optional[Thing]: lambda: thing}):
assert inject(f)() is thing


def test_inject_into_optional_with_default() -> None:
class Thing: ...

def f(i: Optional[Thing] = None) -> Optional[Thing]:
return i

thing = Thing()
with register(providers={Optional[Thing]: lambda: thing}):
assert inject(f)() is thing
with register(providers={Thing: lambda: thing}):
assert inject(f)() is thing

assert inject(f)() is None

0 comments on commit 8f8ec35

Please sign in to comment.