Fix TPU implementation
davidjurado committed Jun 8, 2023
1 parent e9dd27b commit 21d15d8
Showing 4 changed files with 55 additions and 48 deletions.
4 changes: 4 additions & 0 deletions language_model/tensorflow/bert/mlcube/mlcube_tpu.yaml
@@ -54,6 +54,10 @@ tasks:
  train_tpu:
    entrypoint: ./run_tpu.sh -a
    parameters:
+      inputs:
+        parameters_yaml:
+          type: file
+          default: parameters.yaml
      outputs:
        log_dir: logs/
  check_logs:
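With the parameters_yaml input in place, the train_tpu task can be launched through the MLCube CLI roughly as sketched below; the path and the --platform value are assumptions about the local setup rather than part of this change, and exact flags can differ between MLCube versions. The parameters file whose hunk follows supplies the values run_tpu.sh reads at runtime.

    # From language_model/tensorflow/bert: run the TPU training task.
    # parameters_yaml falls back to the workspace's parameters.yaml via the
    # default declared above unless it is overridden.
    mlcube run --mlcube=mlcube/mlcube_tpu.yaml --task=train_tpu --platform=docker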
@@ -1,4 +1,4 @@
-input_gs: gs://bert_tf_data/
+input_gs: gs://bert_tf_data
output_gs: your_gs_bucket_name
tpu_name: your_tpu_instance_name
tpu_zone: your_tpu_zone
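The trailing slash is dropped because run_tpu.sh now appends "/<object name>" to this prefix itself; keeping the slash would produce URIs with a double slash, which point at different object names in GCS. A minimal illustration:

    input_gs="gs://bert_tf_data/"          # old value
    echo "$input_gs/bert_config.json"      # gs://bert_tf_data//bert_config.json
    input_gs="gs://bert_tf_data"           # new value
    echo "$input_gs/bert_config.json"      # gs://bert_tf_data/bert_config.json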
97 changes: 50 additions & 47 deletions language_model/tensorflow/bert/run_tpu.sh
@@ -9,71 +9,74 @@ start_fmt=$(date +%Y-%m-%d\ %r)
echo "STARTING TIMING RUN AT $start_fmt"

# Set variables
: "${TFDATA_PATH:=./workspace/output_data}"
: "${INIT_CHECKPOINT:=./workspace/data/tf2_ckpt}"
: "${EVAL_FILE:=./workspace/tf_eval_data/eval_10k}"
: "${CONFIG_PATH:=./workspace/data/bert_config.json}"
: "${LOG_DIR:=./workspace/logs}"
: "${OUTPUT_DIR:=./workspace/final_output}"
: "${PARAMETERS_YAML:=./workspace/parameters.yaml}"

# Handle MLCube parameters
while [ $# -gt 0 ]; do
  case "$1" in
-    --tfdata_path=*)
-      TFDATA_PATH="${1#*=}"
-      ;;
-    --config_path=*)
-      CONFIG_PATH="${1#*=}"
-      ;;
-    --init_checkpoint=*)
-      INIT_CHECKPOINT="${1#*=}"
-      ;;
    --log_dir=*)
      LOG_DIR="${1#*=}"
      ;;
-    --output_dir=*)
-      OUTPUT_DIR="${1#*=}"
-      ;;
-    --eval_file=*)
-      EVAL_FILE="${1#*=}"
+    --parameters_yaml=*)
+      PARAMETERS_YAML="${1#*=}"
      ;;
    *) ;;
  esac
  shift
done
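# Illustration (sketch, not part of this commit): the MLCube task above declares
# parameters_yaml and log_dir, so the runner is expected to invoke this script
# roughly as
#   ./run_tpu.sh -a --parameters_yaml=<workspace>/parameters.yaml --log_dir=<workspace>/logs
# The exact argument forwarding depends on the MLCube runner configuration.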

+function parse_yaml {
+  local prefix=$2
+  local s='[[:space:]]*' w='[a-zA-Z0-9_]*' fs=$(echo @|tr @ '\034')
+  sed -ne "s|^\($s\):|\1|" \
+    -e "s|^\($s\)\($w\)$s:$s[\"']\(.*\)[\"']$s\$|\1$fs\2$fs\3|p" \
+    -e "s|^\($s\)\($w\)$s:$s\(.*\)$s\$|\1$fs\2$fs\3|p" $1 |
+  awk -F$fs '{
+    indent = length($1)/2;
+    vname[indent] = $2;
+    for (i in vname) {if (i > indent) {delete vname[i]}}
+    if (length($3) > 0) {
+      vn=""; for (i=0; i<indent; i++) {vn=(vn)(vname[i])("_")}
+      printf("%s%s%s=\"%s\"\n", "'$prefix'",vn, $2, $3);
+    }
+  }'
+}
+
+eval $(parse_yaml $PARAMETERS_YAML)
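# For reference (illustrative comment, not part of this commit): with the
# parameters.yaml shown earlier, the eval above defines plain shell variables
# such as
#   input_gs="gs://bert_tf_data"
#   output_gs="your_gs_bucket_name"
#   tpu_name="your_tpu_instance_name"
#   tpu_zone="your_tpu_zone"
# plus gcp_project, which the training command below reads. parse_yaml flattens
# nested keys with underscores and an optional prefix, e.g. "a:\n  b: v"
# becomes a_b="v".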

# run benchmark
echo "running benchmark"

python3 ./run_pretraining.py \
-  --bert_config_file=gs://bert_tf_data/bert_config.json \
-  --nodo_eval \
-  --do_train \
-  --eval_batch_size=64 \
-  --init_checkpoint=gs://bert_tf_data/tf2_ckpt/model.ckpt-28252 \
-  --input_file=gs://bert_tf_data/tf_data/part-* \
-  --iterations_per_loop=1 \
-  --lamb_beta_1=0.88 \
-  --lamb_beta_2=0.88 \
-  --lamb_weight_decay_rate=0.0166629 \
-  --learning_rate=0.00288293 \
-  --log_epsilon=-6 \
-  --max_eval_steps=125 \
-  --max_predictions_per_seq=76 \
-  --max_seq_length=512 \
-  --num_tpu_cores=8 \
-  --num_train_steps=15000 \
-  --num_warmup_steps=28 \
-  --optimizer=lamb \
-  --output_dir=gs://bert_tf_data/output/ \
-  --save_checkpoints_steps=3 \
-  --start_warmup_step=-76 \
-  --steps_per_update=1 \
-  --train_batch_size=256 \
-  --use_tpu \
-  --tpu_name=node3 \
-  --tpu_zone=us-central1-b \
-  --gcp_project=training-reference-bench-test |& tee "$LOG_DIR/train_console.log"
+  --bert_config_file="$input_gs/bert_config.json" \
+  --nodo_eval \
+  --do_train \
+  --eval_batch_size=64 \
+  --init_checkpoint="$input_gs/tf2_ckpt/model.ckpt-28252" \
+  --input_file="$input_gs/tf_data/part-*" \
+  --iterations_per_loop=1 \
+  --lamb_beta_1=0.88 \
+  --lamb_beta_2=0.88 \
+  --lamb_weight_decay_rate=0.0166629 \
+  --learning_rate=0.00288293 \
+  --log_epsilon=-6 \
+  --max_eval_steps=125 \
+  --max_predictions_per_seq=76 \
+  --max_seq_length=512 \
+  --num_tpu_cores=8 \
+  --num_train_steps=15000 \
+  --num_warmup_steps=28 \
+  --optimizer=lamb \
+  --output_dir="$output_gs/output/" \
+  --save_checkpoints_steps=3 \
+  --start_warmup_step=-76 \
+  --steps_per_update=1 \
+  --train_batch_size=256 \
+  --use_tpu \
+  --tpu_name="$tpu_name" \
+  --tpu_zone="$tpu_zone" \
+  --gcp_project="$gcp_project" |& tee "$LOG_DIR/train_console.log"

set +x

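Because the script now trusts parameters.yaml for the bucket, TPU, and project settings, a short pre-flight check along the following lines can surface typos before a long run is launched. This is only a suggestion, not part of the commit; it assumes gsutil and gcloud are authenticated for the target project and that the parse_yaml variables (or their literal values) are in scope.

    # Confirm the input objects exist in the bucket
    gsutil ls "$input_gs/bert_config.json" "$input_gs/tf2_ckpt/"
    # Confirm the TPU node is visible in the expected zone and project
    gcloud compute tpus describe "$tpu_name" --zone="$tpu_zone" --project="$gcp_project"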
