From 1577a440a583cb6f16046f0887ece8ae0e54d04b Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 3 Dec 2024 14:01:16 +0800 Subject: [PATCH 1/4] Add handle_units decorator --- brainunit/_base.py | 44 +++++++++++++++++++++++++++++++++++ brainunit/_base_test.py | 51 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/brainunit/_base.py b/brainunit/_base.py index 19257a4..245cea7 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -55,6 +55,7 @@ # functions for checking 'check_dims', 'check_units', + 'handle_units', 'fail_for_dimension_mismatch', 'fail_for_unit_mismatch', 'assert_quantity', @@ -4441,3 +4442,46 @@ def new_f(*args, **kwds): return new_f return do_check_units + + +@set_module_as('brainunit') +def handle_units(**au): + """ + Decorator to transform units of arguments passed to a function + """ + + def do_handle_units(f): + @wraps(f) + def new_f(*args, **kwds): + newkeyset = kwds.copy() + arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount] + for n, v in zip(arg_names, args[0: f.__code__.co_argcount]): + if n in au and v is not None: + specific_unit = au[n] + if isinstance(specific_unit, bool) or specific_unit == 1: + if isinstance(v, bool): + newkeyset[n] = v + elif isinstance(v, Quantity): + newkeyset[n] = v.to_decimal() + elif isinstance(v, jax.typing.ArrayLike): + newkeyset[n] = jnp.asarray(v) + if isinstance(v, Quantity): + v = v.to_decimal(specific_unit) + newkeyset[n] = v + else: + raise UnitMismatchError( + f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'" + ) + else: + newkeyset[n] = v + + result = f(**newkeyset) + if "result" in au: + specific_unit = au["result"] + if isinstance(result, Quantity): + result = Quantity(result, unit=specific_unit) + return result + + return new_f + + return do_handle_units \ No newline at end of file diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index e2d415d..83c9e7e 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -1316,6 +1316,57 @@ def c_function(a, b): with pytest.raises(TypeError): c_function(False, 1) + def test_handle_units(self): + """ + Test the handle_units decorator + """ + + @u.handle_units(v=volt) + def a_function(v, x): + """ + v has to have units of volt, x can have any (or no) unit. + """ + return v + + # Try correct units + assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt) + assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt) + assert a_function(5 * volt, "something") == (5 * volt).to_decimal(volt) + assert_quantity(a_function([1, 2, 3] * volt, None), ([1, 2, 3] * volt).to_decimal(volt)) + + # Try incorrect units + with pytest.raises(u.UnitMismatchError): + a_function(5 * second, None) + with pytest.raises(u.UnitMismatchError): + a_function(5, None) + with pytest.raises(u.UnitMismatchError): + a_function(object(), None) + + @u.handle_units(result=second) + def b_function(): + """ + Return a value in seconds if return_second is True, otherwise return + a value in volt. + """ + return 5 * second + + # Should work (returns second) + assert b_function() == 5 * second + + @u.handle_units(a=bool, b=1, result=bool) + def c_function(a, b): + if a: + return b > 0 + else: + return b + + assert c_function(True, 1) + assert not c_function(True, -1) + with pytest.raises(TypeError): + c_function(1, 1) + with pytest.raises(TypeError): + c_function(1 * mV, 1) + def test_str_repr(): """ From c697fa628ea645e193be3e0ec87911b298f87527 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 3 Dec 2024 16:23:12 +0800 Subject: [PATCH 2/4] Fix bugs --- brainunit/_base.py | 48 +++++++++++++++++++++++++++++++++-------- brainunit/_base_test.py | 4 ++-- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index 245cea7..c1ca3e7 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -4458,19 +4458,35 @@ def new_f(*args, **kwds): for n, v in zip(arg_names, args[0: f.__code__.co_argcount]): if n in au and v is not None: specific_unit = au[n] - if isinstance(specific_unit, bool) or specific_unit == 1: + # if the specific unit is a boolean, just check and return + if specific_unit == bool: if isinstance(v, bool): newkeyset[n] = v - elif isinstance(v, Quantity): + else: + raise TypeError(f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'") + + elif specific_unit == 1: + if isinstance(v, Quantity): newkeyset[n] = v.to_decimal() elif isinstance(v, jax.typing.ArrayLike): - newkeyset[n] = jnp.asarray(v) - if isinstance(v, Quantity): - v = v.to_decimal(specific_unit) - newkeyset[n] = v + newkeyset[n] = v + else: + specific_unit = jax.typing.ArrayLike + raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object" + f"or {specific_unit} for argument '{n}' but got '{v}'") + + elif isinstance(specific_unit, Unit): + if isinstance(v, Quantity): + v = v.to_decimal(specific_unit) + newkeyset[n] = v + else: + raise TypeError( + f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'" + ) else: - raise UnitMismatchError( - f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'" + raise TypeError( + f"Function '{f.__name__}' expected a target unit object or" + f" a Number, boolean object for checking, but got '{specific_unit}'" ) else: newkeyset[n] = v @@ -4478,7 +4494,21 @@ def new_f(*args, **kwds): result = f(**newkeyset) if "result" in au: specific_unit = au["result"] - if isinstance(result, Quantity): + if specific_unit == bool: + if isinstance(result, bool): + pass + else: + raise TypeError(f"Function '{f.__name__}' expected a boolean value for the return value but got '{result}'") + elif specific_unit == 1: + if isinstance(result, Quantity): + result = result.to_decimal() + elif isinstance(result, jax.typing.ArrayLike): + result = jnp.asarray(result) + else: + specific_unit = jax.typing.ArrayLike + raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object" + f" or {specific_unit} for the return value but got '{result}'") + elif isinstance(specific_unit, Unit): result = Quantity(result, unit=specific_unit) return result diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index 83c9e7e..52c434b 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -1337,9 +1337,9 @@ def a_function(v, x): # Try incorrect units with pytest.raises(u.UnitMismatchError): a_function(5 * second, None) - with pytest.raises(u.UnitMismatchError): + with pytest.raises(TypeError): a_function(5, None) - with pytest.raises(u.UnitMismatchError): + with pytest.raises(TypeError): a_function(object(), None) @u.handle_units(result=second) From efcfd757def749587a4ea3aa0c7a7ed6a16f27f5 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 3 Dec 2024 16:33:31 +0800 Subject: [PATCH 3/4] Update _base.py --- brainunit/_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index c1ca3e7..47a4d0f 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -4468,7 +4468,7 @@ def new_f(*args, **kwds): elif specific_unit == 1: if isinstance(v, Quantity): newkeyset[n] = v.to_decimal() - elif isinstance(v, jax.typing.ArrayLike): + elif isinstance(v, (jax.Array, np.ndarray, int, float, complex)): newkeyset[n] = v else: specific_unit = jax.typing.ArrayLike @@ -4502,7 +4502,7 @@ def new_f(*args, **kwds): elif specific_unit == 1: if isinstance(result, Quantity): result = result.to_decimal() - elif isinstance(result, jax.typing.ArrayLike): + elif isinstance(result, (jax.Array, np.ndarray, int, float, complex)): result = jnp.asarray(result) else: specific_unit = jax.typing.ArrayLike From 21d59a73d65eaac3b74557fe0cf9f63dcd7baf23 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 3 Dec 2024 16:52:18 +0800 Subject: [PATCH 4/4] Support multiple results --- brainunit/_base.py | 6 ++++++ brainunit/_base_test.py | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/brainunit/_base.py b/brainunit/_base.py index 47a4d0f..d6e58ba 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -4492,6 +4492,12 @@ def new_f(*args, **kwds): newkeyset[n] = v result = f(**newkeyset) + if isinstance(result, tuple): + assert isinstance(au["result"], tuple), "The return value of the function is a tuple, but the decorator expected a single unit." + result = tuple( + Quantity(r, unit=au["result"][i]) if isinstance(au["result"][i], Unit) else r + for i, r in enumerate(result) + ) if "result" in au: specific_unit = au["result"] if specific_unit == bool: diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index 52c434b..27345c0 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -1367,6 +1367,16 @@ def c_function(a, b): with pytest.raises(TypeError): c_function(1 * mV, 1) + # Multiple results + @u.handle_units(result=(second, volt)) + def d_function(): + return 5, 3 + + # Should work (returns second) + assert d_function()[0] == 5 * second + assert d_function()[1] == 3 * volt + + def test_str_repr(): """