diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 61ce599..d904d80 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -475,9 +475,10 @@ def asarray( Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. """ + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' if isinstance(a, Quantity): if unit is not None: - assert isinstance(unit, Unit) fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): @@ -494,8 +495,8 @@ def asarray( if any(x.dim != leaves[0].dim for x in leaves): raise ValueError('Units do not match for asarray operation.') values = jax.tree.unflatten(tree, [x.value for x in a]) - - fail_for_dimension_mismatch(a[0], unit, error_message="a and unit have to have the same units.") + if unit is not None: + fail_for_dimension_mismatch(a[0], unit, error_message="a and unit have to have the same units.") unit = a[0].dim # Convert the values to a jnp.ndarray and create a Quantity object return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) @@ -503,7 +504,6 @@ def asarray( values = jax.tree.unflatten(tree, leaves) val = jnp.asarray(values, dtype=dtype, order=order) if unit is not None: - assert isinstance(unit, Unit) return val * unit else: return val diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 489e140..3c74dfb 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -162,6 +162,10 @@ def test_asarray(self): result_q = bu.math.asarray([1, 2, 3], unit=bu.second) assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second) + a = bu.math.asarray([1 * bu.second, 2 * bu.second, 3 * bu.second], unit=bu.second) + b = bu.math.asarray([1 * bu.second, 2 * bu.second, 3 * bu.second]) + assert bu.math.allclose(a, b) + def test_arange(self): result = bu.math.arange(5) self.assertEqual(result.shape, (5,))