Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 14, 2024
1 parent 36d284d commit cf27d36
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 69 deletions.
35 changes: 17 additions & 18 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
'DIMENSIONLESS',
'DimensionMismatchError',
'get_or_create_dimension',
'get_unit',
'get_dim',
'get_basic_unit',
'is_unitless',
'have_same_unit',
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_unit_for_display(d):
if (isinstance(d, int) and d == 1) or d is DIMENSIONLESS:
return "1"
else:
return str(get_unit(d))
return str(get_dim(d))


# SI dimensions (see table at the top of the file) and various descriptions,
Expand Down Expand Up @@ -497,7 +497,7 @@ def __str__(self):
return s


def get_unit(obj) -> Dimension:
def get_dim(obj) -> Dimension:
"""
Return the unit of any object that has them.
Expand Down Expand Up @@ -551,8 +551,8 @@ def have_same_unit(obj1, obj2) -> bool:
# should only add a small amount of unnecessary computation for cases in
# which this function returns False which very likely leads to a
# DimensionMismatchError anyway.
dim1 = get_unit(obj1)
dim2 = get_unit(obj2)
dim1 = get_dim(obj1)
dim2 = get_dim(obj2)
return (dim1 is dim2) or (dim1 == dim2) or dim1 is None or dim2 is None


Expand Down Expand Up @@ -598,11 +598,11 @@ def fail_for_dimension_mismatch(
if not _unit_checking:
return None, None

dim1 = get_unit(obj1)
dim1 = get_dim(obj1)
if obj2 is None:
dim2 = DIMENSIONLESS
else:
dim2 = get_unit(obj2)
dim2 = get_dim(obj2)

if dim1 is not dim2 and not (dim1 is None or dim2 is None):
# Special treatment for "0":
Expand Down Expand Up @@ -779,7 +779,7 @@ def is_unitless(obj) -> bool:
dimensionless : `bool`
``True`` if `obj` is dimensionless.
"""
return get_unit(obj) is DIMENSIONLESS
return get_dim(obj) is DIMENSIONLESS


def is_scalar_type(obj) -> bool:
Expand Down Expand Up @@ -1105,8 +1105,8 @@ def has_same_unit(self, other):
"""
if not _unit_checking:
return True
other_unit = get_unit(other.dim)
return (get_unit(self.dim) is other_unit) or (get_unit(self.dim) == other_unit)
other_unit = get_dim(other.dim)
return (get_dim(self.dim) is other_unit) or (get_dim(self.dim) == other_unit)

def get_best_unit(self, *regs) -> 'Quantity':
"""
Expand Down Expand Up @@ -1475,7 +1475,7 @@ def _binary_operation(
_, other_dim = fail_for_dimension_mismatch(self, other, message, value1=self, value2=other)

if other_dim is None:
other_dim = get_unit(other)
other_dim = get_dim(other)

new_dim = unit_operation(self.dim, other_dim)
result = value_operation(self.value, other.value)
Expand Down Expand Up @@ -1944,14 +1944,13 @@ def take(
self,
indices,
axis=None,
out=None,
mode=None,
unique_indices=False,
indices_are_sorted=False,
fill_value=None,
) -> 'Quantity':
"""Return an array formed from the elements of a at the given indices."""
return Quantity(jnp.take(self.value, indices=indices, axis=axis, out=out, mode=mode,
return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
fill_value=fill_value), dim=self.dim)

Expand Down Expand Up @@ -3032,8 +3031,8 @@ def new_f(*args, **kwds):
)
raise TypeError(error_message)
if not have_same_unit(newkeyset[k], newkeyset[au[k]]):
d1 = get_unit(newkeyset[k])
d2 = get_unit(newkeyset[au[k]])
d1 = get_dim(newkeyset[k])
d2 = get_dim(newkeyset[au[k]])
error_message = (
f"Function '{f.__name__}' expected "
f"the argument '{k}' to have the same "
Expand All @@ -3054,13 +3053,13 @@ def new_f(*args, **kwds):
f"'{value}'"
)
raise DimensionMismatchError(
error_message, get_unit(newkeyset[k])
error_message, get_dim(newkeyset[k])
)

result = f(*args, **kwds)
if "result" in au:
if isinstance(au["result"], Callable) and au["result"] != bool:
expected_result = au["result"](*[get_unit(a) for a in args])
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]
if au["result"] == bool:
Expand All @@ -3080,7 +3079,7 @@ def new_f(*args, **kwds):
f"unit {unit} but was "
f"'{result}'"
)
raise DimensionMismatchError(error_message, get_unit(result))
raise DimensionMismatchError(error_message, get_dim(result))
return result

