Add SFT + Epochs to levanter (#768)
Adds epoch support behind a config flag: when enabled, training keeps cycling over the dataset and the current epoch is tracked throughout training. This should be backwards compatible with existing checkpoints, and it also lets us read from marin-formatted datasets.
ahmeda14960 authored Nov 7, 2024
2 parents a7e42ec + caf0a38 commit 0f94ff2
Showing 14 changed files with 857 additions and 19 deletions.
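The epoch tracking described above boils down to cycling over a finite dataset while counting completed passes. A minimal illustrative sketch of that idea in Python (not Levanter's actual implementation; the EpochIterator name is hypothetical):

from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

class EpochIterator(Iterator[T]):
    """Cycle over a finite dataset indefinitely, counting completed passes (epochs)."""

    def __init__(self, dataset: Iterable[T]):
        self._dataset = dataset
        self._iter = iter(dataset)
        self.epoch = 0  # number of completed passes over the dataset

    def __next__(self) -> T:
        try:
            return next(self._iter)
        except StopIteration:
            # dataset exhausted: bump the epoch counter and restart from the beginning
            self.epoch += 1
            self._iter = iter(self._dataset)
            return next(self._iter)

# usage: batches = EpochIterator(train_examples); the current pass is available as batches.epoch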
39 changes: 39 additions & 0 deletions config/llama_7b_tulu.yaml
@@ -0,0 +1,39 @@
data:
  train_urls:
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-000.jsonl.gz"
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-001.jsonl.gz"
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-002.jsonl.gz"
  cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/tuluv2_sft/"
  tokenizer: "allenai/OLMo-1B"
model: # 7B class model
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  flash_attention_block_size: 512
  use_bias: false
  use_layer_norm_weight: false
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["dolma", "olmo", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 256
  num_train_steps: 750000  # 3,000,000,000,000 / 4,000,000 = 750,000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
optimizer:
  learning_rate: 4E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 5000

epoch: 3
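The train_urls above point at gzipped JSONL shards. As a rough sketch of what reading one such shard involves, assuming each line holds one JSON document with a "text" field (an assumption for illustration, not something this diff specifies):

import gzip
import json

def read_jsonl_gz(path: str):
    """Yield one parsed JSON document per line from a gzipped JSONL shard."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)

# e.g. for doc in read_jsonl_gz("tulu-v2-sft-mixture-000.jsonl.gz"): print(doc.get("text", "")[:80])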
4 changes: 3 additions & 1 deletion examples/alpaca/alpaca.py
@@ -162,11 +162,13 @@ def _prepare_example(ex: dict) -> LmExample:
        # mask out padding and anything before the start of the target
        Pos = input_ids.resolve_axis("position")
        if config.mask_inputs:
-            loss_mask = hax.arange(Pos) >= ex["source_lens"]
+            loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1  # should be minus 1?

            # don't predict the padding
            targets = hax.roll(input_ids, -1, Pos)
            loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
+            # to not predict EOS token since we don't have target!
+            loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
        else:
            loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32)
        lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
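The "- 1" in the changed line accounts for the shift introduced by hax.roll(input_ids, -1, Pos): position i is trained to predict token i+1, so the last prompt position already predicts the first target token. A tiny NumPy sketch of the same masking logic, with made-up lengths, purely for illustration:

import numpy as np

seq_len, source_len, pad_id = 8, 3, 0
input_ids = np.array([11, 12, 13, 21, 22, 23, pad_id, pad_id])  # 3 prompt tokens, 3 target tokens, padding

# position i predicts token i+1, so start the loss one step before the target begins
loss_mask = np.arange(seq_len) >= source_len - 1

# don't compute loss where the *next* token is padding
targets = np.roll(input_ids, -1)
loss_mask &= targets != pad_id

# the final position has no next token to predict at all
loss_mask[-1] = False
print(loss_mask)  # [False False  True  True  True False False False]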
52 changes: 52 additions & 0 deletions examples/sft/alpaca-llama-sft.yaml
@@ -0,0 +1,52 @@
# Model configuration
model:
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: true
  flash_attention_block_size: 512
  use_bias: false
  use_layer_norm_weight: false

# Training configuration
trainer:
  mp: p=f32,c=bfloat16
  tracker:
    type: wandb
    project: "levanter-sft"
    tags: ["llama", "sft"]
  num_train_steps: 750000
  train_batch_size: 64
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  steps_per_eval: 1000

# Optimizer settings
optimizer:
  learning_rate: 2e-5
  weight_decay: 0.0
  min_lr_ratio: 0.1
  warmup: 100

# Supervised data configuration
supervised_data:
  cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-olmo"
  input_field: "instruction"
  output_field: "output"
  hf_dataset_name: "tatsu-lab/alpaca"  # Changed from id
  hf_dataset_split: "train"
  name: "alpaca"  # Optional metadata
  tags: ["instruction-tuning"]  # Optional metadata
  validation_urls: []  # Empty list for no validation files

# Additional settings
tokenizer: "allenai/OLMo-1B"
max_tune_length: 2048
epoch: 0

initialize_from_hf: false
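For intuition on what input_field, output_field, and max_tune_length control, here is a hedged sketch of turning one Hugging Face record into token ids plus a source length for loss masking (build_sft_example is a hypothetical helper, not Levanter's actual preprocessing):

from transformers import AutoTokenizer

def build_sft_example(record: dict, tokenizer, input_field: str = "instruction",
                      output_field: str = "output", max_len: int = 2048):
    """Tokenize one supervised record into (input_ids, source_len) for loss masking."""
    prompt_ids = tokenizer(record[input_field])["input_ids"]
    target_ids = tokenizer(record[output_field], add_special_tokens=False)["input_ids"]
    input_ids = (prompt_ids + target_ids)[:max_len]  # truncate to max_tune_length
    source_len = min(len(prompt_ids), max_len)       # loss is only computed after the prompt
    return input_ids, source_len

# e.g. tok = AutoTokenizer.from_pretrained("gpt2")
#      build_sft_example({"instruction": "Say hi", "output": "Hi!"}, tok)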
32 changes: 32 additions & 0 deletions examples/sft/alpaca-llama.yaml
@@ -0,0 +1,32 @@
model_name_or_path: meta-llama/Llama-2-7b-hf

# Training configuration
trainer:
  mp: p=f32,c=bfloat16
  wandb:
    project: "levanter-sft"
    tags: ["llama2", "alpaca"]
  num_train_steps: 1218
  train_batch_size: 64
  # If using model parallelism
  tensor_parallel_axes: ["mlp", "heads"]

# Optimizer settings
optimizer:
  learning_rate: 2e-5
  weight_decay: 0.0

supervised_data:
  hf_dataset_name: "tatsu-lab/alpaca"
  hf_dataset_split: "train"
  input_field: "instruction"  # change from prompt
  output_field: "output"  # this is correct
  cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-new"

max_tune_length: 2048
trust_remote_code: false
model_cache_dir: null

hf_save_path: "sft_hf_ckpts"
hf_upload: false
hf_save_steps: 1000
32 changes: 32 additions & 0 deletions examples/sft/dolly-llama.yaml
@@ -0,0 +1,32 @@
model_name_or_path: meta-llama/Llama-2-7b-hf

# Training configuration
trainer:
  mp: p=f32,c=bfloat16
  wandb:
    project: "levanter-sft"
    tags: ["llama2", "oasst"]
  num_train_steps: 1218
  train_batch_size: 128
  # If using model parallelism
  tensor_parallel_axes: ["mlp", "heads"]

# Optimizer settings
optimizer:
  learning_rate: 2e-5
  weight_decay: 0.0

supervised_data:
  hf_dataset_name: "databricks/databricks-dolly-15k"
  hf_dataset_split: "train"
  input_field: "instruction"  # change from prompt
  output_field: "response"  # this is correct
  cache_dir: "cache/dolly"

max_tune_length: 2048
trust_remote_code: false
model_cache_dir: null

hf_save_path: "sft_hf_ckpts"
hf_upload: false
hf_save_steps: 1000
38 changes: 38 additions & 0 deletions examples/sft/oasst-llama.yaml
@@ -0,0 +1,38 @@
model_name_or_path: meta-llama/Llama-2-7b-hf

# Training configuration
trainer:
  mp: p=f32,c=bfloat16
  wandb:
    project: "levanter-sft"
    tags: ["llama2", "oasst"]
  num_train_steps: 1218
  train_batch_size: 128

  # If using model parallelism
  tensor_parallel_axes: ["mlp", "heads"]

# Optimizer settings
optimizer:
  learning_rate: 2e-5
  weight_decay: 0.0

# Supervised data configuration
supervised_data:
  # For HF dataset
  id: "databricks/databricks-dolly-15k"
  input_field: "instruction"  # adjust based on dataset
  output_field: "response"  # adjust based on dataset
  cache_dir: "cache/dolly"

# Model configuration
max_tune_length: 2048
trust_remote_code: false
model_cache_dir: null

# Checkpoint saving configuration
hf_save_path: "sft_hf_ckpts"
hf_upload: false
hf_save_steps: 1000

# python examples/sft/sft.py --config_path examples/sft/oasst-llama2.yaml