From bc6ae290662fc67d6e914dfe254fafc432a3d7fa Mon Sep 17 00:00:00 2001 From: Antony Lewis Date: Mon, 4 Nov 2024 12:45:21 +0000 Subject: [PATCH] fix --- cobaya/typing.py | 7 ++++--- tests/test_type_checking.py | 7 +++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cobaya/typing.py b/cobaya/typing.py index bca29629..d0dde1e7 100644 --- a/cobaya/typing.py +++ b/cobaya/typing.py @@ -165,8 +165,9 @@ def validate_type(expected_type: type, value: Any, path: str = ''): validate_type(type_hints[key], val, f"{path}.{key}" if path else str(key)) return True - if origin := typing.get_origin(expected_type): - args = typing.get_args(expected_type) + if (origin := typing.get_origin(expected_type)) and ( + args := typing.get_args(expected_type)): + # complex types like Dict[str, float] etc. if origin is Union: errors = [] @@ -209,7 +210,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''): if origin is typing.ClassVar: return validate_type(args[0], value, path) - if origin in (dict, Mapping): + if issubclass(origin, Mapping): if not isinstance(value, Mapping): raise TypeError(f"{curr_path} must be a mapping, " f"got {type(value).__name__}") diff --git a/tests/test_type_checking.py b/tests/test_type_checking.py index 83ec0877..c0b13cab 100644 --- a/tests/test_type_checking.py +++ b/tests/test_type_checking.py @@ -1,6 +1,6 @@ """General test for types of components.""" -from typing import Any, ClassVar, Dict, List, Optional, Tuple +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Mapping import numpy as np import pytest @@ -57,6 +57,7 @@ class GenericComponent(CobayaComponent): tuple_params: Tuple[float, float] = (0.0, 1.0) array: Sequence[float] array2: Sequence[float] + map: Mapping[float, str] _enforce_types = True @@ -75,7 +76,8 @@ def test_component_types(): "params": {"a": [0.0, 1.0], "b": [0, 1]}, "tuple_params": (0.0, 1.0), "array": np.arange(2, dtype=np.float64), - "array2": [1, 2] + "array2": [1, 2], + "map": {1.0: "a", 2.0: "b"} } GenericComponent(correct_kwargs) @@ -95,6 +97,7 @@ def test_component_types(): {"tuple_params": "not_a_tuple"}, {"tuple_params": (0.0, "not_a_float")}, {"array": 2}, + {"map": {"a": 2.0}} ] for case in wrong_cases: with pytest.raises(TypeError):