Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try launching unit tests on TPUs from CI #596

Merged
merged 31 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions .github/workflows/tpu_unit_tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
name: CI with GCP TPU

on: [pull_request]

jobs:
test:
runs-on: ubuntu-latest
env:
TPU_ZONE: "us-central2-b"

steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@v1
with:
project_id: ${{ secrets.GCP_PROJECT_ID }}

- name: Authenticate to Google Cloud
uses: google-github-actions/auth@v1
with:
credentials_json: ${{ secrets.GCP_SA_KEY }}

- name: Configure Google Cloud
run: |
gcloud config set project ${{ secrets.GCP_PROJECT_ID }}

- name: Create VM
run: |
export TPU_NAME=ci-run-${{ github.run_id }}
eval "$(ssh-agent -s)"
TRUE_SHA=${{ github.event.pull_request.head.sha }}
bash infra/spin-up-vm.sh $TPU_NAME -z ${TPU_ZONE} -t v4-8 --preemptible -s infra/helpers/setup-tpu-vm-tests.sh -b ${TRUE_SHA} --retries 1
# infra/babysit-tpu-vm.sh $TPU_NAME -z ${{ TPU_ZONE }} -t v4-8 --preemptible -s infra/helpers/setup-tpu-vm-tests.sh -b ${{ github.sha }} --retries 1 -- \
# PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest levanter/tests -m "not entry"

- name: Run most tests
run: |
export TPU_NAME=ci-run-${{ github.run_id }}
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest levanter/tests -m 'not entry'"
# Something's wrong with these
#
# - name: Run forked tests
# run: |
# export TPU_NAME=ci-run-${{ github.run_id }}
# gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests bash levanter/infra/run.sh pytest --forked levanter/tests -m 'entry'"
#
- name: Cleanup
if: ${{ always() }}
run: |
export TPU_NAME=ci-run-${{ github.run_id }}
echo gcloud compute tpus tpu-vm delete $TPU_NAME --zone ${TPU_ZONE} --quiet
gcloud compute tpus tpu-vm delete $TPU_NAME --zone ${TPU_ZONE} --quiet
21 changes: 19 additions & 2 deletions infra/babysit-tpu-vm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ CMD_ARGS_STR=$(printf ' %s' "${CMD_ARGS[@]}")
CMD_ARGS_STR=${CMD_ARGS_STR:1}
CMD_ARGS_STR="RUN_ID=${RUN_ID} ${CMD_ARGS_STR}"

TRIES=0

# check if the VM is running
# if not, spin it up
# if it is, just run the command
Expand All @@ -77,11 +79,19 @@ while true; do
echo "Running command on VM $VM_NAME"
echo "gcloud compute tpus tpu-vm ssh --zone=$ZONE $VM_NAME --command='$CMD_ARGS_STR' --worker=all"
gcloud compute tpus tpu-vm ssh --zone=$ZONE $VM_NAME --command="$CMD_ARGS_STR" --worker=all
if [ $? -eq 0 ]; then
EXIT_CODE=$?
if [ $EXIT_CODE -eq 0 ]; then
echo "Command succeeded. Exiting"
break
else
echo "Command failed"
TRIES=$((TRIES+1))
if [ "$RETRIES" -ge 0 ]; then
if [ $TRIES -ge "$RETRIES" ]; then
echo "Command failed $TRIES times, exiting with $EXIT_CODE"
break
fi
fi
fi
fi
else
Expand All @@ -92,11 +102,18 @@ while true; do
sleep 10
done

echo "Job finished!"
# exit code is the exit code of the command
if [ $EXIT_CODE -eq 0 ]; then
echo "Command succeeded"
else
echo "Command failed too many times, ending with exit code $EXIT_CODE"
fi

# delete the VM when we're done
gcloud compute tpus tpu-vm describe --zone $ZONE $VM_NAME &> /dev/null
if [ $? -eq 0 ]; then
echo "Deleting VM $VM_NAME"
yes | gcloud compute tpus tpu-vm delete --zone $ZONE $VM_NAME
fi

