
Commit

enough device puts and we're good
dlwh committed Dec 3, 2024
2 parents 6395305 + 9804b34 commit 87d7665
Showing 90 changed files with 5,097 additions and 2,102 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -33,16 +33,18 @@ Haliax's documentation is available at [haliax.readthedocs.io](https://haliax.re

## Features

-* **Distributed Training**: We support distributed training on TPUs (and soon, GPUs), including FSDP and tensor parallelism.
+* **Distributed Training**: We support distributed training on TPUs and GPUs, including FSDP and tensor parallelism.
* **Compatibility**: Levanter supports importing and exporting models to/from the Hugging Face ecosystem, including tokenizers, datasets, and models via [SafeTensors](https://github.com/huggingface/safetensors).
* **Performance**: Levanter's performance rivals commercially-backed frameworks like MosaicML's Composer or Google's MaxText.
* **Resilience**: Levanter supports fast, distributed checkpointing and fast resume from checkpoints with no data seek, making Levanter robust to preemption and hardware failure.
* **Cached On-Demand Data Preprocessing**: We preprocess corpora online, but we cache the results of preprocessing so
that resumes are much faster and so that subsequent runs are even faster. As soon as the first part of the cache is complete, Levanter will start training.
-* **Optimization**: Levanter supports the new [Sophia](https://arxiv.org/abs/2305.14342) optimizer, which can be 2x as fast as Adam. We also support [Optax](https://github.com/deepmind/optax) for optimization with AdamW, etc.
-* **Logging**: Levanter supports a few different logging backends, including [WandB](https://wandb.ai/site) and [TensorBoard](https://www.tensorflow.org/tensorboard). (Adding a new logging backend is easy!) Levanter even exposes the ability
+* **Logging**: Levanter logs a rich and detailed set of metrics covering loss and performance. Levanter also supports a few different logging backends, including [WandB](https://wandb.ai/site) and [TensorBoard](https://www.tensorflow.org/tensorboard). (Adding a new logging backend is easy!) Levanter even exposes the ability
to log inside of JAX `jit`-ted functions.
* **Reproducibility**: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
* **Distributed Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now.
+* **Optimization**: Levanter supports the new [Sophia](https://arxiv.org/abs/2305.14342) optimizer, which can be 2x as fast as Adam. We also support [Optax](https://github.com/deepmind/optax) for optimization with AdamW, etc.
+* **Flexible**: Levanter supports tuning data mixtures without having to retokenize or shuffle data (sketched below).
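
To make the **Flexible** bullet concrete, here is a minimal sketch of how a weighted data mixture is declared; the `config/data/marin_dolma.yaml` file added later in this commit is a full, real example. The corpus names, bucket paths, and weights below are placeholders for illustration only:

```yaml
# Hypothetical mixture config: two corpora sampled at a 3:1 ratio.
data:
  tokenizer: gpt2
  configs:
    wiki:
      train_urls:
        - gs://my-bucket/wiki-{0000..0009}.jsonl.gz   # brace range expands to 10 shard URLs
    code:
      train_urls:
        - gs://my-bucket/code-{0000..0099}.jsonl.gz
  train_weights:   # relative sampling weights
    code: 3.0
    wiki: 1.0
```

Because each corpus keeps its own token cache, changing `train_weights` re-balances the mixture without re-tokenizing or re-shuffling anything.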

<!--levanter-intro-end-->

200 changes: 200 additions & 0 deletions config/data/marin_dolma.yaml
@@ -0,0 +1,200 @@
tokenizer: meta-llama/Meta-Llama-3.1-8B
cache_dir: null
cache_options:
  batch_size: 128
  num_shard_groups: 128
  target_size_per_flush: 512MB
configs:
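  # One entry per corpus; {NNNN..MMMM} brace ranges expand to one shard URL per index.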
  dolma/algebraic-stack:
    cache_dir: gs://marin-us-central2/tokenized/dolma/algebraic-stack-cc00cf
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/algebraic-stack-train-{0000..0015}.json.gz
  dolma/arxiv:
    cache_dir: gs://marin-us-central2/tokenized/dolma/arxiv-07a51f
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/arxiv-{0000..0099}.json.gz
    validation_urls: []
  dolma/c4:
    cache_dir: gs://marin-us-central2/tokenized/dolma/c4-e0e5ec
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/c4-{0000..0170}.json.gz
    validation_urls: []
  dolma/cc:
    cache_dir: gs://marin-us-central2/tokenized/dolma/cc-74b017
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/cc_en_head-{0000..0274}.json.gz
      - gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0000..0238}.json.gz
      - gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0240..0379}.json.gz
      - gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0000..0152}.json.gz
      - gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0154..0444}.json.gz
    validation_urls: []
  dolma/cc-news:
    cache_dir: gs://marin-us-central2/tokenized/dolma/cc-news-625d3e
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/cc_news_head-{0000..0004}.json.gz
      - gs://marin-us-central2/raw/dolma/v1.7/cc_news_middle-{0000..0002}.json.gz
      - gs://marin-us-central2/raw/dolma/v1.7/cc_news_tail-0000.json.gz
    validation_urls: []
  dolma/falcon:
    cache_dir: gs://marin-us-central2/tokenized/dolma/falcon-da8fd0
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/falcon-{0000..0499}.json.gz
    validation_urls: []
  dolma/flan:
    cache_dir: gs://marin-us-central2/tokenized/dolma/flan-a99cb2
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/tulu_flan-{0000..0065}.json.gz
    validation_urls: []
  dolma/gutenberg:
    cache_dir: gs://marin-us-central2/tokenized/dolma/gutenberg-f9eb99
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/books-{0000..0002}.json.gz
    validation_urls: []
  dolma/megawika:
    cache_dir: gs://marin-us-central2/tokenized/dolma/megawika-34abf2
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/megawika-{0000..0261}.json.gz
    validation_urls: []
  dolma/open-web-math:
    cache_dir: gs://marin-us-central2/tokenized/dolma/open-web-math-79823d
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/open-web-math-train-{0000..0012}.json.gz
    validation_urls: []
  dolma/pes2o:
    cache_dir: gs://marin-us-central2/tokenized/dolma/pes2o-538363
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/pes2o-{0000..0025}.json.gz
    validation_urls: []
  dolma/reddit:
    cache_dir: gs://marin-us-central2/tokenized/dolma/reddit-62a64a
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/reddit-{0000..0077}.json.gz
    validation_urls: []
  dolma/stackexchange:
    cache_dir: gs://marin-us-central2/tokenized/dolma/stackexchange-adfc49
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/stackexchange-{0000..0025}.json.gz
    validation_urls: []
  dolma/starcoder:
    cache_dir: gs://marin-us-central2/tokenized/dolma/starcoder-8b6089
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/starcoder-{0000..0048}.json.gz
    validation_urls: []
  dolma/wiki:
    cache_dir: gs://marin-us-central2/tokenized/dolma/wiki-212315
    train_urls:
      - gs://marin-us-central2/raw/dolma/v1.7/wiki-{0000..0001}.json.gz
    validation_urls: []
  paloma/4chan:
    cache_dir: gs://marin-us-central2/tokenized/paloma/4chan-48513a
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/4chan_meta_sep/val/val*.jsonl.gz
  paloma/c4_100_domains:
    cache_dir: gs://marin-us-central2/tokenized/paloma/c4_100_domains-96277e
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/c4_100_domains/val/val*.jsonl.gz
  paloma/c4_en:
    cache_dir: gs://marin-us-central2/tokenized/paloma/c4_en-21e708
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/c4_en/val/val*.jsonl.gz
  paloma/dolma-v1_5:
    cache_dir: gs://marin-us-central2/tokenized/paloma/dolma-v1_5-ed8c0b
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/dolma-v1_5/val/val*.jsonl.gz
  paloma/dolma_100_programing_languages:
    cache_dir: gs://marin-us-central2/tokenized/paloma/dolma_100_programing_languages-3eb825
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/dolma_100_programing_languages/val/val*.jsonl.gz
  paloma/dolma_100_subreddits:
    cache_dir: gs://marin-us-central2/tokenized/paloma/dolma_100_subreddits-2381a1
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/dolma_100_subreddits/val/val*.jsonl.gz
  paloma/falcon-refinedweb:
    cache_dir: gs://marin-us-central2/tokenized/paloma/falcon-refinedweb-fea8c8
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/falcon-refinedweb/val/val*.jsonl.gz
  paloma/gab:
    cache_dir: gs://marin-us-central2/tokenized/paloma/gab-487af2
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/gab/val/val*.jsonl.gz
  paloma/m2d2_s2orc_unsplit:
    cache_dir: gs://marin-us-central2/tokenized/paloma/m2d2_s2orc_unsplit-9fb6dc
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/m2d2_s2orc_unsplit/val/val*.jsonl.gz
  paloma/m2d2_wikipedia_unsplit:
    cache_dir: gs://marin-us-central2/tokenized/paloma/m2d2_wikipedia_unsplit-0fdb36
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
  paloma/manosphere_meta_sep:
    cache_dir: gs://marin-us-central2/tokenized/paloma/manosphere_meta_sep-bb4a78
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/manosphere_meta_sep/val/val*.jsonl.gz
  paloma/mc4:
    cache_dir: gs://marin-us-central2/tokenized/paloma/mc4-8ead3d
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/mc4/val/val*.jsonl.gz
  paloma/ptb:
    cache_dir: gs://marin-us-central2/tokenized/paloma/ptb-4559d1
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/ptb/val/val*.jsonl.gz
  paloma/redpajama:
    cache_dir: gs://marin-us-central2/tokenized/paloma/redpajama-f43c27
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/redpajama/val/val*.jsonl.gz
  paloma/twitterAAE_HELM_fixed:
    cache_dir: gs://marin-us-central2/tokenized/paloma/twitterAAE_HELM_fixed-b346ca
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/twitterAAE_HELM_fixed/val/val*.jsonl.gz
  paloma/wikitext_103:
    cache_dir: gs://marin-us-central2/tokenized/paloma/wikitext_103-631615
    train_urls: []
    validation_urls:
      - gs://marin-us-central2/raw/paloma-fc6827/65cd6fc/wikitext_103/val/val*.jsonl.gz
shuffle: true
stop_strategy: restart
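# Relative sampling weights for training; the Paloma sets below are held out for evaluation only (weight 0.0).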
train_weights:
  dolma/algebraic-stack: 12.6
  dolma/arxiv: 28.0
  dolma/c4: 124.95
  dolma/cc: 597.75
  dolma/cc-news: 14.3
  dolma/falcon: 456.4
  dolma/flan: 16.5
  dolma/gutenberg: 5.3
  dolma/megawika: 4.6
  dolma/open-web-math: 12.6
  dolma/pes2o: 57.2
  dolma/reddit: 79.9
  dolma/stackexchange: 19.6
  dolma/starcoder: 263.8
  dolma/wiki: 7.4
  paloma/4chan: 0.0
  paloma/c4_100_domains: 0.0
  paloma/c4_en: 0.0
  paloma/dolma-v1_5: 0.0
  paloma/dolma_100_programing_languages: 0.0
  paloma/dolma_100_subreddits: 0.0
  paloma/falcon-refinedweb: 0.0
  paloma/gab: 0.0
  paloma/m2d2_s2orc_unsplit: 0.0
  paloma/m2d2_wikipedia_unsplit: 0.0
  paloma/manosphere_meta_sep: 0.0
  paloma/mc4: 0.0
  paloma/ptb: 0.0
  paloma/redpajama: 0.0
  paloma/twitterAAE_HELM_fixed: 0.0
  paloma/wikitext_103: 0.0
@@ -13,9 +13,17 @@ data:
  tokenizer: gpt2
  cache_dir: "gs://levanter-data/tokenized/data_mix"
supervised_data:
-  validation_urls:
-    - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"
-  cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/"
+  mmlu:
+    validation_urls:
+      - "gs://marin-us-central2/evaluation/mmlu-eval-subject-2eb39e/cais/*-validation-evaluation.jsonl.gz"
+    cache_dir: "gs://levanter-data/tokenized-gpt2/mmlu/"
+    tags: [ "e"]
+  arc_easy:
+    validation_urls:
+      - "gs://marin-us-central2/evaluation/arc-easy-b39e70/allenai/ai2_arc-ARC-Easy-validation-evaluation.jsonl.gz"
+    cache_dir: "gs://levanter-data/tokenized-gpt2/arc_easy/"
+    tags: [ "arc", "e"]

model:
  type: gpt2
  hidden_dim: 768
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_pile.yaml
@@ -1,4 +1,4 @@
-data: !include data/pile_source_old.yaml
+data: !include data/pile_mixture.yaml
model:
  type: gpt2
  hidden_dim: 768
32 changes: 32 additions & 0 deletions config/llama3_small_fast.yaml
@@ -0,0 +1,32 @@
data:
  train_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://levanter-data/tokenized/openwebtext_llama3/"
  tokenizer: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
model:
  type: llama
  hidden_dim: 768
  intermediate_dim: 2048
  num_heads: 12
  num_kv_heads: 12
  num_layers: 12
  seq_len: 1024
  gradient_checkpointing: true
trainer:
  tracker:
    - type: wandb
      project: "levanter"
      tags: [ "openwebtext", "llama", "itest"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: -1

  train_batch_size: 256
  num_train_steps: 20000
optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  warmup: 0.01
39 changes: 39 additions & 0 deletions config/llama_7b_tulu.yaml
@@ -0,0 +1,39 @@
data:
  train_urls:
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-000.jsonl.gz"
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-001.jsonl.gz"
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-002.jsonl.gz"
  cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/tuluv2_sft/"
  tokenizer: "allenai/OLMo-1B"
model: # 7B class model
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  flash_attention_block_size: 512
  use_bias: false
  use_layer_norm_weight: false
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["dolma", "olmo", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 256
  num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
optimizer:
  learning_rate: 4E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 5000

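# Number of passes to make over the SFT mixture.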
epoch: 3
2 changes: 1 addition & 1 deletion config/llama_7b_with_dclm.yaml
@@ -17,7 +17,7 @@ trainer:

  mp: p=f32,c=bfloat16
  train_batch_size: 2048
-  num_train_steps: 70000 # 280B / 4M
+  num_train_steps: 480000 # 2T / 4M
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
13 changes: 13 additions & 0 deletions config/llama_sft_hf_ckpt.yaml
@@ -0,0 +1,13 @@
# Model configuration
model:
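  # 7B-class dimensions (same shape as Llama-2-7B: 4096 hidden, 11008 intermediate, 32 layers/heads).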
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: true
  flash_attention_block_size: 512
  use_bias: false
  use_layer_norm_weight: false
3 changes: 2 additions & 1 deletion docker/tpu/Dockerfile.base
@@ -5,7 +5,8 @@ RUN pip install virtualenv
# venv binaries encode their directory, so we need to setup the venv in the final location
RUN virtualenv -p python3.10 /opt/levanter/.venv
ENV PATH /opt/levanter/.venv/bin:$PATH
-RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.34" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install package dependencies to make incremental builds faster.
WORKDIR /tmp/