diff --git a/trax/layers/research/efficient_attention_test.py b/trax/layers/research/efficient_attention_test.py
index 912a68d01..125023718 100644
--- a/trax/layers/research/efficient_attention_test.py
+++ b/trax/layers/research/efficient_attention_test.py
@@ -94,14 +94,14 @@ def _test_equivalence_to_reference_code(
     test_all = self._run_forward_and_backward(test_model, inp, weights, state)
     test_out, test_state, test_inp_grad, test_weights_grad = test_all

-    self.assertEqual(jax.tree_structure(ref_out),
-                     jax.tree_structure(test_out))
-    self.assertEqual(jax.tree_structure(ref_state),
-                     jax.tree_structure(test_state))
-    self.assertEqual(jax.tree_structure(ref_inp_grad),
-                     jax.tree_structure(test_inp_grad))
-    self.assertEqual(jax.tree_structure(ref_weights_grad),
-                     jax.tree_structure(test_weights_grad))
+    self.assertEqual(jax.tree.structure(ref_out),
+                     jax.tree.structure(test_out))
+    self.assertEqual(jax.tree.structure(ref_state),
+                     jax.tree.structure(test_state))
+    self.assertEqual(jax.tree.structure(ref_inp_grad),
+                     jax.tree.structure(test_inp_grad))
+    self.assertEqual(jax.tree.structure(ref_weights_grad),
+                     jax.tree.structure(test_weights_grad))

     check_close = lambda x, y: self.assertAllClose(x, y, rtol=2e-3, atol=2e-3)
     fastmath.nested_map_multiarg(check_close, ref_out, test_out)
@@ -168,7 +168,7 @@ def get_slice_for_val(x):
                                    dtype=x.dtype)
         else:
           return x[:, i:i+1]
-      return jax.tree_map(get_slice_for_val, pytree)
+      return jax.tree.map(get_slice_for_val, pytree)

     seqlen = x[0].shape[1] if isinstance(x, (tuple, list)) else x.shape[1]

diff --git a/trax/models/research/bert.py b/trax/models/research/bert.py
index 57ec5495b..6fb3259a3 100644
--- a/trax/models/research/bert.py
+++ b/trax/models/research/bert.py
@@ -308,9 +308,9 @@ def reshape_bias(name):
     for a, b in zip(fastmath.tree_leaves(self.weights), new_w):
       assert a.shape == b.shape, (
           f'Expected shape {a.shape}, got shape {b.shape}')
-    self.weights = jax.tree_unflatten(jax.tree_structure(self.weights), new_w)
+    self.weights = jax.tree.unflatten(jax.tree.structure(self.weights), new_w)
     move_to_device = jax.jit(lambda x: x)
-    self.weights = jax.tree_map(move_to_device, self.weights)
+    self.weights = jax.tree.map(move_to_device, self.weights)

   def _settable_attrs(self):
     """We allow to set attributes required for loading the model from its checkpoints."""
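
Note: the change is purely mechanical — each deprecated flat alias (jax.tree_map, jax.tree_structure, jax.tree_unflatten) is replaced by its counterpart in the jax.tree namespace, which newer JAX releases provide (the flat aliases were deprecated and later removed). A minimal sketch of the equivalence, using a toy pytree as a hypothetical stand-in for Trax weights/gradients rather than anything from this repo:

# Minimal sketch (not part of the patch): the jax.tree namespace mirrors
# the deprecated flat aliases one-for-one.
import jax

pytree = {'w': [1.0, 2.0], 'b': (3.0,)}  # toy stand-in for model weights

# jax.tree.map replaces jax.tree_map; structure is preserved.
doubled = jax.tree.map(lambda x: 2 * x, pytree)

# jax.tree.structure / jax.tree.unflatten replace jax.tree_structure /
# jax.tree_unflatten: flatten + unflatten round-trips the tree, and a
# mapped tree keeps the same treedef.
leaves, treedef = jax.tree.flatten(pytree)
assert jax.tree.unflatten(treedef, leaves) == pytree
assert jax.tree.structure(doubled) == treedef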