Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into dclm
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Oct 14, 2024
2 parents 8ecb7ea + 02f34ac commit 0ea3eb4
Show file tree
Hide file tree
Showing 133 changed files with 10,270 additions and 5,177 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

scratch
cache
new-cache
wandb
checkpoints

Expand Down Expand Up @@ -116,3 +117,4 @@ dmypy.json

# local execution commands
local_*.sh
.aider*
9 changes: 6 additions & 3 deletions .github/workflows/docker-base-image.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
name: Build and Push Docker TPU Images

on:
push:
branches:
- main
workflow_run:
workflows: ["Run Tests"]
types:
- completed
branches: [main]
workflow_dispatch:

jobs:
build:
Expand Down
72 changes: 72 additions & 0 deletions .github/workflows/launch_small_fast.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
name: Launch Llama 2 Small Fast

on:
workflow_run:
workflows: ["Build and Push Docker TPU Images"]
types:
- completed
branches: [main, "experiment/*"]
# pull_request:
workflow_dispatch:

jobs:
test:
if: (github.event.pull_request.head.repo.full_name == github.repository)
runs-on: ubuntu-latest
env:
TPU_ZONE: "us-central2-b"
TPU_TYPE: "v4-32"

steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@v1
with:
project_id: ${{ secrets.GCP_PROJECT_ID }}

- name: Authenticate to Google Cloud
uses: google-github-actions/auth@v1
with:
credentials_json: ${{ secrets.GCP_SA_KEY }}

- name: Configure Google Cloud
run: |
gcloud config set project ${{ secrets.GCP_PROJECT_ID }}
REGION=${TPU_ZONE%-*}
echo "$REGION"
gcloud auth configure-docker $REGION-docker.pkg.dev
- name: Install locally
run: |
python -m pip install --upgrade pip
pip install -e .[test] "jax[cpu]==0.4.30"
- name: Launch Small Fast TPU Train LM job
run: |
export TPU_NAME=small-fast-${{ github.run_id }}
export WANDB_API_KEY=${{ secrets.WANDB_API_KEY }}
export RUN_ID=small_fast_${{ github.run_id }}
export HF_TOKEN=${{ secrets.HF_TOKEN }}
cat > .config <<EOF
env:
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
WANDB_ENTITY: stanford-mercury
WANDB_PROJECT: levanter
HF_TOKEN: ${{ secrets.HF_TOKEN }}
EOF
python infra/launch.py -e CI 1 --foreground --tpu_name ${TPU_NAME} --run_id $RUN_ID --zone ${TPU_ZONE} --tpu_type ${TPU_TYPE} --preemptible -- \
python -m levanter.main.train_lm \
--config_path config/llama_small_fast.yaml \
--trainer.checkpointer.base_path gs://levanter-checkpoints/llama-itest/ \
--trainer.checkpointer.save_interval 10m \
--trainer.num_train_steps 10000
- name: Cleanup
if: ${{ always() }}
run: |
export TPU_NAME=small-fast-${{ github.run_id }}
gcloud compute tpus queued-resources delete $TPU_NAME --zone ${TPU_ZONE} --quiet --force
2 changes: 1 addition & 1 deletion .github/workflows/run_entry_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install soundfile librosa
- name: Run entry tests with pytest
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/run_ray_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install soundfile librosa
- name: Run ray tests with pytest
run: |
XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. pytest tests -m ray
PYTHONPATH=$(pwd)/tests:$(pwd)/src:$(pwd):. pytest tests -m ray
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install .[test] "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install -r ./tests/requirements.txt
- name: Test with pytest
run: |
Expand Down
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,9 @@ ledger.json
/checkpoints
*.jaxpr

# local execution commands
local_*.sh

# aider
.aider*

