Add factorized llama model for testing.
rjpower committed May 31, 2024
1 parent fd6333c commit 3a84009
Showing 10 changed files with 1,891 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
/scratch
/cache

# Configuration for TPU launches/secrets
.config
45 changes: 45 additions & 0 deletions config/distill_llama3_8b.yaml
@@ -0,0 +1,45 @@
data:
  id: dlwh/wikitext_103_detokenized
  tokenizer: "meta-llama/Meta-Llama-3-8B"
  cache_dir: gs://wasabi-tpu-training/wikitext-103-detokenized

teacher:
  type: llama
  reference_checkpoint: "meta-llama/Meta-Llama-3-8B"
  gradient_checkpointing: True
  seq_len: 4096
  hidden_dim: 4096
  intermediate_dim: 14336
  num_layers: 32
  num_heads: 32
  num_kv_heads: 8
  use_flash_attention: False

student:
  type: factorized_llama
  reference_checkpoint: "meta-llama/Meta-Llama-3-8B"
  gradient_checkpointing: True
  seq_len: 4096
  hidden_dim: 4096
  intermediate_dim: 14336
  num_layers: 32
  num_heads: 32
  num_kv_heads: 8
  use_flash_attention: False
  factor_dim: 128

trainer:
  mp: p=bf16,c=bfloat16
  train_batch_size: 64
  num_train_steps: 10000
  steps_per_eval: 5000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  load_checkpoint_path: "gs://wasabi-tpu-training/distill-8b/checkpoints"

optimizer:
  learning_rate: 1.2E-5  # set low for fine-tuning
  weight_decay: 0.1
  min_lr_ratio: 0.1
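
The student block above differs from the llama teacher only by type: factorized_llama and the extra factor_dim: 128 field. The model code itself is not among the files shown here, so the following is only a hypothetical sketch: it assumes factor_dim denotes a low-rank bottleneck that replaces each dense (hidden_dim x hidden_dim) projection with two smaller factors. The names and shapes below are illustrative, not taken from this commit.

  # Hypothetical sketch only: one plausible reading of factor_dim as a
  # low-rank factorization of a dense projection. Not code from this commit.
  import jax
  import jax.numpy as jnp

  hidden_dim, factor_dim = 4096, 128            # values from distill_llama3_8b.yaml
  k_down, k_up, k_x = jax.random.split(jax.random.PRNGKey(0), 3)

  # Two low-rank factors standing in for one full projection matrix W.
  w_down = jax.random.normal(k_down, (hidden_dim, factor_dim)) / jnp.sqrt(hidden_dim)
  w_up = jax.random.normal(k_up, (factor_dim, hidden_dim)) / jnp.sqrt(factor_dim)

  x = jax.random.normal(k_x, (2, hidden_dim))   # (batch, hidden)
  y = (x @ w_down) @ w_up                       # rank-128 approximation of x @ W

  # Per projection: 2 * 4096 * 128 ≈ 1.0M parameters instead of 4096**2 ≈ 16.8M.

Under that assumed reading, each projection shrinks by roughly 16x, which would fit the pattern of a smaller student distilled from the full Llama-3-8B teacher.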
41 changes: 41 additions & 0 deletions config/distill_llama3_tiny.yaml
@@ -0,0 +1,41 @@
data:
  id: dlwh/wikitext_103_detokenized
  tokenizer: "meta-llama/Meta-Llama-3-8B"
  cache_dir: gs://wasabi-tpu-training/wikitext-103-detokenized

teacher:
  type: llama
  seq_len: 4096
  hidden_dim: 64
  intermediate_dim: 64
  num_layers: 32
  num_heads: 4
  num_kv_heads: 2
  use_flash_attention: True

student:
  type: factorized_llama
  seq_len: 4096
  hidden_dim: 64
  intermediate_dim: 64
  factor_dim: 16
  num_layers: 32
  num_heads: 4
  num_kv_heads: 2
  use_flash_attention: True

trainer:
  mp: p=bf16,c=bfloat16
  train_batch_size: 256
  num_train_steps: 10000
  steps_per_eval: 5000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  load_checkpoint_path: "gs://wasabi-tpu-training/distill-tiny/checkpoints"

optimizer:
  learning_rate: 1.2E-5  # set low for fine-tuning
  weight_decay: 0.1
  min_lr_ratio: 0.1
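
This tiny config mirrors the 8B one with toy dimensions (hidden_dim: 64, factor_dim: 16), in line with the commit message's "for testing". The training objective is not visible in the files shown above; as a hedged sketch only, assuming a standard logit-distillation setup rather than anything confirmed by this commit, a teacher/student pair like this is commonly trained with a soft cross-entropy against the teacher's output distribution.

  # Hypothetical sketch only: a standard soft-target distillation loss.
  # Not code from this commit; function name and temperature are illustrative.
  import jax
  import jax.numpy as jnp

  def distill_loss(student_logits, teacher_logits, temperature=1.0):
      """Soft cross-entropy (forward KL direction) from teacher to student, mean over tokens."""
      teacher_probs = jax.nn.softmax(teacher_logits / temperature, axis=-1)
      student_logp = jax.nn.log_softmax(student_logits / temperature, axis=-1)
      return -jnp.mean(jnp.sum(teacher_probs * student_logp, axis=-1)) * temperature**2

  # Example with dummy logits of shape (batch, seq, vocab).
  s = jax.random.normal(jax.random.PRNGKey(1), (2, 8, 128))
  t = jax.random.normal(jax.random.PRNGKey(2), (2, 8, 128))
  print(distill_loss(s, t))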

0 comments on commit 3a84009
