Skip to content

Commit

Permalink
fix asarray() bug (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 authored Jun 12, 2024
1 parent 4c1a377 commit 01d4c4a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 4 additions & 4 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,10 @@ def asarray(
Returns:
Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array.
'''
if unit is not None:
assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}'
if isinstance(a, Quantity):
if unit is not None:
assert isinstance(unit, Unit)
fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.")
return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim)
elif isinstance(a, (jax.Array, np.ndarray)):
Expand All @@ -494,16 +495,15 @@ def asarray(
if any(x.dim != leaves[0].dim for x in leaves):
raise ValueError('Units do not match for asarray operation.')
values = jax.tree.unflatten(tree, [x.value for x in a])

fail_for_dimension_mismatch(a[0], unit, error_message="a and unit have to have the same units.")
if unit is not None:
fail_for_dimension_mismatch(a[0], unit, error_message="a and unit have to have the same units.")
unit = a[0].dim
# Convert the values to a jnp.ndarray and create a Quantity object
return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit)
else:
values = jax.tree.unflatten(tree, leaves)
val = jnp.asarray(values, dtype=dtype, order=order)
if unit is not None:
assert isinstance(unit, Unit)
return val * unit
else:
return val
Expand Down
4 changes: 4 additions & 0 deletions brainunit/math/_compat_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def test_asarray(self):
result_q = bu.math.asarray([1, 2, 3], unit=bu.second)
assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second)

a = bu.math.asarray([1 * bu.second, 2 * bu.second, 3 * bu.second], unit=bu.second)
b = bu.math.asarray([1 * bu.second, 2 * bu.second, 3 * bu.second])
assert bu.math.allclose(a, b)

def test_arange(self):
result = bu.math.arange(5)
self.assertEqual(result.shape, (5,))
Expand Down

0 comments on commit 01d4c4a

Please sign in to comment.