Add factorized llama model for testing.
rjpower committed Jun 3, 2024
1 parent 8f2689e commit e5ec7ca
Showing 11 changed files with 1,921 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
/scratch
/cache

# Configuration for TPU launches/secrets
.config
51 changes: 51 additions & 0 deletions config/distill_llama3_8b.yaml
@@ -0,0 +1,51 @@
data:
  id: dlwh/wikitext_103_detokenized
  tokenizer: "meta-llama/Meta-Llama-3-8B"
  cache_dir: gs://wasabi-tpu-training/wikitext-103-detokenized

teacher:
  type: llama
  reference_checkpoint: "meta-llama/Meta-Llama-3-8B"
  gradient_checkpointing: True
  seq_len: 4096
  hidden_dim: 4096
  intermediate_dim: 14336
  num_layers: 32
  num_heads: 32
  num_kv_heads: 8
  use_flash_attention: False

student:
  type: factorized_llama
  reference_checkpoint: "meta-llama/Meta-Llama-3-8B"
  gradient_checkpointing: True
  seq_len: 4096
  hidden_dim: 4096
  intermediate_dim: 14336
  num_layers: 32
  num_heads: 32
  num_kv_heads: 8
  use_flash_attention: False
  factor_dim: 128

trainer:
  mp: p=bf16,c=bfloat16
  train_batch_size: 16
  num_train_steps: 10000
  steps_per_eval: 5000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  load_checkpoint_path: "gs://wasabi-tpu-training/distill-8b/checkpoints"
  tracker:
    type: wandb
    project: "distill-8B"

optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  min_lr_ratio: 0.1

init_from_hf: True

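Note: factor_dim appears to control the rank of the student's factorized weight matrices. A minimal sketch of what a low-rank factorized linear layer looks like, assuming a plain two-factor decomposition (an illustration, not the implementation added in this commit):

    # Illustrative sketch only: a factorized linear layer replaces a dense
    # (in_dim x out_dim) weight with two factors of rank factor_dim, cutting
    # parameters from in_dim*out_dim to factor_dim*(in_dim + out_dim).
    import jax
    import jax.numpy as jnp

    def init_factorized_linear(key, in_dim, out_dim, factor_dim):
        k1, k2 = jax.random.split(key)
        down = jax.random.normal(k1, (in_dim, factor_dim)) / jnp.sqrt(in_dim)
        up = jax.random.normal(k2, (factor_dim, out_dim)) / jnp.sqrt(factor_dim)
        return down, up

    def apply_factorized_linear(params, x):
        down, up = params
        return (x @ down) @ up  # W ~= down @ up

    params = init_factorized_linear(jax.random.PRNGKey(0), 4096, 4096, 128)
    y = apply_factorized_linear(params, jnp.ones((2, 4096)))  # shape (2, 4096)

With hidden_dim = 4096 and factor_dim = 128, a single 4096x4096 projection factored this way drops from roughly 16.8M to roughly 1.05M parameters.
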
44 changes: 44 additions & 0 deletions config/distill_llama3_tiny.yaml
@@ -0,0 +1,44 @@
data:
  id: dlwh/wikitext_103_detokenized
  tokenizer: "meta-llama/Meta-Llama-3-8B"
  cache_dir: gs://wasabi-tpu-training/wikitext-103-detokenized

teacher:
  type: llama
  seq_len: 4096
  hidden_dim: 64
  intermediate_dim: 64
  num_layers: 4
  num_heads: 4
  num_kv_heads: 2
  use_flash_attention: True

student:
  type: factorized_llama
  seq_len: 4096
  hidden_dim: 64
  intermediate_dim: 64
  factor_dim: 16
  num_layers: 4
  num_heads: 4
  num_kv_heads: 2
  use_flash_attention: True

trainer:
  mp: p=bf16,c=bfloat16
  train_batch_size: 1
  num_train_steps: 10000
  steps_per_eval: 5000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  load_checkpoint_path: "gs://wasabi-tpu-training/distill-tiny/checkpoints"
  tracker:
    type: wandb
    project: "distill-tiny"

optimizer:
  learning_rate: 1.2E-5 # set low for fine-tuning
  weight_decay: 0.1
  min_lr_ratio: 0.1
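
Note: the tiny config pairs the same teacher/student setup at toy sizes, which makes it cheap to exercise the distillation path end to end. For reference, a logit-distillation objective for such a teacher/student pair typically looks like the sketch below (an assumption about the training loss, not code from this commit):

    # Sketch of a standard KL-based distillation loss between teacher and
    # student logits; the temperature and reduction choices are assumptions.
    import jax.numpy as jnp
    from jax.nn import log_softmax

    def distillation_loss(student_logits, teacher_logits, temperature=1.0):
        t_logp = log_softmax(teacher_logits / temperature, axis=-1)
        s_logp = log_softmax(student_logits / temperature, axis=-1)
        # KL(teacher || student), averaged over batch and sequence positions
        kl = jnp.sum(jnp.exp(t_logp) * (t_logp - s_logp), axis=-1)
        return jnp.mean(kl) * temperature**2
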
4 changes: 3 additions & 1 deletion docker/tpu/Dockerfile.incremental
@@ -14,4 +14,6 @@ WORKDIR /opt/levanter

ADD pyproject.toml README.md /opt/levanter/
RUN pip install -e '.[test]'
ADD . /opt/levanter
ADD . /opt/levanter

RUN cd haliax && pip install --no-deps -e '.'
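
Note: the new RUN line installs a local haliax checkout in editable mode without pulling its dependencies, so the build context must contain a haliax/ directory alongside the Levanter source. A typical build command (image tag chosen for illustration) would be:

    docker build -f docker/tpu/Dockerfile.incremental -t levanter-tpu-incremental .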