From 15fda9b64b015807b57f180857f521c7c82949ee Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 26 Apr 2024 14:10:46 -0700 Subject: [PATCH] is this the problem? --- tests/test_longformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_longformer.py b/tests/test_longformer.py index e284814a2..b7ae2c7e1 100644 --- a/tests/test_longformer.py +++ b/tests/test_longformer.py @@ -1,4 +1,4 @@ -import jax.config +import jax import jax.numpy as jnp import numpy as np @@ -37,7 +37,6 @@ def test_causal_sliding_window_attention_simple(): def test_sliding_window_attention_fancier(): - # jax.config.update("jax_disable_jit", True) D = 4 for L, W in [(2, 1), (2, 2), (4, 2), (10, 5), (15, 5), (16, 2), (15, 3), (10, 10)]: Pos = Axis("Pos", L)