diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py index a1fc26de9..e973ef79f 100644 --- a/trax/tf_numpy/jax_tests/lax_numpy_test.py +++ b/trax/tf_numpy/jax_tests/lax_numpy_test.py @@ -2881,7 +2881,7 @@ def body(i, xy): f = lambda y: lax.fori_loop(0, 5, body, (y, y)) wrapped = linear_util.wrap_init(f) pv = partial_eval.PartialVal( - (jax.ShapedArray((3, 4), onp.float32), jax.core.unit)) + (jax.core.ShapedArray((3, 4), onp.float32), jax.core.unit)) _, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv]) self.assertFalse( any(onp.array_equal(x, onp.full((3, 4), 2., dtype=onp.float32))