-
Notifications
You must be signed in to change notification settings - Fork 103
/
tr11-176B-ml.slurm
221 lines (184 loc) · 6.87 KB
/
tr11-176B-ml.slurm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
#!/bin/bash
# tr11-176B-ml: SLURM batch script that launches Megatron-DeepSpeed GPT
# pretraining (pretrain_gpt.py) on 24 nodes x 8 GPUs.
#
# Environment: six_ALL_CCFRWORK (code, caches) must be set in the submitting
# environment; six_ALL_CCFRSCRATCH (checkpoints, logs) must be set once the
# start-tr11-176B-ml script sourced below has run.
#SBATCH --job-name=tr11-176B-ml
#SBATCH --partition=gpu_p5
#SBATCH --constraint=a100
#SBATCH --reservation=hug
#SBATCH --qos=qos_gpu-gc # up to 100h
#SBATCH --nodes=24
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=64 # number of cores per tasks
#SBATCH --hint=nomultithread # we get physical cores not logical
#SBATCH --gres=gpu:8 # number of gpus
#SBATCH --time 100:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --output=%x-%j.out # output file name
#SBATCH --account=six@a100
# Note on --ntasks-per-node=1: each node runs exactly one task, and that task's
# torch.distributed.run launcher spawns the per-GPU worker processes itself.

# Trace every command and abort on the first unhandled failure.
set -x -e

# Fail fast with a readable message instead of trying to source "/code/...".
: "${six_ALL_CCFRWORK:?six_ALL_CCFRWORK must be set}"

# Alternative start scripts, disabled:
#source $six_ALL_CCFRWORK/start-py38-pt110
#source $six_ALL_CCFRWORK/start-py38-pt111
source "$six_ALL_CCFRWORK/code/tr11-176B-ml/bigscience/train/tr11-176B-ml/start-tr11-176B-ml"

# All checkpoint and log paths are derived from this; without it they would
# silently point at "/checkpoints/...".
: "${six_ALL_CCFRSCRATCH:?six_ALL_CCFRSCRATCH must be set}"

echo "START TIME: $(date)"
# Run variant: selects the checkpoint / tensorboard / log sub-directories.
variant=main

# Everything this run writes lives under the scratch filesystem.
DATA_OUTPUT_PATH="${six_ALL_CCFRSCRATCH}/checkpoints/tr11-176B-ml"
REPO_PATH="${DATA_OUTPUT_PATH}/tr11-176B-ml-logs"
CHECKPOINT_PATH="${DATA_OUTPUT_PATH}/checkpoints/${variant}"
TENSORBOARD_PATH="${REPO_PATH}/tensorboard/${variant}"
LOGS_PATH="${REPO_PATH}/logs/${variant}"
mkdir -p "$LOGS_PATH"

# Code checkout; the job runs from here, so relative paths used later
# (e.g. the generated DeepSpeed config, `pwd`) resolve inside the repo.
MEGATRON_DEEPSPEED_REPO="${six_ALL_CCFRWORK}/code/tr11-176B-ml/Megatron-DeepSpeed"
cd "$MEGATRON_DEEPSPEED_REPO"

# Handed to the trainer via --kill-switch-path (presumably: creating this
# file requests a stop - confirm against Megatron-DeepSpeed).
KILL_SWITCH_PATH="${MEGATRON_DEEPSPEED_REPO}/kill-switch-tr11-176B-exp1"
# Data mixture: regenerate the per-split ratio files consumed by
# --train/--valid-weighted-split-paths-path from the catalogue JSON on every run.
BIGSCIENCE_REPO=$six_ALL_CCFRWORK/code/tr11-176B-ml/bigscience
CATALOGUE_JSON_PATH=$BIGSCIENCE_REPO/data/catalogue/training_dataset_ratios_merged_nigercongo_v3.json
LOAD_RATIOS_SCRIPT=$BIGSCIENCE_REPO/data/catalogue/load_ratios_meg_ds_format.py
TRAIN_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/train-splits.txt
VALID_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/data/valid-splits.txt

