diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py
index 7600a1448..cf014f6f4 100644
--- a/trax/tf_numpy/jax_tests/lax_numpy_test.py
+++ b/trax/tf_numpy/jax_tests/lax_numpy_test.py
@@ -927,7 +927,7 @@ def onp_fun(a, b):
     check_xla = not set((lhs_dtype, rhs_dtype)).intersection(
         (onp.int32, onp.int64))
-    tol = {onp.float64: 1e-14}
+    tol = {onp.float64: 1e-14, onp.float16: 0.04}
     tol = max(jtu.tolerance(lhs_dtype, tol), jtu.tolerance(rhs_dtype, tol))
     self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True,
                           check_incomplete_shape=True,
@@ -1301,8 +1301,12 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng_fact
         tol=tol)
     # XLA lacks int64 Cumsum/Cumprod kernels (b/168841378).
     check_xla = out_dtype != onp.int64
+    rtol = None
+    if out_dtype == onp.float16:
+      rtol = 2e-3
     self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True,
+        lnp_fun, args_maker, check_dtypes=True, rtol=rtol,
+        check_incomplete_shape=True,
         check_experimental_compile=check_xla,
         check_xla_forced_compile=check_xla)
diff --git a/trax/tf_numpy/jax_tests/test_util.py b/trax/tf_numpy/jax_tests/test_util.py
index fae593759..b12b04676 100644
--- a/trax/tf_numpy/jax_tests/test_util.py
+++ b/trax/tf_numpy/jax_tests/test_util.py
@@ -333,8 +333,10 @@ def format_test_name_suffix(opname, shapes, dtypes):
 # We use special symbols, represented as singleton objects, to distinguish
 # between NumPy scalars, Python scalars, and 0-D arrays.
-class ScalarShape(object):
+class ScalarShape:
   def __len__(self): return 0
+  def __getitem__(self, i):
+    raise IndexError(f'index {i} out of range.')
 class _NumpyScalar(ScalarShape): pass
 class _PythonScalar(ScalarShape): pass
 NUMPY_SCALAR_SHAPE = _NumpyScalar()
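
Context for the first two hunks: `jtu.tolerance` resolves a per-dtype tolerance, and the test keeps the loosest (max) of the two operand dtypes' tolerances, so a float16 operand relaxes the whole comparison. float16 has a machine epsilon of about 9.8e-4, which is why cumulative ops need an rtol around 2e-3 (a couple of ulps of accumulated rounding). The sketch below is a minimal, hypothetical stand-in for that lookup; the `tolerance` helper and `_DEFAULT_TOL` values here are illustrative, not jtu's real table.

```python
import numpy as onp

# Hypothetical per-dtype defaults, standing in for jtu's real tolerance table.
_DEFAULT_TOL = {
    onp.dtype(onp.float16): 1e-3,
    onp.dtype(onp.float32): 1e-6,
    onp.dtype(onp.float64): 1e-15,
}

def tolerance(dtype, overrides=None):
  """Simplified stand-in for jtu.tolerance: per-dtype override, else default."""
  dtype = onp.dtype(dtype)
  overrides = {onp.dtype(k): v for k, v in (overrides or {}).items()}
  return overrides.get(dtype, _DEFAULT_TOL.get(dtype, 0.0))

# As in the patched test: an explicit float16 entry in the override dict...
tol = {onp.float64: 1e-14, onp.float16: 0.04}
assert tolerance(onp.float16, tol) == 0.04
# ...and taking the max over both operand dtypes keeps the loosest requirement,
# so one float16 operand is enough to relax the whole check.
assert max(tolerance(onp.float64, tol), tolerance(onp.float16, tol)) == 0.04
```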
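Context for the `test_util.py` hunk: defining `__len__` alone does not make an object iterable, so code that iterates over a shape's dimensions would raise `TypeError` on the scalar-shape sentinels. A `__getitem__` that raises `IndexError` at index 0 opts the sentinel into Python's legacy sequence-iteration protocol as an empty sequence. A minimal standalone sketch of that behavior, mirroring the patched classes:

```python
class ScalarShape:
  def __len__(self): return 0
  def __getitem__(self, i):
    raise IndexError(f'index {i} out of range.')

class _NumpyScalar(ScalarShape): pass
NUMPY_SCALAR_SHAPE = _NumpyScalar()

# With __len__ only, iter(NUMPY_SCALAR_SHAPE) would raise TypeError. With a
# __getitem__ that raises IndexError immediately, the sentinel iterates like
# the empty shape tuple of a 0-D array.
assert list(NUMPY_SCALAR_SHAPE) == []
assert tuple(NUMPY_SCALAR_SHAPE) == ()
assert len(NUMPY_SCALAR_SHAPE) == 0
```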