From 3ea327761c445e42df32c697a801e75f3144bde1 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 20 Nov 2024 14:16:25 -0800 Subject: [PATCH] fix for token bug that skips EOS --- src/levanter/data/text.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 0c9c6e9de..c0a86d830 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -196,7 +196,7 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: len = await self.wait_until_len_at_least(max(indices) + 1) if len is not None and len < max(indices) + 1: raise ValueError("Requested indices beyond the end of the dataset") - offsets = np.array(indices, dtype=np.int64) * self.seq_len + offsets = np.array(indices) * self.seq_len with ts.Batch(): out = [] for offset in offsets: @@ -929,15 +929,24 @@ def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: ) -def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: +def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_append_eos: bool) -> dict: """ Preprocess chat examples to match the format of preprocess_supervised_example. Returns a dict with input_ids and sources_len like the supervised case. + + Args: + batch: List of dicts with input/output pairs + tokenizer: HuggingFace tokenizer + should_append_eos: Whether we need to manually add EOS (True if tokenizer doesn't do it automatically) """ # Get sources (inputs) and targets (outputs) from the batch sources = [example["input"] for example in batch] targets = [example["output"] for example in batch] + # Add EOS only if needed (tokenizer doesn't do it automatically) + if should_append_eos: + targets = [t + tokenizer.eos_token for t in targets] + # Tokenize sources alone first to get the source lengths sources_tokenized = tokenizer(sources, padding=False, truncation=True) @@ -945,7 +954,7 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: full_examples = [f"{s}{t}" for s, t in zip(sources, targets)] examples_tokenized = tokenizer(full_examples, padding=False, truncation=True) - # Get source lengths to mask loss appropriately + # Get source lengths to mask loss appropriately source_lens = [len(s) for s in sources_tokenized["input_ids"]] return { @@ -965,9 +974,13 @@ def mk_chat_sft_dataset( # Set up example structure matching supervised case output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + input_ids = tokenizer("hi there")["input_ids"] + should_append_eos = input_ids[-1] != tokenizer.eos_token_id + logger.info(f"Manual EOS Needed: {should_append_eos}") + # Process the dataset dataset = source.map_batches( - lambda ex: preprocess_chat_example(ex, tokenizer), + lambda ex: preprocess_chat_example(ex, tokenizer, should_append_eos), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar,