Skip to content

Commit

Permalink
Support return pytree structure for assign_units
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 4, 2024
1 parent 46be2e3 commit f6346ee
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
45 changes: 22 additions & 23 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4502,30 +4502,24 @@ def new_f(*args, **kwds):
newkeyset[n] = v

result = f(**newkeyset)
if isinstance(result, tuple):
assert isinstance(au["result"], tuple), "The return value of the function is a tuple, but the decorator expected a single unit."
result = tuple(
Quantity(r, unit=au["result"][i]) if isinstance(au["result"][i], Unit) else r
for i, r in enumerate(result)
)
if "result" in au:
specific_unit = au["result"]
if specific_unit == bool:
if isinstance(result, bool):
pass
else:
raise TypeError(f"Function '{f.__name__}' expected a boolean value for the return value but got '{result}'")
elif specific_unit == 1:
if isinstance(result, Quantity):
result = result.to_decimal()
elif isinstance(result, (jax.Array, np.ndarray, int, float, complex)):
result = jnp.asarray(result)
else:
specific_unit = jax.typing.ArrayLike
raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object"
f" or {specific_unit} for the return value but got '{result}'")
elif isinstance(specific_unit, Unit):
result = Quantity(result, unit=specific_unit)
if isinstance(au["result"], Callable) and au["result"] != bool:
expected_result = au["result"](*[get_unit(a) for a in args])
else:
expected_result = au["result"]

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
)

result = jax.tree.map(
partial(_assign_unit, f), result, expected_result
)
return result

return new_f
Expand All @@ -4543,6 +4537,11 @@ def _check_unit(f, val, unit):
)
raise UnitMismatchError(error_message, get_unit(val))

def _assign_unit(f, val, unit):
if unit is None or unit == bool or unit == 1:
return val
return Quantity(val, unit=unit)


def _is_quantity(x):
return isinstance(x, Quantity)
2 changes: 1 addition & 1 deletion brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1427,7 +1427,7 @@ def b_function():
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
return 5 * second
return 5

# Should work (returns second)
assert b_function() == 5 * second
Expand Down

0 comments on commit f6346ee

Please sign in to comment.