diff --git a/src/in_n_out/_store.py b/src/in_n_out/_store.py index d266180..5d445c2 100644 --- a/src/in_n_out/_store.py +++ b/src/in_n_out/_store.py @@ -5,7 +5,7 @@ import warnings import weakref from functools import cached_property, wraps -from inspect import CO_VARARGS, isgeneratorfunction +from inspect import CO_VARARGS, isgeneratorfunction, unwrap from types import CodeType from typing import ( TYPE_CHECKING, @@ -719,7 +719,7 @@ def _inner(func: Callable[P, R]) -> Callable[P, R]: return self.inject_processors(func) if processors else func # bail if there aren't any annotations at all - code: Optional[CodeType] = getattr(func, "__code__", None) + code: Optional[CodeType] = getattr(unwrap(func), "__code__", None) if (code and not code.co_argcount) and "return" not in getattr( func, "__annotations__", {} ): diff --git a/tests/test_injection.py b/tests/test_injection.py index 7659547..2f8a17c 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -1,3 +1,4 @@ +import functools from contextlib import nullcontext from inspect import isgeneratorfunction from typing import ContextManager, Generator, Optional @@ -237,3 +238,22 @@ def generator_func() -> Generator: with pytest.raises(TypeError, match="generator function"): inject(generator_func, processors=True) + + +def test_wrapped_functions(): + def func(foo: Foo): + return foo + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + @functools.wraps(wrapper) + def wrapper2(*args, **kwargs): + return wrapper(*args, **kwargs) + + injected = inject(wrapper2) + + foo = Foo() + with register(providers={Foo: lambda: foo}): + assert injected() == foo diff --git a/tests/test_type_resolution.py b/tests/test_type_resolution.py index f5d9de0..b4c1ae0 100644 --- a/tests/test_type_resolution.py +++ b/tests/test_type_resolution.py @@ -98,3 +98,20 @@ def func2(x: int, y: str, z: list): ppf = pf(z=["hi"]) assert resolve_type_hints(ppf) == {"x": int, "y": str, "z": list} + + +def test_wrapped_resolution() -> None: + from functools import wraps + + def func(x: int, y: str, z: list): + ... + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + @wraps(wrapper) + def wrapper2(*args, **kwargs): + return wrapper(*args, **kwargs) + + assert resolve_type_hints(wrapper2) == {"x": int, "y": str, "z": list}