From e108e8ed08642e9d35c227722cc319a913f0200b Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 10 Dec 2024 15:47:59 -0500 Subject: [PATCH 1/2] overwrite pos for models loaded from hf --- src/levanter/main/sft.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index b3ff0e74c..3f8329a2b 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -1,3 +1,4 @@ +import dataclasses import logging import os from dataclasses import dataclass, field @@ -99,6 +100,7 @@ def train(config: SFTConfig): converter = converter.replaced(tokenizer=tokenizer) model_config = converter.default_config + model_config = dataclasses.replace(converter.default_config, seq_len=config.max_seq_len) elif config.trainer.initialize_from is None: raise ValueError("Must specify either --initialize_from_hf or --initialize_from") else: From b3cbaaf9df8a62d3cf973afdeb5e394b36e9658f Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 10 Dec 2024 17:53:17 -0500 Subject: [PATCH 2/2] fix for sft misaligned position --- src/levanter/models/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py index bf0bd380e..fc6e424a1 100644 --- a/src/levanter/models/loss.py +++ b/src/levanter/models/loss.py @@ -43,7 +43,7 @@ def maybe_fused_next_token_loss( NamedArray: Computed loss. """ # Resolve axes - Pos = pred_embeddings.resolve_axis(Pos) + Pos = pred_embeddings.resolve_axis(Pos.name) Vocab = pred_lm_head.resolve_axis(Vocab) if block_size is None: