Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
Routhleck committed Jun 11, 2024
1 parent 337b365 commit 6fc6add
Showing 6 changed files with 20 additions and 20 deletions.
30 changes: 15 additions & 15 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
@@ -73,7 +73,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds):
def assert_quantity(q, values, unit):
values = jnp.asarray(values)
if isinstance(q, Quantity):
assert have_same_unit(q.unit, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})"
assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})"
if not jnp.allclose(q.value, values):
raise AssertionError(f"Values do not match: {q.value} != {values}")
elif isinstance(q, jnp.ndarray):
@@ -144,10 +144,10 @@ def test_get_dimensions():
Test various ways of getting/comparing the dimensions of a Array.
"""
q = 500 * ms
assert get_unit(q) is get_or_create_dimension(q.unit._dims)
assert get_unit(q) is q.unit
assert get_unit(q) is get_or_create_dimension(q.dim._dims)
assert get_unit(q) is q.dim
assert q.has_same_unit(3 * second)
dims = q.unit
dims = q.dim
assert_equal(dims.get_dimension("time"), 1.0)
assert_equal(dims.get_dimension("length"), 0)

@@ -201,11 +201,11 @@ def test_unary_operations():


def test_operations():
q1 = Quantity(5, dim=mV)
q2 = Quantity(10, dim=mV)
assert_quantity(q1 + q2, 15, mV)
assert_quantity(q1 - q2, -5, mV)
assert_quantity(q1 * q2, 50, mV * mV)
q1 = 5 * second
q2 = 10 * second
assert_quantity(q1 + q2, 15, second)
assert_quantity(q1 - q2, -5, second)
assert_quantity(q1 * q2, 50, second * second)
assert_quantity(q2 / q1, 2, DIMENSIONLESS)
assert_quantity(q2 // q1, 2, DIMENSIONLESS)
assert_quantity(q2 % q1, 0, second)
@@ -215,21 +215,21 @@ def test_operations():
assert_quantity(round(q1, 0), 5, second)

# matmul
q1 = Quantity([1, 2], dim=mV)
q2 = Quantity([3, 4], dim=mV)
assert_quantity(q1 @ q2, 11, mV ** 2)
q1 = [1, 2] * second
q2 = [3, 4] * second
assert_quantity(q1 @ q2, 11, second ** 2)
q1 = Quantity([1, 2], unit=second)
q2 = Quantity([3, 4], unit=second)
assert_quantity(q1 @ q2, 11, second ** 2)

# shift
q1 = Quantity(0b1100, dtype=jnp.int32, unit=DIMENSIONLESS)
q1 = Quantity(0b1100, dtype=jnp.int32, dim=DIMENSIONLESS)
assert_quantity(q1 << 1, 0b11000, second)
assert_quantity(q1 >> 1, 0b110, second)


def test_numpy_methods():
q = Quantity([[1, 2], [3, 4]], dim=mV)
q = [[1, 2], [3, 4]] * second
assert q.all()
assert q.any()
assert q.nonzero()[0].tolist() == [0, 0, 1, 1]
@@ -1603,7 +1603,7 @@ def test_constants():
import brainunit._unit_constants as constants

# Check that the expected names exist and have the correct dimensions
assert constants.avogadro_constant.dim == (1 / mole).unit
assert constants.avogadro_constant.dim == (1 / mole).dim
assert constants.boltzmann_constant.dim == (joule / kelvin).dim
assert constants.electric_constant.dim == (farad / meter).dim
assert constants.electron_mass.dim == kilogram.dim
2 changes: 1 addition & 1 deletion brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
@@ -521,7 +521,7 @@ def arange(*args, **kwargs):
stop=stop,
step=step,
)
unit = getattr(stop, "unit", DIMENSIONLESS)
unit = getattr(stop, "dim", DIMENSIONLESS)
# start is a position-only argument in numpy 2.0
# https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only
# TODO: check whether this is still the case in the final release
2 changes: 1 addition & 1 deletion brainunit/math/_compat_numpy_funcs_change_unit.py
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ def decorator(func: Callable) -> Callable:
@wraps(func)
def f(x, *args, **kwargs):
if isinstance(x, Quantity):
return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dtype=change_unit_func(x.dim)))
return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dim=change_unit_func(x.dim)))
elif isinstance(x, (jnp.ndarray, np.ndarray)):
return func(x, *args, **kwargs)
else:
2 changes: 1 addition & 1 deletion brainunit/math/_compat_numpy_misc.py
Original file line number Diff line number Diff line change
@@ -245,7 +245,7 @@ def intersect1d(
fail_for_dimension_mismatch(ar1, ar2, 'intersect1d')
unit = None
if isinstance(ar1, Quantity):
unit = ar1.unit
unit = ar1.dim
ar1 = ar1.value if isinstance(ar1, Quantity) else ar1
ar2 = ar2.value if isinstance(ar2, Quantity) else ar2
result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices)
2 changes: 1 addition & 1 deletion brainunit/math/_compat_numpy_test.py
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@
def assert_quantity(q, values, unit):
values = jnp.asarray(values)
if isinstance(q, Quantity):
assert q.unit == unit.dim, f"Unit mismatch: {q.unit} != {unit}"
assert q.dim == unit.dim, f"Unit mismatch: {q.dim} != {unit}"
assert jnp.allclose(q.value, values), f"Values do not match: {q.value} != {values}"
else:
assert jnp.allclose(q, values), f"Values do not match: {q} != {values}"
2 changes: 1 addition & 1 deletion brainunit/math/_others.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@

import brainstate as bst

from ._compat_numpy import wrap_math_funcs_only_accept_unitless_unary
from ._compat_numpy_funcs_accept_unitless import wrap_math_funcs_only_accept_unitless_unary

__all__ = [
'exprel',

0 comments on commit 6fc6add

Please sign in to comment.