exit $EXIT_CODE
37 changes: 25 additions & 12 deletions infra/helpers/parse-tpu-creation-args.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ AUTODELETE=true
SETUP_SCRIPT="$SCRIPT_DIR/helpers/setup-tpu-vm.sh"
SUBNETWORK="default"
USE_ALPHA=false
RETRIES=-1 # how many times babysit-tpu-vm.sh should retry before giving up. -1 means infinite

if [ -z "$GIT_BRANCH" ]; then
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD)
Expand Down Expand Up @@ -86,6 +87,11 @@ while [[ $# -gt 0 ]]; do
USE_ALPHA="true"
shift # past argument
;;
--retries)
RETRIES="$2"
shift # past argument
shift # past value
;;
*) # unknown option, assume it's the vm name if it doesn't start with a dash
if [[ $1 == -* ]]; then
echo "Error: unknown option $1" >&2
Expand Down Expand Up @@ -115,19 +121,26 @@ done

# check if the branch we chose has been pushed to the remote
# if not, warn

# get the remote branch name
REMOTE_BRANCH=$(git ls-remote --heads origin "$GIT_BRANCH" | awk '{print $2}' | sed 's/refs\/heads\///g')
# if it's empty, warn
if [ -z "$REMOTE_BRANCH" ]; then
>&2 echo "Warning: branch $GIT_BRANCH not found on remote $GIT_REPO"
# if it's a commit sha/short-sha (or something that looks like one), check if it's in the remote
if [[ "$GIT_BRANCH" =~ ^[0-9a-f]{7,40}$ ]]; then
# if it's a commit, check if it's in the remote
BRANCHES=$(git branch -r --contains "$GIT_BRANCH")
if [ -z "$BRANCHES" ]; then
>&2 echo "Warning: commit $GIT_BRANCH not found on remote $GIT_REPO"
fi
else
# get the remote branch name
REMOTE_BRANCH=$(git ls-remote --heads origin "$GIT_BRANCH" | awk '{print $2}' | sed 's/refs\/heads\///g')
# if it's empty, warn
if [ -z "$REMOTE_BRANCH" ]; then
>&2 echo "Warning: branch $GIT_BRANCH not found on remote $GIT_REPO"
else
# make sure it's pushed
LOCAL_COMMIT=$(git rev-parse --short "$GIT_BRANCH")
REMOTE_COMMIT=$(git rev-parse --short "origin/$REMOTE_BRANCH")

# make sure it's pushed
LOCAL_COMMIT=$(git rev-parse --short "$GIT_BRANCH")
REMOTE_COMMIT=$(git rev-parse --short "origin/$REMOTE_BRANCH")

if [ "$LOCAL_COMMIT" != "$REMOTE_COMMIT" ]; then
>&2 echo "Warning: branch $GIT_BRANCH not pushed to remote $GIT_REPO. Local commit: $LOCAL_COMMIT, remote commit: $REMOTE_COMMIT"
if [ "$LOCAL_COMMIT" != "$REMOTE_COMMIT" ]; then
>&2 echo "Warning: branch $GIT_BRANCH not pushed to remote $GIT_REPO. Local commit: $LOCAL_COMMIT, remote commit: $REMOTE_COMMIT"
fi
fi
fi
68 changes: 0 additions & 68 deletions infra/helpers/setup-tpu-vm-nfs.sh

This file was deleted.

126 changes: 126 additions & 0 deletions infra/helpers/setup-tpu-vm-tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# broadly based on https://github.com/ayaka14732/tpu-starter

# parse some arguments
# usage: ./setup-tpu-vm.sh -b|--branch <git commit or branch for levanter> -r <git repo for levanter>

if [ "$DEBUG" == "1" ]; then
set -x
fi

REPO="https://github.com/stanford-crfm/levanter.git"
BRANCH=main

if [ "$GIT_BRANCH" != "" ]; then
BRANCH="$GIT_BRANCH"
fi

while [[ $# -gt 0 ]]; do
key="$1"
case $key in
-b|--branch)
BRANCH="$2"
shift
shift
;;
-r|--repo)
REPO="$2"
shift
shift
;;
*)
>&2 echo "Unknown option $1"
exit 1
;;
esac
done

