Skip to content

Commit

Permalink
Add handling for JAX typed PRNG keys in chex.assert_trees_all_equal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 690573965
  • Loading branch information
Jake VanderPlas authored and ChexDev committed Oct 28, 2024
1 parent eab14bb commit 7b2f989
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
11 changes: 10 additions & 1 deletion chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,11 +1551,20 @@ def _assert_trees_all_equal_static(
AssertionError: If the leaf values actual and desired are not exactly equal.
"""
def assert_fn(arr_1, arr_2):
if isinstance(arr_1, jax.Array) and jax.dtypes.issubdtype(
arr_1.dtype, jax.dtypes.prng_key
) and isinstance(arr_2, jax.Array) and jax.dtypes.issubdtype(
arr_2.dtype, jax.dtypes.prng_key
):
assert jax.random.key_impl(arr_1) == jax.random.key_impl(arr_2)
arr_1 = jax.random.key_data(arr_1)
arr_2 = jax.random.key_data(arr_2)
np.testing.assert_array_equal(
_ai.jnp_to_np_array(arr_1),
_ai.jnp_to_np_array(arr_2),
err_msg="Error in value equality check: Values not exactly equal",
strict=strict)
strict=strict,
)

def cmp_fn(arr_1, arr_2) -> bool:
try:
Expand Down
15 changes: 15 additions & 0 deletions chex/_src/asserts_chexify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,21 @@ def fn(x, y):
):
chexified_fn(tree_1, tree_2) # Fail: not equal

def test_assert_trees_all_equal_with_prng_keys(self):
@jax.jit
def fn(x, y):
asserts.assert_trees_all_equal(x, y)
return x['a'] + y['a']

chexified_fn = asserts_chexify.chexify(fn, async_check=False)
tree1 = {'a': jnp.array([3]), 'key': jax.random.split(jax.random.key(1))}
tree2 = {'a': jnp.array([3]), 'key': jax.random.split(jax.random.key(2))}
chexified_fn(tree1, tree1) # OK
with self.assertRaisesRegex(
AssertionError, re.escape("Trees 0 and 1 differ in leaves '('key',)'")
):
chexified_fn(tree1, tree2) # Fail: not equal

def test_assert_trees_all_close(self):
@jax.jit
def fn(x, y, z):
Expand Down
9 changes: 9 additions & 0 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,15 @@ def test_tree_all_finite_should_fail_inf(self):
with self.assertRaisesRegex(ValueError, err_msg):
asserts._assert_tree_all_finite_jittable(inf_tree)

def test_assert_trees_all_equal_prng_keys(self):
tree1 = {'a': jnp.array([3]), 'key': jax.random.split(jax.random.key(1))}
tree2 = {'a': jnp.array([3]), 'key': jax.random.split(jax.random.key(2))}
asserts.assert_trees_all_equal(tree1, tree1) # OK

err_regex = _get_err_regex(r'Trees 0 and 1 differ in leaves \'key\'')
with self.assertRaisesRegex(AssertionError, err_regex):
asserts.assert_trees_all_equal(tree1, tree2) # Fail: not equal

def test_assert_trees_all_equal_passes_same_tree(self):
tree = {
'a': [jnp.zeros((1,))],
Expand Down

0 comments on commit 7b2f989

Please sign in to comment.