Added EMA of running weights to the current codebase #853

Open · wants to merge 3 commits into main
Changes from all commits
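
For context on the change itself: an exponential moving average (EMA) of the weights keeps a second, slowly moving copy of the model parameters that is typically used for evaluation while the raw weights continue training. The snippet below is only an illustration of the update rule that the new ema_beta option controls; it is not the PR's actual implementation, which lives in trainer_state.py and the optimizer code and is not shown in this view.

    # Illustrative only: the standard per-parameter EMA update, ema <- beta * ema + (1 - beta) * param.
    # The function name ema_update is hypothetical; 0.995 matches the TrainerConfig.ema_beta default added below.
    def ema_update(ema: float, param: float, ema_beta: float = 0.995) -> float:
        return ema_beta * ema + (1.0 - ema_beta) * param
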
32 changes: 32 additions & 0 deletions config/llama2_100M_constant_lr2e-3.yaml
@@ -0,0 +1,32 @@
data: !include data/dclm_gpt_neo.yaml
model:
  type: llama
  seq_len: 4096
  hidden_dim: 768
  intermediate_dim: 3072
  num_layers: 12
  num_heads: 12
  num_kv_heads: 12
trainer:
  tracker:
    project: "levanter"
    tags: ["pile", "llama"]
  mp: p=f32,c=bfloat16
  model_axis_size: 1
  checkpointer:
    keep:
      - every: 100
    save_interval: 30m

  train_batch_size: 1024
  per_device_parallelism: 4  # set for v3 TPU
  per_device_eval_parallelism: 4  # set a larger batch size for eval
  num_train_steps: 1001
optimizer:
  learning_rate: 1E-3  # set low for fine-tuning
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 1000
  cooldown: 0.0
  lr_schedule: constant
130 changes: 130 additions & 0 deletions infra/babysit-tpu-vm-ema.sh
@@ -0,0 +1,130 @@
#!/bin/bash

# TRC recently made a change where they unceremoniously kill TPU VMs whenever they need capacity for paying customers
# Understandable, but we have to work around it.
# This script runs on a non-TPU VM (some server somewhere) and periodically relaunches the TPU VM if it's not running
# and restarts the process
# My preference would be to use pdsh for this, but we don't reliably have it on our internal cluster...

# Syntax: babysit-tpu-vm-ema.sh <args to spin-up-vm.sh> -- command to run on the vm


SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

# first extract the args for just spin-up-vm.sh and pass them to the helper

CREATION_ARGS=()

while [[ $# -gt 0 ]]; do
  key="$1"
  case $key in
    --)
      shift
      break
      ;;
    *)
      CREATION_ARGS+=("$1")
      shift
      ;;
  esac
done

source "$SCRIPT_DIR"/helpers/parse-tpu-creation-args.sh "${CREATION_ARGS[@]}"

if [ -z "$VM_NAME" ]; then
  echo "Error: VM name not set"
  exit 1
fi

if [ -z "$SSH_AUTH_SOCK" ]; then
  echo "Error: ssh-agent not running. This script needs to be run from a machine with ssh-agent running. Please run ssh-add ~/.ssh/google_compute_engine and try again"
  exit 1
fi

if [ -z "$RUN_ID" ]; then
  RUN_ID=$(bash "${SCRIPT_DIR}"/helpers/gen-id.sh)
  echo "RUN_ID not set, setting to $RUN_ID"
fi

# set the cmd args. We want to be sure everything is fully quoted when we pass it to the gcloud ssh command
# in case there are spaces in the command (or embedded quotes)
CMD_ARGS=()
for arg in "$@"; do
  # need to escape any embedded quotes using printf
  CMD_ARGS+=("$(printf '%q' "$arg")")
done

# Now turn CMD_ARGS into a single string we can pass
CMD_ARGS_STR=$(printf ' %s' "${CMD_ARGS[@]}")
CMD_ARGS_STR=${CMD_ARGS_STR:1}
CMD_ARGS_STR="RUN_ID=${RUN_ID} ${CMD_ARGS_STR}"

TRIES=0

# check if the VM is running
# if not, spin it up
# if it is, just run the command
while true; do
  # check if it's there
  gcloud compute tpus tpu-vm describe --zone $ZONE $VM_NAME &> /dev/null
  if [ $? -eq 0 ]; then
    # check if it's running
    STATE=$(gcloud compute tpus tpu-vm describe --zone $ZONE $VM_NAME | grep state | awk '{print $2}')
    if [ "$STATE" != "READY" ]; then
      echo "VM $VM_NAME is not in READY state, state is $STATE"
      echo "Deleting VM $VM_NAME"
      yes | gcloud compute tpus tpu-vm delete --zone $ZONE $VM_NAME
    else
      # run the command
      echo "Running command on VM $VM_NAME"
      # sync the locally modified config and source files onto the VM before launching
      gcloud compute tpus tpu-vm ssh $VM_NAME --zone $ZONE --worker=all --command 'rm -r levanter/config'
      gcloud compute tpus tpu-vm ssh $VM_NAME --zone $ZONE --worker=all --command 'rm -r levanter/src/levanter/optim'
      gcloud compute tpus tpu-vm scp --recurse config $VM_NAME:levanter/config --zone $ZONE --worker=all
      gcloud compute tpus tpu-vm scp --recurse src/levanter/optim $VM_NAME:levanter/src/levanter/optim --zone $ZONE --worker=all
      gcloud compute tpus tpu-vm scp src/levanter/models/llama.py $VM_NAME:levanter/src/levanter/models/llama.py --zone $ZONE --worker=all
      gcloud compute tpus tpu-vm scp src/levanter/callbacks.py $VM_NAME:levanter/src/levanter/callbacks.py --zone $ZONE --worker=all
      gcloud compute tpus tpu-vm scp src/levanter/eval.py $VM_NAME:levanter/src/levanter/eval.py --zone $ZONE --worker=all
      gcloud compute tpus tpu-vm scp src/levanter/trainer.py $VM_NAME:levanter/src/levanter/trainer.py --zone $ZONE --worker=all
      gcloud compute tpus tpu-vm scp src/levanter/trainer_state.py $VM_NAME:levanter/src/levanter/trainer_state.py --zone $ZONE --worker=all
      # run the actual command
      echo "gcloud compute tpus tpu-vm ssh --zone=$ZONE $VM_NAME --command='$CMD_ARGS_STR' --worker=all"
      gcloud compute tpus tpu-vm ssh --zone=$ZONE $VM_NAME --command="$CMD_ARGS_STR" --worker=all
      EXIT_CODE=$?
      if [ $EXIT_CODE -eq 0 ]; then
        echo "Command succeeded. Exiting"
        break
      else
        echo "Command failed"
        TRIES=$((TRIES+1))
        if [ "$RETRIES" -ge 0 ]; then
          if [ $TRIES -ge "$RETRIES" ]; then
            echo "Command failed $TRIES times, exiting with $EXIT_CODE"
            break
          fi
        fi
      fi
    fi
  else
    echo "VM $VM_NAME not found, creating it"
    bash "$SCRIPT_DIR"/spin-up-vm.sh "${CREATION_ARGS[@]}"
  fi
  echo "Sleeping for 10s"
  sleep 10
done

# exit code is the exit code of the command
if [ $EXIT_CODE -eq 0 ]; then
  echo "Command succeeded"
else
  echo "Command failed too many times, ending with exit code $EXIT_CODE"
fi

# delete the VM when we're done
gcloud compute tpus tpu-vm describe --zone $ZONE $VM_NAME &> /dev/null
if [ $? -eq 0 ]; then
  echo "Deleting VM $VM_NAME"
  yes | gcloud compute tpus tpu-vm delete --zone $ZONE $VM_NAME
fi

exit $EXIT_CODE
2 changes: 1 addition & 1 deletion infra/helpers/setup-tpu-vm.sh
@@ -105,7 +105,7 @@ pip install -U wheel

# jax and jaxlib
# libtpu sometimes has issues installing for clinical (probably firewall?)
-retru pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+retry pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# clone levanter
git clone $REPO levanter
9 changes: 9 additions & 0 deletions run_ema.sh
@@ -0,0 +1,9 @@
eval $(ssh-agent -s)
bash infra/babysit-tpu-vm-ema.sh test-ema -z us-central2-b -t v4-128 --preemptible -- \
WANDB_API_KEY=1c85c63399be786e59026e288175122f49a434b0 \
bash levanter/infra/run.sh python \
levanter/src/levanter/main/train_lm.py \
--config_path levanter/config/llama2_100M_constant_lr2e-3.yaml \
--trainer.checkpointer.base_path gs://marin-us-central2/scratch/kaiyue/checkpoints/test_ema \
--optimizer.min_lr_ratio 0.0 \
--trainer.use_ema True
9 changes: 9 additions & 0 deletions run_wo_ema.sh
@@ -0,0 +1,9 @@
eval $(ssh-agent -s)
bash infra/babysit-tpu-vm-ema.sh test-no-ema -z us-central2-b -t v4-128 --preemptible -- \
WANDB_API_KEY=1c85c63399be786e59026e288175122f49a434b0 \
bash levanter/infra/run.sh python \
levanter/src/levanter/main/train_lm.py \
--config_path levanter/config/llama2_100M_constant_lr2e-3.yaml \
--trainer.checkpointer.base_path gs://marin-us-central2/scratch/kaiyue/checkpoints/test_ema \
--optimizer.min_lr_ratio 0.0 \
--trainer.use_ema False
2 changes: 2 additions & 0 deletions src/levanter/callbacks.py
@@ -58,6 +58,8 @@ class StepInfo(Generic[S]):
    opt_state = property(lambda self: self.state.opt_state)

    step = property(lambda self: int(self.state.step) - 1)
+    use_ema = property(lambda self: self.state.use_ema)
+
    """
    The step that was just completed. If you want the next step, use `next_step`.
    """
94 changes: 49 additions & 45 deletions src/levanter/eval.py
@@ -193,54 +193,58 @@ def cb_tagged_lm_evaluate(
    )

    def eval_callback(step: StepInfo):
+        results = dict()
        with levanter.tracker.capture_time() as time_fn:
-            result = evaluator.evaluate(step.model)
-
-        log_dict = {
-            # log micro average as just "loss"
-            _join_prefix(prefix, "loss"): result.micro_avg_loss,
-            _join_prefix(prefix, "loading_time"): result.total_eval_loading_time,
-            _join_prefix(prefix, "total_time"): time_fn(),
-        }
-
-        logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}")
-        has_tags = len(evaluator.dataset.tag_to_index) > 1  # 1 tag means there's no difference between micro and macro
-        if has_tags:
-            log_dict[_join_prefix(prefix, "macro_loss")] = result.macro_avg_loss
-
-            for tag, loss in result.tag_macro_losses.items():
-                # don't log leaf tag macro losses because it doesn't mean anything different than micro loss
-                if tag in evaluator.dataset.tag_to_index:
-                    continue
-                if not tag:
-                    continue
-                log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss
-                logger.info(f"{tag} macro loss: {loss:.3f}")
-
-        for tag, loss in result.tag_micro_losses.items():
-            if not tag:
-                continue
-            if tag in evaluator.dataset.tag_to_index:
-                log_dict[_join_prefix(prefix, tag) + "/loss"] = loss
-                logger.info(f"{tag} loss: {loss:.3f}")
-            else:
-                log_dict[_join_prefix(prefix, tag) + "/micro_loss"] = loss
-                logger.info(f"{tag} micro loss: {loss:.3f}")
-
-        if tokenizer is not None:
-            log_dict[_join_prefix(prefix, "bpb")] = result.micro_bpb
-            if has_tags:
-                log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb
-            for tag, bpb in result.tag_micro_bpb.items():
-                log_dict[_join_prefix(prefix, tag) + "/bpb"] = bpb
-
-            if has_tags:
-                for tag, bpb in result.tag_macro_bpb.items():
-                    log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb
-
-        levanter.tracker.log(log_dict, step=step.step)
-
-        return result
+            original_result = evaluator.evaluate(step.model)
+            results[prefix] = original_result
+            if step.use_ema:
+                results[_join_prefix(prefix, "ema")] = evaluator.evaluate(step.state.ema_model)
+
+        for p, result in results.items():
+            log_dict = {
+                # log micro average as just "loss"
+                _join_prefix(p, "loss"): result.micro_avg_loss,
+                _join_prefix(p, "loading_time"): result.total_eval_loading_time,
+                _join_prefix(p, "total_time"): time_fn(),
+            }
+
+            logger.info(f"{p} loss: {result.micro_avg_loss:.3f}")
+            has_tags = len(evaluator.dataset.tag_to_index) > 1  # 1 tag means there's no difference between micro and macro
+            if has_tags:
+                log_dict[_join_prefix(p, "macro_loss")] = result.macro_avg_loss
+
+                for tag, loss in result.tag_macro_losses.items():
+                    # don't log leaf tag macro losses because it doesn't mean anything different than micro loss
+                    if tag in evaluator.dataset.tag_to_index:
+                        continue
+                    if not tag:
+                        continue
+                    log_dict[_join_prefix(p, tag) + "/macro_loss"] = loss
+                    logger.info(f"{tag} macro loss: {loss:.3f}")
+
+            for tag, loss in result.tag_micro_losses.items():
+                if not tag:
+                    continue
+                if tag in evaluator.dataset.tag_to_index:
+                    log_dict[_join_prefix(p, tag) + "/loss"] = loss
+                    logger.info(f"{tag} loss: {loss:.3f}")
+                else:
+                    log_dict[_join_prefix(p, tag) + "/micro_loss"] = loss
+                    logger.info(f"{tag} micro loss: {loss:.3f}")
+
+            if tokenizer is not None:
+                log_dict[_join_prefix(p, "bpb")] = result.micro_bpb
+                if has_tags:
+                    log_dict[_join_prefix(p, "macro_bpb")] = result.macro_bpb
+                for tag, bpb in result.tag_micro_bpb.items():
+                    log_dict[_join_prefix(p, tag) + "/bpb"] = bpb
+
+                if has_tags:
+                    for tag, bpb in result.tag_macro_bpb.items():
+                        log_dict[_join_prefix(p, tag) + "/macro_bpb"] = bpb
+
+            levanter.tracker.log(log_dict, step=step.step)
+        return original_result

    return eval_callback

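
One practical effect of the eval.py change above: when use_ema is enabled, each evaluation metric is logged twice, once under the original prefix for the raw training weights and once under an extra "ema" sub-prefix for the EMA weights. A rough illustration of the resulting key layout, assuming _join_prefix (defined elsewhere in eval.py, not shown in this diff) joins non-empty parts with "/":

    # Hypothetical stand-in for _join_prefix, for illustration only.
    def join_prefix(prefix: str, name: str) -> str:
        return f"{prefix}/{name}" if prefix else name

    prefix = "eval"
    print(join_prefix(prefix, "loss"))                      # eval/loss      (raw weights)
    print(join_prefix(join_prefix(prefix, "ema"), "loss"))  # eval/ema/loss  (EMA weights)
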
7 changes: 5 additions & 2 deletions src/levanter/eval_harness.py
@@ -688,8 +688,11 @@ def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_reso
    def lm_eval_harness(step: StepInfo, force=False):
        if step.step == 0 and not force:
            return

-        model = inference_mode(step.model, True)
-
+        if step.use_ema:
+            model = inference_mode(step.ema_model, True)
+        else:
+            model = inference_mode(step.model, True)
+
        logger.info("Running eval harness...")
        outputs = _actually_run_eval_harness(
            config,
5 changes: 4 additions & 1 deletion src/levanter/trainer.py
@@ -356,6 +356,8 @@ def init_state_and_model(model_init, training_key):
            is_trainable=is_trainable,
            mp=self.mp,
            fp8=self.fp8,
+            use_ema=self.config.use_ema,
+            ema_beta=self.config.ema_beta,
        )
        return state

@@ -577,7 +579,8 @@ class TrainerConfig:
    seed: int = 0  # random seed
    mp: jmp.Policy = jmp.get_policy("f32")  # mixed precision policy
    fp8: Optional[bool | Fp8Config] = None
-
+    use_ema: bool = False
+    ema_beta: float = 0.995
    wandb: Optional[tracker.wandb.WandbConfig] = None
    log_dir: Path = Path("logs/")
    id: Optional[str] = None  # run id. if None, will be set to a random string
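
The trainer_state.py changes that babysit-tpu-vm-ema.sh copies onto the VM are not included in this view, so the following is only a guess at the state-side update implied by the new use_ema and ema_beta options: keep an ema_model alongside the live model and fold the freshly updated parameters into it after each optimizer step. The names EmaStateSketch, ema_model, and take_step are assumptions for illustration, not the PR's actual code.

    import dataclasses
    from typing import Any

    import jax


    @dataclasses.dataclass
    class EmaStateSketch:
        model: Any       # pytree of live (trained) parameters
        ema_model: Any   # pytree holding the exponential moving average of the parameters
        ema_beta: float = 0.995

        def take_step(self, new_model: Any) -> "EmaStateSketch":
            # ema <- ema_beta * ema + (1 - ema_beta) * new_params, applied leaf-wise over the pytree
            new_ema = jax.tree_util.tree_map(
                lambda e, p: self.ema_beta * e + (1.0 - self.ema_beta) * p, self.ema_model, new_model
            )
            return dataclasses.replace(self, model=new_model, ema_model=new_ema)
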