Skip to content

Commit

Permalink
fix: fix functools.wrapped functions (#28)
Browse files Browse the repository at this point in the history
* fix: fix functools.wrapped functions

* test: add test for type resolution as well
  • Loading branch information
tlambert03 authored Jul 15, 2022
1 parent a5c1803 commit eb6cbb0
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/in_n_out/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
import weakref
from functools import cached_property, wraps
from inspect import CO_VARARGS, isgeneratorfunction
from inspect import CO_VARARGS, isgeneratorfunction, unwrap
from types import CodeType
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -719,7 +719,7 @@ def _inner(func: Callable[P, R]) -> Callable[P, R]:
return self.inject_processors(func) if processors else func

# bail if there aren't any annotations at all
code: Optional[CodeType] = getattr(func, "__code__", None)
code: Optional[CodeType] = getattr(unwrap(func), "__code__", None)
if (code and not code.co_argcount) and "return" not in getattr(
func, "__annotations__", {}
):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_injection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from contextlib import nullcontext
from inspect import isgeneratorfunction
from typing import ContextManager, Generator, Optional
Expand Down Expand Up @@ -237,3 +238,22 @@ def generator_func() -> Generator:

with pytest.raises(TypeError, match="generator function"):
inject(generator_func, processors=True)


def test_wrapped_functions():
def func(foo: Foo):
return foo

@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

@functools.wraps(wrapper)
def wrapper2(*args, **kwargs):
return wrapper(*args, **kwargs)

injected = inject(wrapper2)

foo = Foo()
with register(providers={Foo: lambda: foo}):
assert injected() == foo
17 changes: 17 additions & 0 deletions tests/test_type_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,20 @@ def func2(x: int, y: str, z: list):
ppf = pf(z=["hi"])

assert resolve_type_hints(ppf) == {"x": int, "y": str, "z": list}


def test_wrapped_resolution() -> None:
from functools import wraps

def func(x: int, y: str, z: list):
...

@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

@wraps(wrapper)
def wrapper2(*args, **kwargs):
return wrapper(*args, **kwargs)

assert resolve_type_hints(wrapper2) == {"x": int, "y": str, "z": list}

0 comments on commit eb6cbb0

Please sign in to comment.