From e180743acc40611b659ffa64bbbf577cab8190bc Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Fri, 28 Jun 2024 13:10:38 +0000
Subject: [PATCH] update fp8 config for reproducing the ug

---
 examples/config_fp8_llama.yaml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/config_fp8_llama.yaml b/examples/config_fp8_llama.yaml
index 54eea99e..d3036fef 100644
--- a/examples/config_fp8_llama.yaml
+++ b/examples/config_fp8_llama.yaml
@@ -73,9 +73,9 @@ model:
     intermediate_size: 2048
     is_llama_config: true
     max_position_embeddings: 256
-    num_attention_heads: 16
+    num_attention_heads: 4
     num_hidden_layers: 2
-    num_key_value_heads: 16
+    num_key_value_heads: 4
     pad_token_id: null
     pretraining_tp: 1
     rms_norm_eps: 1.0e-05
@@ -119,13 +119,13 @@ optimizer:
   # clip_grad: 1.0
   learning_rate_scheduler:
     # learning_rate: 0.0015 # note: 1/2 of pythia use this for a 400m model
-    learning_rate: 6.0e-4
+    learning_rate: 0.0006
     lr_decay_starting_step: null
     lr_decay_steps: null
     lr_decay_style: cosine
     lr_warmup_steps: 1000 # 10% warm up of total training steps
     lr_warmup_style: linear
-    min_decay_lr: 6.0e-5
+    min_decay_lr: 0.00006
   optimizer_factory:
     adam_beta1: 0.9
     adam_beta2: 0.95
@@ -158,7 +158,7 @@ tokens:
   batch_accumulation_per_replica: 1
   limit_test_batches: 0
   limit_val_batches: 0
-  micro_batch_size: 128 # 256
+  micro_batch_size: 256 # 256
   # micro_batch_size: 1
   sequence_length: 256
   train_steps: 24376
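
Note (not part of the patch): a minimal sketch for checking that an applied config contains the values this patch expects. It assumes the file sits at examples/config_fp8_llama.yaml relative to the working directory and that PyYAML is installed; it deliberately uses a generic recursive key lookup instead of nanotron's own config loader, so the exact nesting of the touched keys does not matter.

import yaml  # PyYAML

# Values this patch expects to find in the config after it is applied.
EXPECTED = {
    "num_attention_heads": 4,
    "num_key_value_heads": 4,
    "learning_rate": 0.0006,
    "min_decay_lr": 0.00006,
    "micro_batch_size": 256,
}


def find_key(node, key):
    """Recursively search a parsed YAML tree for the first occurrence of `key`."""
    if isinstance(node, dict):
        if key in node:
            return node[key]
        for value in node.values():
            hit = find_key(value, key)
            if hit is not None:
                return hit
    elif isinstance(node, list):
        for item in node:
            hit = find_key(item, key)
            if hit is not None:
                return hit
    return None


with open("examples/config_fp8_llama.yaml") as f:
    cfg = yaml.safe_load(f)

for key, want in EXPECTED.items():
    got = find_key(cfg, key)
    status = "ok" if got == want else f"mismatch (found {got})"
    print(f"{key}: {status}")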