diff --git a/cobaya/input.py b/cobaya/input.py index a1c0929c..9b8f5827 100644 --- a/cobaya/input.py +++ b/cobaya/input.py @@ -29,6 +29,7 @@ from cobaya.log import LoggedError, get_logger from cobaya.parameterization import expand_info_param from cobaya import mpi +import cobaya.typing # Logger logger = get_logger(__name__) @@ -141,6 +142,8 @@ def load_info_overrides(*infos_or_yaml_or_files, **flags) -> InputDict: for flag, value in flags.items(): if value is not None: info[flag] = value + if cobaya.typing.enforce_type_checking: + cobaya.typing.validate_type(InputDict, info) return info diff --git a/cobaya/typing.py b/cobaya/typing.py index d0dde1e7..8dfe3deb 100644 --- a/cobaya/typing.py +++ b/cobaya/typing.py @@ -54,7 +54,7 @@ class ParamDict(TypedDict, total=False): value: Union[float, Callable, str] derived: Union[bool, str, Callable] prior: Union[None, Sequence[float], SciPyDistDict, SciPyMinMaxDict] - ref: Union[None, Sequence[float], SciPyDistDict, SciPyMinMaxDict] + ref: Union[None, float, Sequence[float], SciPyDistDict, SciPyMinMaxDict] proposal: Optional[float] renames: Union[str, Sequence[str]] latex: str @@ -129,7 +129,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''): if expected_type is int: if not (value in (np.inf, -np.inf) or isinstance(value, numbers.Integral)): raise TypeError( - f"{curr_path} must be an integer or infinity, got {type(value).__name__}" + f"{curr_path} must be an integer, got {type(value).__name__}" ) return @@ -140,14 +140,13 @@ def validate_type(expected_type: type, value: Any, path: str = ''): return if expected_type is bool: - if not hasattr(value, '__bool__') and not isinstance(value, (str, np.ndarray)): + if not isinstance(value, bool): + # if not hasattr(value, '__bool__') and not isinstance(value, (str, np.ndarray)): raise TypeError( f"{curr_path} must be boolean, got {type(value).__name__}" ) return - # special case for Cobaya - if sys.version_info < (3, 10): from typing_extensions import is_typeddict else: @@ -163,7 +162,7 @@ def validate_type(expected_type: type, value: Any, path: str = ''): f"'{expected_type.__name__}': {invalid_keys}") for key, val in value.items(): validate_type(type_hints[key], val, f"{path}.{key}" if path else str(key)) - return True + return if (origin := typing.get_origin(expected_type)) and ( args := typing.get_args(expected_type)): @@ -178,6 +177,8 @@ def validate_type(expected_type: type, value: Any, path: str = ''): return validate_type(t, value, path) except TypeError as e: error_msg = str(e) + if ' any Union type' in error_msg: + raise error_path = error_msg.split(' ')[0].strip("'") # If error is about the current path, it's a structural error @@ -210,10 +211,12 @@ def validate_type(expected_type: type, value: Any, path: str = ''): if origin is typing.ClassVar: return validate_type(args[0], value, path) + if isinstance(value, Mapping) != issubclass(origin, Mapping): + raise TypeError( + f"{curr_path} must be {args[0]}, got {type(value).__name__}" + ) + if issubclass(origin, Mapping): - if not isinstance(value, Mapping): - raise TypeError(f"{curr_path} must be a mapping, " - f"got {type(value).__name__}") for k, v in value.items(): key_path = f"{path}[{k!r}]" if path else f"[{k!r}]" validate_type(args[0], k, f"{key_path} (key)") @@ -231,12 +234,11 @@ def validate_type(expected_type: type, value: Any, path: str = ''): ) return - if not isinstance(value, Iterable): - raise TypeError( - f"{curr_path} must be iterable, got {type(value).__name__}" - ) - if len(args) == 1: + if not isinstance(value, Iterable): + raise TypeError( + f"{curr_path} must be iterable, got {type(value).__name__}" + ) for i, item in enumerate(value): validate_type(args[0], item, f"{path}[{i}]" if path else f"[{i}]") else: diff --git a/tests/test_type_checking.py b/tests/test_type_checking.py index c0b13cab..3719686a 100644 --- a/tests/test_type_checking.py +++ b/tests/test_type_checking.py @@ -5,42 +5,8 @@ import pytest from cobaya.component import CobayaComponent -from cobaya.likelihood import Likelihood from cobaya.tools import NumberWithUnits -from cobaya.typing import InputDict, ParamDict, Sequence -from cobaya.run import run - - -class GenericLike(Likelihood): - any: Any - classvar: ClassVar[int] = 1 - infinity: int = float("inf") - mean: NumberWithUnits = 1 - noise: float = 0 - none: int = None - numpy_int: int = np.int64(1) - optional: Optional[int] = None - paramdict_params: ParamDict = {"prior": [0.0, 1.0]} - params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]} - tuple_params: Tuple[float, float] = (0.0, 1.0) - - _enforce_types = True - - def logp(self, **params_values): - return 1 - - -def test_sampler_types(): - original_info: InputDict = { - "likelihood": {"like": GenericLike}, - "sampler": {"mcmc": {"max_samples": 1}}, - } - _ = run(original_info) - - info = original_info.copy() - info["sampler"]["mcmc"]["max_samples"] = "not_an_int" - with pytest.raises(TypeError): - run(info) +from cobaya.typing import ParamDict, Sequence class GenericComponent(CobayaComponent):