diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 0c9c6e9de..3e74c96b7 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -929,15 +929,24 @@ def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: ) -def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: +def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_append_eos: bool) -> dict: """ Preprocess chat examples to match the format of preprocess_supervised_example. Returns a dict with input_ids and sources_len like the supervised case. + + Args: + batch: List of dicts with input/output pairs + tokenizer: HuggingFace tokenizer + should_append_eos: Whether we need to manually add EOS (True if tokenizer doesn't do it automatically) """ # Get sources (inputs) and targets (outputs) from the batch sources = [example["input"] for example in batch] targets = [example["output"] for example in batch] + # Add EOS only if needed (tokenizer doesn't do it automatically) + if should_append_eos: + targets = [t + tokenizer.eos_token for t in targets] + # Tokenize sources alone first to get the source lengths sources_tokenized = tokenizer(sources, padding=False, truncation=True) @@ -965,9 +974,13 @@ def mk_chat_sft_dataset( # Set up example structure matching supervised case output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + input_ids = tokenizer("hi there")["input_ids"] + should_append_eos = input_ids[-1] != tokenizer.eos_token_id + logger.info(f"Manual EOS Needed: {should_append_eos}") + # Process the dataset dataset = source.map_batches( - lambda ex: preprocess_chat_example(ex, tokenizer), + lambda ex: preprocess_chat_example(ex, tokenizer, should_append_eos), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar,