Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
don't inherit trainlm
Browse files Browse the repository at this point in the history
ahmeda14960 committed Nov 12, 2024
1 parent 93250b4 commit 5408c75
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions src/levanter/main/sft.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,6 @@
mk_chat_sft_dataset,
mk_supervised_dataset,
)
from levanter.main.train_lm import TrainLmConfig
from levanter.models.llama import LlamaConfig
from levanter.models.lm_model import LmConfig, LmHeadModel, compute_next_token_loss
from levanter.optim import AdamConfig, OptimizerConfig
@@ -46,7 +45,7 @@ class DatasetType(str, Enum):


@dataclass
class SFTConfig(TrainLmConfig):
class SFTConfig):
# inherit most of the config from TrainLmConfig
trainer: TrainerConfig = field(default_factory=TrainerConfig)
model: LmConfig = field(default_factory=LlamaConfig)

0 comments on commit 5408c75

Please sign in to comment.