Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add new decorator handle_units for unit processing on input and output for a function that does not support unit operations #72

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 113 additions & 43 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
# functions for checking
'check_dims',
'check_units',
'handle_units',
'fail_for_dimension_mismatch',
'fail_for_unit_mismatch',
'assert_quantity',
Expand Down Expand Up @@ -2410,10 +2411,6 @@ def size(self) -> int:
def T(self) -> 'Quantity':
return Quantity(jnp.asarray(self.mantissa).T, unit=self.unit)

@property
def mT(self) -> 'Quantity':
return Quantity(jnp.asarray(self.mantissa).mT, unit=self.unit)

@property
def isreal(self) -> jax.Array:
return jnp.isreal(self.mantissa)
Expand Down Expand Up @@ -4150,20 +4147,24 @@ def new_f(*args, **kwds):
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
if au["result"] == bool:
if not isinstance(result, bool):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be "
"a boolean value, but was of type "
f"{type(result)}"
)
raise TypeError(error_message)
elif not have_same_dim(result, expected_result):
unit = get_dim_for_display(expected_result)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"dimension {unit} but was "
f"'{result}'"
)

jax.tree.map(
partial(_check_dim, f), result, expected_result,
is_leaf=_is_quantity
)
raise DimensionMismatchError(error_message, get_dim(result))
return result

new_f._orig_func = f
Expand Down Expand Up @@ -4202,19 +4203,6 @@ def new_f(*args, **kwds):
return do_check_units


def _check_dim(f, val, dim):
dim = DIMENSIONLESS if dim is None else dim
if not have_same_dim(val, dim):
unit = get_dim_for_display(dim)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"dimension {unit} but was "
f"'{val}'"
)
raise DimensionMismatchError(error_message, get_dim(val))


@set_module_as('brainunit')
def check_units(**au):
"""
Expand Down Expand Up @@ -4401,20 +4389,23 @@ def new_f(*args, **kwds):
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
if au["result"] == bool:
if not isinstance(result, bool):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be "
"a boolean value, but was of type "
f"{type(result)}"
)
raise TypeError(error_message)
elif not has_same_unit(result, expected_result):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"unit {get_unit(expected_result)} but was "
f"'{result}'"
)

jax.tree.map(
partial(_check_unit, f), result, expected_result,
is_leaf=_is_quantity
)
raise UnitMismatchError(error_message, get_unit(result))
return result

new_f._orig_func = f
Expand Down Expand Up @@ -4467,3 +4458,82 @@ def _check_unit(f, val, unit):

def _is_quantity(x):
return isinstance(x, Quantity)


@set_module_as('brainunit')
def handle_units(**au):
"""
Decorator to transform units of arguments passed to a function
"""

def do_handle_units(f):
@wraps(f)
def new_f(*args, **kwds):
newkeyset = kwds.copy()
arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount]
for n, v in zip(arg_names, args[0: f.__code__.co_argcount]):
if n in au and v is not None:
specific_unit = au[n]
# if the specific unit is a boolean, just check and return
if specific_unit == bool:
if isinstance(v, bool):
newkeyset[n] = v
else:
raise TypeError(f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'")

elif specific_unit == 1:
if isinstance(v, Quantity):
newkeyset[n] = v.to_decimal()
elif isinstance(v, (jax.Array, np.ndarray, int, float, complex)):
newkeyset[n] = v
else:
specific_unit = jax.typing.ArrayLike
raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object"
f"or {specific_unit} for argument '{n}' but got '{v}'")

elif isinstance(specific_unit, Unit):
if isinstance(v, Quantity):
v = v.to_decimal(specific_unit)
newkeyset[n] = v
else:
raise TypeError(
f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'"
)
else:
raise TypeError(
f"Function '{f.__name__}' expected a target unit object or"
f" a Number, boolean object for checking, but got '{specific_unit}'"
)
else:
newkeyset[n] = v

result = f(**newkeyset)
if isinstance(result, tuple):
assert isinstance(au["result"], tuple), "The return value of the function is a tuple, but the decorator expected a single unit."
result = tuple(
Quantity(r, unit=au["result"][i]) if isinstance(au["result"][i], Unit) else r
for i, r in enumerate(result)
)
if "result" in au:
specific_unit = au["result"]
if specific_unit == bool:
if isinstance(result, bool):
pass
else:
raise TypeError(f"Function '{f.__name__}' expected a boolean value for the return value but got '{result}'")
elif specific_unit == 1:
if isinstance(result, Quantity):
result = result.to_decimal()
elif isinstance(result, (jax.Array, np.ndarray, int, float, complex)):
result = jnp.asarray(result)
else:
specific_unit = jax.typing.ArrayLike
raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object"
f" or {specific_unit} for the return value but got '{result}'")
elif isinstance(specific_unit, Unit):
result = Quantity(result, unit=specific_unit)
return result

return new_f

return do_handle_units
61 changes: 61 additions & 0 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,6 +1395,67 @@ def d_function2(true_result):
with pytest.raises(u.UnitMismatchError):
d_function2(2)

def test_handle_units(self):
"""
Test the handle_units decorator
"""

@u.handle_units(v=volt)
def a_function(v, x):
"""
v has to have units of volt, x can have any (or no) unit.
"""
return v

# Try correct units
assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt)
assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt)
assert a_function(5 * volt, "something") == (5 * volt).to_decimal(volt)
assert_quantity(a_function([1, 2, 3] * volt, None), ([1, 2, 3] * volt).to_decimal(volt))

# Try incorrect units
with pytest.raises(u.UnitMismatchError):
a_function(5 * second, None)
with pytest.raises(TypeError):
a_function(5, None)
with pytest.raises(TypeError):
a_function(object(), None)

@u.handle_units(result=second)
def b_function():
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
return 5 * second

# Should work (returns second)
assert b_function() == 5 * second

@u.handle_units(a=bool, b=1, result=bool)
def c_function(a, b):
if a:
return b > 0
else:
return b

assert c_function(True, 1)
assert not c_function(True, -1)
with pytest.raises(TypeError):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# Multiple results
@u.handle_units(result=(second, volt))
def d_function():
return 5, 3

# Should work (returns second)
assert d_function()[0] == 5 * second
assert d_function()[1] == 3 * volt



def test_str_repr():
"""
Expand Down
Loading