Commit
Merge branch 'main' into psdp
blahBlahhhJ committed May 27, 2024
2 parents 24fbf37 + 81ba8c0 commit 4a63e09
Showing 50 changed files with 1,450 additions and 400 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/run_tests.yaml
@@ -20,9 +20,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
pip install soundfile librosa
pip install -r ./tests/requirements.txt
- name: Test with pytest
run: |
XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry and not slow and not ray"
54 changes: 54 additions & 0 deletions .github/workflows/tpu_unit_tests.yaml
@@ -0,0 +1,54 @@
name: CI with GCP TPU

on: [pull_request]

jobs:
test:
runs-on: ubuntu-latest
env:
TPU_ZONE: "us-central2-b"

steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@v1
with:
project_id: ${{ secrets.GCP_PROJECT_ID }}

- name: Authenticate to Google Cloud
uses: google-github-actions/auth@v1
with:
credentials_json: ${{ secrets.GCP_SA_KEY }}

- name: Configure Google Cloud
run: |
gcloud config set project ${{ secrets.GCP_PROJECT_ID }}
- name: Create VM
run: |
export TPU_NAME=ci-run-${{ github.run_id }}
eval "$(ssh-agent -s)"
TRUE_SHA=${{ github.event.pull_request.head.sha }}
bash infra/spin-up-vm.sh $TPU_NAME -z ${TPU_ZONE} -t v4-8 --preemptible -s infra/helpers/setup-tpu-vm-tests.sh -b ${TRUE_SHA} --retries 1
# infra/babysit-tpu-vm.sh $TPU_NAME -z ${{ TPU_ZONE }} -t v4-8 --preemptible -s infra/helpers/setup-tpu-vm-tests.sh -b ${{ github.sha }} --retries 1 -- \
# PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest levanter/tests -m "not entry"

- name: Run most tests
run: |
export TPU_NAME=ci-run-${{ github.run_id }}
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest levanter/tests -m 'not entry'"
# Something's wrong with these
#
# - name: Run forked tests
# run: |
# export TPU_NAME=ci-run-${{ github.run_id }}
# gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest --forked levanter/tests -m 'entry'"
#
- name: Cleanup
if: ${{ always() }}
run: |
export TPU_NAME=ci-run-${{ github.run_id }}
echo gcloud compute tpus tpu-vm delete $TPU_NAME --zone ${TPU_ZONE} --quiet
gcloud compute tpus tpu-vm delete $TPU_NAME --zone ${TPU_ZONE} --quiet
2 changes: 2 additions & 0 deletions .gitignore
@@ -133,6 +133,8 @@ dmypy.json
# JetBrains
.idea/

# vscode
.vscode

# Wandb stuff
/wandb
138 changes: 138 additions & 0 deletions config/data/dolma_olmo_paloma.yaml
@@ -0,0 +1,138 @@
cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7"
tokenizer: "allenai/OLMo-1B" # requires `pip install ai2-olmo`
# tokenizer: "meta-llama/Llama-2-7b-hf"
stop_strategy: all_exhausted
configs:
dolma-algebraic-stack:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz
dolma-arxiv:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz
dolma-gutenberg:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz
dolma-c4:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz
dolma-cc:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz
dolma-cc-news:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz
dolma-falcon:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz
dolma-megawika:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz
dolma-owmath:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz
dolma-pes2o:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz
dolma-reddit:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz
dolma-stackexchange:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz
dolma-starcoder:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz
dolma-flan:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz
dolma-wiki:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz
# these are just for eval
"paloma/4chan":
validation_urls:
- gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz
"paloma/c4_100_domains":
validation_urls:
- gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz
"paloma/c4_en":
validation_urls:
- gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz
"paloma/dolma-v1_5":
validation_urls:
- gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz
"paloma/dolma_100_programing_languages":
validation_urls:
- gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz
"paloma/dolma_100_subreddits":
validation_urls:
- gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz
"paloma/falcon-refinedweb":
validation_urls:
- gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz
"paloma/gab":
validation_urls:
- gs://levanter-data/paloma/gab/val/val*.jsonl.gz
"paloma/m2d2_s2orc_unsplit":
validation_urls:
- gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz
"paloma/m2d2_wikipedia_unsplit":
validation_urls:
- gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
"paloma/manosphere_meta_sep":
validation_urls:
- gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz
"paloma/mc4":
validation_urls:
- gs://levanter-data/paloma/mc4/val/val*.jsonl.gz
"paloma/ptb":
validation_urls:
- gs://levanter-data/paloma/ptb/val/val*.jsonl.gz
"paloma/redpajama":
validation_urls:
- gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz
"paloma/twitterAAE_HELM_fixed":
validation_urls:
- gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz
"paloma/wikitext_103":
validation_urls:
- gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz
train_weights:
# sampling proportion comes from https://huggingface.co/datasets/allenai/dolma
dolma-algebraic-stack: 12.6 # 12.6 * 1.0
dolma-arxiv: 28.0 # 28.0 * 1.0
dolma-gutenberg: 5.3 # 5.3 * 1.0
dolma-c4: 69.2 # 138.4 * 0.5
dolma-cc: 597.75 # 1,195.5 * 0.5
dolma-cc-news: 14.3 # 1.0
dolma-falcon: 456.4 # 1.0, refined web
dolma-megawika: 4.6 # 1.0
dolma-owmath: 12.6 # 1.0
dolma-pes2o: 57.2 # 1.0
dolma-reddit: 79.9 # 1.0
dolma-stackexchange: 19.6 # 1.0
dolma-starcoder: 263.8 # 1.0
dolma-flan: 16.5 # 6.5 * 1.0
dolma-wiki: 7.4 # 3.7 * 2.0
paloma/4chan: 0.0
paloma/c4_100_domains: 0.0
paloma/c4_en: 0.0
paloma/dolma-v1_5: 0.0
paloma/dolma_100_programing_languages: 0.0
paloma/dolma_100_subreddits: 0.0
paloma/falcon-refinedweb: 0.0
paloma/gab: 0.0
paloma/m2d2_s2orc_unsplit: 0.0
paloma/m2d2_wikipedia_unsplit: 0.0
paloma/manosphere_meta_sep: 0.0
paloma/mc4: 0.0
paloma/ptb: 0.0
paloma/redpajama: 0.0
paloma/twitterAAE_HELM_fixed: 0.0
paloma/wikitext_103: 0.0
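
The train_weights block mixes the Dolma sources by relative weight (each comment shows the underlying figure times the Dolma sampling proportion from the dataset card). A minimal sketch of how such weights would turn into per-source sampling probabilities, under the assumption that Levanter simply normalizes them and never draws 0.0-weight sources during training (the exact mixing logic lives in Levanter's dataset code and may differ):

```python
import yaml  # PyYAML

# Load the config above; assumes the script runs from the repo root.
with open("config/data/dolma_olmo_paloma.yaml") as f:
    cfg = yaml.safe_load(f)

weights = cfg["train_weights"]
total = sum(weights.values())
# Sources with weight 0.0 (the paloma/* entries) are eval-only under this assumption.
probs = {name: w / total for name, w in weights.items() if w > 0}

for name, p in sorted(probs.items(), key=lambda kv: -kv[1]):
    print(f"{name:25s} {p:.2%}")  # e.g. dolma-cc gets the largest share, ~36%
```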
30 changes: 30 additions & 0 deletions config/gemma_2b.yaml
@@ -0,0 +1,30 @@
data:
train_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
validation_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..1}-of-8.jsonl.gz"
cache_dir: "gs://wasabi-tpu-training/openwebtext-mini"
tokenizer: "google/gemma-2b"
model:
type: gemma
initialize_from_hf: "google/gemma-2b"
use_hf_model_config: true
trainer:
checkpointer:
base_path: "gs://wasabi-tpu-training/gemma-2b/"
tracker:
type: wandb
project: "levanter"
tags: ["openwebtext", "gemma"]

mp: p=bfloat16,c=bfloat16
train_batch_size: 16 # set for v5e-16 TPU
num_train_steps: 100000
steps_per_eval: 50
tensor_parallel_axes: ["mlp", "heads"]
fsdp_axis: "embed"
batch_axis: "batch"
optimizer:
learning_rate: 1.2E-5 # set low for fine-tuning
weight_decay: 0.1
min_lr_ratio: 0.1
1 change: 1 addition & 0 deletions config/gpt2_small.yaml
@@ -7,6 +7,7 @@ model:
seq_len: 1024
gradient_checkpointing: true
scale_attn_by_inverse_layer_idx: true
attn_backend: jax_flash
trainer:
tracker:
project: "levanter"
5 changes: 3 additions & 2 deletions config/llama2_3b_pretrain.yaml
@@ -6,10 +6,11 @@ model:
intermediate_dim: 8640
num_layers: 26
num_heads: 32
use_flash_attention: True
attn_backend: jax_flash
flash_attention_block_size: 2048
trainer:
wandb:
tracker:
type: wandb
project: "levanter"
tags: ["redpajama", "llama"]

29 changes: 29 additions & 0 deletions config/llama_1b_with_olmo_config.yaml
@@ -0,0 +1,29 @@
data: !include data/dolma_olmo_paloma.yaml
model: # 1B class model
type: llama
seq_len: 2048
hidden_dim: 2048
intermediate_dim: 8192
num_layers: 16
num_heads: 16
num_kv_heads: 16
use_flash_attention: True
flash_attention_block_size: 1024
trainer:
tracker:
type: wandb
project: "marin"
tags: ["dolma", "olmo", "llama"]

mp: p=f32,c=bfloat16
train_batch_size: 1024
num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000
steps_per_eval: 1000
tensor_parallel_axes: ["mlp", "heads"]
fsdp_axis: "embed"
batch_axis: "batch"
optimizer:
learning_rate: 4E-4
weight_decay: 0.1
min_lr_ratio: 0.1
warmup: 5000
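
The num_train_steps comment above encodes a token budget (3T tokens at 4M tokens per optimizer step). A quick back-of-the-envelope check, under the assumption that tokens per step is simply train_batch_size × seq_len (gradient accumulation or a revised batch size would change the figure):

```python
# Hypothetical sanity check of the token budget implied by this config.
train_batch_size = 1024
seq_len = 2048
num_train_steps = 750_000

tokens_per_step = train_batch_size * seq_len      # 2,097,152 under this assumption
total_tokens = tokens_per_step * num_train_steps  # ~1.57e12 tokens
print(f"tokens/step={tokens_per_step:,}  total tokens={total_tokens:,}")

# The comment's 4,000,000 tokens/step corresponds to a larger effective batch
# (e.g. 2048 sequences of length 2048), which is what reaches the stated 3T target.
```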
17 changes: 0 additions & 17 deletions docs/Getting-Started-TPU-VM.md
@@ -174,20 +174,3 @@ gcloud compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'sudo r

**and then you have to ctrl-c this after about 10 seconds**. Otherwise, gcloud will think the command failed and will
try again, and get stuck in a loop forever. (You can ctrl-c it at any point after 10 seconds.)


## Random Tricks

I (@dlwh) personally like to use pdsh instead of gcloud to run commands on all workers. It doesn't have the reboot
issue, and seems to work better for long-lived jobs and such. You can install it with `sudo apt-get install pdsh`.
You can then get the ips for your machines like so:

```bash
gcloud compute tpus tpu-vm describe --zone us-east1-d $name | awk '/externalIp: (.*)/ {print $2}' > my-hosts
```

Then you can run a command on all workers like so:

```bash
pdsh -R ssh -w ^my-hosts 'echo hello'
```
6 changes: 3 additions & 3 deletions docs/Levanter-1.0-Release.md
@@ -158,7 +158,7 @@ W = hax.random.uniform(PRNGKey(2), (Feature,))
def mse(pred, target):
return hax.mean((pred - target) * (pred - target), axis=Batch)

y_pred = hax.dot(Feature, x, W)
y_pred = hax.dot(x, W, axis=Feature)
mse(y_pred, y)
```
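
The doc edits in this file track a haliax API change: the contracted axis moves from the first positional argument of hax.dot to an axis= keyword. A minimal self-contained sketch of the new call pattern (the axis sizes here are illustrative):

```python
import haliax as hax
from jax.random import PRNGKey

Batch = hax.Axis("batch", 32)
Feature = hax.Axis("feature", 16)

x = hax.random.uniform(PRNGKey(0), (Batch, Feature))
W = hax.random.uniform(PRNGKey(1), (Feature,))

# Contract over the named Feature axis; the Batch axis is preserved.
y_pred = hax.dot(x, W, axis=Feature)  # previously written as hax.dot(Feature, x, W)
```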

@@ -218,7 +218,7 @@ Embed = hax.Axis("embed", 512) # embedding size

def attention(Key, KPos, query, key, value, mask):
# how similar is each query to each key
scores = hax.dot(Key, query, key) / jnp.sqrt(Key.size)
scores = hax.dot(query, key, axis=Key) / jnp.sqrt(Key.size)

# mask out invalid positions
if mask is not None:
@@ -228,7 +228,7 @@ def attention(Key, KPos, query, key, value, mask):
scores = hax.nn.softmax(scores, axis=KPos)

# weighted sum of values
return hax.dot(KPos, scores, value)
return hax.dot(scores, value, axis=KPos)
```

With named tensors, we can write the code in a way that conveys the semantics of the operation, rather than the
2 changes: 1 addition & 1 deletion docs/Training-On-Your-Data.md
@@ -447,5 +447,5 @@ tokenizer = AutoTokenizer.from_pretrained("/tmp/my_exported_model")
After training, you can run a separate script to export levanter checkpoints to Huggingface:

```bash
python -m levanter.main.export_to_hf --config_path my_config.yaml --output_dir gs://path/to/output
python -m levanter.main.export_lm_to_hf --config_path my_config.yaml --output_dir gs://path/to/output
```
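
Once exported (and copied locally if --output_dir points at GCS), the result is a standard Hugging Face model directory. A minimal sketch of loading it with transformers, assuming a causal LM export and the /tmp/my_exported_model path used elsewhere in this doc:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative local path; matches the example earlier in this doc.
model = AutoModelForCausalLM.from_pretrained("/tmp/my_exported_model")
tokenizer = AutoTokenizer.from_pretrained("/tmp/my_exported_model")

inputs = tokenizer("Hello, world", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```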
2 changes: 1 addition & 1 deletion docs/tutorials/Training-On-Audio-Data.md
@@ -229,7 +229,7 @@ model = WhisperForConditionalGeneration.from_pretrained("WillHeld/levanter-whisp
After training, you can run a separate script to export levanter checkpoints to Huggingface:

```bash
python -m levanter.main.export_to_hf --config_path my_config.yaml --output_dir gs://path/to/output
python -m levanter.main.export_lm_to_hf --config_path my_config.yaml --output_dir gs://path/to/output
```

### HuggingFace Inference