sft working w levanter chkpt
ahmeda14960 committed Oct 29, 2024
1 parent b3718c1 commit 5f36eb8
Showing 2 changed files with 110 additions and 48 deletions.
52 changes: 52 additions & 0 deletions examples/sft/alpaca-llama-sft.yaml
@@ -0,0 +1,52 @@
# Model configuration
model:
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: true
  flash_attention_block_size: 512
  use_bias: false
  use_layer_norm_weight: false

# Training configuration
trainer:
  mp: p=f32,c=bfloat16
  tracker:
    type: wandb
    project: "levanter-sft"
    tags: ["llama", "sft"]
  num_train_steps: 1218
  train_batch_size: 64
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  steps_per_eval: 1000

# Optimizer settings
optimizer:
  learning_rate: 2e-5
  weight_decay: 0.0
  min_lr_ratio: 0.1
  warmup: 100

# Supervised data configuration
supervised_data:
  cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-olmo"
  input_field: "instruction"
  output_field: "output"
  hf_dataset_name: "tatsu-lab/alpaca" # Changed from id
  hf_dataset_split: "train"
  name: "alpaca" # Optional metadata
  tags: ["instruction-tuning"] # Optional metadata
  validation_urls: [] # Empty list for no validation files

# Additional settings
tokenizer: "allenai/OLMo-1B"
max_tune_length: 2048
epoch: 3

initialize_from_hf: false
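
As a side note, a minimal sketch of how this nested config can be inspected outside Levanter (assumes PyYAML is available; Levanter's own config machinery maps the file onto the SFTConfig dataclass defined in sft.py below, and the exact entry-point wiring is not shown in this truncated diff):

import yaml  # assumption: PyYAML installed in the environment

# Load the config above and read a few of the nested fields it defines.
with open("examples/sft/alpaca-llama-sft.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["type"])                       # llama
print(cfg["trainer"]["num_train_steps"])          # 1218
print(cfg["supervised_data"]["hf_dataset_name"])  # tatsu-lab/alpaca
print(cfg["epoch"])                               # 3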
106 changes: 58 additions & 48 deletions examples/sft/sft.py
@@ -1,20 +1,21 @@
import logging
import os
from dataclasses import dataclass
from typing import Optional, Union

import jax.random as jrandom
import transformers

import haliax as hax
from haliax import Axis
from haliax.partitioning import round_axis_for_partitioning

import levanter
from levanter import callbacks
from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback
from levanter.data import PermutationDataset
from levanter.data.text import EpochDataset, LMSupervisedDatasetConfig, mk_supervised_dataset
from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
from levanter.optim import OptimizerConfig
from levanter.trainer import Trainer, TrainerConfig
from levanter.data.text import EpochDataset, mk_supervised_dataset
from levanter.main.train_lm import TrainLmConfig
from levanter.models.lm_model import compute_next_token_loss
from levanter.trainer import Trainer
from levanter.utils.py_utils import non_caching_cycle


@@ -28,55 +29,46 @@


@dataclass
class TrainArgs:
    optimizer: OptimizerConfig
    trainer: TrainerConfig

class SFTConfig(TrainLmConfig):
    # inherit most of the config from TrainLmConfig
    max_tune_length: int = 2048 # maximum length of the input to the model during tuning

    # Supervision config
    supervised_data: LMSupervisedDatasetConfig = LMSupervisedDatasetConfig()
    input_field: str = "instruction" # field name for input in dataset
    output_field: str = "output" # field name for output in dataset
    data_cache_dir: str = "cache/" # Path to cache the tokenized data

    model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
    trust_remote_code: bool = False # Trust remote code when loading from HuggingFace checkpoints.
    model_cache_dir: Optional[str] = None # Path to cache the model. must be local.
    tokenizer: str = "gpt2" # Tokenizer to use

    hf_save_path: Optional[str] = "sft_hf_ckpts" # Path to save the HuggingFace checkpoint
    hf_upload: Union[bool, str] = False # Name of the HuggingFace repo to upload to (if any)
    hf_save_steps: int = 1000 # How often to save the HuggingFace checkpoint

    epochs: int = 0 # Number of epochs to train for
def train(config: SFTConfig):

    if config.initialize_from_hf:
        if config.trainer.initialize_from is not None:
            raise ValueError("Cannot use both --initialize_from_hf and --initialize_from")

def train(config: TrainArgs):
    levanter.initialize(config)
        converter = HFCheckpointConverter.from_hf(
            config.model_name_or_path, trust_remote_code=config.trust_remote_code
        )
    else:
        converter = None

    converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=config.trust_remote_code)
    model_config = converter.default_config
    levanter.initialize(config)

    if config.max_tune_length > model_config.Pos.size:
        logger.warning(
            f"max_tune_length ({config.max_tune_length}) is greater than the model's maximum length"
            f" ({model_config.Pos.size}). "
        )
    # randomness in jax is tightly controlled by "keys" which are the states of the random number generators
    # this makes deterministic training pretty easy
    seed = config.trainer.seed
    data_key, _, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4)

    training_key, data_key = jrandom.split(jrandom.PRNGKey(config.trainer.seed), 2)
    if config.data_seed is not None:
        logger.info(f"Overriding data seed with {config.data_seed}")
        data_key = jrandom.PRNGKey(config.data_seed)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config.model_name_or_path,
        cache_dir=config.model_cache_dir,
        config.tokenizer,
        model_max_length=config.max_tune_length,
        padding_side="right",
        trust_remote_code=True,
    )
    logger.info(f"Loaded tokenizer {tokenizer}")
    num_new_tokens = add_special_tokens(tokenizer)
    logger.info(f"Added {num_new_tokens} new tokens")

    # modify converter to use our tokenizer
    converter = converter.replaced(tokenizer=tokenizer)

    # Configure supervised dataset
    supervised_config = config.supervised_data

@@ -87,28 +79,46 @@ def train(config: TrainArgs):
    train_dataset = PermutationDataset(train_dataset, data_key)

    # Then wrap for epochs
    if config.epochs > 0:
        logger.info(f"Wrapping dataset for {config.epochs} epochs")
        train_dataset = EpochDataset(train_dataset, max_epochs=config.epochs)
    if config.epoch > 0:
        logger.info(f"Wrapping dataset for {config.epoch} epochs")
        train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch)

    logger.info("Creating optimizer")
    optimizer = config.optimizer.build(config.trainer.num_train_steps)

    # Using the trainer as a context manager does 3 things:
    # 1. Sets the device mesh
    # 2. Sets the axis mapping (for fsdp)
    # 3. Sets the global metrics tracker
    with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer:
        parameter_axis_mapping = trainer.parameter_axis_mapping

        logger.info(f"Loading pretrained model from {converter.reference_checkpoint}")
        model: LmHeadModel = converter.load_pretrained(
            model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype
        )
        # We have two axis_mappings: one for storing the model and optimizer states, and one for compute
        # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh
        parameter_axis_mapping = trainer.parameter_axis_mapping

        # some axes we need
        Pos = config.model.Pos

        # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
        # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
        # tokens: gpt-2 has 50257, for example. So we round up.
        vocab_size = len(tokenizer)
        Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
        if vocab_size != Vocab.size:
            logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")

        model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model)
        state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))

        flops_per_token = config.model.flops_per_token(vocab_size)
        flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
        trainer.add_hook(
            callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
        )

        loader = trainer.data_loader(train_dataset, trainer.TrainBatch)
        loader = non_caching_cycle(loader)

        state = trainer.initial_state(training_key, model=model)

        if int(state.step) != 0:
            logger.info(f"Resuming training from step {state.step}")
            for i in range(state.step):
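
A side note on the key handling in the hunk above: the comment about JAX randomness refers to the standard PRNG-key discipline, sketched below in a self-contained form. The seed value here is illustrative, not taken from the commit.

import jax.random as jrandom

# One root key derived from a seed; split it once into independent streams
# (sft.py does the same with the seed read from config.trainer.seed).
seed = 0  # illustrative value only
data_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 3)

# Each key deterministically drives its own stream of randomness.
sample = jrandom.normal(model_key, (2, 2))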
