From d85ce404b10cd9c00eeef1b8197fcb1678e82f7e Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 8 Jun 2024 08:02:19 +0800 Subject: [PATCH] Update test cases --- brainunit/math/_compat_numpy.py | 9 ++-- brainunit/math/_compat_numpy_test.py | 69 +++++++++++++++++++++++++--- 2 files changed, 68 insertions(+), 10 deletions(-) diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index b84d204..1b91ab8 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -148,12 +148,13 @@ def f(*args, **kwargs): @set_module_as('brainunit.math') def full_like(a, fill_value, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): + if isinstance(a, Quantity) and isinstance(fill_value, Quantity): + fail_for_dimension_mismatch(a, fill_value, error_message='Units do not match for full_like operation.') + return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), unit=a.unit) + elif isinstance(a, (jax.Array, np.ndarray)) and not isinstance(fill_value, Quantity): return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) else: - raise ValueError(f'Unsupported type: {type(a)} for full_like') + raise ValueError(f'Unsupported types : {type(a)} abd {type(fill_value)} for full_like') @set_module_as('brainunit.math') diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 56127e0..d3e8bf3 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -63,81 +63,121 @@ def test_array(self): self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == jnp.array([1, 2, 3]))) + # with Quantity + def test_full_like(self): array = jnp.array([1, 2, 3]) result = bm.full_like(array, 4) self.assertEqual(result.shape, array.shape) self.assertTrue(jnp.all(result == 4)) + q = [1, 2, 3] * U.second + result_q = bm.full_like(q, 4 * U.second) + assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), U.second) + def test_diag(self): array = jnp.array([1, 2, 3]) result = bm.diag(array) self.assertEqual(result.shape, (3, 3)) self.assertTrue(jnp.all(result == jnp.diag(array))) + q = [1, 2, 3] * U.second + result_q = bm.diag(q) + assert_quantity(result_q, jnp.diag(jnp.array([1, 2, 3])), U.second) + def test_tril(self): array = jnp.ones((3, 3)) result = bm.tril(array) self.assertEqual(result.shape, (3, 3)) self.assertTrue(jnp.all(result == jnp.tril(array))) + q = jnp.ones((3, 3)) * U.second + result_q = bm.tril(q) + assert_quantity(result_q, jnp.tril(jnp.ones((3, 3))), U.second) + def test_triu(self): array = jnp.ones((3, 3)) result = bm.triu(array) self.assertEqual(result.shape, (3, 3)) self.assertTrue(jnp.all(result == jnp.triu(array))) + q = jnp.ones((3, 3)) * U.second + result_q = bm.triu(q) + assert_quantity(result_q, jnp.triu(jnp.ones((3, 3))), U.second) + def test_empty_like(self): array = jnp.array([1, 2, 3]) result = bm.empty_like(array) self.assertEqual(result.shape, array.shape) + q = [1, 2, 3] * U.second + result_q = bm.empty_like(q) + assert_quantity(result_q, jnp.empty_like(jnp.array([1, 2, 3])), U.second) + def test_ones_like(self): array = jnp.array([1, 2, 3]) result = bm.ones_like(array) self.assertEqual(result.shape, array.shape) self.assertTrue(jnp.all(result == 1)) + q = [1, 2, 3] * U.second + result_q = bm.ones_like(q) + assert_quantity(result_q, jnp.ones_like(jnp.array([1, 2, 3])), U.second) + def test_zeros_like(self): array = jnp.array([1, 2, 3]) result = bm.zeros_like(array) self.assertEqual(result.shape, array.shape) self.assertTrue(jnp.all(result == 0)) + q = [1, 2, 3] * U.second + result_q = bm.zeros_like(q) + assert_quantity(result_q, jnp.zeros_like(jnp.array([1, 2, 3])), U.second) + def test_asarray(self): result = bm.asarray([1, 2, 3]) self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == jnp.asarray([1, 2, 3]))) - result = bm.asarray([1 * U.second, 2 * U.second, 3 * U.second]) - assert_quantity(result, jnp.asarray([1, 2, 3]), U.second) + result_q = bm.asarray([1 * U.second, 2 * U.second, 3 * U.second]) + assert_quantity(result_q, jnp.asarray([1, 2, 3]), U.second) def test_arange(self): result = bm.arange(5) self.assertEqual(result.shape, (5,)) self.assertTrue(jnp.all(result == jnp.arange(5))) - result = bm.arange(5 * U.second, step=1 * U.second) - assert_quantity(result, jnp.arange(5, step=1), U.second) + result_q = bm.arange(5 * U.second, step=1 * U.second) + assert_quantity(result_q, jnp.arange(5, step=1), U.second) - result = bm.arange(3 * U.second, 9 * U.second, 1 * U.second) - assert_quantity(result, jnp.arange(3, 9, 1), U.second) + result_q = bm.arange(3 * U.second, 9 * U.second, 1 * U.second) + assert_quantity(result_q, jnp.arange(3, 9, 1), U.second) def test_linspace(self): result = bm.linspace(0, 10, 5) self.assertEqual(result.shape, (5,)) self.assertTrue(jnp.all(result == jnp.linspace(0, 10, 5))) + q = bm.linspace(0 * U.second, 10 * U.second, 5) + assert_quantity(q, jnp.linspace(0, 10, 5), U.second) + def test_logspace(self): result = bm.logspace(0, 2, 5) self.assertEqual(result.shape, (5,)) self.assertTrue(jnp.all(result == jnp.logspace(0, 2, 5))) + q = bm.logspace(0 * U.second, 2 * U.second, 5) + assert_quantity(q, jnp.logspace(0, 2, 5), U.second) + def test_fill_diagonal(self): array = jnp.zeros((3, 3)) result = bm.fill_diagonal(array, 5, inplace=False) self.assertTrue(jnp.all(result == jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]))) + q = jnp.zeros((3, 3)) * U.second + result_q = bm.fill_diagonal(q, 5 * U.second, inplace=False) + assert_quantity(result_q, jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]), U.second) + def test_array_split(self): array = jnp.arange(9) result = bm.array_split(array, 3) @@ -145,6 +185,12 @@ def test_array_split(self): for r, e in zip(result, expected): self.assertTrue(jnp.all(r == e)) + q = jnp.arange(9) * U.second + result_q = bm.array_split(q, 3) + expected_q = jnp.array_split(jnp.arange(9), 3) + for r, e in zip(result_q, expected_q): + assert_quantity(r, e, U.second) + def test_meshgrid(self): x = jnp.array([1, 2, 3]) y = jnp.array([4, 5]) @@ -153,12 +199,23 @@ def test_meshgrid(self): for r, e in zip(result, expected): self.assertTrue(jnp.all(r == e)) + x = jnp.array([1, 2, 3]) * U.second + y = jnp.array([4, 5]) * U.second + result_q = bm.meshgrid(x, y) + expected_q = jnp.meshgrid(jnp.array([1, 2, 3]), jnp.array([4, 5])) + for r, e in zip(result_q, expected_q): + assert_quantity(r, e, U.second) + def test_vander(self): array = jnp.array([1, 2, 3]) result = bm.vander(array) self.assertEqual(result.shape, (3, 3)) self.assertTrue(jnp.all(result == jnp.vander(array))) + q = jnp.array([1, 2, 3]) * U.second + result_q = bm.vander(q) + assert_quantity(result_q, jnp.vander(jnp.array([1, 2, 3])), U.second) + class TestAttributeFunctions(unittest.TestCase):