From 612c7fa20b3b0b6fedda7be7528eeb2461c21eb8 Mon Sep 17 00:00:00 2001 From: Peng Wang Date: Thu, 3 Nov 2022 17:33:54 -0700 Subject: [PATCH] Disables a failing remat test. PiperOrigin-RevId: 486016573 --- trax/tf_numpy/jax_tests/lax_numpy_test.py | 8 ++++++-- trax/tf_numpy/jax_tests/test_util.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py index 7600a1448..cf014f6f4 100644 --- a/trax/tf_numpy/jax_tests/lax_numpy_test.py +++ b/trax/tf_numpy/jax_tests/lax_numpy_test.py @@ -927,7 +927,7 @@ def onp_fun(a, b): check_xla = not set((lhs_dtype, rhs_dtype)).intersection( (onp.int32, onp.int64)) - tol = {onp.float64: 1e-14} + tol = {onp.float64: 1e-14, onp.float16: 0.04} tol = max(jtu.tolerance(lhs_dtype, tol), jtu.tolerance(rhs_dtype, tol)) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True, @@ -1301,8 +1301,12 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng_fact tol=tol) # XLA lacks int64 Cumsum/Cumprod kernels (b/168841378). check_xla = out_dtype != onp.int64 + rtol = None + if out_dtype == onp.float16: + rtol = 2e-3 self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True, + lnp_fun, args_maker, check_dtypes=True, rtol=rtol, + check_incomplete_shape=True, check_experimental_compile=check_xla, check_xla_forced_compile=check_xla) diff --git a/trax/tf_numpy/jax_tests/test_util.py b/trax/tf_numpy/jax_tests/test_util.py index fae593759..b12b04676 100644 --- a/trax/tf_numpy/jax_tests/test_util.py +++ b/trax/tf_numpy/jax_tests/test_util.py @@ -333,8 +333,10 @@ def format_test_name_suffix(opname, shapes, dtypes): # We use special symbols, represented as singleton objects, to distinguish # between NumPy scalars, Python scalars, and 0-D arrays. -class ScalarShape(object): +class ScalarShape: def __len__(self): return 0 + def __getitem__(self, i): + raise IndexError(f'index {i} out of range.') class _NumpyScalar(ScalarShape): pass class _PythonScalar(ScalarShape): pass NUMPY_SCALAR_SHAPE = _NumpyScalar()