diff --git a/python_typing/strict.py b/python_typing/strict.py index 22c3262..b095c7e 100644 --- a/python_typing/strict.py +++ b/python_typing/strict.py @@ -1,6 +1,6 @@ import typing from copy import deepcopy -from inspect import _empty, getfullargspec, signature +from inspect import _empty, getfullargspec, isclass, signature class Typing: @@ -53,6 +53,10 @@ def check_return_type(self, func, return_value): # impact users of this library pass + def _format_type(self, type): + # Get rid of repr + return type.__name__ if isclass(type) else type + def _assert_type_helper(self, value, expected_types): multiple_valid_types = False is_an_expected_type = False @@ -75,10 +79,10 @@ def _assert_type_helper(self, value, expected_types): msg = "" if self._return_type is not None: msg += "return " - return_type = getattr(self._return_type, "__name__", self._return_type) + return_type = self._format_type(self._return_type) func_sig = f'"{self._func_name}() -> {return_type}" ' else: - arg_type = getattr(self._arg_type, "__name__", self._arg_type) + arg_type = self._format_type(self._arg_type) func_sig = f'"{self._func_name}({self._arg_name}={arg_type})" ' shortened_value = str(value)[:10] + (str(value)[10:] and "..") @@ -101,7 +105,7 @@ def _assert_structure(self, structure, expected_structure): expected_structure, "__origin__", expected_structure ) if structure is not expected_structure_type: - arg_type = getattr(self._arg_type, "__name__", self._arg_type) + arg_type = self._format_type(self._arg_type) msg = ( f"Expected type {arg_type} in " f'"{self._func_name}({self._arg_name}={arg_type})" ' diff --git a/python_typing_tests/test_defaults.py b/python_typing_tests/test_defaults.py index 210f921..36e0590 100644 --- a/python_typing_tests/test_defaults.py +++ b/python_typing_tests/test_defaults.py @@ -90,7 +90,8 @@ def _func(_: Optional[str]): with pytest.raises(TypeError) as err: _func() - assert str(err.value) == "_func() missing 1 required positional argument: '_'" + # Assert "in" for python 3.10+ compatibility + assert "_func() missing 1 required positional argument: '_'" in str(err.value) def test_arg_optional_give_none(): diff --git a/python_typing_tests/test_strict.py b/python_typing_tests/test_strict.py index b8593aa..65c2a78 100644 --- a/python_typing_tests/test_strict.py +++ b/python_typing_tests/test_strict.py @@ -20,7 +20,8 @@ def _func(): with pytest.raises(TypeError) as err: _func("arg") - assert str(err.value) == "_func() takes 0 positional arguments but 1 was given" + # Assert "in" for python 3.10+ compatibility + assert "_func() takes 0 positional arguments but 1 was given" in str(err.value) def test_arg_pass_nothing(): @@ -30,7 +31,8 @@ def _func(_): with pytest.raises(TypeError) as err: _func() - assert str(err.value) == "_func() missing 1 required positional argument: '_'" + # Assert "in" for python 3.10+ compatibility + assert "_func() missing 1 required positional argument: '_'" in str(err.value) def test_arg_no_type(): @@ -184,4 +186,7 @@ def _func(_: Callable): with pytest.raises(TypeError) as err: _func(1) - assert str(err.value) == 'Value (1) in "_func(_=typing.Callable)" is not of type Callable' + assert ( + str(err.value) + == 'Value (1) in "_func(_=typing.Callable)" is not of type Callable' + )