Skip to content

Commit

Permalink
enable the result dimension and unit check supporting pytree structure
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Dec 3, 2024
1 parent fe902e5 commit b9c9f11
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 80 deletions.
135 changes: 59 additions & 76 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import operator
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps
from functools import wraps, partial
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict

import jax
Expand Down Expand Up @@ -4146,44 +4146,20 @@ def new_f(*args, **kwds):
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]
if isinstance(expected_result, tuple):
if not isinstance(result, tuple) or len(result) !=len(expected_result):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be a tuple of length "
f"{len(expected_result)} but was of type "
f"{type(result)} with length {len(result) if isinstance(result, tuple) else 'N/A'}"
)
raise TypeError(error_message)
for res, exp_res in zip(result, expected_result):
if not have_same_dim(res, exp_res):
unit = get_dim_for_display(exp_res)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"dimension {unit} but was "
f"'{res}'"
)
raise DimensionMismatchError(error_message, get_dim(res))
else:
if au["result"] == bool:
if not isinstance(result, bool):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be "
"a boolean value, but was of type "
f"{type(result)}"
)
raise TypeError(error_message)
elif not have_same_dim(result, expected_result):
unit = get_dim_for_display(expected_result)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"dimension {unit} but was "
f"'{result}'"
)
raise DimensionMismatchError(error_message, get_dim(result))

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
)

jax.tree.map(
partial(_check_dim, f), result, expected_result,
is_leaf=_is_quantity
)
return result

new_f._orig_func = f
Expand Down Expand Up @@ -4222,6 +4198,19 @@ def new_f(*args, **kwds):
return do_check_units


def _check_dim(f, val, dim):
dim = DIMENSIONLESS if dim is None else dim
if not have_same_dim(val, dim):
unit = get_dim_for_display(dim)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"dimension {unit} but was "
f"'{val}'"
)
raise DimensionMismatchError(error_message, get_dim(val))


@set_module_as('brainunit')
def check_units(**au):
"""
Expand Down Expand Up @@ -4408,42 +4397,20 @@ def new_f(*args, **kwds):
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]
if isinstance(expected_result, tuple):
if not isinstance(result, tuple) or len(result) != len(expected_result):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be a tuple of length "
f"{len(expected_result)}, but was of type "
f"{type(result)} with length {len(result) if isinstance(result, tuple) else 'N/A'}"
)
raise TypeError(error_message)
for res, exp_res in zip(result, expected_result):
if not has_same_unit(res, exp_res):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"unit {get_unit(exp_res)} but was "
f"'{res}'"
)
raise UnitMismatchError(error_message, get_unit(res))
else:
if au["result"] == bool:
if not isinstance(result, bool):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be "
"a boolean value, but was of type "
f"{type(result)}"
)
raise TypeError(error_message)
elif not has_same_unit(result, expected_result):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"unit {get_unit(expected_result)} but was "
f"'{result}'"
)
raise UnitMismatchError(error_message, get_unit(result))

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
)

jax.tree.map(
partial(_check_unit, f), result, expected_result,
is_leaf=_is_quantity
)
return result

new_f._orig_func = f
Expand Down Expand Up @@ -4479,4 +4446,20 @@ def new_f(*args, **kwds):
]
return new_f

return do_check_units
return do_check_units


def _check_unit(f, val, unit):
unit = UNITLESS if unit is None else unit
if not has_same_unit(val, unit):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"unit {get_unit(val)} but was "
f"'{val}'"
)
raise UnitMismatchError(error_message, get_unit(val))


def _is_quantity(x):
return isinstance(x, Quantity)
6 changes: 2 additions & 4 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,8 +1245,8 @@ def c_function(a, b):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)
with pytest.raises(TypeError):
c_function(False, 1)
# with pytest.raises(TypeError):
# c_function(False, 1)

# Multiple results
@u.check_dims(result=(second.dim, volt.dim))
Expand Down Expand Up @@ -1331,8 +1331,6 @@ def c_function(a, b):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)
with pytest.raises(TypeError):
c_function(False, 1)

# Multiple results
@check_units(result=(second, volt))
Expand Down

0 comments on commit b9c9f11

Please sign in to comment.