Skip to content

Commit

Permalink
Update test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 8, 2024
1 parent a3a5d49 commit d85ce40
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 10 deletions.
9 changes: 5 additions & 4 deletions brainunit/math/_compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,13 @@ def f(*args, **kwargs):

@set_module_as('brainunit.math')
def full_like(a, fill_value, dtype=None, shape=None):
if isinstance(a, Quantity):
return Quantity(jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape), unit=a.unit)
elif isinstance(a, (jax.Array, np.ndarray)):
if isinstance(a, Quantity) and isinstance(fill_value, Quantity):
fail_for_dimension_mismatch(a, fill_value, error_message='Units do not match for full_like operation.')
return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), unit=a.unit)
elif isinstance(a, (jax.Array, np.ndarray)) and not isinstance(fill_value, Quantity):
return jnp.full_like(a, fill_value, dtype=dtype, shape=shape)
else:
raise ValueError(f'Unsupported type: {type(a)} for full_like')
raise ValueError(f'Unsupported types : {type(a)} abd {type(fill_value)} for full_like')


@set_module_as('brainunit.math')
Expand Down
69 changes: 63 additions & 6 deletions brainunit/math/_compat_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,88 +63,134 @@ def test_array(self):
self.assertEqual(result.shape, (3,))
self.assertTrue(jnp.all(result == jnp.array([1, 2, 3])))

# with Quantity

def test_full_like(self):
array = jnp.array([1, 2, 3])
result = bm.full_like(array, 4)
self.assertEqual(result.shape, array.shape)
self.assertTrue(jnp.all(result == 4))

q = [1, 2, 3] * U.second
result_q = bm.full_like(q, 4 * U.second)
assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), U.second)

def test_diag(self):
array = jnp.array([1, 2, 3])
result = bm.diag(array)
self.assertEqual(result.shape, (3, 3))
self.assertTrue(jnp.all(result == jnp.diag(array)))

q = [1, 2, 3] * U.second
result_q = bm.diag(q)
assert_quantity(result_q, jnp.diag(jnp.array([1, 2, 3])), U.second)

def test_tril(self):
array = jnp.ones((3, 3))
result = bm.tril(array)
self.assertEqual(result.shape, (3, 3))
self.assertTrue(jnp.all(result == jnp.tril(array)))

q = jnp.ones((3, 3)) * U.second
result_q = bm.tril(q)
assert_quantity(result_q, jnp.tril(jnp.ones((3, 3))), U.second)

def test_triu(self):
array = jnp.ones((3, 3))
result = bm.triu(array)
self.assertEqual(result.shape, (3, 3))
self.assertTrue(jnp.all(result == jnp.triu(array)))

q = jnp.ones((3, 3)) * U.second
result_q = bm.triu(q)
assert_quantity(result_q, jnp.triu(jnp.ones((3, 3))), U.second)

def test_empty_like(self):
array = jnp.array([1, 2, 3])
result = bm.empty_like(array)
self.assertEqual(result.shape, array.shape)

q = [1, 2, 3] * U.second
result_q = bm.empty_like(q)
assert_quantity(result_q, jnp.empty_like(jnp.array([1, 2, 3])), U.second)

def test_ones_like(self):
array = jnp.array([1, 2, 3])
result = bm.ones_like(array)
self.assertEqual(result.shape, array.shape)
self.assertTrue(jnp.all(result == 1))

q = [1, 2, 3] * U.second
result_q = bm.ones_like(q)
assert_quantity(result_q, jnp.ones_like(jnp.array([1, 2, 3])), U.second)

def test_zeros_like(self):
array = jnp.array([1, 2, 3])
result = bm.zeros_like(array)
self.assertEqual(result.shape, array.shape)
self.assertTrue(jnp.all(result == 0))

q = [1, 2, 3] * U.second
result_q = bm.zeros_like(q)
assert_quantity(result_q, jnp.zeros_like(jnp.array([1, 2, 3])), U.second)

