From 8f8ec3584e0ce4626a5c3022f3a687e7ee36fa14 Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 22 Apr 2024 20:44:46 +0200 Subject: [PATCH] feat: Try to inject parameter value if default value is None (#110) * feat: Try to inject parameter value if default value is None * style(pre-commit.ci): auto fixes [...] * test: add test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Talley Lambert --- src/in_n_out/_store.py | 5 ++++- tests/test_injection.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/in_n_out/_store.py b/src/in_n_out/_store.py index fee18fd..0006bf7 100644 --- a/src/in_n_out/_store.py +++ b/src/in_n_out/_store.py @@ -778,7 +778,10 @@ def _exec(*args: P.args, **kwargs: P.kwargs) -> R: # first, get and call the provider functions for each parameter type: _injected_names: set[str] = set() for param in sig.parameters.values(): - if param.name not in bound.arguments: + if ( + param.name not in bound.arguments + or bound.arguments[param.name] is None + ): provided = self.provide(param.annotation) if provided is not None or is_optional(param.annotation): logger.debug( diff --git a/tests/test_injection.py b/tests/test_injection.py index f0fd052..74b44e8 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -296,3 +296,18 @@ def f(i: Optional[Thing]) -> Optional[Thing]: thing = Thing() with register(providers={Optional[Thing]: lambda: thing}): assert inject(f)() is thing + + +def test_inject_into_optional_with_default() -> None: + class Thing: ... + + def f(i: Optional[Thing] = None) -> Optional[Thing]: + return i + + thing = Thing() + with register(providers={Optional[Thing]: lambda: thing}): + assert inject(f)() is thing + with register(providers={Thing: lambda: thing}): + assert inject(f)() is thing + + assert inject(f)() is None