diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index ee8f6f6b7..aeaa85516 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -45,7 +45,7 @@ class DatasetType(str, Enum): @dataclass -class SFTConfig): +class SFTConfig: # inherit most of the config from TrainLmConfig trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=LlamaConfig)