diff --git a/chex/_src/asserts_test.py b/chex/_src/asserts_test.py index fed49ce..c172a1d 100644 --- a/chex/_src/asserts_test.py +++ b/chex/_src/asserts_test.py @@ -1200,12 +1200,14 @@ def test_assert_tree_is_on_device(self): asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_2) with self.assertRaisesRegex( - AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=0") + AssertionError, + _get_err_regex(r"'a' resides on.*TpuDevice\(process_index=0, id=0"), ): asserts.assert_tree_is_on_device(tpu_1_tree, device=tpu_2) with self.assertRaisesRegex( - AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=1") + AssertionError, + _get_err_regex(r"'a' resides on.*TpuDevice\(process_index=0, id=1"), ): asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_1) @@ -1735,7 +1737,7 @@ def test_assert_equal_pass(self, first, second): asserts.assert_equal(first, second) def test_assert_equal_pass_on_arrays(self): - # Not using named_parameters, becase JAX cannot be used before app.run(). + # Not using named_parameters, because JAX cannot be used before app.run(). asserts.assert_equal(jnp.ones([]), np.ones([])) asserts.assert_equal( jnp.ones([], dtype=jnp.int32), np.ones([], dtype=np.float64))