Skip to content

Commit

Permalink
feat: add on_unresolved_required_args="ignore" to inject() (#33)
Browse files Browse the repository at this point in the history
* feat: add ignore required

* test: another test

* test: pragma
  • Loading branch information
tlambert03 authored Aug 12, 2022
1 parent 394ca0a commit 65d3886
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 62 deletions.
8 changes: 4 additions & 4 deletions src/in_n_out/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(self, name: str) -> None:
self._providers: List[_RegisteredCallback] = []
self._processors: List[_RegisteredCallback] = []
self._namespace: Union[Namespace, Callable[[], Namespace], None] = None
self.on_unresolved_required_args: RaiseWarnReturnIgnore = "raise"
self.on_unresolved_required_args: RaiseWarnReturnIgnore = "warn"
self.on_unannotated_required_args: RaiseWarnReturnIgnore = "warn"
self.guess_self: bool = True

Expand Down Expand Up @@ -655,13 +655,13 @@ def inject(
on_unresolved_required_args : RaiseWarnReturnIgnore
What to do when a required parameter (one without a default) is encountered
with an unresolvable type annotation.
Must be one of the following (by default 'raise'):
Must be one of the following (by default 'warn'):
- 'raise': immediately raise an exception
- 'warn': warn and return the original function
- 'return': return the original function without warning
- 'ignore': currently an alias for `return`, but will be used in
the future to allow the decorator to proceed.
- 'ignore': continue decorating without warning (at call time, this
function will fail without additional arguments).
on_unannotated_required_args : RaiseWarnReturnIgnore
What to do when a required parameter (one without a default) is encountered
Expand Down
89 changes: 57 additions & 32 deletions src/in_n_out/_type_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
from functools import lru_cache, partial
from inspect import Signature
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tuple, Type

try:
from toolz import curry
Expand Down Expand Up @@ -108,6 +108,7 @@ def type_resolved_signature(
*,
localns: Optional[dict] = None,
raise_unresolved_optional_args: bool = True,
raise_unresolved_required_args: bool = True,
guess_self: bool = True,
) -> Signature:
"""Return a Signature object for a function with resolved type annotations.
Expand All @@ -121,6 +122,9 @@ def type_resolved_signature(
raise_unresolved_optional_args : bool
Whether to raise an exception when an optional parameter (one with a default
value) has an unresolvable type annotation, by default True
raise_unresolved_required_args : bool
Whether to raise an exception when a required parameter has an unresolvable
type annotation, by default True
guess_self : bool
Whether to infer the type of the first argument if the function is an unbound
class method. This is done as follows:
Expand Down Expand Up @@ -178,7 +182,9 @@ class method. This is done as follows:
"To allow optional parameters and return types to remain unresolved, "
"use `raise_unresolved_optional_args=False`"
) from err
hints = _resolve_mandatory_params(sig)
hints = _resolve_params_one_by_one(
sig, exclude_unresolved_mandatory=not raise_unresolved_required_args
)

resolved_parameters = [
param.replace(annotation=hints.get(param.name, param.empty))
Expand All @@ -190,9 +196,10 @@ class method. This is done as follows:
)


def _resolve_mandatory_params(
def _resolve_params_one_by_one(
sig: Signature,
exclude_unresolved_optionals: bool = False,
exclude_unresolved_mandatory: bool = False,
) -> Dict[str, Any]:
"""Resolve all required param annotations in `sig`, but allow optional ones to fail.
Expand All @@ -201,8 +208,16 @@ def _resolve_mandatory_params(
It resolves each parameter's type annotation independently, and only raises an
error if a parameter without a default value has an unresolvable type annotation.
If `exclude_unresolved_optionals` is `True`, then unresolved optional parameters
will not appear in the output dict
Parameters
----------
sig : Signature
:class:`inspect.Signature` object with unresolved type annotations.
exclude_unresolved_optionals : bool
Whether to exclude parameters with unresolved type annotations that have a
default value, by default False
exclude_unresolved_mandatory : bool
Whether to exclude parameters with unresolved type annotations that do not
have a default value, by default False
Returns
-------
Expand All @@ -221,12 +236,17 @@ def _resolve_mandatory_params(
try:
hints[name] = resolve_single_type_hints(param.annotation)[0]
except NameError as e:
if param.default is param.empty:
if (
param.default is param.empty
and exclude_unresolved_mandatory
or param.default is not param.empty
and not exclude_unresolved_optionals
):
hints[name] = param.annotation
elif param.default is param.empty:
raise NameError(
f"Could not resolve type hint for required parameter {name!r}: {e}"
) from e
elif not exclude_unresolved_optionals:
hints[name] = param.annotation
if sig.return_annotation is not sig.empty:
try:
hints["return"] = resolve_single_type_hints(sig.return_annotation)[0]
Expand All @@ -247,32 +267,37 @@ def _resolve_sig_or_inform(
all parameters are described above in inject_dependencies
"""
try:
sig = type_resolved_signature(
func,
localns=localns,
raise_unresolved_optional_args=False,
guess_self=guess_self,
)
except NameError as e:
errmsg = str(e)
if on_unresolved_required_args == "raise":
msg = (
f"{errmsg}. To simply return the original function, pass `on_un"
'resolved_required_args="return"`. To emit a warning, pass "warn".'
)
raise NameError(msg) from e
if on_unresolved_required_args == "warn":
msg = (
f"{errmsg}. To suppress this warning and simply return the original "
'function, pass `on_unresolved_required_args="return"`.'
)

warnings.warn(msg, UserWarning, stacklevel=2)
return None
sig = type_resolved_signature(
func,
localns=localns,
raise_unresolved_optional_args=False,
raise_unresolved_required_args=False,
guess_self=guess_self,
)

for param in sig.parameters.values():
if param.default is param.empty and param.annotation is param.empty:
if param.default is not param.empty:
continue # pragma: no cover
if isinstance(param.annotation, (str, ForwardRef)):
errmsg = (
f"Could not resolve type hint for required parameter {param.name!r}"
)
if on_unresolved_required_args == "raise":
msg = (
f"{errmsg}. To simply return the original function, pass `on_un"
'annotated_required_args="return"`. To emit a warning, pass "warn".'
)
raise NameError(msg)
elif on_unresolved_required_args == "warn":
msg = (
f"{errmsg}. To suppress this warning and simply return the original"
' function, pass `on_unannotated_required_args="return"`.'
)
warnings.warn(msg, UserWarning, stacklevel=2)
elif on_unresolved_required_args == "return":
return None

elif param.annotation is param.empty:
fname = (getattr(func, "__name__", ""),)
name = param.name
base = (
Expand Down
61 changes: 35 additions & 26 deletions tests/test_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ def f():
modes = ["raise", "warn", "return", "ignore"]


def unknown(v: "Unknown") -> int: # type: ignore # noqa
def unannotated(x) -> int: # type: ignore # noqa
...


def unannotated(x) -> int: # type: ignore # noqa
def unknown(v: "Unknown") -> int: # type: ignore # noqa
...


Expand All @@ -129,47 +129,42 @@ def unknown_and_unannotated(v: "Unknown", x) -> int: # type: ignore # noqa
def test_injection_errors(in_func, on_unresolved, on_unannotated):

ctx: ContextManager = nullcontext()
ctxb: ContextManager = nullcontext()
expect_same_func_back = False

if "unknown" in in_func.__name__: # required params with unknown annotations
UNANNOTATED_MSG = "Injecting dependencies .* with a required, unannotated param"

if "unknown" in in_func.__name__ and on_unresolved != "ignore":
# required params with unknown annotations
UNRESOLVED_MSG = "Could not resolve type hint for required parameter"

if on_unresolved == "raise":
ctx = pytest.raises(
NameError,
match="Could not resolve type hint for required parameter",
)
ctx = pytest.raises(NameError, match=UNRESOLVED_MSG)
elif on_unresolved == "warn":
ctx = pytest.warns(UserWarning, match=UNRESOLVED_MSG)
if "unannotated" in in_func.__name__:
if on_unannotated == "raise":
ctxb = pytest.raises(TypeError, match=UNANNOTATED_MSG)
elif on_unannotated == "return":
expect_same_func_back = True
else:
expect_same_func_back = True
if on_unresolved == "warn":
ctx = pytest.warns(
UserWarning,
match="Could not resolve type hint for required parameter",
)

elif "unannotated" in in_func.__name__: # required params without annotations
if on_unannotated == "raise":
ctx = pytest.raises(
TypeError,
match="Injecting dependencies .* with a required, unannotated param",
)
ctx = pytest.raises(TypeError, match=UNANNOTATED_MSG)
elif on_unannotated == "warn":
ctx = pytest.warns(
UserWarning,
match="Injecting dependencies .* with a required, unannotated param",
)
ctx = pytest.warns(UserWarning, match=UNANNOTATED_MSG)
elif on_unannotated == "return":
expect_same_func_back = True

with ctx:
with ctx, ctxb:
out_func = inject(
in_func,
on_unannotated_required_args=on_unannotated,
on_unresolved_required_args=on_unresolved,
)

if expect_same_func_back:
assert out_func is in_func
else:
assert out_func is not in_func
assert (out_func is in_func) is expect_same_func_back


def test_processors_not_passed_none(test_store: Store):
Expand Down Expand Up @@ -257,3 +252,17 @@ def wrapper2(*args, **kwargs):
foo = Foo()
with register(providers={Foo: lambda: foo}):
assert injected() == foo


def test_partial_annotations():
def func(foo: Foo, bar: "Bar"): # noqa
return foo, bar

with pytest.warns(UserWarning):
injected = inject(func)

injected = inject(func, on_unresolved_required_args="ignore")

foo = Foo()
with register(providers={Foo: lambda: foo}):
assert injected(bar=2) == (foo, 2) # type: ignore

0 comments on commit 65d3886

Please sign in to comment.