Revise SFT File (#793)
PR to revise the SFT file to avoid breaking changes to marin and to address a request from @dlwh.
ahmeda14960 authored Nov 14, 2024
1 parent a885f20 commit f8ab21a
Showing 2 changed files with 67 additions and 18 deletions.
13 changes: 13 additions & 0 deletions config/llama_sft_hf_ckpt.yaml
@@ -0,0 +1,13 @@
# Model configuration
model:
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: true
  flash_attention_block_size: 512
  use_bias: false
  use_layer_norm_weight: false
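
For reference, these are Llama-2-7B-scale dimensions. A rough Python equivalent of the YAML block above, sketched on the assumption that the keys map one-to-one onto LlamaConfig fields (as levanter's YAML-driven configs normally do):

from levanter.models.llama import LlamaConfig

# Sketch only, not part of this commit: mirrors the YAML keys above.
model_config = LlamaConfig(
    seq_len=2048,
    hidden_dim=4096,
    intermediate_dim=11008,
    num_layers=32,
    num_heads=32,
    num_kv_heads=32,
    use_flash_attention=True,
    flash_attention_block_size=512,
    use_bias=False,
    use_layer_norm_weight=False,
)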
72 changes: 54 additions & 18 deletions examples/sft/sft.py → src/levanter/main/sft.py
@@ -1,8 +1,8 @@
import logging
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
from typing import List, Optional, Union

import jax.random as jrandom
import transformers
@@ -15,10 +15,17 @@
from levanter import callbacks
from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
from levanter.data import PermutationDataset
from levanter.data.text import ChatUrlDataSourceConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
from levanter.main.train_lm import TrainLmConfig
from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
from levanter.trainer import Trainer
from levanter.data.text import (
    ChatUrlDataSourceConfig,
    EpochDataset,
    SupervisedSourceConfig,
    mk_chat_sft_dataset,
    mk_supervised_dataset,
)
from levanter.models.llama import LlamaConfig
from levanter.models.lm_model import LmConfig, LmHeadModel, compute_next_token_loss
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.trainer import Trainer, TrainerConfig


logger = logging.getLogger(__name__)
Expand All @@ -38,24 +45,40 @@ class DatasetType(str, Enum):


@dataclass
class SFTConfig(TrainLmConfig):
class SFTConfig:
    # inherit most of the config from TrainLmConfig
    max_tune_length: int = 2048
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    model: LmConfig = field(default_factory=LlamaConfig)
    optimizer: OptimizerConfig = field(default_factory=AdamConfig)
    supervised_data: Optional[SupervisedSourceConfig | dict[str, SupervisedSourceConfig]] = None

    # config related to continued pretraining
    initialize_from_hf: Union[bool, str] = False
    hf_save_path: Optional[str] = None
    hf_upload: Optional[str] = None
    hf_save_steps: int = 0

    max_seq_len: int = 2048
    model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
    tokenizer: str = "meta-llama/Llama-2-7b-hf"

    # Add dataset type and chat-specific fields
    dataset_type: DatasetType = DatasetType.HUGGINGFACE
    dataset_type: DatasetType = DatasetType.CHAT_JSONL
    chat_train_urls: Optional[List[str]] = None
    messages_field: str = "messages"
    input_role: str = "user"
    output_role: str = "assistant"

    data_seed: Optional[int] = None  # if provided, will override the data seed from the trainer

    # if provided, will initialize from this checkpoint, used for llama style data mixture
    epoch: int = 0


def train(config: SFTConfig):
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config.tokenizer,
        model_max_length=config.max_tune_length,
        model_max_length=config.max_seq_len,
        padding_side="right",
        trust_remote_code=True,
    )
@@ -79,7 +102,11 @@ def train(config: SFTConfig):
    elif config.trainer.initialize_from is None:
        raise ValueError("Must specify either --initialize_from_hf or --initialize_from")
    else:
        converter = None
        if config.hf_save_steps:
            converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True)
            converter = converter.replaced(tokenizer=tokenizer)
        else:
            converter = None
    model_config = config.model

    levanter.initialize(config)
@@ -100,8 +127,16 @@ def train(config: SFTConfig):
    if config.dataset_type == DatasetType.CHAT_JSONL:
        assert config.chat_train_urls is not None
        assert config.supervised_data is not None

        # Get the cache_dir safely
        cache_dir = (
            config.supervised_data.cache_dir
            if not isinstance(config.supervised_data, dict)
            else next(iter(config.supervised_data.values())).cache_dir
        )

        chat_config = ChatUrlDataSourceConfig(
            cache_dir=config.supervised_data.cache_dir,
            cache_dir=cache_dir,
            train_urls=config.chat_train_urls,  # No validation in this config
            messages_field=config.messages_field,
            input_role=config.input_role,
@@ -110,7 +145,13 @@ def train(config: SFTConfig):
        train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos)
    else:
        assert config.supervised_data is not None
        train_dataset = mk_supervised_dataset(config.supervised_data, "train", tokenizer, model_config.Pos)
        if isinstance(config.supervised_data, dict):
            # TODO: figure out what actually makes sense here
            # for marin we will just use the url code path
            config_to_use = next(iter(config.supervised_data.values()))
        else:
            config_to_use = config.supervised_data
        train_dataset = mk_supervised_dataset(config_to_use, "train", tokenizer, model_config.Pos)
    logger.info("Supervised dataset created")
    train_dataset = PermutationDataset(train_dataset, data_key)

@@ -161,11 +202,6 @@ def train(config: SFTConfig):

        loader = trainer.data_loader(train_dataset, trainer.TrainBatch)

        if int(state.step) != 0:
            logger.info(f"Resuming training from step {state.step}")
            for i in range(state.step):
                next(loader)

        if config.hf_save_path is not None:
            # bit gross to reach this far into the config, but it's fine
            if config.trainer.checkpointer.append_run_id_to_base_path:
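
With SFTConfig now a standalone dataclass rather than a TrainLmConfig subclass, it can be filled in directly. A minimal sketch of the chat-JSONL path follows; the bucket URLs and cache directory are placeholders, and supervised_data is given a ChatUrlDataSourceConfig only because the CHAT_JSONL branch reads its cache_dir from that field. A real run would normally supply these fields through a YAML config instead.

from levanter.data.text import ChatUrlDataSourceConfig
from levanter.main.sft import DatasetType, SFTConfig, train

# Sketch only, not part of this commit; all paths below are hypothetical placeholders.
chat_urls = ["gs://my-bucket/sft/train-*.jsonl.gz"]  # hypothetical training shards

config = SFTConfig(
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    tokenizer="meta-llama/Llama-2-7b-hf",
    initialize_from_hf=True,
    dataset_type=DatasetType.CHAT_JSONL,
    chat_train_urls=chat_urls,
    supervised_data=ChatUrlDataSourceConfig(
        cache_dir="gs://my-bucket/cache/sft",  # hypothetical cache location
        train_urls=chat_urls,
        messages_field="messages",
        input_role="user",
        output_role="assistant",
    ),
    max_seq_len=2048,
)

train(config)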
