Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[docs] Fix and Update #15

Merged
merged 19 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
'DIMENSIONLESS',
'DimensionMismatchError',
'get_or_create_dimension',
'get_unit',
'get_dim',
'get_basic_unit',
'is_unitless',
'have_same_unit',
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_unit_for_display(d):
if (isinstance(d, int) and d == 1) or d is DIMENSIONLESS:
return "1"
else:
return str(get_unit(d))
return str(get_dim(d))


# SI dimensions (see table at the top of the file) and various descriptions,
Expand Down Expand Up @@ -497,7 +497,7 @@ def __str__(self):
return s


def get_unit(obj) -> Dimension:
def get_dim(obj) -> Dimension:
"""
Return the unit of any object that has them.

Expand Down Expand Up @@ -551,8 +551,8 @@ def have_same_unit(obj1, obj2) -> bool:
# should only add a small amount of unnecessary computation for cases in
# which this function returns False which very likely leads to a
# DimensionMismatchError anyway.
dim1 = get_unit(obj1)
dim2 = get_unit(obj2)
dim1 = get_dim(obj1)
dim2 = get_dim(obj2)
return (dim1 is dim2) or (dim1 == dim2) or dim1 is None or dim2 is None


Expand Down Expand Up @@ -598,11 +598,11 @@ def fail_for_dimension_mismatch(
if not _unit_checking:
return None, None

dim1 = get_unit(obj1)
dim1 = get_dim(obj1)
if obj2 is None:
dim2 = DIMENSIONLESS
else:
dim2 = get_unit(obj2)
dim2 = get_dim(obj2)

if dim1 is not dim2 and not (dim1 is None or dim2 is None):
# Special treatment for "0":
Expand Down Expand Up @@ -779,7 +779,7 @@ def is_unitless(obj) -> bool:
dimensionless : `bool`
``True`` if `obj` is dimensionless.
"""
return get_unit(obj) is DIMENSIONLESS
return get_dim(obj) is DIMENSIONLESS


def is_scalar_type(obj) -> bool:
Expand Down Expand Up @@ -1105,8 +1105,8 @@ def has_same_unit(self, other):
"""
if not _unit_checking:
return True
other_unit = get_unit(other.dim)
return (get_unit(self.dim) is other_unit) or (get_unit(self.dim) == other_unit)
other_unit = get_dim(other.dim)
return (get_dim(self.dim) is other_unit) or (get_dim(self.dim) == other_unit)

def get_best_unit(self, *regs) -> 'Quantity':
"""
Expand Down Expand Up @@ -1475,7 +1475,7 @@ def _binary_operation(
_, other_dim = fail_for_dimension_mismatch(self, other, message, value1=self, value2=other)

if other_dim is None:
other_dim = get_unit(other)
other_dim = get_dim(other)

new_dim = unit_operation(self.dim, other_dim)
result = value_operation(self.value, other.value)
Expand Down Expand Up @@ -1940,9 +1940,21 @@ def split(self, indices_or_sections, axis=0) -> List['Quantity']:
"""
return [Quantity(a, dim=self.dim) for a in jnp.split(self.value, indices_or_sections, axis=axis)]

def take(self, indices, axis=None, mode=None) -> 'Quantity':
def take(
self,
indices,
axis=None,
mode=None,
unique_indices=False,
indices_are_sorted=False,
fill_value=None,
) -> 'Quantity':
"""Return an array formed from the elements of a at the given indices."""
return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode), dim=self.dim)
if isinstance(fill_value, Quantity):
fill_value = fill_value.value
return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
fill_value=fill_value), dim=self.dim)

def tolist(self):
"""Return the array as an ``a.ndim``-levels deep nested list of Python scalars.
Expand Down Expand Up @@ -3021,8 +3033,8 @@ def new_f(*args, **kwds):
)
raise TypeError(error_message)
if not have_same_unit(newkeyset[k], newkeyset[au[k]]):
d1 = get_unit(newkeyset[k])
d2 = get_unit(newkeyset[au[k]])
d1 = get_dim(newkeyset[k])
d2 = get_dim(newkeyset[au[k]])
error_message = (
f"Function '{f.__name__}' expected "
f"the argument '{k}' to have the same "
Expand All @@ -3043,13 +3055,13 @@ def new_f(*args, **kwds):
f"'{value}'"
)
raise DimensionMismatchError(
error_message, get_unit(newkeyset[k])
error_message, get_dim(newkeyset[k])
)

result = f(*args, **kwds)
if "result" in au:
if isinstance(au["result"], Callable) and au["result"] != bool:
expected_result = au["result"](*[get_unit(a) for a in args])
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]
if au["result"] == bool:
Expand All @@ -3069,7 +3081,7 @@ def new_f(*args, **kwds):
f"unit {unit} but was "
f"'{result}'"
)
raise DimensionMismatchError(error_message, get_unit(result))
raise DimensionMismatchError(error_message, get_dim(result))
return result

new_f._orig_func = f
Expand Down
20 changes: 10 additions & 10 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
check_units,
fail_for_dimension_mismatch,
get_or_create_dimension,
get_unit,
get_dim,
get_basic_unit,
have_same_unit,
in_unit,
Expand Down Expand Up @@ -74,7 +74,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds):
def assert_quantity(q, values, unit):
values = jnp.asarray(values)
if isinstance(q, Quantity):
assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})"
assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_dim(q)}) ({get_dim(unit)})"
if not jnp.allclose(q.value, values):
raise AssertionError(f"Values do not match: {q.value} != {values}")
elif isinstance(q, jnp.ndarray):
Expand Down Expand Up @@ -145,19 +145,19 @@ def test_get_dimensions():
Test various ways of getting/comparing the dimensions of a Array.
"""
q = 500 * ms
assert get_unit(q) is get_or_create_dimension(q.dim._dims)
assert get_unit(q) is q.dim
assert get_dim(q) is get_or_create_dimension(q.dim._dims)
assert get_dim(q) is q.dim
assert q.has_same_unit(3 * second)
dims = q.dim
assert_equal(dims.get_dimension("time"), 1.0)
assert_equal(dims.get_dimension("length"), 0)

assert get_unit(5) is DIMENSIONLESS
assert get_unit(5.0) is DIMENSIONLESS
assert get_unit(np.array(5, dtype=np.int32)) is DIMENSIONLESS
assert get_unit(np.array(5.0)) is DIMENSIONLESS
assert get_unit(np.float32(5.0)) is DIMENSIONLESS
assert get_unit(np.float64(5.0)) is DIMENSIONLESS
assert get_dim(5) is DIMENSIONLESS
assert get_dim(5.0) is DIMENSIONLESS
assert get_dim(np.array(5, dtype=np.int32)) is DIMENSIONLESS
assert get_dim(np.array(5.0)) is DIMENSIONLESS
assert get_dim(np.float32(5.0)) is DIMENSIONLESS
assert get_dim(np.float64(5.0)) is DIMENSIONLESS
assert is_scalar_type(5)
assert is_scalar_type(5.0)
assert is_scalar_type(np.array(5, dtype=np.int32))
Expand Down
Loading