From 1bb3b89427f669f2f0ec84633952e21b68964a23 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 28 Mar 2023 18:14:20 -0700 Subject: [PATCH] Remove references to deprecated jax.ShapedArray This is deprecated as of https://github.com/google/jax/pull/15263: most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion. PiperOrigin-RevId: 520189916 --- trax/tf_numpy/jax_tests/lax_numpy_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py index a1fc26de9..e973ef79f 100644 --- a/trax/tf_numpy/jax_tests/lax_numpy_test.py +++ b/trax/tf_numpy/jax_tests/lax_numpy_test.py @@ -2881,7 +2881,7 @@ def body(i, xy): f = lambda y: lax.fori_loop(0, 5, body, (y, y)) wrapped = linear_util.wrap_init(f) pv = partial_eval.PartialVal( - (jax.ShapedArray((3, 4), onp.float32), jax.core.unit)) + (jax.core.ShapedArray((3, 4), onp.float32), jax.core.unit)) _, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv]) self.assertFalse( any(onp.array_equal(x, onp.full((3, 4), 2., dtype=onp.float32))