From 400ef6ba29e2caf37c3fc742f5be4fdcbb9daa19 Mon Sep 17 00:00:00 2001 From: Antony Lewis Date: Thu, 7 Nov 2024 20:08:38 +0000 Subject: [PATCH] fix for non-types --- cobaya/tools.py | 1 - cobaya/typing.py | 27 ++++++++++++++------------- tests/test_type_checking.py | 2 ++ 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/cobaya/tools.py b/cobaya/tools.py index bfa5f1eb..7b777363 100644 --- a/cobaya/tools.py +++ b/cobaya/tools.py @@ -306,7 +306,6 @@ def get_external_function(string_or_function, name=None): if isinstance(string_or_function, str): try: scope = globals() - import scipy.stats as stats # provide default scope for eval scope['stats'] = stats scope['np'] = np string_or_function = replace_optimizations(string_or_function) diff --git a/cobaya/typing.py b/cobaya/typing.py index a6b3cf45..eccbbed2 100644 --- a/cobaya/typing.py +++ b/cobaya/typing.py @@ -201,7 +201,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''): "\n".join(f"- {e}" for e in path_errors) ) - if origin is typing.ClassVar: + if not isinstance(origin, type): return validate_type(args[0], value, path) if isinstance(value, Mapping) != issubclass(origin, Mapping): @@ -245,18 +245,19 @@ def validate_type(expected_type: type, value: Any, path: str = ''): validate_type(t, v, f"{path}[{i}]" if path else f"[{i}]") return - if not (isinstance(value, expected_type) or - expected_type is Sequence and isinstance(value, np.ndarray)): + if not isinstance(expected_type, type) or isinstance(value, expected_type) \ + or expected_type is Sequence and isinstance(value, np.ndarray): + return - type_name = getattr(expected_type, "__name__", repr(expected_type)) + type_name = getattr(expected_type, "__name__", repr(expected_type)) - # special case for Cobaya's NumberWithUnits, if not instance yet - if type_name == 'NumberWithUnits': - if not isinstance(value, (numbers.Real, str)): - raise TypeError( - f"{curr_path} must be a number or string for NumberWithUnits," - f" got {type(value).__name__}") - return + # special case for Cobaya's NumberWithUnits, if not instance yet + if type_name == 'NumberWithUnits': + if not isinstance(value, (numbers.Real, str)): + raise TypeError( + f"{curr_path} must be a number or string for NumberWithUnits," + f" got {type(value).__name__}") + return - raise TypeError(f"{curr_path} must be of type {type_name}, " - f"got {type(value).__name__}") + raise TypeError(f"{curr_path} must be of type {type_name}, " + f"got {type(value).__name__}") diff --git a/tests/test_type_checking.py b/tests/test_type_checking.py index 63e8bbb2..4a895a4b 100644 --- a/tests/test_type_checking.py +++ b/tests/test_type_checking.py @@ -26,6 +26,7 @@ class GenericComponent(CobayaComponent): map: Mapping[float, str] deferred: 'ParamDict' unset = 1 + install_options: ClassVar _enforce_types = True @@ -47,6 +48,7 @@ def test_component_types(): "array2": [1, 2], "map": {1.0: "a", 2.0: "b"}, "deferred": {'value': lambda x: x}, + "install_options": {} } GenericComponent(correct_kwargs)