Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Nov 4, 2024
1 parent 3bdbf5a commit bc6ae29
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
7 changes: 4 additions & 3 deletions cobaya/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
validate_type(type_hints[key], val, f"{path}.{key}" if path else str(key))
return True

if origin := typing.get_origin(expected_type):
args = typing.get_args(expected_type)
if (origin := typing.get_origin(expected_type)) and (
args := typing.get_args(expected_type)):
# complex types like Dict[str, float] etc.

if origin is Union:
errors = []
Expand Down Expand Up @@ -209,7 +210,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
if origin is typing.ClassVar:
return validate_type(args[0], value, path)

if origin in (dict, Mapping):
if issubclass(origin, Mapping):
if not isinstance(value, Mapping):
raise TypeError(f"{curr_path} must be a mapping, "
f"got {type(value).__name__}")
Expand Down
7 changes: 5 additions & 2 deletions tests/test_type_checking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""General test for types of components."""

from typing import Any, ClassVar, Dict, List, Optional, Tuple
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Mapping
import numpy as np
import pytest

Expand Down Expand Up @@ -57,6 +57,7 @@ class GenericComponent(CobayaComponent):
tuple_params: Tuple[float, float] = (0.0, 1.0)
array: Sequence[float]
array2: Sequence[float]
map: Mapping[float, str]

_enforce_types = True

Expand All @@ -75,7 +76,8 @@ def test_component_types():
"params": {"a": [0.0, 1.0], "b": [0, 1]},
"tuple_params": (0.0, 1.0),
"array": np.arange(2, dtype=np.float64),
"array2": [1, 2]
"array2": [1, 2],
"map": {1.0: "a", 2.0: "b"}
}
GenericComponent(correct_kwargs)

Expand All @@ -95,6 +97,7 @@ def test_component_types():
{"tuple_params": "not_a_tuple"},
{"tuple_params": (0.0, "not_a_float")},
{"array": 2},
{"map": {"a": 2.0}}
]
for case in wrong_cases:
with pytest.raises(TypeError):
Expand Down

0 comments on commit bc6ae29

Please sign in to comment.