diff --git a/brainunit/_base.py b/brainunit/_base.py index c29b40d..b56d7e3 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -4502,30 +4502,24 @@ def new_f(*args, **kwds): newkeyset[n] = v result = f(**newkeyset) - if isinstance(result, tuple): - assert isinstance(au["result"], tuple), "The return value of the function is a tuple, but the decorator expected a single unit." - result = tuple( - Quantity(r, unit=au["result"][i]) if isinstance(au["result"][i], Unit) else r - for i, r in enumerate(result) - ) if "result" in au: - specific_unit = au["result"] - if specific_unit == bool: - if isinstance(result, bool): - pass - else: - raise TypeError(f"Function '{f.__name__}' expected a boolean value for the return value but got '{result}'") - elif specific_unit == 1: - if isinstance(result, Quantity): - result = result.to_decimal() - elif isinstance(result, (jax.Array, np.ndarray, int, float, complex)): - result = jnp.asarray(result) - else: - specific_unit = jax.typing.ArrayLike - raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object" - f" or {specific_unit} for the return value but got '{result}'") - elif isinstance(specific_unit, Unit): - result = Quantity(result, unit=specific_unit) + if isinstance(au["result"], Callable) and au["result"] != bool: + expected_result = au["result"](*[get_unit(a) for a in args]) + else: + expected_result = au["result"] + + if ( + jax.tree.structure(expected_result, is_leaf=_is_quantity) + != + jax.tree.structure(result, is_leaf=_is_quantity) + ): + raise TypeError( + f"Expected a return value of type {expected_result} but got {result}" + ) + + result = jax.tree.map( + partial(_assign_unit, f), result, expected_result + ) return result return new_f @@ -4543,6 +4537,11 @@ def _check_unit(f, val, unit): ) raise UnitMismatchError(error_message, get_unit(val)) +def _assign_unit(f, val, unit): + if unit is None or unit == bool or unit == 1: + return val + return Quantity(val, unit=unit) + def _is_quantity(x): return isinstance(x, Quantity) diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index 5cd9053..b66d996 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -1427,7 +1427,7 @@ def b_function(): Return a value in seconds if return_second is True, otherwise return a value in volt. """ - return 5 * second + return 5 # Should work (returns second) assert b_function() == 5 * second