From 3125f2a57ded13873adb19d4172da4fd26d155ca Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 14 Jan 2025 16:59:48 -0800 Subject: [PATCH] fix sequence packing context truncation --- src/levanter/data/packing.py | 11 +++++++++++ src/levanter/eval_harness.py | 10 +++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index aa6028690..a049de56c 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -89,6 +89,17 @@ class PromptCompletion: prompt_length: int segment_id: int | None = None + def __post_init__(self): + if len(self.ids) == 0: + raise ValueError("PromptCompletion must have at least one token") + + # check that there is at least one token in the response + if len(self.ids) <= self.prompt_length: + raise ValueError( + f"PromptCompletion must have strictly more tokens than the prompt length. Got {len(self.ids)} tokens" + f" and prompt length {self.prompt_length}" + ) + def pack_prompt_completions( Pos: hax.Axis, diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index c3c33ac74..8c87c0530 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -766,20 +766,20 @@ def _iterate_tokenized_requests( for off in range(len(batch_indices)): i = batch_indices[off] context_enc = context_encodings["input_ids"][off] - whole_ids = combined_encodings["input_ids"][off] + all_enc = combined_encodings["input_ids"][off] context_enc_len = len(context_enc) - if len(whole_ids) > max_len: + if len(all_enc) > max_len: logger.warning(f"Request {i} is too long. Truncating.") # Truncate from the left - whole_ids = whole_ids[-max_len:] - context_enc_len = max_len - len(whole_ids) + context_enc_len + context_enc_len = len(context_enc) - (len(all_enc) - max_len) + all_enc = all_enc[-max_len:] if context_enc_len < 0: context_enc_len = 0 logger.warning("Prompt length is negative after truncation. Setting to 0.") - yield PromptCompletion(ids=whole_ids, prompt_length=context_enc_len, segment_id=i) + yield PromptCompletion(ids=all_enc, prompt_length=context_enc_len, segment_id=i) def _pack_requests(