Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Nov 4, 2024
1 parent 29db230 commit 5be78e8
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions cobaya/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
:param path: string tracking the nested path for error messages
:raises TypeError: with descriptive message when validation fails
"""
curr_path = f"'{path}'" if path else 'value'

if value is None or expected_type is Any:
return

curr_path = f"'{path}'" if path else 'value'

if expected_type is int:
if not (value in (np.inf, -np.inf) or isinstance(value, numbers.Integral)):
raise TypeError(
Expand All @@ -142,7 +142,6 @@ def validate_type(expected_type: type, value: Any, path: str = ''):

if expected_type is bool:
if not isinstance(value, bool):
# if not hasattr(value, '__bool__') and not isinstance(value, (str, np.ndarray)):
raise TypeError(
f"{curr_path} must be boolean, got {type(value).__name__}"
)
Expand Down Expand Up @@ -176,8 +175,6 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
return validate_type(t, value, path)
except TypeError as e:
error_msg = str(e)
# if ' any Union type' in error_msg or 'for TypedDict ' in error_msg:
# raise
error_path = error_msg.split(' ')[0].strip("'")
errors.append((error_path, error_msg))

Expand Down Expand Up @@ -235,13 +232,15 @@ def validate_type(expected_type: type, value: Any, path: str = ''):
if not (isinstance(value, expected_type) or
expected_type is Sequence and isinstance(value, np.ndarray)):

type_name = getattr(expected_type, "__name__", repr(expected_type))

# special case for Cobaya's NumberWithUnits, if not instance yet
if getattr(expected_type, "__name__", "") == 'NumberWithUnits':
if type_name == 'NumberWithUnits':
if not isinstance(value, (numbers.Real, str)):
raise TypeError(
f"{curr_path} must be a number or string for NumberWithUnits,"
f" got {type(value).__name__}")
return

raise TypeError(f"{curr_path} must be of type {expected_type.__name__}, "
raise TypeError(f"{curr_path} must be of type {type_name}, "
f"got {type(value).__name__}")

0 comments on commit 5be78e8

Please sign in to comment.