Skip to content

Commit

Permalink
Sort devices explicitly by process index, then id (as opposed to IDs …
Browse files Browse the repository at this point in the history
…alone). IDs may be randomly generated, and are not guaranteed to be ordered based on their process index.

PiperOrigin-RevId: 650294623
  • Loading branch information
ChexDev authored and ChexDev committed Jul 18, 2024
1 parent 343d03a commit 1496cb7
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,12 +1200,14 @@ def test_assert_tree_is_on_device(self):
asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_2)

with self.assertRaisesRegex(
AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=0")
AssertionError,
_get_err_regex(r"'a' resides on.*TpuDevice\(process_index=0, id=0"),
):
asserts.assert_tree_is_on_device(tpu_1_tree, device=tpu_2)

with self.assertRaisesRegex(
AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=1")
AssertionError,
_get_err_regex(r"'a' resides on.*TpuDevice\(process_index=0, id=1"),
):
asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_1)

Expand Down Expand Up @@ -1735,7 +1737,7 @@ def test_assert_equal_pass(self, first, second):
asserts.assert_equal(first, second)

def test_assert_equal_pass_on_arrays(self):
# Not using named_parameters, becase JAX cannot be used before app.run().
# Not using named_parameters, because JAX cannot be used before app.run().
asserts.assert_equal(jnp.ones([]), np.ones([]))
asserts.assert_equal(
jnp.ones([], dtype=jnp.int32), np.ones([], dtype=np.float64))
Expand Down

0 comments on commit 1496cb7

Please sign in to comment.