refactor parallelism
blahBlahhhJ committed May 12, 2024
1 parent 2888a35 commit 79815c0
Showing 28 changed files with 192 additions and 54 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -157,7 +157,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: 4

train_batch_size: 512
2 changes: 1 addition & 1 deletion config/backpack.yaml
@@ -18,7 +18,7 @@ trainer:

num_train_steps: 50000
train_batch_size: 1024
-model_axis_size: 1
+model_ici_axis_size: 1

optimizer:
learning_rate: 6E-4
2 changes: 1 addition & 1 deletion config/backpack_nano.yaml
@@ -15,7 +15,7 @@ trainer:

num_train_steps: 100
train_batch_size: 32
-model_axis_size: 1
+model_ici_axis_size: 1

optimizer:
learning_rate: 6E-4
106 changes: 106 additions & 0 deletions config/config-markweb-web_comparison-owt_1b.yaml
@@ -0,0 +1,106 @@
data:
  cache_dir: "gs://levanter-data/tokenized/markweb_llama/"
  tokenizer: "meta-llama/Llama-2-7b-hf"
  configs:
    openwebtext:
      train_urls:
        - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
      validation_urls:
        - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
    # these are just for eval
    "paloma/4chan":
      validation_urls:
        - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz
    "paloma/c4_100_domains":
      validation_urls:
        - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz
    "paloma/c4_en":
      validation_urls:
        - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz
    "paloma/dolma-v1_5":
      validation_urls:
        - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz
    "paloma/dolma_100_programing_languages":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz
    "paloma/dolma_100_subreddits":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz
    "paloma/falcon-refinedweb":
      validation_urls:
        - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz
    "paloma/gab":
      validation_urls:
        - gs://levanter-data/paloma/gab/val/val*.jsonl.gz
    "paloma/m2d2_s2orc_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz
    "paloma/m2d2_wikipedia_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
    "paloma/manosphere_meta_sep":
      validation_urls:
        - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz
    "paloma/mc4":
      validation_urls:
        - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz
    "paloma/ptb":
      validation_urls:
        - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz
    "paloma/redpajama":
      validation_urls:
        - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz
    "paloma/twitterAAE_HELM_fixed":
      validation_urls:
        - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz
    "paloma/wikitext_103":
      validation_urls:
        - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz

  train_weights:
    openwebtext: 1.0
    paloma/4chan: 0.0
    paloma/c4_100_domains: 0.0
    paloma/c4_en: 0.0
    paloma/dolma-v1_5: 0.0
    paloma/dolma_100_programing_languages: 0.0
    paloma/dolma_100_subreddits: 0.0
    paloma/falcon-refinedweb: 0.0
    paloma/gab: 0.0
    paloma/m2d2_s2orc_unsplit: 0.0
    paloma/m2d2_wikipedia_unsplit: 0.0
    paloma/manosphere_meta_sep: 0.0
    paloma/mc4: 0.0
    paloma/ptb: 0.0
    paloma/redpajama: 0.0
    paloma/twitterAAE_HELM_fixed: 0.0
    paloma/wikitext_103: 0.0
model:
  # 1B class model
  type: llama
  seq_len: 2048
  hidden_dim: 2048
  intermediate_dim: 4096
  num_layers: 24
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  flash_attention_block_size: 2048
trainer:
  tracker:
    type: wandb
    project: "markweb"
    tags: ["owt", "llama", "web_comparison"]

  mp: p=f32,c=bfloat16
  train_batch_size: 512
  num_train_steps: 50000
  steps_per_eval: 1000
  per_device_eval_parallelism: 64
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
optimizer:
  learning_rate: 2E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
2 changes: 1 addition & 1 deletion config/gpt2_1536.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_eval_parallelism: 8
optimizer:
learning_rate: 1E-4
2 changes: 1 addition & 1 deletion config/gpt2_1536_sophiah.yaml
@@ -19,7 +19,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
optimizer:
type: sophia-h
learning_rate: 2E-4
2 changes: 1 addition & 1 deletion config/gpt2_7b.yaml
@@ -17,7 +17,7 @@ trainer:

mp: p=f32,c=bfloat16

-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1
per_device_eval_parallelism: -1

2 changes: 1 addition & 1 deletion config/gpt2_large.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1
optimizer:
learning_rate: 2E-4
2 changes: 1 addition & 1 deletion config/gpt2_medium.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
optimizer:
learning_rate: 3E-4
weight_decay: 0.1
2 changes: 1 addition & 1 deletion config/gpt2_small.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1

train_batch_size: 512
2 changes: 1 addition & 1 deletion config/gpt2_small_fast.yaml
@@ -14,7 +14,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_fp8.yaml
@@ -15,7 +15,7 @@ trainer:

mp: p=f32,c=bfloat16
fp8: true
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_mix.yaml
@@ -26,7 +26,7 @@ trainer:
tags: [ "openwebtext+wiki", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1

train_batch_size: 256
num_train_steps: 20000
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_pile.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "pile", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1

train_batch_size: 256
num_train_steps: 20000
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_public.yaml
@@ -20,7 +20,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_sophia_h.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest", "sophia-h"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1

train_batch_size: 256
num_train_steps: 20000
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_sophiah.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_wiki.yaml
@@ -14,7 +14,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_pile.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "pile", "gpt2"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1

train_batch_size: 256
num_train_steps: 50000
2 changes: 1 addition & 1 deletion config/gpt2_small_pile_mixture.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "pile", "gpt2"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1

train_batch_size: 256
num_train_steps: 50000
2 changes: 1 addition & 1 deletion config/gpt2_small_sophiah.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "sophia-h"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1

train_batch_size: 512
optimizer: !include optim/sophia-h_small.yaml
2 changes: 1 addition & 1 deletion config/llama2_7b_continued.yaml
@@ -13,7 +13,7 @@ trainer:

mp: p=f32,c=bfloat16

-model_axis_size: 1
+model_ici_axis_size: 1
per_device_eval_parallelism: 4

train_batch_size: 1024
2 changes: 1 addition & 1 deletion config/llama_small_fast.yaml
@@ -21,7 +21,7 @@ trainer:
tags: [ "openwebtext", "llama", "itest"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/lora/mpt_biomed.yaml
@@ -18,7 +18,7 @@ trainer:

mp: p=f32,c=bfloat16

-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: 4
per_device_eval_parallelism: 4

2 changes: 1 addition & 1 deletion config/whisper_tiny_librispeech.yaml
@@ -15,7 +15,7 @@ trainer:
tags: [ "librispeech", "whisper"]

mp: p=f32,c=bf16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1

train_batch_size: 128
8 changes: 4 additions & 4 deletions docs/Configuration-Guide.md
@@ -41,7 +41,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: 4

train_batch_size: 512
@@ -100,7 +100,7 @@ The following table lists some of the parameters that you might want to change.
| `seed` | The random seed | 0 |
| `num_train_steps` | The number of training steps to run | 400,000 |
| `train_batch_size` | The batch size | 32 |
-| `per_device_train_parallelism` | Number of examples to process on each device during training | `train_batch_size / (num_accelerators * model_axis_size)` |
+| `per_device_train_parallelism` | Number of examples to process on each device during training | `train_batch_size / (num_accelerators * model_ici_axis_size)` |
| `per_device_eval_parallelism` | Number of examples to process on each device during eval | `per_device_train_parallelism` |
| `steps_per_eval` | How often to evaluate the model during training | 1,000 |
| `max_eval_batches` | How many batches to evaluate during each evaluation | `None` (meaning all) |
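
For intuition, here is a sketch of how that `per_device_train_parallelism` default works out; the numbers are illustrative assumptions, not taken from this commit:

```yaml
trainer:
  train_batch_size: 512        # global batch size
  model_ici_axis_size: 1       # no tensor parallelism
  # With 64 accelerators, per_device_train_parallelism defaults to
  # train_batch_size / (num_accelerators * model_ici_axis_size)
  #   = 512 / (64 * 1) = 8 examples per device per step
```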
@@ -133,7 +133,7 @@ reasonable defaults and an "advanced" mode that gives you more control.
| `batch_axis` | The axis to shard the batch over, for distributed data parallelism | `"batch"` |
| `fsdp_axis` | The axis or axes to shard the model over, for Fully Sharded Data Parallelism | `"embed"` |
| `tensor_parallel_axes` | The axis or axes to shard the model over, for Tensor Parallelism | `None` |
-| `model_axis_size` | How many devices for tensor parallelism | `1` |
+| `model_ici_axis_size` | How many devices for tensor parallelism | `1` |
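
Taken together, a minimal partitioning block mirrors the one added in `config/config-markweb-web_comparison-owt_1b.yaml` above; the comments are interpretive glosses based on this table, not text from the commit:

```yaml
trainer:
  batch_axis: "batch"                      # shard the batch for data parallelism
  fsdp_axis: "embed"                       # fully shard parameters over the embed axis
  tensor_parallel_axes: ["mlp", "heads"]   # logical axes split for tensor parallelism
  model_ici_axis_size: 1                   # devices devoted to tensor parallelism
```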

#### Advanced Mode

@@ -142,7 +142,7 @@ reasonable defaults and an "advanced" mode that gives you more control.
| `axis_resources` | Mapping from logical axis to physical axis shared by both mappings | -- |
| `parameter_axis_resources` | Mapping from logical axis to physical axis for the parameter mapping | -- |
| `compute_axis_resources` | Mapping from logical axis to physical axis for the compute mapping | -- |
-| `model_axis_size` | How many devices for tensor parallelism | `1` |
+| `model_ici_axis_size` | How many devices for tensor parallelism | `1` |
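
A hedged sketch of the advanced mode; the physical mesh axis names (`data`, `model`) and the specific logical-to-physical assignments are assumptions for illustration, not shown in this diff:

```yaml
trainer:
  axis_resources:
    batch: "data"        # shared by both the parameter and compute mappings
  parameter_axis_resources:
    embed: "data"        # parameter mapping: shard parameters FSDP-style
  compute_axis_resources:
    heads: "model"       # compute mapping: split attention heads for tensor parallelism
  model_ici_axis_size: 1
```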

### Checkpointing and Initialization

2 changes: 1 addition & 1 deletion docs/tutorials/Training-On-Audio-Data.md
@@ -117,7 +117,7 @@ trainer:
tags: [ "librispeech", "whisper"]
mp: p=f32,c=bf16
-model_axis_size: 1
+model_ici_axis_size: 1
per_device_parallelism: -1
train_batch_size: 128