Skip to content

Commit

Permalink
tweaks: truncate after pad
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 12, 2024
1 parent 0503001 commit 554b86d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,9 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po
max_length=Pos.size,
)
ex = {k: v[0] for k, v in ex.items()}
- input_ids = hax.named(ex["input_ids"], Pos)
+ # padding doesn't do truncation, so we have to do it ourselves.
+ # Truncate from the left since we want to predict the last tokens
+ input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos)
# mask out padding and anything before the start of the target
loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1

Expand Down

0 comments on commit 554b86d

Please sign in to comment.