diff --git a/tests/test_longformer.py b/tests/test_longformer.py index e284814a2..b7ae2c7e1 100644 --- a/tests/test_longformer.py +++ b/tests/test_longformer.py @@ -1,4 +1,4 @@ -import jax.config +import jax import jax.numpy as jnp import numpy as np @@ -37,7 +37,6 @@ def test_causal_sliding_window_attention_simple(): def test_sliding_window_attention_fancier(): - # jax.config.update("jax_disable_jit", True) D = 4 for L, W in [(2, 1), (2, 2), (4, 2), (10, 5), (15, 5), (16, 2), (15, 3), (10, 10)]: Pos = Axis("Pos", L)