
Commit

fix for token bug that skips EOS
ahmeda14960 committed Nov 20, 2024
1 parent 80b2296 commit 3ea3277
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/levanter/data/text.py
@@ -196,7 +196,7 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
        len = await self.wait_until_len_at_least(max(indices) + 1)
        if len is not None and len < max(indices) + 1:
            raise ValueError("Requested indices beyond the end of the dataset")
-        offsets = np.array(indices, dtype=np.int64) * self.seq_len
+        offsets = np.array(indices) * self.seq_len
        with ts.Batch():
            out = []
            for offset in offsets:
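
For orientation, this hunk touches the flat-offset arithmetic get_batch uses over the packed token cache: example i is read from the half-open slice [i * seq_len, (i + 1) * seq_len). A tiny NumPy sketch of that indexing, using a plain array as a stand-in for the TensorStore-backed cache (the array and values below are illustrative, not from the repo):

import numpy as np

seq_len = 4
flat_tokens = np.arange(3 * seq_len)        # pretend cache holding 3 packed examples

indices = [0, 2]                            # examples requested for this batch
offsets = np.asarray(indices) * seq_len     # -> [0, 8]
batch = [flat_tokens[o : o + seq_len] for o in offsets]
print(batch)  # [array([0, 1, 2, 3]), array([ 8,  9, 10, 11])]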
@@ -929,23 +929,32 @@ def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]:
)


-def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict:
+def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_append_eos: bool) -> dict:
    """
    Preprocess chat examples to match the format of preprocess_supervised_example.
    Returns a dict with input_ids and sources_len like the supervised case.
    Args:
        batch: List of dicts with input/output pairs
        tokenizer: HuggingFace tokenizer
        should_append_eos: Whether we need to manually add EOS (True if tokenizer doesn't do it automatically)
    """
    # Get sources (inputs) and targets (outputs) from the batch
    sources = [example["input"] for example in batch]
    targets = [example["output"] for example in batch]

    # Add EOS only if needed (tokenizer doesn't do it automatically)
    if should_append_eos:
        targets = [t + tokenizer.eos_token for t in targets]

    # Tokenize sources alone first to get the source lengths
    sources_tokenized = tokenizer(sources, padding=False, truncation=True)

    # Combine source and target for full examples
    full_examples = [f"{s}{t}" for s, t in zip(sources, targets)]
    examples_tokenized = tokenizer(full_examples, padding=False, truncation=True)

    # Get source lengths to mask loss appropriately
    source_lens = [len(s) for s in sources_tokenized["input_ids"]]

    return {
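
The hunk above is the heart of the fix: when the tokenizer does not emit EOS on its own, each target gets tokenizer.eos_token appended before the full example is tokenized, and the tokenized source length is recorded so the loss can be masked over the prompt. A minimal standalone sketch of that behavior, assuming a HuggingFace tokenizer; the gpt2 checkpoint and the toy batch are illustrative choices, not part of the commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # gpt2's tokenizer does not add EOS itself
should_append_eos = True                           # what the probe in mk_chat_sft_dataset reports for gpt2

batch = [
    {"input": "Q: What is 2 + 2?\nA:", "output": " 4"},
    {"input": "Q: Name a prime number.\nA:", "output": " 7"},
]
sources = [ex["input"] for ex in batch]
targets = [ex["output"] for ex in batch]

if should_append_eos:
    # Without this, a chat example never ends in EOS and the fine-tuned model
    # gets no signal for when to stop generating.
    targets = [t + tokenizer.eos_token for t in targets]

sources_tokenized = tokenizer(sources, padding=False, truncation=True)
full_examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized = tokenizer(full_examples, padding=False, truncation=True)

# Prompt lengths, used downstream to mask the prompt tokens out of the loss.
source_lens = [len(ids) for ids in sources_tokenized["input_ids"]]
print(source_lens)
print([len(ids) for ids in examples_tokenized["input_ids"]])  # full lengths, including the appended EOS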
@@ -965,9 +974,13 @@ def mk_chat_sft_dataset(
    # Set up example structure matching supervised case
    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}

    input_ids = tokenizer("hi there")["input_ids"]
    should_append_eos = input_ids[-1] != tokenizer.eos_token_id
    logger.info(f"Manual EOS Needed: {should_append_eos}")

    # Process the dataset
    dataset = source.map_batches(
-        lambda ex: preprocess_chat_example(ex, tokenizer),
+        lambda ex: preprocess_chat_example(ex, tokenizer, should_append_eos),
        batch_size=128,
        num_cpus=num_cpus_used_by_tokenizer(tokenizer),
        output_exemplar=output_exemplar,
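The last hunk wires the flag through: mk_chat_sft_dataset probes the tokenizer once with a throwaway string, and the lambda handed to map_batches closes over the resulting should_append_eos, so the check is not repeated for every batch. A standalone sketch of that probe, again assuming HuggingFace tokenizers; gpt2 and t5-small are only illustrative checkpoints:

from transformers import AutoTokenizer

def needs_manual_eos(tokenizer) -> bool:
    # Same heuristic as above: tokenize a throwaway string and see whether the
    # tokenizer already terminated it with its EOS id.
    input_ids = tokenizer("hi there")["input_ids"]
    return input_ids[-1] != tokenizer.eos_token_id

print(needs_manual_eos(AutoTokenizer.from_pretrained("gpt2")))      # True: GPT-2 never appends EOS
print(needs_manual_eos(AutoTokenizer.from_pretrained("t5-small")))  # False: T5 appends </s> automatically

Doing the probe once up front avoids repeating it per batch and, via the logger.info call, makes the decision visible in the logs.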
