From ea4e9d5e9b360a0c55dcbf90f04101c3be4f3cfb Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 12 Jun 2024 13:10:58 +0800 Subject: [PATCH] Add magnitude conversion for `asarray` --- brainunit/math/_compat_numpy_array_creation.py | 9 +++++++-- brainunit/math/_compat_numpy_test.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 32219a8..68873dc 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -498,6 +498,7 @@ def asarray( Convert the input to a quantity or array. If unit is provided, the input will be checked whether it has the same unit as the provided unit. + (If they have same dimension but different magnitude, the input will be converted to the provided unit.) If unit is not provided, the input will be converted to an array. Args: @@ -513,9 +514,13 @@ def asarray( if unit is not None: assert isinstance(unit, Unit) fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) + if a.dim == unit: + return a + else: + # Convert to the magnitude of the provided unit + return Quantity(a.value / unit.value, dim=unit) else: - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order) / unit.value, dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit) diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 2e73403..0bd8c10 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -32,7 +32,7 @@ def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert q.dim == unit.dim, f"Unit mismatch: {q.dim} != {unit}" + assert q.dim == unit.dim or q.dim == unit, f"Unit mismatch: {q.dim} != {unit}" assert jnp.allclose(q.value, values), f"Values do not match: {q.value} != {values}" else: assert jnp.allclose(q, values), f"Values do not match: {q} != {values}" @@ -162,6 +162,10 @@ def test_asarray(self): result_q = bu.math.asarray([1, 2, 3], unit=bu.second) assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second) + q1 = [1, 2, 3] * bu.second + result_q = bu.math.asarray(q1, unit=bu.ms) + assert_quantity(result_q, jnp.asarray([1, 2, 3]) * 1000, bu.ms) + def test_arange(self): result = bu.math.arange(5) self.assertEqual(result.shape, (5,)) @@ -171,7 +175,7 @@ def test_arange(self): assert_quantity(result_q, jnp.arange(5, step=1), bu.second) result_q = bu.math.arange(3 * bu.second, 9 * bu.second, 1 * bu.second) - assert_quantity(result_q, jnp.arange(3, 9, 1), bu.second) + assert_quantity(result_q, jnp.arange(3, 9, 1), bu.ms) def test_linspace(self): result = bu.math.linspace(0, 10, 5)