Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disables a failing remat test. #1763

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions trax/tf_numpy/jax_tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def onp_fun(a, b):
check_xla = not set((lhs_dtype, rhs_dtype)).intersection(
(onp.int32, onp.int64))

tol = {onp.float64: 1e-14}
tol = {onp.float64: 1e-14, onp.float16: 0.04}
tol = max(jtu.tolerance(lhs_dtype, tol), jtu.tolerance(rhs_dtype, tol))
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True,
check_incomplete_shape=True,
Expand Down Expand Up @@ -1301,8 +1301,12 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng_fact
tol=tol)
# XLA lacks int64 Cumsum/Cumprod kernels (b/168841378).
check_xla = out_dtype != onp.int64
rtol = None
if out_dtype == onp.float16:
rtol = 2e-3
self._CompileAndCheck(
lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True,
lnp_fun, args_maker, check_dtypes=True, rtol=rtol,
check_incomplete_shape=True,
check_experimental_compile=check_xla,
check_xla_forced_compile=check_xla)

Expand Down
4 changes: 3 additions & 1 deletion trax/tf_numpy/jax_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,10 @@ def format_test_name_suffix(opname, shapes, dtypes):

# We use special symbols, represented as singleton objects, to distinguish
# between NumPy scalars, Python scalars, and 0-D arrays.
class ScalarShape(object):
class ScalarShape:
def __len__(self): return 0
def __getitem__(self, i):
raise IndexError(f'index {i} out of range.')
class _NumpyScalar(ScalarShape): pass
class _PythonScalar(ScalarShape): pass
NUMPY_SCALAR_SHAPE = _NumpyScalar()
Expand Down