Skip to content

Commit

Permalink
Fix logic of asarray
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 11, 2024
1 parent b0154ab commit d0fcce6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
35 changes: 24 additions & 11 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,21 +433,34 @@ def asarray(
order: Optional[str] = None,
unit: Optional[Unit] = None,
) -> Union[Quantity, jax.Array]:
'''
Convert the input to a quantity or array.
If unit is provided, the input is converted to a Quantity object with the given unit.
Args:
a: array_like, Quantity, or Sequence[Quantity]
dtype: data-type, optional
order: {'C', 'F', 'A', 'K'}, optional
unit: Unit, optional
Returns:
Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array.
'''
from builtins import all as origin_all
from builtins import any as origin_any
if isinstance(a, Quantity):
return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim)
if unit is not None:
assert isinstance(unit, Unit)
return jnp.asarray(a.value, dtype=dtype, order=order) * unit
else:
return jnp.asarray(a.value, dtype=dtype, order=order)
elif isinstance(a, (jax.Array, np.ndarray)):
return jnp.asarray(a, dtype=dtype, order=order)
# list[Quantity]
elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a):
# check all elements have the same unit
if origin_any(x.dim != a[0].dim for x in a):
raise ValueError('Units do not match for asarray operation.')
values = [x.value for x in a]
unit = a[0].dim
# Convert the values to a jnp.ndarray and create a Quantity object
return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit)
if unit is not None:
assert isinstance(unit, Unit)
return jnp.asarray(a, dtype=dtype, order=order) * unit
else:
return jnp.asarray(a, dtype=dtype, order=order)
else:
return jnp.asarray(a, dtype=dtype, order=order)

Expand Down
2 changes: 1 addition & 1 deletion brainunit/math/_compat_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_asarray(self):
self.assertEqual(result.shape, (3,))
self.assertTrue(jnp.all(result == jnp.asarray([1, 2, 3])))

result_q = bu.math.asarray([1 * bu.second, 2 * bu.second, 3 * bu.second])
result_q = bu.math.asarray([1, 2, 3], unit=bu.second)
assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second)

def test_arange(self):
Expand Down

0 comments on commit d0fcce6

Please sign in to comment.