Skip to content

Commit

Permalink
skip gptj slow generate tests for now (huggingface#13809)
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj authored Sep 30, 2021
1 parent 41436d3 commit 8bbb53e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/test_modeling_gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,9 @@ def test_gptj_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)

@slow
@tooslow
def test_batch_generation(self):
# Marked as @tooslow due to GPU OOM
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16)
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
Expand Down Expand Up @@ -464,8 +465,9 @@ def test_model_from_pretrained(self):

@require_torch
class GPTJModelLanguageGenerationTest(unittest.TestCase):
@slow
@tooslow
def test_lm_generate_gptj(self):
# Marked as @tooslow due to GPU OOM
for checkpointing in [True, False]:
model = GPTJForCausalLM.from_pretrained(
"EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
Expand Down

0 comments on commit 8bbb53e

Please sign in to comment.