Skip to content

Commit

Permalink
fix test_generated_length_assisted_generation (#34935)
Browse files Browse the repository at this point in the history
fix test_generated_length_assisted_generation
  • Loading branch information
keyboardAnt authored Jan 29, 2025
1 parent ec7afad commit 42c8ccf
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3405,7 +3405,14 @@ def test_generated_length_assisted_generation(self):
assistant_model=assistant,
min_new_tokens=10,
)
self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)
self.assertTrue((input_length + 10) <= out.shape[-1])

out = model.generate(
input_ids,
assistant_model=assistant,
max_new_tokens=7,
)
self.assertTrue(out.shape[-1] <= (input_length + 7))

def test_model_kwarg_assisted_decoding_decoder_only(self):
# PT-only test: TF doesn't support assisted decoding yet.
Expand Down

0 comments on commit 42c8ccf

Please sign in to comment.