# we frequently deal with commands failing, and we like to loop until they succeed. this function does that for us
function retry {
for i in {1..5}; do
$@
if [ $? -eq 0 ]; then
break
fi
if [ $i -eq 5 ]; then
>&2 echo "Error running $*, giving up"
exit 1
fi
>&2 echo "Error running $*, retrying in 5 seconds"
sleep 5
done
}

# tcmalloc interferes with intellij remote ide
sudo patch -f -b /etc/environment << EOF
2c2
< LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4"
---
> #LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4"
EOF



# don't complain if already applied
retCode=$?
[[ $retCode -le 1 ]] || exit $retCode


# set these env variables b/c it makes tensorstore behave better
if ! grep -q TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS /etc/environment; then
# need sudo
echo "TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS=60" | sudo tee -a /etc/environment > /dev/null
fi

if ! grep -q TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES /etc/environment; then
echo "TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES=1024" | sudo tee -a /etc/environment > /dev/null
fi

# install python 3.10, latest git
sudo systemctl stop unattended-upgrades # this frequently holds the apt lock
sudo systemctl disable unattended-upgrades
sudo apt remove -y unattended-upgrades
# if it's still running somehow, kill it
if [ $(ps aux | grep unattended-upgrade | wc -l) -gt 1 ]; then
sudo kill -9 $(ps aux | grep unattended-upgrade | awk '{print $2}')
fi

# sometimes apt-get update fails, so retry a few times
retry sudo apt-get install -y software-properties-common
retry sudo add-apt-repository -y ppa:deadsnakes/ppa
retry sudo add-apt-repository -y ppa:git-core/ppa
retry sudo apt-get -qq update
retry sudo apt-get -qq install -y python3.10-full python3.10-dev git

VENV=~/venv310
# if the venv doesn't exist, make it
if [ ! -d "$VENV" ]; then
echo "Creating virtualenv at $VENV"
python3.10 -m venv $VENV
fi

source $VENV/bin/activate

pip install -U pip
pip install -U wheel

# jax and jaxlib
# libtpu sometimes has issues installing for clinical (probably firewall?)
retry pip install -U "jax[tpu]==0.4.26" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# clone levanter
git clone $REPO levanter
echo $VENV > levanter/infra/venv_path.txt

cd levanter

# checkout the branch we want

echo "Checking out branch $BRANCH"

git checkout $BRANCH

# install levanter

pip install -e .

pip install -r tests/requirements.txt
2 changes: 1 addition & 1 deletion src/levanter/main/cache_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main(args: RayCachedLMDatasetConfig):
logger.warning(f"Skipping {split} because it is empty.")
continue

monitors = [RichMetricsMonitor(source.num_shards)]
monitors: list = [RichMetricsMonitor(source.num_shards)]
if not isinstance(args.tracker, NoopConfig):
monitors.append(LoggingMetricsMonitor("preprocess/" + split, commit=True))

Expand Down
3 changes: 3 additions & 0 deletions src/levanter/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,9 @@ def _tpu_splash_attention(

q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key)

# pre-divide q_ by sqrt(d) to match the reference implementation
query = query / jnp.sqrt(query.resolve_axis(Key).size)

q_: jax.Array = _reshape_axes_for_bshd_bins(query, q_class, output_order=list("BHSD")).array
k_ = _reshape_axes_for_bshd_bins(key, k_class, output_order=list("BHSD")).array
v_ = _reshape_axes_for_bshd_bins(value, v_class, output_order=list("BHSD")).array
Expand Down
5 changes: 3 additions & 2 deletions src/levanter/tracker/tracker_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optiona
def _no_throw_log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None):
try:
if _global_tracker is None:
raise RuntimeError("No global tracker set")
_global_tracker.log(metrics, step=step, commit=False)
warnings.warn("No global tracker set")
else:
_global_tracker.log(metrics, step=step, commit=False)
except Exception:
logger.exception("Error logging metrics")

Expand Down
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ flake8
pytest
soundfile
librosa
pytest-forked
Loading
Loading