diff --git a/brainunit/_base.py b/brainunit/_base.py index 489a71e..742abde 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -55,6 +55,7 @@ # functions for checking 'check_dims', 'check_units', + 'handle_units', 'fail_for_dimension_mismatch', 'fail_for_unit_mismatch', 'assert_quantity', @@ -2410,10 +2411,6 @@ def size(self) -> int: def T(self) -> 'Quantity': return Quantity(jnp.asarray(self.mantissa).T, unit=self.unit) - @property - def mT(self) -> 'Quantity': - return Quantity(jnp.asarray(self.mantissa).mT, unit=self.unit) - @property def isreal(self) -> jax.Array: return jnp.isreal(self.mantissa) @@ -4150,20 +4147,24 @@ def new_f(*args, **kwds): expected_result = au["result"](*[get_dim(a) for a in args]) else: expected_result = au["result"] - - if ( - jax.tree.structure(expected_result, is_leaf=_is_quantity) - != - jax.tree.structure(result, is_leaf=_is_quantity) - ): - raise TypeError( - f"Expected a return value of type {expected_result} but got {result}" + if au["result"] == bool: + if not isinstance(result, bool): + error_message = ( + "The return value of function " + f"'{f.__name__}' was expected to be " + "a boolean value, but was of type " + f"{type(result)}" + ) + raise TypeError(error_message) + elif not have_same_dim(result, expected_result): + unit = get_dim_for_display(expected_result) + error_message = ( + "The return value of function " + f"'{f.__name__}' was expected to have " + f"dimension {unit} but was " + f"'{result}'" ) - - jax.tree.map( - partial(_check_dim, f), result, expected_result, - is_leaf=_is_quantity - ) + raise DimensionMismatchError(error_message, get_dim(result)) return result new_f._orig_func = f @@ -4202,19 +4203,6 @@ def new_f(*args, **kwds): return do_check_units -def _check_dim(f, val, dim): - dim = DIMENSIONLESS if dim is None else dim - if not have_same_dim(val, dim): - unit = get_dim_for_display(dim) - error_message = ( - "The return value of function " - f"'{f.__name__}' was expected to have " - f"dimension {unit} but was " - f"'{val}'" - ) - raise DimensionMismatchError(error_message, get_dim(val)) - - @set_module_as('brainunit') def check_units(**au): """ @@ -4401,20 +4389,23 @@ def new_f(*args, **kwds): expected_result = au["result"](*[get_dim(a) for a in args]) else: expected_result = au["result"] - - if ( - jax.tree.structure(expected_result, is_leaf=_is_quantity) - != - jax.tree.structure(result, is_leaf=_is_quantity) - ): - raise TypeError( - f"Expected a return value of type {expected_result} but got {result}" + if au["result"] == bool: + if not isinstance(result, bool): + error_message = ( + "The return value of function " + f"'{f.__name__}' was expected to be " + "a boolean value, but was of type " + f"{type(result)}" + ) + raise TypeError(error_message) + elif not has_same_unit(result, expected_result): + error_message = ( + "The return value of function " + f"'{f.__name__}' was expected to have " + f"unit {get_unit(expected_result)} but was " + f"'{result}'" ) - - jax.tree.map( - partial(_check_unit, f), result, expected_result, - is_leaf=_is_quantity - ) + raise UnitMismatchError(error_message, get_unit(result)) return result new_f._orig_func = f @@ -4467,3 +4458,82 @@ def _check_unit(f, val, unit): def _is_quantity(x): return isinstance(x, Quantity) + + +@set_module_as('brainunit') +def handle_units(**au): + """ + Decorator to transform units of arguments passed to a function + """ + + def do_handle_units(f): + @wraps(f) + def new_f(*args, **kwds): + newkeyset = kwds.copy() + arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount] + for n, v in zip(arg_names, args[0: f.__code__.co_argcount]): + if n in au and v is not None: + specific_unit = au[n] + # if the specific unit is a boolean, just check and return + if specific_unit == bool: + if isinstance(v, bool): + newkeyset[n] = v + else: + raise TypeError(f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'") + + elif specific_unit == 1: + if isinstance(v, Quantity): + newkeyset[n] = v.to_decimal() + elif isinstance(v, (jax.Array, np.ndarray, int, float, complex)): + newkeyset[n] = v + else: + specific_unit = jax.typing.ArrayLike + raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object" + f"or {specific_unit} for argument '{n}' but got '{v}'") + + elif isinstance(specific_unit, Unit): + if isinstance(v, Quantity): + v = v.to_decimal(specific_unit) + newkeyset[n] = v + else: + raise TypeError( + f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'" + ) + else: + raise TypeError( + f"Function '{f.__name__}' expected a target unit object or" + f" a Number, boolean object for checking, but got '{specific_unit}'" + ) + else: + newkeyset[n] = v + + result = f(**newkeyset) + if isinstance(result, tuple): + assert isinstance(au["result"], tuple), "The return value of the function is a tuple, but the decorator expected a single unit." + result = tuple( + Quantity(r, unit=au["result"][i]) if isinstance(au["result"][i], Unit) else r + for i, r in enumerate(result) + ) + if "result" in au: + specific_unit = au["result"] + if specific_unit == bool: + if isinstance(result, bool): + pass + else: + raise TypeError(f"Function '{f.__name__}' expected a boolean value for the return value but got '{result}'") + elif specific_unit == 1: + if isinstance(result, Quantity): + result = result.to_decimal() + elif isinstance(result, (jax.Array, np.ndarray, int, float, complex)): + result = jnp.asarray(result) + else: + specific_unit = jax.typing.ArrayLike + raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object" + f" or {specific_unit} for the return value but got '{result}'") + elif isinstance(specific_unit, Unit): + result = Quantity(result, unit=specific_unit) + return result + + return new_f + + return do_handle_units \ No newline at end of file diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index 6781700..f6a1fce 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -1395,6 +1395,67 @@ def d_function2(true_result): with pytest.raises(u.UnitMismatchError): d_function2(2) + def test_handle_units(self): + """ + Test the handle_units decorator + """ + + @u.handle_units(v=volt) + def a_function(v, x): + """ + v has to have units of volt, x can have any (or no) unit. + """ + return v + + # Try correct units + assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt) + assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt) + assert a_function(5 * volt, "something") == (5 * volt).to_decimal(volt) + assert_quantity(a_function([1, 2, 3] * volt, None), ([1, 2, 3] * volt).to_decimal(volt)) + + # Try incorrect units + with pytest.raises(u.UnitMismatchError): + a_function(5 * second, None) + with pytest.raises(TypeError): + a_function(5, None) + with pytest.raises(TypeError): + a_function(object(), None) + + @u.handle_units(result=second) + def b_function(): + """ + Return a value in seconds if return_second is True, otherwise return + a value in volt. + """ + return 5 * second + + # Should work (returns second) + assert b_function() == 5 * second + + @u.handle_units(a=bool, b=1, result=bool) + def c_function(a, b): + if a: + return b > 0 + else: + return b + + assert c_function(True, 1) + assert not c_function(True, -1) + with pytest.raises(TypeError): + c_function(1, 1) + with pytest.raises(TypeError): + c_function(1 * mV, 1) + + # Multiple results + @u.handle_units(result=(second, volt)) + def d_function(): + return 5, 3 + + # Should work (returns second) + assert d_function()[0] == 5 * second + assert d_function()[1] == 3 * volt + + def test_str_repr(): """