From 554b86dac08e3140d6baed3779ebc1bc6662967d Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 9 Nov 2024 10:49:27 -0800 Subject: [PATCH] tweaks: truncate after pad --- src/levanter/data/text.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 7e92d200b..4cc000e59 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -708,7 +708,9 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po max_length=Pos.size, ) ex = {k: v[0] for k, v in ex.items()} - input_ids = hax.named(ex["input_ids"], Pos) + # padding doesn't do truncation, so we have to do it ourselves. + # Truncate from the left since we want to predict the last tokens + input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos) # mask out padding and anything before the start of the target loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1