.benchmarks
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ repos:
hooks:
- id: mypy
args: [--ignore-missing-imports]
additional_dependencies: [wandb, types-PyYAML]
additional_dependencies: [wandb==0.17.8, types-PyYAML]
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ Please see the [CUDA Getting Started](docs/Getting-Started-GPU.md) guide for mor

## Contributing

[![GitHub repo Good Issues for newbies](https://img.shields.io/github/issues/stanford-crfm/levanter/good%20first%20issue?style=flat&logo=github&logoColor=green&label=Good%20First%20issues)](https://github.com/stanford-crfm/levanter/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) [![GitHub Help Wanted issues](https://img.shields.io/github/issues/stanford-crfm/levanter/help%20wanted?style=flat&logo=github&logoColor=b545d1&label=%22Help%20Wanted%22%20issues)](https://github.com/stanford-crfm/levanter/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) [![GitHub Help Wanted PRs](https://img.shields.io/github/issues-pr/stanford-crfm/levanter/help%20wanted?style=flat&logo=github&logoColor=b545d1&label=%22Help%20Wanted%22%20PRs)](https://github.com/stanford-crfm/levanter/pulls?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) [![GitHub repo Issues](https://img.shields.io/github/issues/stanford-crfm/levanter?style=flat&logo=github&logoColor=red&label=Issues)](https://github.com/stanford-crfm/levanter/issues?q=is%3Aopen)

We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more information.

## License
Expand Down
9 changes: 6 additions & 3 deletions config/data/dclm_gpt_neo.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
cache_dir: "gs://marin-data/tokenized/dclm/gpt_neo_tokenizer"
cache_dir: "gs://marin-us-central2/tokenized/gpt_neox/"
tokenizer: "EleutherAI/gpt-neox-20b"
cache_options:
batch_size: 256
num_shard_groups: 1024
stop_strategy: restart
shuffle_buffer_size: 100000
shuffle: 100000
configs:
"dclm":
train_urls:
- gs://marin-data/datacomp/dclm-baseline-dedup-07-09/*/*/*.jsonl.zstd
- gs://marin-us-central2/raw/dclm/v2024-07-09-baseline-dedup/**/*.zstd
# these are just for eval
"paloma/4chan":
validation_urls:
Expand Down
44 changes: 22 additions & 22 deletions config/data/dolma_olmo_paloma.yaml
Original file line number Diff line number Diff line change
@@ -1,59 +1,59 @@
cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7"
cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/dolma/v1.7"
tokenizer: "allenai/OLMo-1B" # requires `pip install ai2-olmo`
# tokenizer: "meta-llama/Llama-2-7b-hf"
stop_strategy: restart
configs:
dolma-algebraic-stack:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/algebraic-stack-train-{0000..0015}.json.gz
dolma-arxiv:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/arxiv-{0000..0099}.json.gz
dolma-gutenberg:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/books-{0000..0002}.json.gz
dolma-c4:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/c4-{0000..0170}.json.gz
dolma-cc:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_head-{0000..0274}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0240..0379}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0154..0444}.json.gz
dolma-cc-news:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_news_head-{0000..0004}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_news_middle-{0000..0002}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_news_tail-0000.json.gz
dolma-falcon:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/falcon-{0000..0499}.json.gz
dolma-megawika:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/megawika-{0000..0261}.json.gz
dolma-owmath:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/open-web-math-train-{0000..0012}.json.gz
dolma-pes2o:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/pes2o-{0000..0025}.json.gz
dolma-reddit:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/reddit-{0000..0077}.json.gz
dolma-stackexchange:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/stackexchange-{0000..0025}.json.gz
dolma-starcoder:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/starcoder-{0000..0048}.json.gz
dolma-flan:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/tulu_flan-{0000..0065}.json.gz
dolma-wiki:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/wiki-{0000..0001}.json.gz
# these are just for eval
"paloma/4chan":
validation_urls:
Expand Down
3 changes: 3 additions & 0 deletions config/data/openwebtext_source.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ validation_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
cache_dir: "gs://levanter-data/tokenized/openwebtext/"
tokenizer: "gpt2"
cache_options:
batch_size: 1024
num_shard_groups: 64
15 changes: 9 additions & 6 deletions config/data/pile_mixture.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
cache_dir: "gs://levanter-data/tokenized/pile-domains/"
tokenizer: "EleutherAI/gpt-neox-20b"
cache_options:
batch_size: 32
num_shard_groups: 16
configs:
arxiv:
train_urls:
Expand All @@ -11,11 +14,11 @@ configs:
- gs://levanter-data/pile-domains/books2/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/books2/val.jsonl.zst
books3:
train_urls:
- gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/books3/val.jsonl.zst
# books3:
# train_urls:
# - gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst
# validation_urls:
# - gs://levanter-data/pile-domains/books3/val.jsonl.zst
dm_math:
train_urls:
- gs://levanter-data/pile-domains/dm_math/{00..29}.jsonl.zst
Expand Down Expand Up @@ -115,7 +118,7 @@ train_weights:
# these weights come from the paper https://arxiv.org/pdf/2101.00027.pdf
pile_cc: 0.1811
pubmed_central: 0.1440
books3: 0.1207
# books3: 0.1207
owt2: 0.1001
arxiv: 0.0896
github: 0.0759
Expand Down
1 change: 0 additions & 1 deletion config/data/redpajama_1b_source.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@ cache_dir: gs://levanter-data/tokenized/redpajama-sample/
tokenizer: EleutherAI/gpt-neox-20b
splits:
- train
rows_per_chunk: 32768
1 change: 0 additions & 1 deletion config/data/redpajama_1t_source.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@ cache_dir: gs://levanter-data/tokenized/redpajama/
tokenizer: EleutherAI/gpt-neox-20b
splits:
- train
rows_per_chunk: 4096
1 change: 0 additions & 1 deletion config/data/rpv1_llama.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
cache_dir: gs://levanter-data/tokenized/redpajama_v1_llama_mixture
rows_per_chunk: 4096
tokenizer: "meta-llama/Llama-2-7b-hf"
configs:
arxiv:
Expand Down
3 changes: 2 additions & 1 deletion config/gpt2_nano_mixture.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ data:
id: dlwh/wikitext_103_detokenized
w2:
id: dlwh/wikitext_103_detokenized
cache_dir: wikitext2_cache
train_weights:
wikitext: 1.0
w2: 0
w2: 1.0
model:
type: gpt2
hidden_dim: 32
Expand Down
3 changes: 3 additions & 0 deletions config/gpt2_small_fast.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ trainer:

train_batch_size: 256
num_train_steps: 20000

# tensor_parallel_axes: ["position", "key_position"]
# tensor_parallel_axes: ["heads", "mlp"]
optimizer:
learning_rate: 1E-3
weight_decay: 0.1
Expand Down
40 changes: 40 additions & 0 deletions config/gpt2_small_fast_supervised.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
data:
configs:
owt:
train_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
validation_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
wikitext:
id: dlwh/wikitext_103_detokenized
train_weights:
owt: 0.6
wikitext: 0.4
tokenizer: gpt2
cache_dir: "gs://levanter-data/tokenized/data_mix"
supervised_data:
validation_urls:
- "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"
cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/"
model:
type: gpt2
hidden_dim: 768
num_heads: 12
num_layers: 12
seq_len: 1024
gradient_checkpointing: true
scale_attn_by_inverse_layer_idx: true
trainer:
tracker:
project: "levanter"
tags: [ "openwebtext+wiki", "gpt2", "itest"]

mp: p=f32,c=bfloat16
model_axis_size: 1

train_batch_size: 256
num_train_steps: 20000
optimizer:
learning_rate: 1E-3
weight_decay: 0.1
warmup: 0.01
Loading

0 comments on commit 0ea3eb4

Please sign in to comment.