Fix more CI tests (#35661)
add tooslow for the fat ones
ArthurZucker authored Jan 23, 2025
1 parent 0a950e0 commit 8f1509a
Showing 1 changed file with 6 additions and 0 deletions.
tests/models/gemma2/test_modeling_gemma2.py: 6 additions, 0 deletions
@@ -28,6 +28,7 @@
     require_torch,
     require_torch_gpu,
     slow,
+    tooslow,
     torch_device,
 )

@@ -209,6 +210,7 @@ def setUpClass(cls):
         # 8 is for A100 / A10 and 7 for T4
         cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

+    @tooslow
     @require_read_token
     def test_model_9b_bf16(self):
         model_id = "google/gemma-2-9b"
@@ -229,6 +231,7 @@ def test_model_9b_bf16(self):

         self.assertEqual(output_text, EXPECTED_TEXTS)

+    @tooslow
     @require_read_token
     def test_model_9b_fp16(self):
         model_id = "google/gemma-2-9b"
@@ -250,6 +253,7 @@ def test_model_9b_fp16(self):
         self.assertEqual(output_text, EXPECTED_TEXTS)

     @require_read_token
+    @tooslow
     def test_model_9b_pipeline_bf16(self):
         # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
         model_id = "google/gemma-2-9b"
@@ -296,6 +300,7 @@ def test_model_2b_pipeline_bf16_flex_attention(self):
     @require_torch_gpu
     @mark.flash_attn_test
     @slow
+    @tooslow
     def test_model_9b_flash_attn(self):
         # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
         model_id = "google/gemma-2-9b"
@@ -370,6 +375,7 @@ def test_export_static_cache(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)

     @require_read_token
+    @tooslow
     def test_model_9b_bf16_flex_attention(self):
         model_id = "google/gemma-2-9b"
         EXPECTED_TEXTS = [
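
Note on the marker: `tooslow` comes from `transformers.testing_utils`, and this commit applies it to the Gemma-2 9B integration tests ("the fat ones") so they no longer run on CI hardware. Below is a minimal, hypothetical sketch of how such a marker can behave, assuming it simply skips the decorated test unconditionally; the decorator body and the `Gemma2MarkerExample` class are illustrative, not the library's actual implementation.

import unittest


def tooslow(test_case):
    # Hypothetical stand-in for transformers.testing_utils.tooslow: skip tests
    # that are too heavy to run even in the scheduled slow CI job.
    return unittest.skip("test is too slow")(test_case)


class Gemma2MarkerExample(unittest.TestCase):
    # Illustrative test class mirroring how the diff applies the marker.
    @tooslow
    def test_model_9b_bf16(self):
        # Would load the 9B checkpoint; the decorator skips the test before it runs.
        raise RuntimeError("never reached: the test is skipped")


if __name__ == "__main__":
    unittest.main()

With a skip-style marker like this, stacking `@tooslow` on top of `@require_read_token` or `@slow` leaves the test skipped regardless of the other conditions, which matches the commit's intent of keeping the 9B checkpoints out of CI.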
