diff --git a/brainunit/_base.py b/brainunit/_base.py index 4e4c7da..a075487 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -19,7 +19,7 @@ import operator from contextlib import contextmanager from copy import deepcopy -from functools import wraps +from functools import wraps, partial from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict import jax @@ -4146,44 +4146,20 @@ def new_f(*args, **kwds): expected_result = au["result"](*[get_dim(a) for a in args]) else: expected_result = au["result"] - if isinstance(expected_result, tuple): - if not isinstance(result, tuple) or len(result) !=len(expected_result): - error_message = ( - "The return value of function " - f"'{f.__name__}' was expected to be a tuple of length " - f"{len(expected_result)} but was of type " - f"{type(result)} with length {len(result) if isinstance(result, tuple) else 'N/A'}" - ) - raise TypeError(error_message) - for res, exp_res in zip(result, expected_result): - if not have_same_dim(res, exp_res): - unit = get_dim_for_display(exp_res) - error_message = ( - "The return value of function " - f"'{f.__name__}' was expected to have " - f"dimension {unit} but was " - f"'{res}'" - ) - raise DimensionMismatchError(error_message, get_dim(res)) - else: - if au["result"] == bool: - if not isinstance(result, bool): - error_message = ( - "The return value of function " - f"'{f.__name__}' was expected to be " - "a boolean value, but was of type " - f"{type(result)}" - ) - raise TypeError(error_message) - elif not have_same_dim(result, expected_result): - unit = get_dim_for_display(expected_result) - error_message = ( - "The return value of function " - f"'{f.__name__}' was expected to have " - f"dimension {unit} but was " - f"'{result}'" - ) - raise DimensionMismatchError(error_message, get_dim(result)) + + if ( + jax.tree.structure(expected_result, is_leaf=_is_quantity) + != + jax.tree.structure(result, is_leaf=_is_quantity) + ): + raise TypeError( + f"Expected a return value of type {expected_result} but got {result}" + ) + + jax.tree.map( + partial(_check_dim, f), result, expected_result, + is_leaf=_is_quantity + ) return result new_f._orig_func = f @@ -4222,6 +4198,19 @@ def new_f(*args, **kwds): return do_check_units +def _check_dim(f, val, dim): + dim = DIMENSIONLESS if dim is None else dim + if not have_same_dim(val, dim): + unit = get_dim_for_display(dim) + error_message = ( + "The return value of function " + f"'{f.__name__}' was expected to have " + f"dimension {unit} but was " + f"'{val}'" + ) + raise DimensionMismatchError(error_message, get_dim(val)) + + @set_module_as('brainunit') def check_units(**au): """ @@ -4408,42 +4397,20 @@ def new_f(*args, **kwds): expected_result = au["result"](*[get_dim(a) for a in args]) else: expected_result = au["result"] - if isinstance(expected_result, tuple): - if not isinstance(result, tuple) or len(result) != len(expected_result): - error_message = ( - "The return value of function " - f"'{f.__name__}' was expected to be a tuple of length " - f"{len(expected_result)}, but was of type " - f"{type(result)} with length {len(result) if isinstance(result, tuple) else 'N/A'}" - ) - raise TypeError(error_message) - for res, exp_res in zip(result, expected_result): - if not has_same_unit(res, exp_res): - error_message = ( - "The return value of function " - f"'{f.__name__}' was expected to have " - f"unit {get_unit(exp_res)} but was " - f"'{res}'" - ) - raise UnitMismatchError(error_message, get_unit(res)) - else: - if au["result"] == bool: - if not isinstance(result, bool): - error_message = ( - "The return value of function " - f"'{f.__name__}' was expected to be " - "a boolean value, but was of type " - f"{type(result)}" - ) - raise TypeError(error_message) - elif not has_same_unit(result, expected_result): - error_message = ( - "The return value of function " - f"'{f.__name__}' was expected to have " - f"unit {get_unit(expected_result)} but was " - f"'{result}'" - ) - raise UnitMismatchError(error_message, get_unit(result)) + + if ( + jax.tree.structure(expected_result, is_leaf=_is_quantity) + != + jax.tree.structure(result, is_leaf=_is_quantity) + ): + raise TypeError( + f"Expected a return value of type {expected_result} but got {result}" + ) + + jax.tree.map( + partial(_check_unit, f), result, expected_result, + is_leaf=_is_quantity + ) return result new_f._orig_func = f @@ -4479,4 +4446,20 @@ def new_f(*args, **kwds): ] return new_f - return do_check_units \ No newline at end of file + return do_check_units + + +def _check_unit(f, val, unit): + unit = UNITLESS if unit is None else unit + if not has_same_unit(val, unit): + error_message = ( + "The return value of function " + f"'{f.__name__}' was expected to have " + f"unit {get_unit(val)} but was " + f"'{val}'" + ) + raise UnitMismatchError(error_message, get_unit(val)) + + +def _is_quantity(x): + return isinstance(x, Quantity) diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index fa4956c..0eb6e4b 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -1245,8 +1245,8 @@ def c_function(a, b): c_function(1, 1) with pytest.raises(TypeError): c_function(1 * mV, 1) - with pytest.raises(TypeError): - c_function(False, 1) + # with pytest.raises(TypeError): + # c_function(False, 1) # Multiple results @u.check_dims(result=(second.dim, volt.dim)) @@ -1331,8 +1331,6 @@ def c_function(a, b): c_function(1, 1) with pytest.raises(TypeError): c_function(1 * mV, 1) - with pytest.raises(TypeError): - c_function(False, 1) # Multiple results @check_units(result=(second, volt))