def test_asarray(self):
result = bm.asarray([1, 2, 3])
self.assertEqual(result.shape, (3,))
self.assertTrue(jnp.all(result == jnp.asarray([1, 2, 3])))

result = bm.asarray([1 * U.second, 2 * U.second, 3 * U.second])
assert_quantity(result, jnp.asarray([1, 2, 3]), U.second)
result_q = bm.asarray([1 * U.second, 2 * U.second, 3 * U.second])
assert_quantity(result_q, jnp.asarray([1, 2, 3]), U.second)

def test_arange(self):
result = bm.arange(5)
self.assertEqual(result.shape, (5,))
self.assertTrue(jnp.all(result == jnp.arange(5)))

result = bm.arange(5 * U.second, step=1 * U.second)
assert_quantity(result, jnp.arange(5, step=1), U.second)
result_q = bm.arange(5 * U.second, step=1 * U.second)
assert_quantity(result_q, jnp.arange(5, step=1), U.second)

result = bm.arange(3 * U.second, 9 * U.second, 1 * U.second)
assert_quantity(result, jnp.arange(3, 9, 1), U.second)
result_q = bm.arange(3 * U.second, 9 * U.second, 1 * U.second)
assert_quantity(result_q, jnp.arange(3, 9, 1), U.second)

def test_linspace(self):
result = bm.linspace(0, 10, 5)
self.assertEqual(result.shape, (5,))
self.assertTrue(jnp.all(result == jnp.linspace(0, 10, 5)))

q = bm.linspace(0 * U.second, 10 * U.second, 5)
assert_quantity(q, jnp.linspace(0, 10, 5), U.second)

def test_logspace(self):
result = bm.logspace(0, 2, 5)
self.assertEqual(result.shape, (5,))
self.assertTrue(jnp.all(result == jnp.logspace(0, 2, 5)))

q = bm.logspace(0 * U.second, 2 * U.second, 5)
assert_quantity(q, jnp.logspace(0, 2, 5), U.second)

def test_fill_diagonal(self):
array = jnp.zeros((3, 3))
result = bm.fill_diagonal(array, 5, inplace=False)
self.assertTrue(jnp.all(result == jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]])))

q = jnp.zeros((3, 3)) * U.second
result_q = bm.fill_diagonal(q, 5 * U.second, inplace=False)
assert_quantity(result_q, jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]), U.second)

def test_array_split(self):
array = jnp.arange(9)
result = bm.array_split(array, 3)
expected = jnp.array_split(array, 3)
for r, e in zip(result, expected):
self.assertTrue(jnp.all(r == e))

q = jnp.arange(9) * U.second
result_q = bm.array_split(q, 3)
expected_q = jnp.array_split(jnp.arange(9), 3)
for r, e in zip(result_q, expected_q):
assert_quantity(r, e, U.second)

def test_meshgrid(self):
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5])
Expand All @@ -153,12 +199,23 @@ def test_meshgrid(self):
for r, e in zip(result, expected):
self.assertTrue(jnp.all(r == e))

x = jnp.array([1, 2, 3]) * U.second
y = jnp.array([4, 5]) * U.second
result_q = bm.meshgrid(x, y)
expected_q = jnp.meshgrid(jnp.array([1, 2, 3]), jnp.array([4, 5]))
for r, e in zip(result_q, expected_q):
assert_quantity(r, e, U.second)

def test_vander(self):
array = jnp.array([1, 2, 3])
result = bm.vander(array)
self.assertEqual(result.shape, (3, 3))
self.assertTrue(jnp.all(result == jnp.vander(array)))

q = jnp.array([1, 2, 3]) * U.second
result_q = bm.vander(q)
assert_quantity(result_q, jnp.vander(jnp.array([1, 2, 3])), U.second)


class TestAttributeFunctions(unittest.TestCase):

Expand Down

0 comments on commit d85ce40

Please sign in to comment.