# Write the ratio file for split $1 ("train" / "valid") to path $2.
write_split_ratios() {
  python "$LOAD_RATIOS_SCRIPT" \
    --dataset-ratios-path "$CATALOGUE_JSON_PATH" \
    --split "$1" \
    --output-meg-ds-ratio-file "$2"
}
write_split_ratios train "$TRAIN_DATA_PATH"
write_split_ratios valid "$VALID_DATA_PATH"

# Tokenizer identifier, passed with --tokenizer-type PretrainedFromHF.
TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
# Hugging Face caches all live on the shared work filesystem.
export TRANSFORMERS_CACHE=$six_ALL_CCFRWORK/models
export HF_DATASETS_CACHE=$six_ALL_CCFRWORK/datasets
export HF_MODULES_CACHE=$six_ALL_CCFRWORK/modules
export HF_METRICS_CACHE=$six_ALL_CCFRWORK/metrics
# Offline mode: resolve everything from the caches above, never the Hub.
export HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1

# Optional faulty-node probe (prints hostname + CUDA availability per node);
# uncomment to run before training:
# srun --jobid $SLURM_JOB_ID bash -c 'python -c "import torch, socket; print(socket.gethostname(), torch.cuda.is_available())"'
# Rendezvous: the first host of the allocation hosts the c10d endpoint that
# every node's launcher connects to.
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000

# Cluster shape and model-parallel degrees
# (--tensor-model-parallel-size / --pipeline-model-parallel-size).
NNODES=$SLURM_NNODES
GPUS_PER_NODE=8
TP_SIZE=4
PP_SIZE=12

# Batch sizes, in samples of SEQ_LEN tokens. MBS was 1 until GBS reached 784.
# GBS 2048 x 2048 tokens = ~4.2M tokens/step: above the initially planned
# 3.2M tokens, chosen for higher throughput.
MICRO_BATCH_SIZE=2
GLOBAL_BATCH_SIZE=2048

# Architecture.
NLAYERS=70
NHIDDEN=14336
NHEADS=112
SEQ_LEN=2048

# Schedule. Sample counts convert to tokens at 2048 tokens/sample.
SAVE_INTERVAL=100
TRAIN_SAMPLES=220_000_000    # ~450B tokens
LR_DECAY_SAMPLES=200_000_000 # decay over the first ~410B tokens, then hold at --min-lr
LR_WARMUP_SAMPLES=183_105    # ~375M tokens
# Gradient-norm clipping threshold. It must reach both Megatron (--clip-grad)
# and DeepSpeed ("gradient_clipping" in the JSON config generated below), so it
# is defined once here to keep the two from drifting apart.
CLIP_GRAD=1.0

# Adam; LR warms up over LR_WARMUP_SAMPLES, then follows a cosine decay from
# --lr to --min-lr over LR_DECAY_SAMPLES.
OPTIMIZER_ARGS=" \
--optimizer adam \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--adam-eps 1e-8 \
--lr 6e-5 \
--min-lr 6e-6 \
--lr-decay-style cosine \
--lr-decay-samples $LR_DECAY_SAMPLES \
--lr-warmup-samples $LR_WARMUP_SAMPLES \
--clip-grad $CLIP_GRAD \
--weight-decay 1e-1 \
"

# Exit 10 minutes before the SLURM --time limit:
# 1190 for a 20h allocation, 5990 for the current 100h one.
# --exit-duration-in-mins 1190 \
EXIT_OPTS=" \
--exit-duration-in-mins 5990 \
"

# Model, data and precision arguments. The single quotes around the
# --pp-partition-method value are literal characters in this string; they only
# take effect because the final command line is re-parsed by `bash -c` at
# launch, which is also what stops the `|` from acting as a pipe.
GPT_ARGS=" \
--pp-partition-method 'type:transformer|embedding' \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--num-attention-heads $NHEADS \
--seq-length $SEQ_LEN \
--max-position-embeddings $SEQ_LEN \
--micro-batch-size $MICRO_BATCH_SIZE \
--rampup-batch-size 192 16 9_765_625 \
--global-batch-size $GLOBAL_BATCH_SIZE \
--train-samples $TRAIN_SAMPLES \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
--init-method-std 0.0048 \
--embed-layernorm \
--sync-tp-duplicated-parameters \
--bf16 \
--seed 42 \
--position-embedding-type alibi \
--checkpoint-activations \
--abort-on-unmet-fused-kernel-constraints \
--kill-switch-path $KILL_SWITCH_PATH \
--pad-vocab-size-to 250880 \
$OPTIMIZER_ARGS \
$EXIT_OPTS \
"

