Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634149223
  • Loading branch information
Jake VanderPlas authored and copybara-github committed May 16, 2024
1 parent a6a508e commit 1451479
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions trax/layers/research/efficient_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ def _test_equivalence_to_reference_code(
test_all = self._run_forward_and_backward(test_model, inp, weights, state)
test_out, test_state, test_inp_grad, test_weights_grad = test_all

self.assertEqual(jax.tree_structure(ref_out),
jax.tree_structure(test_out))
self.assertEqual(jax.tree_structure(ref_state),
jax.tree_structure(test_state))
self.assertEqual(jax.tree_structure(ref_inp_grad),
jax.tree_structure(test_inp_grad))
self.assertEqual(jax.tree_structure(ref_weights_grad),
jax.tree_structure(test_weights_grad))
self.assertEqual(jax.tree.structure(ref_out),
jax.tree.structure(test_out))
self.assertEqual(jax.tree.structure(ref_state),
jax.tree.structure(test_state))
self.assertEqual(jax.tree.structure(ref_inp_grad),
jax.tree.structure(test_inp_grad))
self.assertEqual(jax.tree.structure(ref_weights_grad),
jax.tree.structure(test_weights_grad))

check_close = lambda x, y: self.assertAllClose(x, y, rtol=2e-3, atol=2e-3)
fastmath.nested_map_multiarg(check_close, ref_out, test_out)
Expand Down Expand Up @@ -168,7 +168,7 @@ def get_slice_for_val(x):
dtype=x.dtype)
else:
return x[:, i:i+1]
return jax.tree_map(get_slice_for_val, pytree)
return jax.tree.map(get_slice_for_val, pytree)

seqlen = x[0].shape[1] if isinstance(x, (tuple, list)) else x.shape[1]

Expand Down
4 changes: 2 additions & 2 deletions trax/models/research/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ def reshape_bias(name):
for a, b in zip(fastmath.tree_leaves(self.weights), new_w):
assert a.shape == b.shape, (
f'Expected shape {a.shape}, got shape {b.shape}')
self.weights = jax.tree_unflatten(jax.tree_structure(self.weights), new_w)
self.weights = jax.tree.unflatten(jax.tree.structure(self.weights), new_w)
move_to_device = jax.jit(lambda x: x)
self.weights = jax.tree_map(move_to_device, self.weights)
self.weights = jax.tree.map(move_to_device, self.weights)

def _settable_attrs(self):
"""We allow to set attributes required for loading the model from its checkpoints."""
Expand Down

0 comments on commit 1451479

Please sign in to comment.