diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 3adfe45d1..078324502 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -20,9 +20,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - pip install soundfile librosa + pip install -r ./tests/requirements.txt - name: Test with pytest run: | XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry and not slow and not ray" diff --git a/.github/workflows/tpu_unit_tests.yaml b/.github/workflows/tpu_unit_tests.yaml new file mode 100644 index 000000000..3e27426eb --- /dev/null +++ b/.github/workflows/tpu_unit_tests.yaml @@ -0,0 +1,54 @@ +name: CI with GCP TPU + +on: [pull_request] + +jobs: + test: + runs-on: ubuntu-latest + env: + TPU_ZONE: "us-central2-b" + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Google Cloud SDK + uses: google-github-actions/setup-gcloud@v1 + with: + project_id: ${{ secrets.GCP_PROJECT_ID }} + + - name: Authenticate to Google Cloud + uses: google-github-actions/auth@v1 + with: + credentials_json: ${{ secrets.GCP_SA_KEY }} + + - name: Configure Google Cloud + run: | + gcloud config set project ${{ secrets.GCP_PROJECT_ID }} + + - name: Create VM + run: | + export TPU_NAME=ci-run-${{ github.run_id }} + eval "$(ssh-agent -s)" + TRUE_SHA=${{ github.event.pull_request.head.sha }} + bash infra/spin-up-vm.sh $TPU_NAME -z ${TPU_ZONE} -t v4-8 --preemptible -s infra/helpers/setup-tpu-vm-tests.sh -b ${TRUE_SHA} --retries 1 +# infra/babysit-tpu-vm.sh $TPU_NAME -z ${{ TPU_ZONE }} -t v4-8 --preemptible -s infra/helpers/setup-tpu-vm-tests.sh -b ${{ github.sha }} --retries 1 -- \ +# PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest levanter/tests -m "not entry" + + - name: Run most tests + run: | + export TPU_NAME=ci-run-${{ github.run_id }} + gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest levanter/tests -m 'not entry'" +# Something's wrong with these +# +# - name: Run forked tests +# run: | +# export TPU_NAME=ci-run-${{ github.run_id }} +# gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest --forked levanter/tests -m 'entry'" +# + - name: Cleanup + if: ${{ always() }} + run: | + export TPU_NAME=ci-run-${{ github.run_id }} + echo gcloud compute tpus tpu-vm delete $TPU_NAME --zone ${TPU_ZONE} --quiet + gcloud compute tpus tpu-vm delete $TPU_NAME --zone ${TPU_ZONE} --quiet diff --git a/.gitignore b/.gitignore index 57dd43310..c66f6f352 100644 --- a/.gitignore +++ b/.gitignore @@ -133,6 +133,8 @@ dmypy.json # JetBrains .idea/ +# vscode +.vscode # Wandb stuff /wandb diff --git a/config/data/dolma_olmo_paloma.yaml b/config/data/dolma_olmo_paloma.yaml new file mode 100644 index 000000000..aaf45f802 --- /dev/null +++ b/config/data/dolma_olmo_paloma.yaml @@ -0,0 +1,138 @@ +cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7" +tokenizer: "allenai/OLMo-1B" # requires `pip install ai2-olmo` +# tokenizer: "meta-llama/Llama-2-7b-hf" +stop_strategy: all_exhausted +configs: + dolma-algebraic-stack: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz + dolma-arxiv: + train_urls: + - 
gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz + dolma-gutenberg: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz + dolma-c4: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz + dolma-cc: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing + - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz + dolma-cc-news: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz + - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz + dolma-falcon: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz + dolma-megawika: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz + dolma-owmath: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz + dolma-pes2o: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz + dolma-reddit: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz + dolma-stackexchange: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz + dolma-starcoder: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz + dolma-flan: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz + dolma-wiki: + train_urls: + - gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz + # these are just for eval + "paloma/4chan": + validation_urls: + - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz + "paloma/c4_100_domains": + validation_urls: + - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz + "paloma/c4_en": + validation_urls: + - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz + "paloma/dolma-v1_5": + validation_urls: + - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz + "paloma/dolma_100_programing_languages": + validation_urls: + - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz + "paloma/dolma_100_subreddits": + validation_urls: + - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz + "paloma/falcon-refinedweb": + validation_urls: + - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz + "paloma/gab": + validation_urls: + - gs://levanter-data/paloma/gab/val/val*.jsonl.gz + "paloma/m2d2_s2orc_unsplit": + validation_urls: + - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz + "paloma/m2d2_wikipedia_unsplit": + validation_urls: + - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz + "paloma/manosphere_meta_sep": + validation_urls: + - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz + "paloma/mc4": + validation_urls: + - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz + "paloma/ptb": + validation_urls: + - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz + "paloma/redpajama": + validation_urls: + - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz + "paloma/twitterAAE_HELM_fixed": + validation_urls: + - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz + "paloma/wikitext_103": + validation_urls: + 
- gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz +train_weights: + # sampling proportion comes from https://huggingface.co/datasets/allenai/dolma + dolma-algebraic-stack: 12.6 # 12.6 * 1.0 + dolma-arxiv: 28.0 # 28.0 * 1.0 + dolma-gutenberg: 5.3 # 5.3 * 1.0 + dolma-c4: 69.2 # 138.4 * 0.5 + dolma-cc: 597.75 # 1,195.5 * 0.5 + dolma-cc-news: 14.3 # 1.0 + dolma-falcon: 456.4 # 1.0, refined web + dolma-megawika: 4.6 # 1.0 + dolma-owmath: 12.6 # 1.0 + dolma-pes2o: 57.2 # 1.0 + dolma-reddit: 79.9 # 1.0 + dolma-stackexchange: 19.6 # 1.0 + dolma-starcoder: 263.8 # 1.0 + dolma-flan: 16.5 # 6.5 * 1.0 + dolma-wiki: 7.4 # 3.7 * 2.0 + paloma/4chan: 0.0 + paloma/c4_100_domains: 0.0 + paloma/c4_en: 0.0 + paloma/dolma-v1_5: 0.0 + paloma/dolma_100_programing_languages: 0.0 + paloma/dolma_100_subreddits: 0.0 + paloma/falcon-refinedweb: 0.0 + paloma/gab: 0.0 + paloma/m2d2_s2orc_unsplit: 0.0 + paloma/m2d2_wikipedia_unsplit: 0.0 + paloma/manosphere_meta_sep: 0.0 + paloma/mc4: 0.0 + paloma/ptb: 0.0 + paloma/redpajama: 0.0 + paloma/twitterAAE_HELM_fixed: 0.0 + paloma/wikitext_103: 0.0 diff --git a/config/gemma_2b.yaml b/config/gemma_2b.yaml new file mode 100644 index 000000000..f64b5d881 --- /dev/null +++ b/config/gemma_2b.yaml @@ -0,0 +1,30 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..1}-of-8.jsonl.gz" + cache_dir: "gs://wasabi-tpu-training/openwebtext-mini" + tokenizer: "google/gemma-2b" +model: + type: gemma +initialize_from_hf: "google/gemma-2b" +use_hf_model_config: true +trainer: + checkpointer: + base_path: "gs://wasabi-tpu-training/gemma-2b/" + tracker: + type: wandb + project: "levanter" + tags: ["openwebtext", "gemma"] + + mp: p=bfloat16,c=bfloat16 + train_batch_size: 16 # set for v5e-16 TPU + num_train_steps: 100000 + steps_per_eval: 50 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + learning_rate: 1.2E-5 # set low for fine-tuning + weight_decay: 0.1 + min_lr_ratio: 0.1 diff --git a/config/gpt2_small.yaml b/config/gpt2_small.yaml index b3e0295af..36751f933 100644 --- a/config/gpt2_small.yaml +++ b/config/gpt2_small.yaml @@ -7,6 +7,7 @@ model: seq_len: 1024 gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true + attn_backend: jax_flash trainer: tracker: project: "levanter" diff --git a/config/llama2_3b_pretrain.yaml b/config/llama2_3b_pretrain.yaml index 6b3fd4321..fb0577290 100644 --- a/config/llama2_3b_pretrain.yaml +++ b/config/llama2_3b_pretrain.yaml @@ -6,10 +6,11 @@ model: intermediate_dim: 8640 num_layers: 26 num_heads: 32 - use_flash_attention: True + attn_backend: jax_flash flash_attention_block_size: 2048 trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["redpajama", "llama"] diff --git a/config/llama_1b_with_olmo_config.yaml b/config/llama_1b_with_olmo_config.yaml new file mode 100644 index 000000000..fe315de71 --- /dev/null +++ b/config/llama_1b_with_olmo_config.yaml @@ -0,0 +1,29 @@ +data: !include data/dolma_olmo_paloma.yaml +model: # 1B class model + type: llama + seq_len: 2048 + hidden_dim: 2048 + intermediate_dim: 8192 + num_layers: 16 + num_heads: 16 + num_kv_heads: 16 + use_flash_attention: True + flash_attention_block_size: 1024 +trainer: + tracker: + type: wandb + project: "marin" + tags: ["dolma", "olmo", "llama"] + + mp: p=f32,c=bfloat16 + train_batch_size: 1024 + num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 + 
steps_per_eval: 1000 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + learning_rate: 4E-4 + weight_decay: 0.1 + min_lr_ratio: 0.1 + warmup: 5000 diff --git a/docs/Getting-Started-TPU-VM.md b/docs/Getting-Started-TPU-VM.md index fc8d53e73..fe73eef70 100644 --- a/docs/Getting-Started-TPU-VM.md +++ b/docs/Getting-Started-TPU-VM.md @@ -174,20 +174,3 @@ gcloud compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'sudo r **and then you have to ctrl-c this after about 10 seconds**. Otherwise, gcloud will think the command failed and will try again, and get stuck in a loop forever. (You can ctrl-c it at any point after 10 seconds.) - - -## Random Tricks - -I (@dlwh) personally like to use pdsh instead of gcloud to run commands on all workers. It doesn't have the reboot -issue, and seems to work better for long-lived jobs and such. You can install it with `sudo apt-get install pdsh`. -You can then get the ips for your machines like so: - -```bash -gcloud compute tpus tpu-vm describe --zone us-east1-d $name | awk '/externalIp: (.*)/ {print $2}' > my-hosts -``` - -Then you can run a command on all workers like so: - -```bash -pdsh -R ssh -w ^my-hosts 'echo hello' -``` diff --git a/docs/Levanter-1.0-Release.md b/docs/Levanter-1.0-Release.md index 05c66683a..d9fb2c106 100644 --- a/docs/Levanter-1.0-Release.md +++ b/docs/Levanter-1.0-Release.md @@ -158,7 +158,7 @@ W = hax.random.uniform(PRNGKey(2), (Feature,)) def mse(pred, target): return hax.mean((pred - target) * (pred - target), axis=Batch) -y_pred = hax.dot(Feature, x, W) +y_pred = hax.dot(x, W, axis=Feature) mse(y_pred, y) ``` @@ -218,7 +218,7 @@ Embed = hax.Axis("embed", 512) # embedding size def attention(Key, KPos, query, key, value, mask): # how similar is each query to each key - scores = hax.dot(Key, query, key) / jnp.sqrt(Key.size) + scores = hax.dot(query, key, axis=Key) / jnp.sqrt(Key.size) # mask out invalid positions if mask is not None: @@ -228,7 +228,7 @@ def attention(Key, KPos, query, key, value, mask): scores = hax.nn.softmax(scores, axis=KPos) # weighted sum of values - return hax.dot(KPos, scores, value) + return hax.dot(scores, value, axis=KPos) ``` With named tensors, we can write the code in a way that conveys the semantics of the operation, rather than the diff --git a/docs/Training-On-Your-Data.md b/docs/Training-On-Your-Data.md index b9a1c2fd3..c14b0ba66 100644 --- a/docs/Training-On-Your-Data.md +++ b/docs/Training-On-Your-Data.md @@ -447,5 +447,5 @@ tokenizer = AutoTokenizer.from_pretrained("/tmp/my_exported_model") After training, you can run a separate script to export levanter checkpoints to Huggingface: ```bash -python -m levanter.main.export_to_hf --config_path my_config.yaml --output_dir gs://path/to/output +python -m levanter.main.export_lm_to_hf --config_path my_config.yaml --output_dir gs://path/to/output ``` diff --git a/docs/tutorials/Training-On-Audio-Data.md b/docs/tutorials/Training-On-Audio-Data.md index c378fda08..235b2e79b 100644 --- a/docs/tutorials/Training-On-Audio-Data.md +++ b/docs/tutorials/Training-On-Audio-Data.md @@ -229,7 +229,7 @@ model = WhisperForConditionalGeneration.from_pretrained("WillHeld/levanter-whisp After training, you can run a separate script to export levanter checkpoints to Huggingface: ```bash -python -m levanter.main.export_to_hf --config_path my_config.yaml --output_dir gs://path/to/output +python -m levanter.main.export_lm_to_hf --config_path my_config.yaml --output_dir gs://path/to/output ``` ### HuggingFace 
Inference diff --git a/infra/babysit-tpu-vm.sh b/infra/babysit-tpu-vm.sh index bd4bf6405..318d61604 100755 --- a/infra/babysit-tpu-vm.sh +++ b/infra/babysit-tpu-vm.sh @@ -59,6 +59,8 @@ CMD_ARGS_STR=$(printf ' %s' "${CMD_ARGS[@]}") CMD_ARGS_STR=${CMD_ARGS_STR:1} CMD_ARGS_STR="RUN_ID=${RUN_ID} ${CMD_ARGS_STR}" +TRIES=0 + # check if the VM is running # if not, spin it up # if it is, just run the command @@ -77,11 +79,19 @@ while true; do echo "Running command on VM $VM_NAME" echo "gcloud compute tpus tpu-vm ssh --zone=$ZONE $VM_NAME --command='$CMD_ARGS_STR' --worker=all" gcloud compute tpus tpu-vm ssh --zone=$ZONE $VM_NAME --command="$CMD_ARGS_STR" --worker=all - if [ $? -eq 0 ]; then + EXIT_CODE=$? + if [ $EXIT_CODE -eq 0 ]; then echo "Command succeeded. Exiting" break else echo "Command failed" + TRIES=$((TRIES+1)) + if [ "$RETRIES" -ge 0 ]; then + if [ $TRIES -ge "$RETRIES" ]; then + echo "Command failed $TRIES times, exiting with $EXIT_CODE" + break + fi + fi fi fi else @@ -92,7 +102,12 @@ while true; do sleep 10 done -echo "Job finished!" +# exit code is the exit code of the command +if [ $EXIT_CODE -eq 0 ]; then + echo "Command succeeded" +else + echo "Command failed too many times, ending with exit code $EXIT_CODE" +fi # delete the VM when we're done gcloud compute tpus tpu-vm describe --zone $ZONE $VM_NAME &> /dev/null @@ -100,3 +115,5 @@ if [ $? -eq 0 ]; then echo "Deleting VM $VM_NAME" yes | gcloud compute tpus tpu-vm delete --zone $ZONE $VM_NAME fi + +exit $EXIT_CODE diff --git a/infra/helpers/parse-tpu-creation-args.sh b/infra/helpers/parse-tpu-creation-args.sh index ec6796213..44da2a719 100644 --- a/infra/helpers/parse-tpu-creation-args.sh +++ b/infra/helpers/parse-tpu-creation-args.sh @@ -23,6 +23,7 @@ AUTODELETE=true SETUP_SCRIPT="$SCRIPT_DIR/helpers/setup-tpu-vm.sh" SUBNETWORK="default" USE_ALPHA=false +RETRIES=-1 # how many times babysit-tpu-vm.sh should retry before giving up. 
-1 means infinite if [ -z "$GIT_BRANCH" ]; then GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD) @@ -86,6 +87,11 @@ while [[ $# -gt 0 ]]; do USE_ALPHA="true" shift # past argument ;; + --retries) + RETRIES="$2" + shift # past argument + shift # past value + ;; *) # unknown option, assume it's the vm name if it doesn't start with a dash if [[ $1 == -* ]]; then echo "Error: unknown option $1" >&2 @@ -115,19 +121,26 @@ done # check if the branch we chose has been pushed to the remote # if not, warn - -# get the remote branch name -REMOTE_BRANCH=$(git ls-remote --heads origin "$GIT_BRANCH" | awk '{print $2}' | sed 's/refs\/heads\///g') -# if it's empty, warn -if [ -z "$REMOTE_BRANCH" ]; then - >&2 echo "Warning: branch $GIT_BRANCH not found on remote $GIT_REPO" +# if it's a commit sha/short-sha (or something that looks like one), check if it's in the remote +if [[ "$GIT_BRANCH" =~ ^[0-9a-f]{7,40}$ ]]; then + # if it's a commit, check if it's in the remote + BRANCHES=$(git branch -r --contains "$GIT_BRANCH") + if [ -z "$BRANCHES" ]; then + >&2 echo "Warning: commit $GIT_BRANCH not found on remote $GIT_REPO" + fi else + # get the remote branch name + REMOTE_BRANCH=$(git ls-remote --heads origin "$GIT_BRANCH" | awk '{print $2}' | sed 's/refs\/heads\///g') + # if it's empty, warn + if [ -z "$REMOTE_BRANCH" ]; then + >&2 echo "Warning: branch $GIT_BRANCH not found on remote $GIT_REPO" + else + # make sure it's pushed + LOCAL_COMMIT=$(git rev-parse --short "$GIT_BRANCH") + REMOTE_COMMIT=$(git rev-parse --short "origin/$REMOTE_BRANCH") - # make sure it's pushed - LOCAL_COMMIT=$(git rev-parse --short "$GIT_BRANCH") - REMOTE_COMMIT=$(git rev-parse --short "origin/$REMOTE_BRANCH") - - if [ "$LOCAL_COMMIT" != "$REMOTE_COMMIT" ]; then - >&2 echo "Warning: branch $GIT_BRANCH not pushed to remote $GIT_REPO. Local commit: $LOCAL_COMMIT, remote commit: $REMOTE_COMMIT" + if [ "$LOCAL_COMMIT" != "$REMOTE_COMMIT" ]; then + >&2 echo "Warning: branch $GIT_BRANCH not pushed to remote $GIT_REPO. Local commit: $LOCAL_COMMIT, remote commit: $REMOTE_COMMIT" + fi fi fi diff --git a/infra/helpers/setup-tpu-vm-nfs.sh b/infra/helpers/setup-tpu-vm-nfs.sh deleted file mode 100755 index a159b8469..000000000 --- a/infra/helpers/setup-tpu-vm-nfs.sh +++ /dev/null @@ -1,68 +0,0 @@ -set -x -# broadly based on https://github.com/ayaka14732/tpu-starter - -# tcmalloc interferes with intellij remote ide -sudo patch -f -b /etc/environment << EOF -2c2 -< LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" ---- -> #LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" -EOF - -# don't complain if already applied -retCode=$? 
-[[ $retCode -le 1 ]] || exit $retCode - -# install python 3.10, latest git, and nfs -#sudo apt-get install -y software-properties-common -#sudo add-apt-repository -y ppa:deadsnakes/ppa -#sudo add-apt-repository -y ppa:git-core/ppa -#sudo apt-get update -#sudo apt-get install -y python3.10-full python3.10-dev nfs-common git golang - -sudo systemctl stop unattended-upgrades # this frequently holds the apt lock -sudo systemctl disable unattended-upgrades -sudo apt remove -y unattended-upgrades -# if it's still running somehow, kill it -if [ $(ps aux | grep unattended-upgrade | wc -l) -gt 1 ]; then - sudo kill -9 $(ps aux | grep unattended-upgrade | awk '{print $2}') -fi -# sometimes apt-get update fails, so retry a few times -for i in {1..5}; do - sudo apt-get install -y software-properties-common \ - && sudo add-apt-repository -y ppa:deadsnakes/ppa \ - && sudo add-apt-repository -y ppa:git-core/ppa \ - && sudo apt-get update \ - && sudo apt-get install -y python3.10-full python3.10-dev nfs-common git \ - && break -done -sudo systemctl start unattended-upgrades - -# set up nfs -NFS_SERVER=10.5.220.250 -MOUNT_POINT="/files" -sudo mkdir -p ${MOUNT_POINT} -CURRENT_NFS_ENTRY=$(grep ${NFS_SERVER} /etc/fstab) -DESIRED_NFS_ENTRY="${NFS_SERVER}:/propulsion ${MOUNT_POINT} nfs defaults 0 0" -# if different, fix -if [ "$CURRENT_NFS_ENTRY" != "$DESIRED_NFS_ENTRY" ]; then - set -e - echo "Setting up nfs" - grep -v "${NFS_SERVER}" /etc/fstab > /tmp/fstab.new - echo "${DESIRED_NFS_ENTRY}" >> /tmp/fstab.new - # then move the new fstab back into place - sudo cp /etc/fstab /etc/fstab.orig - sudo mv /tmp/fstab.new /etc/fstab -fi -sudo mount -a - - -# default to loading the venv -sudo bash -c "echo \"source ${MOUNT_POINT}/venv310/bin/activate\" > /etc/profile.d/activate_shared_venv.sh" - -for x in `ls -d /files/lev*`; do - git config --global --add safe.directory $x -done - -# symlink lev* to home -ln -s /files/lev* ~ diff --git a/infra/helpers/setup-tpu-vm-tests.sh b/infra/helpers/setup-tpu-vm-tests.sh new file mode 100755 index 000000000..4b6cf27f5 --- /dev/null +++ b/infra/helpers/setup-tpu-vm-tests.sh @@ -0,0 +1,126 @@ +# broadly based on https://github.com/ayaka14732/tpu-starter + +# parse some arguments +# usage: ./setup-tpu-vm.sh -b|--branch -r + +if [ "$DEBUG" == "1" ]; then + set -x +fi + +REPO="https://github.com/stanford-crfm/levanter.git" +BRANCH=main + +if [ "$GIT_BRANCH" != "" ]; then + BRANCH="$GIT_BRANCH" +fi + +while [[ $# -gt 0 ]]; do + key="$1" + case $key in + -b|--branch) + BRANCH="$2" + shift + shift + ;; + -r|--repo) + REPO="$2" + shift + shift + ;; + *) + >&2 echo "Unknown option $1" + exit 1 + ;; + esac +done + +# we frequently deal with commands failing, and we like to loop until they succeed. this function does that for us +function retry { + for i in {1..5}; do + $@ + if [ $? -eq 0 ]; then + break + fi + if [ $i -eq 5 ]; then + >&2 echo "Error running $*, giving up" + exit 1 + fi + >&2 echo "Error running $*, retrying in 5 seconds" + sleep 5 + done +} + +# tcmalloc interferes with intellij remote ide +sudo patch -f -b /etc/environment << EOF +2c2 +< LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" +--- +> #LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" +EOF + + + +# don't complain if already applied +retCode=$? +[[ $retCode -le 1 ]] || exit $retCode + + +# set these env variables b/c it makes tensorstore behave better +if ! 
grep -q TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS /etc/environment; then + # need sudo + echo "TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS=60" | sudo tee -a /etc/environment > /dev/null +fi + +if ! grep -q TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES /etc/environment; then + echo "TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES=1024" | sudo tee -a /etc/environment > /dev/null +fi + +# install python 3.10, latest git +sudo systemctl stop unattended-upgrades # this frequently holds the apt lock +sudo systemctl disable unattended-upgrades +sudo apt remove -y unattended-upgrades +# if it's still running somehow, kill it +if [ $(ps aux | grep unattended-upgrade | wc -l) -gt 1 ]; then + sudo kill -9 $(ps aux | grep unattended-upgrade | awk '{print $2}') +fi + +# sometimes apt-get update fails, so retry a few times +retry sudo apt-get install -y software-properties-common +retry sudo add-apt-repository -y ppa:deadsnakes/ppa +retry sudo add-apt-repository -y ppa:git-core/ppa +retry sudo apt-get -qq update +retry sudo apt-get -qq install -y python3.10-full python3.10-dev git + +VENV=~/venv310 +# if the venv doesn't exist, make it +if [ ! -d "$VENV" ]; then + echo "Creating virtualenv at $VENV" + python3.10 -m venv $VENV +fi + +source $VENV/bin/activate + +pip install -U pip +pip install -U wheel + +# jax and jaxlib +# libtpu sometimes has issues installing for clinical (probably firewall?) +retry pip install -U "jax[tpu]==0.4.26" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + +# clone levanter +git clone $REPO levanter +echo $VENV > levanter/infra/venv_path.txt + +cd levanter + +# checkout the branch we want + +echo "Checking out branch $BRANCH" + +git checkout $BRANCH + +# install levanter + +pip install -e . + +pip install -r tests/requirements.txt diff --git a/pyproject.toml b/pyproject.toml index 8ea755b15..f17a26791 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,19 +29,20 @@ dependencies = [ "haliax>=1.4.dev296", "equinox>=0.11.4", "jaxtyping>=0.2.20", + "tokenizers>=0.15.2", "transformers>=4.39.3", "optax>=0.1.9", - "wandb~=0.16.6", + "wandb>=0.16.6,<0.18.0", # We don't actually directly depend on scipy, but recent JAX had an issue "scipy<=1.12.0", - "draccus>=0.7.2", + "draccus>=0.8.0", "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets~=2.18", - "gcsfs>=2024.2,<2024.4", + "gcsfs>=2024.2,<2024.6", "braceexpand>=0.1.7", "jmp>=0.0.3", - "fsspec[http]>=2024.2,<2024.4", + "fsspec[http]>=2024.2,<2024.6", "tensorstore==0.1.56", "pytimeparse>=1.1.8", "humanfriendly==10.0", diff --git a/src/levanter/compat/torch_serialization.py b/src/levanter/compat/torch_serialization.py index 77c7e6953..f07d69e15 100644 --- a/src/levanter/compat/torch_serialization.py +++ b/src/levanter/compat/torch_serialization.py @@ -247,7 +247,7 @@ def _flatten_linear(layer, prefix): return ret_dict tree_prefixes = leaf_key_paths(tree, prefix, is_leaf=lambda x: isinstance(x, hnn.Linear), use_state_dict_keys=True) - jax.tree_map(_flatten_linear, tree, tree_prefixes, is_leaf=lambda x: isinstance(x, hnn.Linear)) + jax.tree_util.tree_map(_flatten_linear, tree, tree_prefixes, is_leaf=lambda x: isinstance(x, hnn.Linear)) return ret_dict @@ -318,7 +318,7 @@ def _unflatten_linear(layer, prefix): tree_prefixes = leaf_key_paths( layer, prefix, is_leaf=lambda x: isinstance(x, hnn.Linear), use_state_dict_keys=True ) - jax.tree_map(_unflatten_linear, layer, tree_prefixes, is_leaf=lambda x: isinstance(x, hnn.Linear)) + jax.tree_util.tree_map(_unflatten_linear, layer, tree_prefixes, is_leaf=lambda x: isinstance(x, hnn.Linear)) 
return ret_dict diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index e4f9b5059..f2ceb5571 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -1399,6 +1399,8 @@ async def get_chunk(self, chunk_idx: int) -> Optional[ChunkMetadata]: return self.chunks[chunk_idx] elif self._is_finished: return None + elif self._finished_promise.exception() is not None: + raise self._finished_promise.exception() # type: ignore else: if chunk_idx not in self._reader_promises: self._reader_promises[chunk_idx] = asyncio.Future() @@ -1638,6 +1640,15 @@ def _get_chunk_unmapped(self, mapped_index: int, *, timeout: Optional[float] = N current_timeout *= 2 current_timeout = min(current_timeout, 100) continue + except asyncio.exceptions.InvalidStateError: + self.logger.warning( + f"Invalid state waiting for chunk {mapped_index} for {int(next_time - time_in)} seconds" + ) + next_time = time.time() + current_timeout *= 2 + current_timeout = min(current_timeout, 100) + time.sleep(current_timeout) + continue if chunk is None: raise IndexError(f"Chunk index out of bounds. (Mapped index {mapped_index})") diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_dataset.py index a31ba02d8..8ddee007f 100644 --- a/src/levanter/data/sharded_dataset.py +++ b/src/levanter/data/sharded_dataset.py @@ -308,6 +308,9 @@ def _sniff_format_for_dataset(url): format_from_url = format break + if format_from_url is None: + raise ValueError(f"Unknown format for {url}") + if format_from_url == ".json": # unfortunately, HF (and others) will use "json" for jsonl files, # so we have to do some extra work to distinguish them. @@ -422,6 +425,8 @@ def _mk_shard_name_mapping(urls): shard_name = shard_name[1:] shard_name = shard_name.replace(".", "_") + if shard_name in _shard_name_to_url_mapping: + raise ValueError(f"Duplicate shard name {shard_name}") _shard_name_to_url_mapping[shard_name] = url if missing_urls: diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index c635a98ea..7af090e52 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -19,7 +19,6 @@ import regex from draccus import field from jaxtyping import PRNGKeyArray -from tokenizers import normalizers import haliax as hax from haliax import Axis @@ -231,7 +230,9 @@ def build_or_load( cache_dir, source: ShardedDataset[str], tokenizer: PreTrainedTokenizerBase, + *, flatten_docs=True, + enforce_bos=True, enforce_eos=True, batch_size=128, rows_per_chunk=DEFAULT_ROWS_PER_CHUNK, @@ -239,7 +240,9 @@ def build_or_load( await_finished=True, override_resources=None, ) -> "TokenizedDocumentCache": - bt = BatchTokenizer(tokenizer, enforce_eos=enforce_eos, override_resources=override_resources) + bt = BatchTokenizer( + tokenizer, enforce_bos=enforce_bos, enforce_eos=enforce_eos, override_resources=override_resources + ) monitors = monitors or [] cache = build_cache( cache_dir, @@ -312,9 +315,6 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): os.environ["TOKENIZERS_PARALLELISM"] = "true" -LONG_STRING_WORKAROUND = 100_000 - - ws = regex.compile(r"\s") @@ -332,7 +332,6 @@ def __init__( *, batch_size=128, override_resources=None, - _workaround_len=LONG_STRING_WORKAROUND, return_attention_mask=False, padding=False, max_length=None, @@ -359,7 +358,7 @@ def __init__( if enforce_eos or enforce_bos: input_ids = tokenizer("hi there")["input_ids"] should_append_eos = input_ids[-1] != tokenizer.eos_token_id and enforce_eos - should_append_bos = input_ids[0] 
== tokenizer.bos_token_id and enforce_bos + should_append_bos = input_ids[0] != tokenizer.bos_token_id and enforce_bos else: should_append_eos = False should_append_bos = False @@ -368,107 +367,21 @@ def __init__( self._need_to_add_eos = should_append_eos self._need_to_add_bos = should_append_bos - self._workaround_len = _workaround_len def __call__(self, batch: Sequence[str]) -> BatchEncoding: - orig_lengths = [len(d) for d in batch] if self._need_to_add_bos: batch = [self.tokenizer.bos_token + " " + d for d in batch] if self._need_to_add_eos: batch = [d + " " + self.tokenizer.eos_token for d in batch] - if self._needs_long_sequence_workaround: - # break any strings that are longer than 50K characters into smaller chunks - orig_batch = batch - batch = [] - needs_merge = [] - for i, d in enumerate(orig_batch): - needs_merge.append(False) - orig_len = orig_lengths[i] - while len(d) > self._workaround_len: - # we'd rather break strings at whitespace, so find the first whitespace - match = ws.search(d, self._workaround_len) - # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit - if match is None: - split = len(d) - else: - split = match.start() - - batch.append(d[:split]) - needs_merge.append(True) - - d = d[split:] - orig_len -= split - - batch.append(d) - else: - needs_merge = [] - if self.padding is not False: encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False, padding=self.padding, max_length=self.max_length, truncation=True) # type: ignore else: encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False) # type: ignore - if needs_merge: - new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) - encoding = BatchEncoding(new_encoding) - return encoding - @staticmethod - def _merge_split_encodings(batch, encoding, needs_merge): - # merge the encodings back together - # we might need to merge multiple encodings together - # needs merge marks the first n-1 encodings that need to be merged for each document - new_encoding = {} - for k, v in encoding.items(): - if len(v) == 0: - continue - if isinstance(v[0], np.ndarray): - assert len(v) == len(batch) - v_out = [] - vs_to_merge = [] - for i in range(len(batch)): - if not needs_merge[i]: - v_out.append(np.concatenate(vs_to_merge)) - vs_to_merge = [] - vs_to_merge.append(v[i]) - - if len(vs_to_merge) > 0: - v_out.append(np.concatenate(vs_to_merge)) - - new_encoding[k] = v_out - elif isinstance(v[0], list): - v_out = [] - vs_to_merge = [] - for i in range(len(batch)): - if not needs_merge[i]: - if len(vs_to_merge) > 0: - v_out.append(list(chain(*vs_to_merge))) - vs_to_merge = [] - vs_to_merge.append(v[i]) - - if len(vs_to_merge) > 0: - v_out.append(list(chain(*vs_to_merge))) - new_encoding[k] = v_out - else: - raise ValueError(f"Unknown type {type(v[0])}") - return new_encoding - - # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1449 - @cached_property - def _needs_long_sequence_workaround(self): - if isinstance(self.tokenizer, PreTrainedTokenizerFast): - normalizer = self.tokenizer.backend_tokenizer.normalizer - if normalizer is None: - return False - # if there's a "Replace" normalizer, then we need to do the workaround - # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it - return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence)) - else: - return False - @property def num_cpus(self) -> int: if 
self.override_resources is not None: diff --git a/src/levanter/lora.py b/src/levanter/lora.py index 3e0dee750..83558f75d 100644 --- a/src/levanter/lora.py +++ b/src/levanter/lora.py @@ -150,7 +150,7 @@ def init(In: hax.Axis, Out: Axis, r: int, alpha: float, dropout_prob: float, *, return LowRankLinear(lora_A, lora_B, dropout, alpha / r) def merge(self) -> hax.NamedArray: - return hax.dot(LORA_R, self.lora_A.weight, self.lora_B.weight) * self.scale + return hax.dot(self.lora_A.weight, self.lora_B.weight, axis=LORA_R) * self.scale class LoraLinear(eqx.Module, StateDictSerializationMixin): diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 4dd46e63c..74e216ad2 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -37,7 +37,9 @@ def main(args: RayCachedLMDatasetConfig): logger.warning(f"Skipping {split} because it is empty.") continue - monitors = [RichMetricsMonitor(source.num_shards), LoggingMetricsMonitor("preprocess/" + split, commit=True)] + monitors: list = [RichMetricsMonitor(source.num_shards)] + if not isinstance(args.tracker, NoopConfig): + monitors.append(LoggingMetricsMonitor("preprocess/" + split, commit=True)) cache = build_cache( cache_dir=split_cache_dir, diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 843379a28..bb4b3b93b 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -1,5 +1,5 @@ import logging -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional import equinox as eqx @@ -32,9 +32,9 @@ class EvalLmConfig: checkpoint_path: Optional[str] = None hf_checkpoint: Optional[RepoRef] = None - trainer: TrainerConfig = TrainerConfig() - data: LMDatasetConfig = LMDatasetConfig() - model: LmConfig = Gpt2Config() + trainer: TrainerConfig = field(default_factory=TrainerConfig) + data: LMDatasetConfig = field(default_factory=LMDatasetConfig) + model: LmConfig = field(default_factory=Gpt2Config) compare_torch: bool = False eval_on_train: bool = False diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index ef16a7238..bf8b603b2 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -1,5 +1,5 @@ import logging -from dataclasses import dataclass +from dataclasses import dataclass, field import equinox as eqx import jax @@ -28,9 +28,9 @@ class VizGpt2Config: checkpoint_path: str path: str = "logprobs.html" - trainer: TrainerConfig = TrainerConfig() - data: LMDatasetConfig = LMDatasetConfig() - model: LmConfig = Gpt2Config() + trainer: TrainerConfig = field(default_factory=TrainerConfig) + data: LMDatasetConfig = field(default_factory=LMDatasetConfig) + model: LmConfig = field(default_factory=Gpt2Config) num_docs: int = 256 diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index fc17ea016..e7c94f50b 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -1,6 +1,7 @@ import functools import math import warnings +from enum import Enum from typing import Optional, Union, overload import equinox as eqx @@ -11,13 +12,31 @@ from jaxtyping import PRNGKeyArray import haliax -from haliax import Axis, AxisSelection, AxisSelector, NamedArray +from haliax import Axis, AxisSelection, AxisSelector, NamedArray, axis_name from haliax.jax_utils import named_call from haliax.nn.attention import causal_mask, combine_masks_and, combine_masks_or from haliax.partitioning import pspec_for_axis from 
haliax.types import PrecisionLike +class AttentionBackend(Enum): + DEFAULT = "default" # use the default attention type for the accelerator + NVTE = "nvte" # with Transformer Engine on NVIDIA GPUs + SPLASH = "splash" # on TPU. + JAX_FLASH = "jax_flash" # Use the JAX reference implementation + VANILLA = "vanilla" # regular dot product attention + + +def default_attention_type() -> AttentionBackend: + accelerator_type = jax.local_devices()[0].platform + if accelerator_type == "gpu": + return AttentionBackend.NVTE + elif accelerator_type == "tpu": + return AttentionBackend.SPLASH + else: + return AttentionBackend.JAX_FLASH + + @named_call def dot_product_attention( QPos: AxisSelector, @@ -30,7 +49,8 @@ def dot_product_attention( bias: Optional[NamedArray] = None, attention_dtype: Optional[jnp.dtype] = None, precision: PrecisionLike = None, - use_flash: bool = False, + use_flash: Optional[bool] = None, + attn_backend: Optional[AttentionBackend] = None, flash_block_size: Optional[int] = None, dropout: float = 0.0, *, @@ -38,9 +58,12 @@ def dot_product_attention( prng: Optional[PRNGKeyArray] = None, ): """ - This method is similar to [haliax.nn.attention.dot_product_attention][] but uses the [AttentionMask][] class, - which we might move to haliax.nn.attention in the future. + This method is similar to [haliax.nn.attention.dot_product_attention][] but it can use different backends for + attention. In particular, it can use the Transformer Engine for NVIDIA GPUs, the Splash Attention kernel for TPUs, + or a pure JAX reference flash attention 2 implementation for other platforms, or it can fall back to regular dot + product attention. + It also uses the [AttentionMask][] class, which we might move to haliax.nn.attention in the future. Unlike the Haliax version, it requires that the QPos and KPos already be different. Args: @@ -62,74 +85,114 @@ def dot_product_attention( Returns: NamedArray of shape (value.axes - KPos + QPos) """ - if QPos == KPos: - raise ValueError("QPos and KPos must be different") + if axis_name(QPos) == axis_name(KPos): + raise ValueError("QPos and KPos must have different names") - accelerator_type = jax.local_devices()[0].platform + if use_flash is not None: + if attn_backend is None: + if not use_flash: + attn_backend = AttentionBackend.VANILLA + else: + attn_backend = AttentionBackend.DEFAULT + else: + if attn_backend != AttentionBackend.VANILLA and not use_flash: + raise ValueError("use_flash is False, but flash_backend is not VANILLA") + elif attn_backend == AttentionBackend.VANILLA and use_flash: + raise ValueError("use_flash is True, but flash_backend is VANILLA") + elif use_flash is None and attn_backend is None: + # if the block_size doesn't divide the seq lens, we can't use flash. 
Previously default was use_flash=False + if flash_block_size is not None: + qlen = query.axis_size(QPos) + klen = key.axis_size(KPos) + if qlen % flash_block_size != 0 or klen % flash_block_size != 0: + use_flash = False + attn_backend = AttentionBackend.VANILLA + + if attn_backend is None or attn_backend == AttentionBackend.DEFAULT: + was_default = True + attn_backend = default_attention_type() + else: + was_default = False + + match attn_backend: + case AttentionBackend.NVTE: + attention_out = _try_te_attention( + QPos, + KPos, + Key, + query, + key, + value, + mask, + bias, + dropout, + inference, + force_te=not was_default, + prng=prng, + attention_dtype=attention_dtype, + precision=precision, + flash_block_size=flash_block_size, + ) + case AttentionBackend.SPLASH: + attention_out = _try_tpu_splash_attention( + QPos, + KPos, + Key, + query, + key, + value, + mask, + bias, + dropout, + inference, + force_flash=not was_default, + prng=prng, + attention_dtype=attention_dtype, + precision=precision, + block_size=flash_block_size, + ) + case AttentionBackend.VANILLA: + attention_out = simple_attention_with_dropout( + QPos, + KPos, + Key, + query, + key, + value, + mask, + bias, + inference, + dropout, + attention_dtype, + precision, + prng=prng, + ) + case _: + attention_out = None - if not use_flash: - return simple_attention_with_dropout( - QPos, KPos, Key, query, key, value, mask, bias, inference, dropout, attention_dtype, precision, prng=prng - ) - elif accelerator_type == "gpu": - attention_out = _try_te_attention( - QPos, - KPos, - Key, - query, - key, - value, - mask, - bias, - dropout, - inference, - prng=prng, - attention_dtype=attention_dtype, - precision=precision, - flash_block_size=flash_block_size, - ) - if attention_out is not None: - return attention_out - elif accelerator_type == "tpu": - attention_out = _try_tpu_splash_attention( + if attention_out is not None: + return attention_out + else: + # local import to avoid circular imports + from levanter.models.flash_attention import flash_attention + + return flash_attention( QPos, KPos, Key, query, key, value, - mask, - bias, - dropout, - inference, - prng=prng, - attention_dtype=attention_dtype, - precision=precision, block_size=flash_block_size, + mask=mask, + bias=bias, + dropout=dropout, + inference=inference, + key=prng, + dtype=attention_dtype, + precision=precision, ) - if attention_out is not None: - return attention_out - - from levanter.models.flash_attention import flash_attention - - return flash_attention( - QPos, - KPos, - Key, - query, - key, - value, - block_size=flash_block_size, - mask=mask, - bias=bias, - dropout=dropout, - inference=inference, - key=prng, - dtype=attention_dtype, - precision=precision, - ) - def simple_attention_with_dropout( QPos: Axis, @@ -154,7 +217,7 @@ def simple_attention_with_dropout( Key, KPos, query, key, mask=m, bias=bias, attention_dtype=attention_dtype, precision=precision ) weights = haliax.nn.dropout(weights, dropout, key=prng, inference=inference) - return haliax.dot(KPos, weights, value) + return haliax.dot(weights, value, axis=KPos) def _try_te_attention( @@ -173,6 +236,7 @@ def _try_te_attention( attention_dtype: Optional[jnp.dtype] = None, precision: PrecisionLike = None, flash_block_size: Optional[int] = None, + force_te: bool, ): try: return _te_flash_attention( @@ -195,27 +259,34 @@ def _try_te_attention( if "transformer_engine" not in str(e): raise - warnings.warn( - "transformer_engine is not installed. Please install it to use NVIDIA's optimized fused attention. 
" - "Falling back to the reference implementation." - ) + msg = "transformer_engine is not installed. Please install it to use NVIDIA's optimized fused attention." + if force_te: + raise ImportError(msg) + + warnings.warn(f"{msg}. Falling back to the reference implementation.") return None except NotImplementedError as e: - message = str(e) - warnings.warn( - f"Could not use transformer_engine for flash attention: {message}. Falling back to the reference" - ) + message = f"Could not use transformer_engine for flash attention: {str(e)}." + if force_te: + raise NotImplementedError(message) + + warnings.warn(f"{message}. Falling back to the reference implementation.") + return None except ValueError as e: message = str(e) if message.startswith("Unsupported backend="): _dtype = attention_dtype or query.dtype - msg = "TE doesn't work with these arguments. Falling back to the reference implementation.\n" + msg = "NVTE doesn't work with these arguments. Falling back to the reference implementation.\n" "Check nvte_get_fused_attn_backend for supported configurations:\n" "https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/fused_attn/fused_attn.cpp#L71" if _dtype not in (jnp.float16, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn): - msg += f"In particular, TE doesn't support {_dtype} yet." + msg += f"In particular, NVTE doesn't support {_dtype} yet." + + if force_te: + raise NotImplementedError(msg) + warnings.warn(msg) else: raise @@ -248,7 +319,7 @@ def _te_flash_attention( value = value.astype(attention_dtype) if precision is not None: - warnings.warn("precision is not supported for TE fused attention. Ignoring.") + warnings.warn("precision is not supported for NVTE fused attention. Ignoring.") # references: https://github.com/NVIDIA/TransformerEngine/blob/8255f87f3ee8076db21777795ce15b6ddf8754c0/transformer_engine/jax/fused_attn.py#L31 # https://github.com/NVIDIA/TransformerEngine/blob/8255f87f3ee8076db21777795ce15b6ddf8754c0/transformer_engine/jax/flax/transformer.py#L269 @@ -300,7 +371,7 @@ def _te_flash_attention( is_training=is_training, ) - # per the TE code, the output is BSHD. we can reshape it to match our axes + # per the NVTE code, the output is BSHD. we can reshape it to match our axes # we have to ungroup the axes, then reshape them to match our expected output attn_output = haliax.named(attn_output, ("B", "S", "H", "D")) # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v @@ -366,10 +437,10 @@ def _te_materialize_mask(KPos, QPos, batch_size, mask): def _bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key): """ - TE and the Splash Attention kernel require the Q, K, and V to be in a specific format. This function groups the axes + NVTE and the Splash Attention kernel require the Q, K, and V to be in a specific format. This function groups the axes of Q, K, and V into the right bins to match that format. - TE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD + NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD the size of the axes is a bit flexible, with the following conditions: - B must be the same for all (TODO: is this true?) @@ -377,7 +448,7 @@ def _bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key): - H: Q's H must be a multiple of K's H (for GQA or MQA) - D must be the same for all (TODO: is this true? 
possibly V can be different) - We can thus classify the axes in q, k, v by their function and populate the TE axes in the right order + We can thus classify the axes in q, k, v by their function and populate the NVTE axes in the right order - Key is D. ATM we're assuming this is a single axis. - QPos and KPos are always S - the latest other axis that is present in all three is H. If there are no other axes, we'll add a dummy axis @@ -606,16 +677,21 @@ def _try_tpu_splash_attention( dropout: float = 0.0, inference: bool = False, *, + force_flash: bool, prng: Optional[PRNGKeyArray] = None, attention_dtype: Optional[jnp.dtype] = None, precision: PrecisionLike = None, block_size: Optional[int] = None, ) -> Optional[NamedArray]: if dropout != 0.0: + if force_flash: + raise NotImplementedError("Splash attention does not support dropout.") warnings.warn("Splash attention does not support. Falling back to the reference implementation.") return None if bias is not None: + if force_flash: + raise NotImplementedError("Splash attention does not support bias.") warnings.warn("Splash attention does not support bias. Falling back to the reference implementation.") return None @@ -639,6 +715,8 @@ def _try_tpu_splash_attention( except ImportError as e: if "pallas" not in str(e): raise + if force_flash: + raise ImportError("Could not import splash attention. You need to update your JAX to at least 0.4.26.") warnings.warn( "Could not import splash attention. You need to update your JAX to at least 0.4.26. " "Falling back to the reference implementation." @@ -646,6 +724,8 @@ def _try_tpu_splash_attention( return None except NotImplementedError as e: message = str(e) + if force_flash: + raise NotImplementedError(f"Could not use splash attention: {message}") warnings.warn(f"Could not use splash attention: {message}. 
Falling back to the reference") return None @@ -685,6 +765,9 @@ def _tpu_splash_attention( q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key) + # pre-divide q_ by sqrt(d) to match the reference implementation + query = query / jnp.sqrt(query.resolve_axis(Key).size) + q_: jax.Array = _reshape_axes_for_bshd_bins(query, q_class, output_order=list("BHSD")).array k_ = _reshape_axes_for_bshd_bins(key, k_class, output_order=list("BHSD")).array v_ = _reshape_axes_for_bshd_bins(value, v_class, output_order=list("BHSD")).array @@ -692,6 +775,13 @@ def _tpu_splash_attention( B, Hq, Sq, D = q_.shape Bk, Hk, Sk, Dk = k_.shape + # number + if Sk % 128 != 0: + raise NotImplementedError("Splash attention requires KPos to be a multiple of 128") + + if block_size is not None and block_size % 128 != 0: + raise NotImplementedError(f"Splash attention requires block_size to be a multiple of 128, got {block_size}") + QPos = query.resolve_axis(QPos) KPos = key.resolve_axis(KPos) diff --git a/src/levanter/models/backpack.py b/src/levanter/models/backpack.py index 6a7cb9b96..332842a7c 100644 --- a/src/levanter/models/backpack.py +++ b/src/levanter/models/backpack.py @@ -342,7 +342,7 @@ def embed(self, input_ids, *, key): return x def unembed(self, x: NamedArray): - return hax.dot("embed", x, self.token_embeddings) + return hax.dot(x, self.token_embeddings, axis="embed") def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"token_embeddings": "wte.weight", "position_embeddings": "wpe.weight"} @@ -416,7 +416,9 @@ def __call__( sense_vectors = sense_vectors.rename({self.Pos: self.config.KeyPos}) ## Weight-and-sum - hidden_states = hax.dot(self.config.KeyPos, contextualization_weights, sense_vectors) # (seq, senses, embed) + hidden_states = hax.dot( + contextualization_weights, sense_vectors, axis=self.config.KeyPos + ) # (seq, senses, embed) hidden_states = hax.sum(hidden_states, axis=self.config.Senses) # Rescale - this is important for large num_senses diff --git a/src/levanter/models/flash_attention.py b/src/levanter/models/flash_attention.py index 0e25a12a8..5998cf34d 100644 --- a/src/levanter/models/flash_attention.py +++ b/src/levanter/models/flash_attention.py @@ -180,7 +180,7 @@ def do_qk_block(state): v_j = v[KPos, ds.block(j, block_size)] # Step 8: compute Sij = QiKj^T - attn_ij = hax.dot(Key, q_i, k_j, precision=precision) + attn_ij = hax.dot(q_i, k_j, precision=precision, axis=Key) if bias is not None: if bias.has_axis(QPos.name): @@ -211,7 +211,7 @@ def do_qk_block(state): sumexp_i = exp_diff * sumexp_i + hax.sum(P_ij, axis=KPos.name) # Step 10: Compute O_i = diag(exp(m_i^{j-1} - m_i^j) O_i + P_i^j V_j - o_i = exp_diff * o_i + hax.dot(KPos.name, P_ij, v_j) + o_i = exp_diff * o_i + hax.dot(P_ij, v_j, axis=KPos.name) return (i, j + 1, o_i, q_i, sumexp_i, max_i) @@ -296,7 +296,7 @@ def do_inner_block(state): L_i = L[QPos, ds.block(i, block_size)] D_i = D[QPos, ds.block(i, block_size)] - attn_ij = hax.dot(Key, q_i, k_j, precision=precision) + attn_ij = hax.dot(q_i, k_j, precision=precision, axis=Key) if dropout > 0 and not inference: attn_ij = hax.nn.dropout(attn_ij, dropout, inference=False, key=jax.random.fold_in(key, i * Tc + j)) @@ -314,12 +314,12 @@ def do_inner_block(state): if dropout > 0 and not inference: p_ij = hax.nn.dropout(p_ij, dropout, inference=False, key=jax.random.fold_in(key, i * Tc + j)) - dP_ij = hax.dot(Key, dO_i, v_j) + dP_ij = hax.dot(dO_i, v_j, axis=Key) dAttn_ij = p_ij * (dP_ij - D_i) dAttn_ij = dAttn_ij.astype(dQ_i.dtype) - 
dV_ji = hax.dot(QPos.name, p_ij, dO_i).astype(dV_j.dtype) - dK_ji = hax.dot(QPos.name, dAttn_ij, q_i).astype(dK_j.dtype) + dV_ji = hax.dot(p_ij, dO_i, axis=QPos.name).astype(dV_j.dtype) + dK_ji = hax.dot(dAttn_ij, q_i, axis=QPos.name).astype(dK_j.dtype) # GQA-specific: eliminate unnecessary axes (e.g. 'q_heads_per_group') unnecessary_axes = hax.eliminate_axes(dV_ji.axes, v.axes) @@ -329,7 +329,7 @@ def do_inner_block(state): dV_j = dV_j + dV_ji dK_j = dK_j + dK_ji - dQ_i = dQ_i + hax.dot(KPos.name, dAttn_ij, k_j).astype(dQ.dtype) + dQ_i = dQ_i + hax.dot(dAttn_ij, k_j, axis=KPos.name).astype(dQ.dtype) # dQ[i*block_size:(i+1)*block_size] = dQi dQ = dQ.updated_slice({QPos: i * block_size}, dQ_i) diff --git a/src/levanter/models/gemma.py b/src/levanter/models/gemma.py new file mode 100644 index 000000000..962f797c6 --- /dev/null +++ b/src/levanter/models/gemma.py @@ -0,0 +1,370 @@ +import dataclasses +from dataclasses import dataclass +from typing import Dict, Optional, Type, Union + +import equinox as eqx +import jax.numpy as jnp +import jax.random as jrandom + +import haliax as hax +import haliax.nn as hnn +from haliax import Axis, AxisSpec, NamedArray +from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split +from haliax.nn.scan import Stacked + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + stack_state_dict, + unstack_state_dict, +) +from levanter.logging import silence_transformer_nag +from levanter.models.attention import AttentionBackend, AttentionMask +from levanter.models.llama import ( # Gemma attention and MLP is identical to LLama + LlamaAttention, + LlamaEmbedding, + LlamaMlp, +) +from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.types import BlockFoldable +from levanter.utils.py_utils import cached_classproperty + + +silence_transformer_nag() +from transformers import GemmaConfig as HfGemmaConfig # noqa: E402 +from transformers import PretrainedConfig as HfConfig # noqa: E402 + + +# Gemma is... very similar to Llama, so we use much of the same modeling code. +# +# The key differences are: +# * Activation is changed to approximate gelu +# * Embedding weights are tied to the LM head +# * Gemma allows specifying a head dimension independently of the hidden and intermediate dims. + + +@LmConfig.register_subclass("gemma") +@dataclass(frozen=True) +class GemmaConfig(HFCompatConfig): + """Config for GemmaModel. + + Defaults are set for Gemma-2B. + + Args: + seq_len (int, optional): maximum length of the input sequence. Defaults to 8192. + hidden_dim (int, optional): dimension of the hidden state. Defaults to 2048. + intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 16384. + num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 18. + num_heads (int, optional): number of attention heads for each attention layer. Defaults to 8. + num_kv_heads (int, optional): number of attention heads for keys and values in each attention layer. + Setting to 1 means MQA. Setting to num_heads means MHA. Otherwise GQA. + Note that num_heads must be divisible by this number. Defaults to 1. + activation_function (str, optional): activation function for the hidden layer. Defaults to "gelu". + rope_scaling (Dict, ignored): dict containing the scaling configuration for the Rotary Positional Embedding. 
+ """ + + activation_function: str = "gelu" + initializer_range: float = 0.02 + layer_norm_epsilon: float = 1e-5 + + seq_len: int = 8192 + hidden_dim: int = 2048 + intermediate_dim: int = 16384 + vocab_size: int = 256_000 + num_layers: int = 18 + num_heads: int = 8 + head_dim: int = 256 + num_kv_heads: int = 1 + attn_dropout = 0.0 + norm_eps = 1e-6 + + rope_base: int = 10_000 + norm_embeddings: bool = True + + # Attention-related config + upcast_attn: bool = False + use_flash_attention: Optional[bool] = None + attn_backend: Optional[AttentionBackend] = None + flash_attention_block_size: Optional[int] = None + + gradient_checkpointing: bool = True + gradient_checkpointing_block_size: int = 5 + scan_layers: bool = True + + use_bias: bool = False + rope_scaling: Optional[dict] = None + + # Axis + Pos = property(lambda self: Axis(name="position", size=self.seq_len)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) + Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) + KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) + + def __post_init__(self): + assert ( + self.num_heads % self.num_kv_heads == 0 + ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}." + + @cached_classproperty + def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["GemmaConfig"]: # type: ignore + return HFCheckpointConverter( + cls, # type: ignore + reference_checkpoint="google/gemma-2b", + trust_remote_code=True, + HfConfigClass=HfGemmaConfig, + ) + + # The activation function specified in the Gemma HF repo is "gelu", but this is incorrect, it should + # be "gelu_pytorch_tanh". For backwards compatibility, HF did not change the value in the repo, but + # instead patches around it. We mimic this behavior and use the approximate gelu internally, and + # specify the approximate gelu for HF when appropriate. + # See https://github.com/huggingface/transformers/pull/29402 for more detail. + @classmethod + def from_hf_config(cls, hf_config: HfConfig): + if hf_config.hidden_activation: + activation_function = hf_config.hidden_activation + else: + activation_function = hf_config.hidden_act + + if activation_function == "gelu_pytorch_tanh": + activation_function = "gelu" + + assert activation_function is not None, "No activation function found in HF configuration." + return GemmaConfig( + seq_len=hf_config.max_position_embeddings, + activation_function=activation_function, + hidden_dim=hf_config.hidden_size, + intermediate_dim=hf_config.intermediate_size, + num_layers=hf_config.num_hidden_layers, + num_heads=hf_config.num_attention_heads, + num_kv_heads=hf_config.num_key_value_heads, + initializer_range=hf_config.initializer_range, + layer_norm_epsilon=hf_config.rms_norm_eps, + rope_base=hf_config.rope_theta, + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfGemmaConfig: + """Convert to HuggingFace's GemmaConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + config_overrides (dict, optional): Overrides for the config. Defaults to None. 
+
+        Returns:
+            HfGemmaConfig: HuggingFace's GemmaConfig
+        """
+        if config_overrides is None:
+            config_overrides = {}
+
+        config = HfGemmaConfig(
+            max_position_embeddings=self.seq_len,
+            hidden_size=self.hidden_dim,
+            intermediate_size=self.intermediate_dim,
+            num_hidden_layers=self.num_layers,
+            num_attention_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            head_dim=self.hidden_dim // self.num_heads,
+            hidden_activation=(
+                "gelu_pytorch_tanh" if self.activation_function == "gelu" else self.activation_function
+            ),
+            initializer_range=self.initializer_range,
+            rms_norm_eps=self.layer_norm_epsilon,
+            vocab_size=vocab_size,
+            **config_overrides,
+        )
+        return config
+
+    @property
+    def model_type(cls) -> Type["GemmaLMHeadModel"]:
+        return GemmaLMHeadModel
+
+
+class GemmaRMSNorm(hnn.LayerNorm):
+    """
+    Like Llama, Gemma uses an RMSNorm instead of a layer norm.
+
+    The canonical Gemma model computes the variance in fp32 explicitly, so
+    we do the same for compatibility.
+    """
+
+    @staticmethod
+    def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = False):
+        assert use_weight, "GemmaRMSNorm does not support use_weight=False"
+        assert not use_bias, "GemmaRMSNorm does not support use_bias=True"
+
+        weight = hax.zeros(axis)
+        bias = None
+
+        return GemmaRMSNorm(axis, weight, bias, eps)
+
+    def __call__(self, x: NamedArray) -> NamedArray:
+        # Gemma's norm is calculated in fp32 explicitly
+        # See https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L173
+        dtype = x.dtype
+        x = x.astype(jnp.float32)
+
+        var = hax.mean(hax.square(x), axis=self.axis)
+        inv = hax.rsqrt(var + self.eps)
+        out = x * inv
+        out = out * (1.0 + self.weight)
+        return out.astype(dtype)
+
+
+class GemmaDecoderLayer(StateDictSerializationMixin, eqx.Module):
+    config: GemmaConfig = eqx.static_field()
+    self_attn: LlamaAttention
+    mlp: LlamaMlp
+    input_layernorm: GemmaRMSNorm
+    post_attention_layernorm: GemmaRMSNorm
+
+    @staticmethod
+    def init(config: GemmaConfig, *, key) -> "GemmaDecoderLayer":
+        k_attn, k_mlp = jrandom.split(key, 2)
+
+        attn = LlamaAttention.init(config, key=k_attn)  # type: ignore
+        mlp = LlamaMlp.init(
+            config.Embed,
+            config.Mlp,
+            config.activation_function,
+            key=k_mlp,
+            use_bias=config.use_bias,
+        )
+        ln_1 = GemmaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias)
+        ln_2 = GemmaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias)
+
+        return GemmaDecoderLayer(config, attn, mlp, ln_1, ln_2)
+
+    @named_call
+    def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, key=None) -> NamedArray:
+        k_attn, k_mlp = maybe_rng_split(key, 2)
+        # self attention and skip connection
+        residual = x
+        x = self.input_layernorm(x)
+        attn_output = self.self_attn(x=x, mask=mask, key=k_attn)
+        x = residual + attn_output
+
+        # MLP and skip connection
+        residual = x
+        x = self.post_attention_layernorm(x)
+        mlp_output = self.mlp(x, key=k_mlp)
+        output = residual + mlp_output
+        return output
+
+
+class GemmaTransformer(StateDictSerializationMixin, eqx.Module):
+    config: GemmaConfig = eqx.static_field()
+    layers: BlockFoldable[GemmaDecoderLayer]
+    norm: GemmaRMSNorm
+
+    @staticmethod
+    def init(config: GemmaConfig, *, key) -> "GemmaTransformer":
+        S = Stacked
+        if not config.scan_layers:
+            from haliax.nn.scan import BlockSeq
+
+            S = BlockSeq
+
+        layers = S.init(config.Layers, GemmaDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)(
+            config,
+            key=shaped_rng_split(key, config.num_layers),
+        )
+        ln_f = GemmaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias)
+
+        return GemmaTransformer(config, layers, ln_f)
+
+    @named_call
+    def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray | AttentionMask], *, key) -> NamedArray:
+        keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None
+        x = self.layers.fold(x, mask=attn_mask, key=keys)
+        x = self.norm(x)
+
+        return x
+
+    def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
+        if isinstance(self.layers, Stacked):
+            state_dict = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layers"))
+
+        out = super().from_state_dict(state_dict, prefix=prefix)
+        return out
+
+    def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
+        my_state_dict: StateDict = {}
+        super().update_state_dict(my_state_dict, prefix=prefix)
+
+        if isinstance(self.layers, Stacked):
+            stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layers"))
+            state_dict.update(stacked_dict)
+
+        return state_dict
+
+
+class GemmaLMHeadModel(eqx.Module, LmHeadModel[GemmaConfig], StateDictSerializationMixin):
+    transformer: GemmaTransformer
+
+    # Gemma ties the weights of the embedding matrix and LM head. Rather than
+    # use eqx.Shared, which is a bit cumbersome, we just re-use the embedding matrix
+    # as we do in GPT-2.
+    embeddings: LlamaEmbedding
+
+    @property
+    def config(self):
+        return self.transformer.config
+
+    @property
+    def vocab_size(self) -> int:
+        return self.Vocab.size
+
+    @property
+    def Vocab(self) -> Axis:
+        return self.embeddings.Vocab
+
+    @classmethod
+    def init(cls, Vocab: Axis, config: GemmaConfig, *, key) -> "GemmaLMHeadModel":
+        k_t, k_emb = jrandom.split(key, 2)
+        transformer = GemmaTransformer.init(config, key=k_t)
+        embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
+        return GemmaLMHeadModel(transformer, embeddings)
+
+    def __call__(
+        self,
+        input_ids: NamedArray,
+        attn_mask: Optional[Union[NamedArray, AttentionMask]] = None,
+        *,
+        key=None,
+    ) -> NamedArray:
+        """
+        Args:
+            input_ids (NamedArray): [batch, position]
+                Indices of input sequence tokens in the vocabulary.
+            attn_mask (Union[NamedArray, AttentionMask], optional): [batch, position]
+                Mask to avoid performing attention on the padding token indices of the input.
+ The attn_mask from training pipeline may be an AttentionMask object instead of NamedArray + """ + x = self.embeddings.embed(input_ids) + x = self.transformer(x, attn_mask=attn_mask, key=key) + lm_logits = self.embeddings.unembed(x) + return lm_logits + + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[GemmaConfig]": + new_embeddings = self.embeddings.resize_embeddings(new_size, key=key) + + return dataclasses.replace(self, embeddings=new_embeddings) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + """Map from Levanter model names to HF.""" + return {"transformer": "model", "embeddings": None} + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + return super().from_state_dict(state_dict, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_dict: StateDict = {} + super().update_state_dict(my_dict, prefix=prefix) + state_dict.update(my_dict) + return state_dict diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index 191ac689d..c2caf5390 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -26,7 +26,7 @@ unstack_state_dict, ) from levanter.logging import silence_transformer_nag -from levanter.models.attention import AttentionMask, dot_product_attention +from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.lm_model import LmConfig from levanter.utils.py_utils import cached_classproperty @@ -64,7 +64,8 @@ class Gpt2Config(HFCompatConfig): use_bias: bool = True - use_flash_attention: bool = True # use flash attention. This is a pure jax impl, and is not faster than normal, but it scales to long sequence lengths + use_flash_attention: Optional[bool] = None + attn_backend: Optional[AttentionBackend] = None flash_attention_block_size: Optional[int] = None # Axes @@ -191,6 +192,7 @@ def __call__(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], la mask=mask, inference=self.inference, use_flash=self.config.use_flash_attention, + attn_backend=self.config.attn_backend, flash_block_size=self.config.flash_attention_block_size, prng=k_drop, attention_dtype=jnp.float32 if self.config.upcast_attn else None, @@ -340,7 +342,7 @@ def embed(self, input_ids, *, key): return x def unembed(self, x: NamedArray): - return hax.dot("embed", x, self.token_embeddings.weight) + return hax.dot(x, self.token_embeddings.weight, axis="embed") def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"token_embeddings": "wte", "position_embeddings": "wpe"} diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 6d74241c5..c0e1ca45a 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -25,7 +25,7 @@ unstack_state_dict, ) from levanter.logging import silence_transformer_nag -from levanter.models.attention import AttentionMask, dot_product_attention +from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.gpt2 import ACT2FN from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.types import BlockFoldable @@ -67,7 +67,8 @@ class LlamaConfig(HFCompatConfig): # Attention-related config upcast_attn: bool = False - use_flash_attention: bool = True + use_flash_attention: Optional[bool] = True + attn_backend: Optional[AttentionBackend] = None flash_attention_block_size: Optional[int] = None gradient_checkpointing: bool = True @@ -270,6 +271,7 @@ def 
__call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, mask, attention_dtype=jnp.float32 if self.config.upcast_attn else x.dtype, use_flash=c.use_flash_attention, + attn_backend=self.config.attn_backend, flash_block_size=c.flash_attention_block_size, ) @@ -439,9 +441,7 @@ class LlamaEmbedding(StateDictSerializationMixin, eqx.Module): @staticmethod def init(Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaEmbedding": - k_wte = jrandom.split(key, 1) - - token_embeddings = hax.random.normal(k_wte, (Vocab, config.Embed)) + token_embeddings = hax.random.normal(key, (Vocab, config.Embed)) return LlamaEmbedding(Vocab, config, token_embeddings) @named_call @@ -451,7 +451,7 @@ def embed(self, input_ids, *args): return x def unembed(self, x: NamedArray): - return hax.dot("embed", x, self.token_embeddings) + return hax.dot(x, self.token_embeddings, axis="embed") def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"token_embeddings": "model.embed_tokens.weight"} diff --git a/src/levanter/models/mistral.py b/src/levanter/models/mistral.py index 913d3e7ac..c1fc62b54 100644 --- a/src/levanter/models/mistral.py +++ b/src/levanter/models/mistral.py @@ -19,7 +19,7 @@ unflatten_linear_layers, ) from levanter.logging import silence_transformer_nag -from levanter.models.attention import AttentionMask +from levanter.models.attention import AttentionBackend, AttentionMask from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaTransformer from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.utils.py_utils import cached_classproperty @@ -61,7 +61,8 @@ class MistralConfig(LlamaConfig): # Attention-related config upcast_attn: bool = False - use_flash_attention: bool = True + use_flash_attention: Optional[bool] = True + attn_backend: Optional[AttentionBackend] = None flash_attention_block_size: Optional[int] = None gradient_checkpointing: bool = True diff --git a/src/levanter/models/whisper.py b/src/levanter/models/whisper.py index ff725151b..6a86d1a50 100644 --- a/src/levanter/models/whisper.py +++ b/src/levanter/models/whisper.py @@ -27,7 +27,7 @@ ) from levanter.logging import silence_transformer_nag from levanter.models.asr_model import ASRConfig, ASRMixin -from levanter.models.attention import AttentionMask, dot_product_attention +from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.lm_model import LmConfig from levanter.utils.py_utils import cached_classproperty @@ -63,6 +63,7 @@ class WhisperConfig(HFCompatConfig, ASRConfig): # Attention-related config upcast_attn: bool = True use_flash_attention: bool = False + attn_backend: Optional[AttentionBackend] = None flash_attention_block_size: Optional[int] = None @property @@ -447,7 +448,7 @@ def embed(self, input_ids, *, key): return x def unembed(self, x: NamedArray): - return hax.dot("embed_dim", x, self.token_embeddings.weight) + return hax.dot(x, self.token_embeddings.weight, axis="embed_dim") def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): new_token_embeddings = self.token_embeddings.resize_embeddings(new_size, key=key) diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py index e3b6a1f71..5e3b6ba4f 100644 --- a/src/levanter/tracker/tracker_fns.py +++ b/src/levanter/tracker/tracker_fns.py @@ -49,8 +49,9 @@ def log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optiona def _no_throw_log_metrics(metrics: dict[str, Any], *, step: Optional[int], 
commit: Optional[bool] = None): try: if _global_tracker is None: - raise RuntimeError("No global tracker set") - _global_tracker.log(metrics, step=step, commit=False) + warnings.warn("No global tracker set") + else: + _global_tracker.log(metrics, step=step, commit=False) except Exception: logger.exception("Error logging metrics") diff --git a/tests/gpt2_test.py b/tests/gpt2_test.py index 8c524c409..293947188 100644 --- a/tests/gpt2_test.py +++ b/tests/gpt2_test.py @@ -7,13 +7,14 @@ import haliax as hax from haliax import Axis -from levanter.models.attention import AttentionMask +from levanter.models.attention import AttentionBackend, AttentionMask from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs @pytest.mark.parametrize("num_blocks", [1, 4, 12]) -def test_gradient_checkpointing(num_blocks): +@pytest.mark.parametrize("attn_backend", [AttentionBackend.JAX_FLASH, AttentionBackend.VANILLA]) +def test_gradient_checkpointing(num_blocks, attn_backend): # ensure that gradient checkpointing doesn't change the output # (this is a regression test for a bug that caused the output to change) config = Gpt2Config( @@ -22,7 +23,8 @@ def test_gradient_checkpointing(num_blocks): num_layers=num_blocks, num_heads=8, gradient_checkpointing=False, - use_flash_attention=True, + # use_flash_attention=True, + attn_backend=attn_backend, ) config_checkpoint = dataclasses.replace(config, gradient_checkpointing=True) key = PRNGKey(0) diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 000000000..fc1700a2d --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,5 @@ +flake8 +pytest +soundfile +librosa +pytest-forked diff --git a/tests/test_attention.py b/tests/test_attention.py index be664281b..7defcb4a0 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,9 +1,17 @@ +import jax import jax.numpy as jnp +import jax.random as jrandom import pytest +from chex import assert_trees_all_close import haliax as hax -from levanter.models.attention import AttentionMask, _bin_and_group_axes_by_function, _te_flash_attention +from levanter.models.attention import ( + AttentionMask, + _bin_and_group_axes_by_function, + _te_flash_attention, + _tpu_splash_attention, +) from test_utils import skip_if_module_missing @@ -155,7 +163,7 @@ def test_llama_attention_uses_te(q_heads): attention_dtype=jnp.bfloat16, ) - assert jnp.allclose(out.array, 0.0) + assert_trees_all_close(out.array, 0.0) @skip_if_module_missing("transformer_engine") @@ -181,4 +189,28 @@ def test_gpt2_attention_uses_te(): mask, attention_dtype=jnp.bfloat16, ) - assert jnp.allclose(out.array, 0.0) + assert_trees_all_close(out.array, 0.0) + + +def test_tpu_splash_attention(): + if jax.default_backend() != "tpu": + pytest.skip("TPU only") + + BLOCK_SIZE = 512 + + Head = hax.Axis("Head", 8) + Key = hax.Axis("Key", 128) # splash only supports 128 + QPos = hax.Axis("QPos", BLOCK_SIZE * 2) + KPos = hax.Axis("KPos", BLOCK_SIZE * 2) + + q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Head, Key)) * 0.02 + k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Head, Key)) * 0.02 + v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Head, Key)) * 0.02 + + mask = AttentionMask.causal() + + with jax.sharding.Mesh(jax.devices(), ("dp",)): + flash_out = _tpu_splash_attention(QPos, KPos, Key, q, k, v, inference=True, mask=mask, block_size=BLOCK_SIZE) + hax_out = hax.nn.attention.dot_product_attention(KPos, Key, q, k, v, 
mask=mask.materialize(QPos, KPos)) + assert hax_out.axes == flash_out.axes + assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index b48ff90c2..306bec9cd 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -10,7 +10,7 @@ import jax.tree_util as jtu import numpy as np import optax -from chex import assert_trees_all_equal +from chex import assert_trees_all_close, assert_trees_all_equal from jax import ShapeDtypeStruct from jax import numpy as jnp @@ -331,7 +331,8 @@ def init_fn(key): assert not any(jax.tree_util.tree_leaves(eqx.filter(loaded, lambda x: isinstance(x, ShapeDtypeStruct)))) # should be the same as model1 - assert_trees_all_equal( + # on TPU, there's a very slight difference for some reason + assert_trees_all_close( jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))), jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed))), ) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index a79aa36fa..7a944f597 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -5,6 +5,7 @@ import jax.random as jrandom import jax.sharding import pytest +from chex import assert_trees_all_close import haliax as hax import haliax.nn as hnn @@ -30,7 +31,7 @@ def test_flash_attention_acausal(): hax_out = hnn.attention.dot_product_attention(KPos, Key, q, k, v) assert hax_out.axes == flash_out.axes - assert jnp.allclose(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) def test_flash_attention_causal_mask(): @@ -40,15 +41,19 @@ def test_flash_attention_causal_mask(): mask = AttentionMask.causal() - q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key)) - k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key)) - v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key)) + q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key)) * 0.02 + k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key)) * 0.02 + v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key)) * 0.02 - flash_out = flash_attention(QPos, KPos, Key, q, k, v, inference=True, mask=mask, block_size=BLOCK_SIZE) - hax_out = hnn.attention.dot_product_attention(KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos)) + flash_out = flash_attention( + QPos, KPos, Key, q, k, v, inference=True, mask=mask, block_size=BLOCK_SIZE, precision="highest" + ) + hax_out = hnn.attention.dot_product_attention( + KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos), precision="highest" + ) assert hax_out.axes == flash_out.axes - assert jnp.allclose(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) def test_grad_attention(): @@ -73,14 +78,14 @@ def d_attn(qkv, fn): (q, k, v), functools.partial(flash_attention, inference=True, block_size=BLOCK_SIZE) ) - assert jnp.allclose(hax_val, fa_val, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_val, fa_val, atol=1e-3, rtol=1e-3) assert hax_dq.axes == fa_dq.axes assert hax_dk.axes == fa_dk.axes assert hax_dv.axes == fa_dv.axes - assert jnp.allclose(hax_dq.array, fa_dq.array, atol=1e-3, rtol=1e-3) - assert jnp.allclose(hax_dk.array, fa_dk.array, atol=1e-3, rtol=1e-3) - assert jnp.allclose(hax_dv.array, fa_dv.array, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_dq.array, fa_dq.array, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_dk.array, fa_dk.array, atol=1e-3, 
rtol=1e-3) + assert_trees_all_close(hax_dv.array, fa_dv.array, atol=1e-3, rtol=1e-3) @pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) @@ -109,14 +114,14 @@ def d_attn(qkv, fn): (q, k, v), functools.partial(flash_attention, inference=True, block_size=BLOCK_SIZE, mask=mask) ) - assert jnp.allclose(hax_val, fa_val, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_val, fa_val, atol=1e-3, rtol=1e-3) assert hax_dq.axes == fa_dq.axes assert hax_dk.axes == fa_dk.axes assert hax_dv.axes == fa_dv.axes - assert jnp.allclose(hax_dq.array, fa_dq.array, atol=1e-3, rtol=1e-3) - assert jnp.allclose(hax_dk.array, fa_dk.array, atol=1e-3, rtol=1e-3) - assert jnp.allclose(hax_dv.array, fa_dv.array, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_dq.array, fa_dq.array, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_dk.array, fa_dk.array, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_dv.array, fa_dv.array, atol=1e-3, rtol=1e-3) def test_fa_dropout_does_something(): @@ -165,4 +170,4 @@ def test_tpu_flash_attention(): hax_out = hnn.attention.dot_product_attention(KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos)) assert hax_out.axes == flash_out.axes - assert jnp.allclose(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) + assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) diff --git a/tests/test_gemma.py b/tests/test_gemma.py new file mode 100644 index 000000000..35a7b9084 --- /dev/null +++ b/tests/test_gemma.py @@ -0,0 +1,258 @@ +import tempfile + +import equinox as eqx +import jax +import numpy as np +import pytest +import transformers +from jax import random + +import haliax as hax + +from levanter.models.attention import AttentionMask +from levanter.models.gemma import GemmaConfig, GemmaDecoderLayer, GemmaLMHeadModel, GemmaRMSNorm +from levanter.utils.jax_utils import parameter_count +from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs, skip_if_no_torch + + +# N.B. Gemma uses LLamaAttention directly so we skip tests for attention and rotary embeddings. + + +@skip_if_no_torch +def test_gemma_config(): + # load HF config and convert to levanter config + hf_config = transformers.GemmaConfig.from_pretrained("google/gemma-2b") + gemma_config = GemmaConfig.from_hf_config(hf_config) + + # convert back to HF config + config_overrides = { + "_name_or_path": hf_config._name_or_path, + "architectures": hf_config.architectures, + "torch_dtype": hf_config.torch_dtype, + } + new_hf_config = gemma_config.to_hf_config( + vocab_size=hf_config.vocab_size, + config_overrides=config_overrides, + ) + + # Gemma has some weird patched behavior in the HF configuration to deal with the original + # version not using an approximate gelu layer: the configuration has both `hidden_act` and + # `hidden_activation` fields. We don't touch the `hidden_act` field, and it is overridden + # by `hidden_activation`. + # See https://github.com/huggingface/transformers/pull/29402 for more info. 
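+    # In short: levanter keeps "gelu" as its activation_function, from_hf_config maps HF's
+    # "gelu_pytorch_tanh" back to "gelu", and to_hf_config exports "gelu" as "gelu_pytorch_tanh",
+    # which is what the assertions below check.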
+ assert gemma_config.activation_function == "gelu" + assert new_hf_config.hidden_activation == "gelu_pytorch_tanh" + assert new_hf_config.hidden_act == "gelu_pytorch_tanh" + + # assert the content in new_hf_config is the same as hf_config + for k in new_hf_config.__dict__.keys(): + if k in ["_commit_hash", "transformers_version"]: + continue + + if k in ["hidden_act", "hidden_activation"]: + continue + + assert getattr(new_hf_config, k) == getattr( + hf_config, k + ), f"{k} {getattr(new_hf_config, k)} != {getattr(hf_config, k)}" + + +def test_gemma_param_counts_dont_change_with_seqlen(): + model = GemmaLMHeadModel.init(hax.Axis("v", 2048), _get_gemma_config(seq_len=128), key=random.PRNGKey(0)) + model2 = GemmaLMHeadModel.init(hax.Axis("v", 2048), _get_gemma_config(seq_len=256), key=random.PRNGKey(0)) + assert parameter_count(model) == parameter_count(model2) + + +@skip_if_no_torch +def test_gemma_rms_norm(): + import torch + from transformers.models.gemma.modeling_gemma import GemmaRMSNorm as HFGemmaRMSNorm + + config = _get_gemma_config() + ln = GemmaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + hf_ln = HFGemmaRMSNorm(config.Embed.size, eps=config.layer_norm_epsilon) + + x, _ = _get_random_inputs(config) + x_torch = torch.from_numpy(np.array(x.array)) + + out = ln(x) + hf_out = hf_ln(x_torch) + + assert np.isclose( + hf_out.detach().cpu().numpy(), np.array(out.array), rtol=1e-6, atol=1e-6 + ).all(), f"{hf_out} != {out}" + + +@skip_if_no_torch +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_gemma_decoder_layer(num_kv_heads): + import torch + from transformers.models.gemma.modeling_gemma import GemmaDecoderLayer as HFGemmaDecoderLayer + + gemma_config = _get_gemma_config(num_kv_heads=num_kv_heads) + key = random.PRNGKey(0) + gemma_decoder_layer = GemmaDecoderLayer.init(config=gemma_config, key=key) + + state = gemma_decoder_layer.to_state_dict() + state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} + hf_decoder_layer = HFGemmaDecoderLayer(gemma_config.to_hf_config(32000), layer_idx=0) + hf_decoder_layer.load_state_dict(state, strict=True) + + x, mask = _get_random_inputs(gemma_config) + x_torch = torch.from_numpy(np.array(x.array)) + batch_size = x_torch.shape[0] + explicit_mask = torch.from_numpy(np.array(mask.materialize(gemma_config.Pos, gemma_config.KeyPos).array)) + mask_torch = explicit_mask.broadcast_to((batch_size, 1, -1, -1)) + mask_torch = (mask_torch == 0).float() * -1e9 + + position_ids = torch.arange(gemma_config.Pos.size).reshape(1, -1) + + out = gemma_decoder_layer(x, mask) + hf_out = hf_decoder_layer(x_torch, position_ids=position_ids, attention_mask=mask_torch) + + assert np.isclose( + hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-4, atol=1e-4 + ).all(), f"{hf_out[0]} != {out}" + + +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_gemma_lm_head_model(num_kv_heads): + gemma_config = _get_gemma_config(num_kv_heads=num_kv_heads) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 1000) + Pos = gemma_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) + mask = AttentionMask.causal() + + gemma_model = GemmaLMHeadModel.init(Vocab=Vocab, config=gemma_config, key=random.PRNGKey(0)) + out = gemma_model(input_ids, mask) + assert out.array.shape == (Batch.size, Pos.size, Vocab.size) + + +@pytest.mark.parametrize("use_flash", [True, False]) +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_gemma_lm_head_model_bwd(use_flash, 
num_kv_heads): + gemma_config = _get_gemma_config(use_flash=use_flash, num_kv_heads=num_kv_heads) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 1000) + Pos = gemma_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) + mask = AttentionMask.causal() + + gemma_model = GemmaLMHeadModel.init(Vocab=Vocab, config=gemma_config, key=random.PRNGKey(0)) + + def f(gemma_model, input_ids, mask): + out = gemma_model(input_ids, mask) + return hax.sum(out).scalar() + + _, grads = eqx.filter_value_and_grad(f)(gemma_model, input_ids, mask) + + +@skip_if_no_torch +@pytest.mark.parametrize("scan_layers", [True, False]) +@pytest.mark.parametrize("num_kv_heads", [1, 2, 4]) +def test_gemma_roundtrip(scan_layers, num_kv_heads): + import torch + from transformers import AutoModelForCausalLM, GemmaForCausalLM + + converter = GemmaConfig.default_hf_checkpoint_converter + + config = GemmaConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + num_kv_heads=num_kv_heads, + gradient_checkpointing=False, + scan_layers=scan_layers, + ) + Vocab = hax.Axis("vocab", 1000) + hf_config = config.to_hf_config(Vocab.size) + + # Make input and attn_mask + input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + attn_mask = AttentionMask.causal() + input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) + + torch.random.manual_seed(0) + + torch_model = GemmaForCausalLM(hf_config) + torch_model.eval() + + torch_out = torch_model(input_torch) + torch_out = torch_out.logits[0].detach().cpu().numpy() + torch_out = jax.nn.softmax(torch_out, axis=-1) + + with tempfile.TemporaryDirectory() as tmpdir: + torch_model.save_pretrained(f"{tmpdir}/torch_model") + + model = converter.load_pretrained( + GemmaLMHeadModel, f"{tmpdir}/torch_model", resize_vocab_to_match_tokenizer=False + ) + + def compute(input): + model_output = model(input, attn_mask=attn_mask) + return hax.nn.softmax(model_output, axis=model.Vocab) + + compute = jax.jit(compute) + jax_out = compute(input).array + + assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-3, atol=1e-3).all(), f"{torch_out} != {jax_out}" + + converter.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False) + torch_model2 = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/lev_model") + torch_model2.eval() + + torch_out2 = torch_model2(input_torch) + torch_out2 = torch_out2.logits[0].detach().cpu().numpy() + torch_out2 = jax.nn.softmax(torch_out2, axis=-1) + assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" + assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-3, atol=1e-3).all(), f"{torch_out2} != {jax_out}" + + +def _get_gemma_config(use_flash=False, num_kv_heads=4, seq_len=128) -> GemmaConfig: + rope_scaling = { + "type": "linear", + "factor": 2.0, + } + return GemmaConfig( + seq_len=seq_len, + hidden_dim=16, + num_heads=4, + num_kv_heads=num_kv_heads, + rope_scaling=rope_scaling, + gradient_checkpointing=False, # disable for tests so debugging is easier + use_flash_attention=use_flash, + flash_attention_block_size=8 if use_flash else None, + ) + + +def _get_random_inputs(config: GemmaConfig): + Embed = config.Embed + Pos = config.Pos + Batch = hax.Axis("batch", 2) + x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed)) + mask = AttentionMask.causal() + return x, mask + + +@parameterize_with_configs("gemma*.yaml") +def test_gemma_configs(config_file): + from 
levanter.main.train_lm import TrainLmConfig + + config_class = TrainLmConfig + + check_load_config(config_class, config_file) + + +@pytest.mark.parametrize("num_kv_heads", [1, 2]) +def test_pass_different_length_seq(num_kv_heads): + config = GemmaConfig( + seq_len=64, + hidden_dim=64, + intermediate_dim=32, + num_heads=2, + num_kv_heads=num_kv_heads, + use_flash_attention=True, + ) + check_model_works_with_seqlen(GemmaLMHeadModel, config, 16) diff --git a/tests/test_grad_accum.py b/tests/test_grad_accum.py index 855873b26..4ca151589 100644 --- a/tests/test_grad_accum.py +++ b/tests/test_grad_accum.py @@ -1,7 +1,7 @@ import equinox as eqx import jax -import jax.numpy as jnp import pytest +from chex import assert_trees_all_close from jax.sharding import Mesh import haliax as hax @@ -28,9 +28,9 @@ def init(In: hax.Axis, Out: hax.Axis, Mid: hax.Axis, *, key): return Mlp(w_in, w_out, In, Out, Mid) def __call__(self, x): - x = hax.dot(self.In, self.w_in, x) + x = hax.dot(self.w_in, x, axis=self.In) x = hnn.relu(x) - x = hax.dot(self.Mid, self.w_out, x) + x = hax.dot(self.w_out, x, axis=self.Mid) return x @@ -69,7 +69,7 @@ def jit_grad_accum(mlp, x): acc_v, acc_g = jit_grad_accum(mlp, x) v, g = grad_fn(mlp, x) - assert jnp.allclose(acc_v, v, atol=1e-3, rtol=1e-3) + assert_trees_all_close(acc_v, v, atol=1e-3, rtol=1e-3) for l1, l2 in zip(jax.tree_util.tree_leaves(acc_g), jax.tree_util.tree_leaves(g)): - assert jnp.allclose(l1, l2, atol=1e-3, rtol=1e-3) + assert_trees_all_close(l1, l2, atol=1e-3, rtol=1e-3) diff --git a/tests/test_hf_checkpoints.py b/tests/test_hf_checkpoints.py index 157d80e22..7416214c0 100644 --- a/tests/test_hf_checkpoints.py +++ b/tests/test_hf_checkpoints.py @@ -1,8 +1,10 @@ import tempfile import jax.numpy as jnp +import jmp import numpy as np import pytest +from chex import assert_trees_all_close, assert_trees_all_equal from jax.random import PRNGKey import haliax @@ -75,7 +77,10 @@ def test_save_backpack_model_with_code(): torch_input = torch.from_numpy(np.array(input.array)).to(torch.int64).unsqueeze(0) loaded_model.eval() np.testing.assert_allclose( - model(torch_input).logits[0].detach().numpy(), loaded_model(torch_input).logits[0].detach().numpy() + model(torch_input).logits[0].detach().numpy(), + loaded_model(torch_input).logits[0].detach().numpy(), + rtol=1e-3, + atol=1e-3, ) @@ -90,7 +95,7 @@ def test_conversion_to_jnp_bfloat16(): x_jnp = _convert_to_jnp(x, None) assert x_jnp.dtype == jnp.bfloat16 assert x_jnp.shape == x.shape - assert jnp.allclose(x_jnp, jnp.arange(10, dtype=jnp.bfloat16) / 3.14) + assert_trees_all_close(x_jnp, jnp.arange(10, dtype=jnp.bfloat16) / 3.14) def test_save_sharded_checkpoints(): @@ -100,6 +105,9 @@ def test_save_sharded_checkpoints(): nano_model = Gpt2LMHeadModel.init(converter.Vocab, nano_config, key=PRNGKey(3)) + mp = jmp.get_policy("f32") + nano_model = mp.cast_to_param(nano_model) + with tempfile.TemporaryDirectory() as tmpdir: converter.save_pretrained(nano_model, tmpdir, max_shard_size=1024) @@ -108,16 +116,12 @@ def test_save_sharded_checkpoints(): assert len(glob.glob(tmpdir + "/*.safetensors")) > 1 - loaded_model = converter.load_pretrained(nano_model.config, ref=tmpdir) + loaded_model = converter.load_pretrained(nano_model.config, ref=tmpdir, dtype=mp.param_dtype) assert loaded_model.config == nano_model.config assert loaded_model.Vocab == nano_model.Vocab - input = haliax.random.randint(PRNGKey(0), nano_model.config.Pos, 0, nano_model.Vocab.size) - causal_mask = AttentionMask.causal() - np.testing.assert_allclose( - 
np.array(nano_model(input, causal_mask, key=None).array), - np.array(loaded_model(input, causal_mask, key=None).array), - rtol=1e-6, - atol=1e-6, + assert_trees_all_equal( + nano_model, + loaded_model, ) diff --git a/tests/test_llama.py b/tests/test_llama.py index 6099aa8ae..d6cb8d5eb 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -62,7 +62,7 @@ def test_llama_rotary_embedding(): levanter_output = llama_rotary_pos_emb(HeadSize=HeadSize, Pos=Pos) hf_rope = HFLlamaRotaryEmbedding(dim=hidden_dim, max_position_embeddings=seq_len, device=device) - hf_output = hf_rope(x_torch, torch.arange(seq_len).reshape(1, -1), seq_len=seq_len) + hf_output = hf_rope(x_torch, torch.arange(seq_len).reshape(1, -1)) for jax_out, torch_out in zip(levanter_output, hf_output): torch_out = torch_out.numpy() @@ -295,7 +295,7 @@ def compute(input): jax_out = compute(input).array assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" - assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-4, atol=1e-4).all(), f"{torch_out} != {jax_out}" converter.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False) torch_model2 = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/lev_model") @@ -305,7 +305,7 @@ def compute(input): torch_out2 = torch_out2.logits[0].detach().cpu().numpy() torch_out2 = jax.nn.softmax(torch_out2, axis=-1) assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" - assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}" + assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-4, atol=1e-4).all(), f"{torch_out2} != {jax_out}" def _get_llama_config(use_flash=False, num_kv_heads=4, seq_len=128) -> LlamaConfig: diff --git a/tests/test_longformer.py b/tests/test_longformer.py index b7ae2c7e1..c964499a0 100644 --- a/tests/test_longformer.py +++ b/tests/test_longformer.py @@ -1,6 +1,7 @@ import jax import jax.numpy as jnp import numpy as np +from chex import assert_trees_all_close import haliax as hax from haliax import Axis @@ -32,8 +33,8 @@ def test_causal_sliding_window_attention_simple(): # we should be able to attend to the previous W positions for each position (including current), so 6-10 can't attend # to 0-4 and can't get the 100.0 key result = result.rearrange((Pos, Head)).array - assert jnp.allclose(result[0:W, 1], 300) - assert jnp.allclose(result[W:, 1], 0) + assert_trees_all_close(result[0:W, 1], 300) + assert_trees_all_close(result[W:, 1], 0) def test_sliding_window_attention_fancier(): @@ -64,7 +65,7 @@ def test_sliding_window_attention_fancier(): expected = expected.rearrange((Pos, Head)).array - assert jnp.allclose(result, expected, atol=1e-3, rtol=1e-3) + assert_trees_all_close(result, expected, atol=1e-3, rtol=1e-3) def test_longformer_alibi_bias_pos_invariance(): diff --git a/tests/test_mistral.py b/tests/test_mistral.py index 6a8153a2e..1d82dca44 100644 --- a/tests/test_mistral.py +++ b/tests/test_mistral.py @@ -122,7 +122,7 @@ def compute(input): jax_out = compute(input).array assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" - assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-4, atol=1e-4).all(), f"{torch_out} != {jax_out}" converter.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False) 
torch_model2 = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/lev_model") @@ -132,7 +132,7 @@ def compute(input): torch_out2 = torch_out2.logits[0].detach().cpu().numpy() torch_out2 = jax.nn.softmax(torch_out2, axis=-1) assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" - assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}" + assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-4, atol=1e-4).all(), f"{torch_out2} != {jax_out}" def _get_mistral_config(use_flash=False, num_kv_heads=4) -> MistralConfig: diff --git a/tests/test_shard_cache.py b/tests/test_shard_cache.py index 9e6eee607..1743a1c1f 100644 --- a/tests/test_shard_cache.py +++ b/tests/test_shard_cache.py @@ -8,7 +8,7 @@ from levanter.data._preprocessor import BatchProcessor from levanter.data.shard_cache import ChunkMetadata, SerialCacheWriter, _get_broker_actor, build_cache -from levanter.data.sharded_dataset import ShardedDataset +from levanter.data.sharded_dataset import ShardedDataset, TextUrlDataset from levanter.utils.py_utils import logical_cpu_core_count @@ -320,3 +320,48 @@ def freeze_batch(batch): return tuple(batch["test"].values.to_numpy()) assert set(freeze_batch(batch) for batch in serial) == set(freeze_batch(batch) for batch in ray_ds) + + +@pytest.mark.ray +def test_shard_cache_fails_with_multiple_shards_with_the_same_name(): + with tempfile.TemporaryDirectory() as tmpdir: + with open(f"{tmpdir}/data.txt", "w") as f: + f.write("") + + with pytest.raises(ValueError): + TextUrlDataset( + [f"{tmpdir}/data.txt", f"{tmpdir}/data.txt"], + ) + + with open(f"{tmpdir}/data.txt.1", "w") as f: + f.write("") + + dataset = TextUrlDataset( + [f"{tmpdir}/data.txt", f"{tmpdir}/data.txt.1"], + ) + + build_cache(tmpdir, dataset, TestProcessor(), await_finished=True) + + +@pytest.mark.ray +def test_shard_cache_fails_gracefully_with_unknown_file_type(): + with tempfile.TemporaryDirectory() as tmpdir: + with open(f"{tmpdir}/data.not_a_real_extension", "w") as f: + f.write("") + + dataset = TextUrlDataset( + [f"{tmpdir}/data.not_a_real_extension"], + ) + + with pytest.raises(ValueError): + build_cache(tmpdir, dataset, TestProcessor(), await_finished=True) + + # now make sure it works in non-blocking mode + + cache = build_cache(tmpdir, dataset, TestProcessor(), await_finished=False) + + with pytest.raises(ValueError): + cache.get_chunk(0, timeout=5) + + with pytest.raises(ValueError): + cache.await_finished(timeout=10) diff --git a/tests/test_sophia.py b/tests/test_sophia.py index 1ca3a7265..282d89d07 100644 --- a/tests/test_sophia.py +++ b/tests/test_sophia.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp import numpy as np +from chex import assert_trees_all_close import levanter import levanter.optim.sophia @@ -42,7 +43,7 @@ def loss_fn(model, data): # print('Test-estimated hessian: most coordinates should be approximately 2') # print('Estimated hessian:', opt_state[0].h.weight) - assert jnp.allclose(opt_state[0].h.weight, 2, rtol=0.2, atol=0.3) # this is very approximate + assert_trees_all_close(opt_state[0].h.weight, 2, rtol=0.2, atol=0.3) # this is very approximate grad_loss_fn = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn)) @@ -50,11 +51,10 @@ def loss_fn(model, data): model_updates, opt_state = optimizer.update(grad, opt_state, params=model, obj_fn=obj_fn) model = eqx.apply_updates(model, model_updates) - # loss should be 15.74834156036377 - assert jnp.allclose(loss, 15.74834156036377) + assert_trees_all_close(loss, 
15.74834156036377, rtol=1e-3, atol=1e-3) # print("Test-model param after 1 step: most coordinates should be very loosely 0.5") - assert jnp.allclose(model.weight, 0.5, rtol=0.2, atol=0.1) # this is very approximate + assert_trees_all_close(model.weight, 0.5, rtol=0.2, atol=0.1) # this is very approximate # print("Test-loss: loss should shrink by approximately 75% after each iteration") for i in range(10): diff --git a/tests/test_text.py b/tests/test_text.py index 70b2d26a7..a9d407b44 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,14 +1,12 @@ import tempfile import jax.numpy as jnp -from transformers import AutoTokenizer import haliax as hax -from levanter.data.text import BatchTokenizer, LMDatasetConfig +from levanter.data.text import LMDatasetConfig from levanter.models.lm_model import LmExample from levanter.models.loss import next_token_loss -from test_utils import skip_if_hf_model_not_accessible def test_dont_blow_up_without_validation_set(): @@ -41,29 +39,3 @@ def test_lm_example_handles_ignore_id(): no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size - - -def test_merge_split_encodings(): - tokenizer = AutoTokenizer.from_pretrained("gpt2") - # make this very short for testing - - lorem = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.""" - - short_batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=len(lorem) // 3) - # force this - short_batch_tokenizer._needs_long_sequence_workaround = True - - batch_tokenizer = BatchTokenizer(tokenizer, _workaround_len=50000) - batch = [lorem] - - short_out = short_batch_tokenizer(batch) - reg_out = batch_tokenizer(batch) - - assert short_out == reg_out - - -@skip_if_hf_model_not_accessible("meta-llama/Llama-2-7b-hf") -def test_llama_tokenizer_needs_long_sequence_workaround(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - batch_tokenizer = BatchTokenizer(tokenizer) - assert batch_tokenizer._needs_long_sequence_workaround diff --git a/tests/test_tokenized_document_cache.py b/tests/test_tokenized_document_cache.py index ff7531c77..7798d0e58 100644 --- a/tests/test_tokenized_document_cache.py +++ b/tests/test_tokenized_document_cache.py @@ -36,6 +36,7 @@ def test_index_empty_file(): source, tokenizer, flatten_docs=True, + enforce_bos=False, enforce_eos=False, override_resources={"num_cpus": 1}, ) diff --git a/tests/test_train_lm.py b/tests/test_train_lm.py index 3c3376620..2d870d483 100644 --- a/tests/test_train_lm.py +++ b/tests/test_train_lm.py @@ -23,7 +23,7 @@ def test_train_lm(): num_heads=2, seq_len=64, hidden_dim=32, - use_flash_attention=True, + attn_backend=None, # use default for platform ), trainer=train_lm.TrainerConfig( num_train_steps=2,