new_f._orig_func = f
Expand Down
20 changes: 10 additions & 10 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
check_units,
fail_for_dimension_mismatch,
get_or_create_dimension,
get_unit,
get_dim,
get_basic_unit,
have_same_unit,
in_unit,
Expand Down Expand Up @@ -74,7 +74,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds):
def assert_quantity(q, values, unit):
values = jnp.asarray(values)
if isinstance(q, Quantity):
assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})"
assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_dim(q)}) ({get_dim(unit)})"
if not jnp.allclose(q.value, values):
raise AssertionError(f"Values do not match: {q.value} != {values}")
elif isinstance(q, jnp.ndarray):
Expand Down Expand Up @@ -145,19 +145,19 @@ def test_get_dimensions():
Test various ways of getting/comparing the dimensions of a Array.
"""
q = 500 * ms
assert get_unit(q) is get_or_create_dimension(q.dim._dims)
assert get_unit(q) is q.dim
assert get_dim(q) is get_or_create_dimension(q.dim._dims)
assert get_dim(q) is q.dim
assert q.has_same_unit(3 * second)
dims = q.dim
assert_equal(dims.get_dimension("time"), 1.0)
assert_equal(dims.get_dimension("length"), 0)

assert get_unit(5) is DIMENSIONLESS
assert get_unit(5.0) is DIMENSIONLESS
assert get_unit(np.array(5, dtype=np.int32)) is DIMENSIONLESS
assert get_unit(np.array(5.0)) is DIMENSIONLESS
assert get_unit(np.float32(5.0)) is DIMENSIONLESS
assert get_unit(np.float64(5.0)) is DIMENSIONLESS
assert get_dim(5) is DIMENSIONLESS
assert get_dim(5.0) is DIMENSIONLESS
assert get_dim(np.array(5, dtype=np.int32)) is DIMENSIONLESS
assert get_dim(np.array(5.0)) is DIMENSIONLESS
assert get_dim(np.float32(5.0)) is DIMENSIONLESS
assert get_dim(np.float64(5.0)) is DIMENSIONLESS
assert is_scalar_type(5)
assert is_scalar_type(5.0)
assert is_scalar_type(np.array(5, dtype=np.int32))
Expand Down
2 changes: 1 addition & 1 deletion brainunit/math/_compat_numpy_array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,7 @@ def searchsorted(

@set_module_as('brainunit.math')
def extract(
condition: Union[Array, Quantity],
condition: Array,
arr: Union[Array, Quantity],
*,
size: Optional[int] = None,
Expand Down
34 changes: 21 additions & 13 deletions brainunit/math/_compat_numpy_funcs_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np

from brainunit._misc import set_module_as
from .._base import Quantity
from .._base import Quantity, fail_for_dimension_mismatch

__all__ = [
# math funcs keep unit (unary)
Expand Down Expand Up @@ -1090,6 +1090,7 @@ def modf(

def funcs_keep_unit_binary(func, x1, x2, *args, **kwargs):
if isinstance(x1, Quantity) and isinstance(x2, Quantity):
fail_for_dimension_mismatch(x1, x2, func.__name__)
return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim)
elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)):
return func(x1, x2, *args, **kwargs)
Expand All @@ -1098,7 +1099,8 @@ def funcs_keep_unit_binary(func, x1, x2, *args, **kwargs):


@set_module_as('brainunit.math')
def fmod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]:
def fmod(x1: Union[Quantity, jax.Array],
x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]:
"""
Return the element-wise remainder of division.
Expand Down Expand Up @@ -1158,7 +1160,8 @@ def copysign(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) ->


@set_module_as('brainunit.math')
def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]:
def heaviside(x1: Union[Quantity, jax.Array],
x2: jax.typing.ArrayLike) -> Union[Quantity, jax.Array]:
"""
Compute the Heaviside step function.
Expand All @@ -1174,7 +1177,8 @@ def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) ->
out : jax.Array, Quantity
Quantity if `x1` and `x2` are Quantities that have the same unit, else an array.
"""
return funcs_keep_unit_binary(jnp.heaviside, x1, x2)
x1 = x1.value if isinstance(x1, Quantity) else x1
return jnp.heaviside(x1, x2)


@set_module_as('brainunit.math')
Expand Down Expand Up @@ -1300,12 +1304,14 @@ def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union
# math funcs keep unit (n-ary)
# ----------------------------
@set_module_as('brainunit.math')
def interp(x: Union[Quantity, jax.typing.ArrayLike],
xp: Union[Quantity, jax.typing.ArrayLike],
fp: Union[Quantity, jax.typing.ArrayLike],
left: Union[Quantity, jax.typing.ArrayLike] = None,
right: Union[Quantity, jax.typing.ArrayLike] = None,
period: Union[Quantity, jax.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]:
def interp(
x: Union[Quantity, jax.typing.ArrayLike],
xp: Union[Quantity, jax.typing.ArrayLike],
fp: Union[Quantity, jax.typing.ArrayLike],
left: Union[Quantity, jax.typing.ArrayLike] = None,
right: Union[Quantity, jax.typing.ArrayLike] = None,
period: Union[Quantity, jax.typing.ArrayLike] = None
) -> Union[Quantity, jax.Array]:
"""
One-dimensional linear interpolation.
Expand Down Expand Up @@ -1343,9 +1349,11 @@ def interp(x: Union[Quantity, jax.typing.ArrayLike],


@set_module_as('brainunit.math')
def clip(a: Union[Quantity, jax.typing.ArrayLike],
a_min: Union[Quantity, jax.typing.ArrayLike],
a_max: Union[Quantity, jax.typing.ArrayLike]) -> Union[Quantity, jax.Array]:
def clip(
a: Union[Quantity, jax.typing.ArrayLike],
a_min: Union[Quantity, jax.typing.ArrayLike],
a_max: Union[Quantity, jax.typing.ArrayLike]
) -> Union[Quantity, jax.Array]:
"""
Clip (limit) the values in an array.
Expand Down
1 change: 0 additions & 1 deletion brainunit/math/_compat_numpy_funcs_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from brainunit._misc import set_module_as

__all__ = [

# window funcs
'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser',
]
Expand Down
Loading

0 comments on commit cf27d36

Please sign in to comment.