diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py index 1c12a6f58..9d2241f5e 100644 --- a/trax/tf_numpy/jax_tests/lax_numpy_test.py +++ b/trax/tf_numpy/jax_tests/lax_numpy_test.py @@ -971,16 +971,33 @@ def onp_fun(lhs, rhs): self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol, check_incomplete_shape=True) - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_amin={}_amax={}".format( - jtu.format_shape_dtype_string(shape, dtype), a_min, a_max), - "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max, - "rng_factory": jtu.rand_default} - for shape in all_shapes for dtype in minus(number_dtypes, complex_dtypes) - for a_min, a_max in [(-1, None), (None, 1), (-1, 1), - (-onp.ones(1), None), - (None, onp.ones(1)), - (-onp.ones(1), onp.ones(1))])) + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_amin={}_amax={}".format( + jtu.format_shape_dtype_string(shape, dtype), a_min, a_max + ), + "shape": shape, + "dtype": dtype, + "a_min": a_min, + "a_max": a_max, + "rng_factory": jtu.rand_default, + } + for shape in all_shapes + for dtype in minus(number_dtypes, complex_dtypes) + for a_min, a_max in [ + (-1, None), + (None, 1), + (-onp.ones(1), None), + (None, onp.ones(1)), + ] + + ( + [] + if onp.__version__ >= onp.lib.NumpyVersion("2.0.0") + else [(-1, 1), (-onp.ones(1), onp.ones(1))] + ) + ) + ) def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng_factory): rng = rng_factory() onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) @@ -1357,7 +1374,6 @@ def testDiagIndices(self, ndim, n): onp.testing.assert_equal(onp.diag_indices(n, ndim), lnp.diag_indices(n, ndim)) - @named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_k={}".format( jtu.format_shape_dtype_string(shape, dtype), k), @@ -1951,7 +1967,6 @@ def testFlipud(self, shape, dtype, rng_factory): self._CompileAndCheck( lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - @named_parameters(jtu.cases_from_list( {"testcase_name": "_{}".format( jtu.format_shape_dtype_string(shape, dtype)), @@ -1968,7 +1983,6 @@ def testFliplr(self, shape, dtype, rng_factory): self._CompileAndCheck( lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - @named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_k={}_axes={}".format( jtu.format_shape_dtype_string(shape, dtype), k, axes), @@ -2295,7 +2309,6 @@ def onp_fun(*args): tol=tol) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol) - @named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( jtu.format_shape_dtype_string(shape, dtype)), @@ -2318,7 +2331,6 @@ def testWhereOneArgument(self, shape, dtype): check_unknown_rank=False, check_experimental_compile=False, check_xla_forced_compile=False) - @named_parameters(jtu.cases_from_list( {"testcase_name": "_{}".format("_".join( jtu.format_shape_dtype_string(shape, dtype) @@ -2373,7 +2385,6 @@ def onp_fun(condlist, choicelist, default): check_incomplete_shape=True, rtol={onp.float64: 1e-7, onp.complex128: 1e-7}) - @jtu.disable def testIssue330(self): x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash @@ -2429,7 +2440,6 @@ def testAtLeastNdLiterals(self, pytype, dtype, op): self._CompileAndCheck( lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - def testLongLong(self): self.assertAllClose( onp.int64(7), npe.jit(lambda x: x)(onp.longlong(7)), check_dtypes=True) @@ -2676,19 +2686,40 @@ def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory): @named_parameters( jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_retstep={}_dtype={}").format( - start_shape, stop_shape, num, endpoint, retstep, dtype), - "start_shape": start_shape, "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, "retstep": retstep, - "dtype": dtype, "rng_factory": rng_factory} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - for retstep in [True, False] - for dtype in number_dtypes + [None,] - for rng_factory in [jtu.rand_default])) + { + "testcase_name": ( + "_start_shape={}_stop_shape={}_num={}_endpoint={}" + "_retstep={}_dtype={}" + ).format(start_shape, stop_shape, num, endpoint, retstep, dtype), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, + "endpoint": endpoint, + "retstep": retstep, + "dtype": dtype, + "rng_factory": rng_factory, + } + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + for retstep in [True, False] + for dtype in ( + ( + float_dtypes + + complex_dtypes + + [ + None, + ] + ) + if onp.__version__ >= onp.lib.NumpyVersion("2.0.0") + else ([ + number_dtypes + None, + ]) + ) + for rng_factory in [jtu.rand_default] + ) + ) def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype, rng_factory): if not endpoint and onp.issubdtype(dtype, onp.integer): @@ -2770,20 +2801,40 @@ def testLogspace(self, start_shape, stop_shape, num, @named_parameters( jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_dtype={}").format( - start_shape, stop_shape, num, endpoint, dtype), - "start_shape": start_shape, - "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, - "dtype": dtype, "rng_factory": rng_factory} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - # NB: numpy's geomspace gives nonsense results on integer types - for dtype in inexact_dtypes + [None,] - for rng_factory in [jtu.rand_default])) + { + "testcase_name": ( + "_start_shape={}_stop_shape={}_num={}_endpoint={}_dtype={}" + ).format(start_shape, stop_shape, num, endpoint, dtype), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, + "endpoint": endpoint, + "dtype": dtype, + "rng_factory": rng_factory, + } + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + # NB: numpy's geomspace gives nonsense results on integer types + for dtype in ( + ( + float_dtypes + + [ + None, + ] + ) + if onp.__version__ >= onp.lib.NumpyVersion("2.0.0") + else ( + inexact_dtypes + + [ + None, + ] + ) + ) + for rng_factory in [jtu.rand_default] + ) + ) def testGeomspace(self, start_shape, stop_shape, num, endpoint, dtype, rng_factory): rng = rng_factory()