Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add new decorator handle_units for unit processing on input and output for a function that does not support unit operations #72

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
# functions for checking
'check_dims',
'check_units',
'handle_units',
'fail_for_dimension_mismatch',
'fail_for_unit_mismatch',
'assert_quantity',
Expand Down Expand Up @@ -4441,3 +4442,82 @@ def new_f(*args, **kwds):
return new_f

return do_check_units


@set_module_as('brainunit')
def handle_units(**au):
"""
Decorator to transform units of arguments passed to a function
"""

def do_handle_units(f):
@wraps(f)
def new_f(*args, **kwds):
newkeyset = kwds.copy()
arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount]
for n, v in zip(arg_names, args[0: f.__code__.co_argcount]):
if n in au and v is not None:
specific_unit = au[n]
# if the specific unit is a boolean, just check and return
if specific_unit == bool:
if isinstance(v, bool):
newkeyset[n] = v
else:
raise TypeError(f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'")

elif specific_unit == 1:
if isinstance(v, Quantity):
newkeyset[n] = v.to_decimal()
elif isinstance(v, (jax.Array, np.ndarray, int, float, complex)):
newkeyset[n] = v
else:
specific_unit = jax.typing.ArrayLike
raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object"
f"or {specific_unit} for argument '{n}' but got '{v}'")

elif isinstance(specific_unit, Unit):
if isinstance(v, Quantity):
v = v.to_decimal(specific_unit)
newkeyset[n] = v
else:
raise TypeError(
f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'"
)
else:
raise TypeError(
f"Function '{f.__name__}' expected a target unit object or"
f" a Number, boolean object for checking, but got '{specific_unit}'"
)
else:
newkeyset[n] = v

result = f(**newkeyset)
if isinstance(result, tuple):
assert isinstance(au["result"], tuple), "The return value of the function is a tuple, but the decorator expected a single unit."
result = tuple(
Quantity(r, unit=au["result"][i]) if isinstance(au["result"][i], Unit) else r
for i, r in enumerate(result)
)
if "result" in au:
specific_unit = au["result"]
if specific_unit == bool:
if isinstance(result, bool):
pass
else:
raise TypeError(f"Function '{f.__name__}' expected a boolean value for the return value but got '{result}'")
elif specific_unit == 1:
if isinstance(result, Quantity):
result = result.to_decimal()
elif isinstance(result, (jax.Array, np.ndarray, int, float, complex)):
result = jnp.asarray(result)
else:
specific_unit = jax.typing.ArrayLike
raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object"
f" or {specific_unit} for the return value but got '{result}'")
elif isinstance(specific_unit, Unit):
result = Quantity(result, unit=specific_unit)
return result

return new_f

return do_handle_units
61 changes: 61 additions & 0 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,67 @@ def c_function(a, b):
with pytest.raises(TypeError):
c_function(False, 1)

def test_handle_units(self):
"""
Test the handle_units decorator
"""

@u.handle_units(v=volt)
def a_function(v, x):
"""
v has to have units of volt, x can have any (or no) unit.
"""
return v

# Try correct units
assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt)
assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt)
assert a_function(5 * volt, "something") == (5 * volt).to_decimal(volt)
assert_quantity(a_function([1, 2, 3] * volt, None), ([1, 2, 3] * volt).to_decimal(volt))

# Try incorrect units
with pytest.raises(u.UnitMismatchError):
a_function(5 * second, None)
with pytest.raises(TypeError):
a_function(5, None)
with pytest.raises(TypeError):
a_function(object(), None)

@u.handle_units(result=second)
def b_function():
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
return 5 * second

# Should work (returns second)
assert b_function() == 5 * second

@u.handle_units(a=bool, b=1, result=bool)
def c_function(a, b):
if a:
return b > 0
else:
return b

assert c_function(True, 1)
assert not c_function(True, -1)
with pytest.raises(TypeError):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# Multiple results
@u.handle_units(result=(second, volt))
def d_function():
return 5, 3

# Should work (returns second)
assert d_function()[0] == 5 * second
assert d_function()[1] == 3 * volt



def test_str_repr():
"""
Expand Down
Loading