Skip to content

Commit

Permalink
fix for main
Browse files (browse the repository at this point in the history)
  • Loading branch information
ahmeda14960 committed Nov 14, 2024
1 parent 82d37e2 commit a7317ea
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/levanter/main/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,16 @@
from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
from levanter.data import PermutationDataset
from levanter.data.text import (
ChatUrlDataSourceConfig,
EpochDataset,
LMSupervisedDatasetConfig,
SupervisedSourceConfig,
mk_chat_sft_dataset,
mk_supervised_dataset,
)
from levanter.models.llama import LlamaConfig
from levanter.models.lm_model import LmConfig, LmHeadModel, compute_next_token_loss
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.trainer import Trainer, TrainerConfig
from levanter.data.text import ChatUrlDataSourceConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
from levanter.trainer import Trainer


logger = logging.getLogger(__name__)
Expand All @@ -52,7 +50,7 @@ class SFTConfig:
trainer: TrainerConfig = field(default_factory=TrainerConfig)
model: LmConfig = field(default_factory=LlamaConfig)
optimizer: OptimizerConfig = field(default_factory=AdamConfig)
supervised_data: Optional[LMSupervisedDatasetConfig] = None
supervised_data: Optional[SupervisedSourceConfig | dict[str, SupervisedSourceConfig]] = None

# config related to continued pretraining
initialize_from_hf: Union[bool, str] = False
Expand Down Expand Up @@ -129,8 +127,16 @@ def train(config: SFTConfig):
if config.dataset_type == DatasetType.CHAT_JSONL:
assert config.chat_train_urls is not None
assert config.supervised_data is not None

# Get the cache_dir safely
cache_dir = (
config.supervised_data.cache_dir
if not isinstance(config.supervised_data, dict)
else next(iter(config.supervised_data.values())).cache_dir
)

chat_config = ChatUrlDataSourceConfig(
cache_dir=config.supervised_data.cache_dir,
cache_dir=cache_dir,
train_urls=config.chat_train_urls, # No validation in this config
messages_field=config.messages_field,
input_role=config.input_role,
Expand All @@ -139,7 +145,13 @@ def train(config: SFTConfig):
train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos)
else:
assert config.supervised_data is not None
train_dataset = mk_supervised_dataset(config.supervised_data, "train", tokenizer, model_config.Pos)
if isinstance(config.supervised_data, dict):
# TODO: figure out what actually makes sense here
# for marin we will just use the url code path
config_to_use = next(iter(config.supervised_data.values()))
else:
config_to_use = config.supervised_data
train_dataset = mk_supervised_dataset(config_to_use, "train", tokenizer, model_config.Pos)
logger.info("Supervised dataset created")
train_dataset = PermutationDataset(train_dataset, data_key)

Expand Down

0 comments on commit a7317ea

Please sign in to comment.