Skip to content

Commit

Permalink
[feat] Support unit checking of a pytree of results in check_dims a…
Browse files Browse the repository at this point in the history
…nd `check_units` (#71)

* Support multiple results for `check_dims` and `check_units`

* enable the result dimension and unit check supporting pytree structure

* add more tests

---------

Co-authored-by: Chaoming Wang <[email protected]>
  • Loading branch information
Routhleck and chaoming0625 authored Dec 3, 2024
1 parent 72e6f58 commit e8f01b2
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 38 deletions.
90 changes: 56 additions & 34 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import operator
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps
from functools import wraps, partial
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict

import jax
Expand Down Expand Up @@ -4150,24 +4150,20 @@ def new_f(*args, **kwds):
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]
if au["result"] == bool:
if not isinstance(result, bool):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be "
"a boolean value, but was of type "
f"{type(result)}"
)
raise TypeError(error_message)
elif not have_same_dim(result, expected_result):
unit = get_dim_for_display(expected_result)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"dimension {unit} but was "
f"'{result}'"

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
)
raise DimensionMismatchError(error_message, get_dim(result))

jax.tree.map(
partial(_check_dim, f), result, expected_result,
is_leaf=_is_quantity
)
return result

new_f._orig_func = f
Expand Down Expand Up @@ -4206,6 +4202,19 @@ def new_f(*args, **kwds):
return do_check_units


def _check_dim(f, val, dim):
dim = DIMENSIONLESS if dim is None else dim
if not have_same_dim(val, dim):
unit = get_dim_for_display(dim)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"dimension {unit} but was "
f"'{val}'"
)
raise DimensionMismatchError(error_message, get_dim(val))


@set_module_as('brainunit')
def check_units(**au):
"""
Expand Down Expand Up @@ -4392,23 +4401,20 @@ def new_f(*args, **kwds):
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]
if au["result"] == bool:
if not isinstance(result, bool):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be "
"a boolean value, but was of type "
f"{type(result)}"
)
raise TypeError(error_message)
elif not has_same_unit(result, expected_result):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"unit {get_unit(expected_result)} but was "
f"'{result}'"

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
)
raise UnitMismatchError(error_message, get_unit(result))

jax.tree.map(
partial(_check_unit, f), result, expected_result,
is_leaf=_is_quantity
)
return result

new_f._orig_func = f
Expand Down Expand Up @@ -4445,3 +4451,19 @@ def new_f(*args, **kwds):
return new_f

return do_check_units


def _check_unit(f, val, unit):
unit = UNITLESS if unit is None else unit
if not has_same_unit(val, unit):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"unit {get_unit(val)} but was "
f"'{val}'"
)
raise UnitMismatchError(error_message, get_unit(val))


def _is_quantity(x):
return isinstance(x, Quantity)
87 changes: 83 additions & 4 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from numpy.testing import assert_equal

import brainunit as u
import brainunit as bu
from brainunit._base import (
DIMENSIONLESS,
UNITLESS,
Expand Down Expand Up @@ -1185,7 +1184,7 @@ def test_fail_for_dimension_mismatch(self):

def test_check_dims(self):
"""
Test the check_units decorator
Test the check_dims decorator
"""

@u.check_dims(v=volt.dim)
Expand Down Expand Up @@ -1245,8 +1244,49 @@ def c_function(a, b):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# with pytest.raises(TypeError):
# c_function(False, 1)

# Multiple results
@u.check_dims(result=(second.dim, volt.dim))
def d_function(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result:
return 5 * second, 3 * volt
else:
return 3 * volt, 5 * second

# Should work (returns second)
d_function(True)
# Should fail (returns volt)
with pytest.raises(u.DimensionMismatchError):
d_function(False)

# Multiple results
@u.check_dims(result={'u': second.dim, 'v': (volt.dim, metre.dim)})
def d_function2(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result == 0:
return {'u': 5 * second, 'v': (3 * volt, 2 * metre)}
elif true_result == 1:
return 3 * volt, 5 * second
else:
return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}

d_function2(0)

with pytest.raises(TypeError):
c_function(False, 1)
d_function2(1)

with pytest.raises(u.DimensionMismatchError):
d_function2(2)

def test_check_units(self):
"""
Expand Down Expand Up @@ -1313,8 +1353,47 @@ def c_function(a, b):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# Multiple results
@check_units(result=(second, volt))
def d_function(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result:
return 5 * second, 3 * volt
else:
return 3 * volt, 5 * second

# Should work (returns second)
d_function(True)
# Should fail (returns volt)
with pytest.raises(u.UnitMismatchError):
d_function(False)

# Multiple results
@check_units(result={'u': second, 'v': (volt, metre)})
def d_function2(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result == 0:
return {'u': 5 * second, 'v': (3 * volt, 2 * metre)}
elif true_result == 1:
return 3 * volt, 5 * second
else:
return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}

# Should work (returns second)
d_function2(0)
# Should fail (returns volt)
with pytest.raises(TypeError):
c_function(False, 1)
d_function2(1)

with pytest.raises(u.UnitMismatchError):
d_function2(2)


def test_str_repr():
Expand Down

0 comments on commit e8f01b2

Please sign in to comment.