Skip to content

Commit

Permalink
[feat] Add new decorator assign_units for unit processing on input an…
Browse files Browse the repository at this point in the history
…d output for a function that does not support unit operations (#73)

* Add handle_units decorator

* Fix bugs

* Update _base.py

* Support multiple results

* Update

* Support return pytree structure for `assign_units`
  • Loading branch information
Routhleck authored Dec 4, 2024
1 parent e8f01b2 commit a840d6c
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 0 deletions.
78 changes: 78 additions & 0 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
# functions for checking
'check_dims',
'check_units',
'assign_units',
'fail_for_dimension_mismatch',
'fail_for_unit_mismatch',
'assert_quantity',
Expand Down Expand Up @@ -4453,6 +4454,78 @@ def new_f(*args, **kwds):
return do_check_units


@set_module_as('brainunit')
def assign_units(**au):
"""
Decorator to transform units of arguments passed to a function
"""

def do_assign_units(f):
@wraps(f)
def new_f(*args, **kwds):
newkeyset = kwds.copy()
arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount]
for n, v in zip(arg_names, args[0: f.__code__.co_argcount]):
if n in au and v is not None:
specific_unit = au[n]
# if the specific unit is a boolean, just check and return
if specific_unit == bool:
if isinstance(v, bool):
newkeyset[n] = v
else:
raise TypeError(f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'")

elif specific_unit == 1:
if isinstance(v, Quantity):
newkeyset[n] = v.to_decimal()
elif isinstance(v, (jax.Array, np.ndarray, int, float, complex)):
newkeyset[n] = v
else:
specific_unit = jax.typing.ArrayLike
raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object"
f"or {specific_unit} for argument '{n}' but got '{v}'")

elif isinstance(specific_unit, Unit):
if isinstance(v, Quantity):
v = v.to_decimal(specific_unit)
newkeyset[n] = v
else:
raise TypeError(
f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'"
)
else:
raise TypeError(
f"Function '{f.__name__}' expected a target unit object or"
f" a Number, boolean object for checking, but got '{specific_unit}'"
)
else:
newkeyset[n] = v

result = f(**newkeyset)
if "result" in au:
if isinstance(au["result"], Callable) and au["result"] != bool:
expected_result = au["result"](*[get_unit(a) for a in args])
else:
expected_result = au["result"]

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
)

result = jax.tree.map(
partial(_assign_unit, f), result, expected_result
)
return result

return new_f

return do_assign_units

def _check_unit(f, val, unit):
unit = UNITLESS if unit is None else unit
if not has_same_unit(val, unit):
Expand All @@ -4464,6 +4537,11 @@ def _check_unit(f, val, unit):
)
raise UnitMismatchError(error_message, get_unit(val))

def _assign_unit(f, val, unit):
if unit is None or unit == bool or unit == 1:
return val
return Quantity(val, unit=unit)


def _is_quantity(x):
return isinstance(x, Quantity)
81 changes: 81 additions & 0 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,6 +1395,87 @@ def d_function2(true_result):
with pytest.raises(u.UnitMismatchError):
d_function2(2)

def test_assign_units(self):
"""
Test the assign_units decorator
"""

@u.assign_units(v=volt)
def a_function(v, x):
"""
v has to have units of volt, x can have any (or no) unit.
"""
return v

# Try correct units
assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt)
assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt)
assert a_function(5 * volt, "something") == (5 * volt).to_decimal(volt)
assert_quantity(a_function([1, 2, 3] * volt, None), ([1, 2, 3] * volt).to_decimal(volt))

# Try incorrect units
with pytest.raises(u.UnitMismatchError):
a_function(5 * second, None)
with pytest.raises(TypeError):
a_function(5, None)
with pytest.raises(TypeError):
a_function(object(), None)

@u.assign_units(result=second)
def b_function():
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
return 5

# Should work (returns second)
assert b_function() == 5 * second

@u.assign_units(a=bool, b=1, result=bool)
def c_function(a, b):
if a:
return b > 0
else:
return b

assert c_function(True, 1)
assert not c_function(True, -1)
with pytest.raises(TypeError):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# Multiple results
@u.assign_units(result=(second, volt))
def d_function():
return 5, 3

# Should work (returns second)
assert d_function()[0] == 5 * second
assert d_function()[1] == 3 * volt

# Multiple results
@u.assign_units(result={'u': second, 'v': (volt, metre)})
def d_function2(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result == 0:
return {'u': 5, 'v': (3, 2)}
elif true_result == 1:
return 3, 5
else:
return 3, 5

# Should work (returns dict)
d_function2(0)
# Should fail (returns tuple)
with pytest.raises(TypeError):
d_function2(1)



def test_str_repr():
"""
Expand Down

0 comments on commit a840d6c

Please sign in to comment.