diff --git a/brainunit/_base.py b/brainunit/_base.py index 489a71e..b56d7e3 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -55,6 +55,7 @@ # functions for checking 'check_dims', 'check_units', + 'assign_units', 'fail_for_dimension_mismatch', 'fail_for_unit_mismatch', 'assert_quantity', @@ -4453,6 +4454,78 @@ def new_f(*args, **kwds): return do_check_units +@set_module_as('brainunit') +def assign_units(**au): + """ + Decorator to transform units of arguments passed to a function + """ + + def do_assign_units(f): + @wraps(f) + def new_f(*args, **kwds): + newkeyset = kwds.copy() + arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount] + for n, v in zip(arg_names, args[0: f.__code__.co_argcount]): + if n in au and v is not None: + specific_unit = au[n] + # if the specific unit is a boolean, just check and return + if specific_unit == bool: + if isinstance(v, bool): + newkeyset[n] = v + else: + raise TypeError(f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'") + + elif specific_unit == 1: + if isinstance(v, Quantity): + newkeyset[n] = v.to_decimal() + elif isinstance(v, (jax.Array, np.ndarray, int, float, complex)): + newkeyset[n] = v + else: + specific_unit = jax.typing.ArrayLike + raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object" + f"or {specific_unit} for argument '{n}' but got '{v}'") + + elif isinstance(specific_unit, Unit): + if isinstance(v, Quantity): + v = v.to_decimal(specific_unit) + newkeyset[n] = v + else: + raise TypeError( + f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'" + ) + else: + raise TypeError( + f"Function '{f.__name__}' expected a target unit object or" + f" a Number, boolean object for checking, but got '{specific_unit}'" + ) + else: + newkeyset[n] = v + + result = f(**newkeyset) + if "result" in au: + if isinstance(au["result"], Callable) and au["result"] != bool: + expected_result = au["result"](*[get_unit(a) for a in args]) + else: + expected_result = au["result"] + + if ( + jax.tree.structure(expected_result, is_leaf=_is_quantity) + != + jax.tree.structure(result, is_leaf=_is_quantity) + ): + raise TypeError( + f"Expected a return value of type {expected_result} but got {result}" + ) + + result = jax.tree.map( + partial(_assign_unit, f), result, expected_result + ) + return result + + return new_f + + return do_assign_units + def _check_unit(f, val, unit): unit = UNITLESS if unit is None else unit if not has_same_unit(val, unit): @@ -4464,6 +4537,11 @@ def _check_unit(f, val, unit): ) raise UnitMismatchError(error_message, get_unit(val)) +def _assign_unit(f, val, unit): + if unit is None or unit == bool or unit == 1: + return val + return Quantity(val, unit=unit) + def _is_quantity(x): return isinstance(x, Quantity) diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index 6781700..b66d996 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -1395,6 +1395,87 @@ def d_function2(true_result): with pytest.raises(u.UnitMismatchError): d_function2(2) + def test_assign_units(self): + """ + Test the assign_units decorator + """ + + @u.assign_units(v=volt) + def a_function(v, x): + """ + v has to have units of volt, x can have any (or no) unit. + """ + return v + + # Try correct units + assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt) + assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt) + assert a_function(5 * volt, "something") == (5 * volt).to_decimal(volt) + assert_quantity(a_function([1, 2, 3] * volt, None), ([1, 2, 3] * volt).to_decimal(volt)) + + # Try incorrect units + with pytest.raises(u.UnitMismatchError): + a_function(5 * second, None) + with pytest.raises(TypeError): + a_function(5, None) + with pytest.raises(TypeError): + a_function(object(), None) + + @u.assign_units(result=second) + def b_function(): + """ + Return a value in seconds if return_second is True, otherwise return + a value in volt. + """ + return 5 + + # Should work (returns second) + assert b_function() == 5 * second + + @u.assign_units(a=bool, b=1, result=bool) + def c_function(a, b): + if a: + return b > 0 + else: + return b + + assert c_function(True, 1) + assert not c_function(True, -1) + with pytest.raises(TypeError): + c_function(1, 1) + with pytest.raises(TypeError): + c_function(1 * mV, 1) + + # Multiple results + @u.assign_units(result=(second, volt)) + def d_function(): + return 5, 3 + + # Should work (returns second) + assert d_function()[0] == 5 * second + assert d_function()[1] == 3 * volt + + # Multiple results + @u.assign_units(result={'u': second, 'v': (volt, metre)}) + def d_function2(true_result): + """ + Return a value in seconds if return_second is True, otherwise return + a value in volt. + """ + if true_result == 0: + return {'u': 5, 'v': (3, 2)} + elif true_result == 1: + return 3, 5 + else: + return 3, 5 + + # Should work (returns dict) + d_function2(0) + # Should fail (returns tuple) + with pytest.raises(TypeError): + d_function2(1) + + def test_str_repr(): """