diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 7e92d200b..4cc000e59 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -708,7 +708,9 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po max_length=Pos.size, ) ex = {k: v[0] for k, v in ex.items()} - input_ids = hax.named(ex["input_ids"], Pos) + # padding doesn't do truncation, so we have to do it ourselves. + # Truncate from the left since we want to predict the last tokens + input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos) # mask out padding and anything before the start of the target loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1