# TODO: decide on efficient eval-interval + eval-iters
OUTPUT_ARGS=" \
--log-interval 1 \
--save-interval $SAVE_INTERVAL \
--eval-interval 1000 \
--eval-iters 1 \
--tensorboard-dir $TENSORBOARD_PATH \
--tensorboard-queue-size 5 \
--log-timers-to-tensorboard \
--log-batch-size-to-tensorboard \
--log-validation-ppl-to-tensorboard \
"

# Important: bf16 must run with ZeRO stage 0 - it implements its own
# ZeRO-stage-1 equivalent.
ZERO_STAGE=0

# Per-job DeepSpeed config, written relative to the current directory.
config_json="./ds_config.$SLURM_JOB_ID.json"

# DeepSpeed figures out gradient accumulation steps dynamically from the
# (ramping) global batch size via set_train_batch_size().
cat <<EOT > "$config_json"
{
  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
  "train_batch_size": $GLOBAL_BATCH_SIZE,
  "gradient_clipping": $CLIP_GRAD,
  "zero_optimization": {
    "stage": $ZERO_STAGE
  },
  "bf16": {
    "enabled": true
  },
  "steps_per_print": 2000,
  "wall_clock_breakdown": false
}
EOT
# DeepSpeed engine options: the per-job JSON config, the ZeRO stage, and
# DeepSpeed-managed activation checkpointing.
DEEPSPEED_ARGS=" \
--deepspeed \
--deepspeed_config ${config_json} \
--zero-stage ${ZERO_STAGE} \
--deepspeed-activation-checkpointing \
"

# Per-node launcher: torch.distributed.run starts GPUS_PER_NODE workers and
# rendezvouses through c10d at MASTER_ADDR:MASTER_PORT. --node_rank is
# appended for each node when srun runs it.
LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node ${GPUS_PER_NODE} \
--nnodes ${NNODES} \
--rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
--rdzv_backend c10d \
--max_restarts 0 \
--tee 3 \
"
export LAUNCHER

# Disabled option, kept for reference:
# --universal-checkpoint \
export CMD=" \
$(pwd)/pretrain_gpt.py \
--tensor-model-parallel-size ${TP_SIZE} \
--pipeline-model-parallel-size ${PP_SIZE} \
${GPT_ARGS} \
${OUTPUT_ARGS} \
--save ${CHECKPOINT_PATH} \
--load ${CHECKPOINT_PATH} \
--train-weighted-split-paths-path ${TRAIN_DATA_PATH} \
--valid-weighted-split-paths-path ${VALID_DATA_PATH} \
--num-workers 2 \
--valid-num-workers 0 \
--data-impl mmap \
--distributed-backend nccl \
${DEEPSPEED_ARGS} \
"
# Log the training command (unquoted on purpose: collapses the continuation
# whitespace into a single readable line).
echo $CMD

# Workaround - do not remove: without it the training hangs and nodes are lost.
export CUDA_LAUNCH_BLOCKING=1

# Hack to hide duplicated error reports; expected to be fixed properly in
# PyTorch 1.12.
export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json

# Make NCCL problems such as a hanging broadcast crash the job instead of
# stalling it.
export NCCL_ASYNC_ERROR_HANDLING=1

# srun error handling:
#   --wait=60            after the first task terminates, wait 60 s before
#                        terminating all remaining tasks
#   --kill-on-bad-exit=1 terminate the step if any task exits non-zero
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
"

# One srun task per node runs the launcher; \$SLURM_PROCID is escaped so it is
# expanded by each task's `bash -c` (giving that node's rank), not here.
# Without pipefail the pipeline status is tee's, so srun's own status is
# captured from PIPESTATUS and becomes the job's exit code: a crashed run is
# reported as FAILED instead of COMPLETED.
srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID $CMD" 2>&1 | tee -a "$LOGS_PATH/main_log.txt"
srun_status=${PIPESTATUS[0]}

echo "END TIME: $(date)"
exit "$srun_status"