diff --git a/docs/source/checkpoints/convert_mlm.rst b/docs/source/checkpoints/convert_mlm.rst
index 61b5b2802e8a..3f0949639975 100644
--- a/docs/source/checkpoints/convert_mlm.rst
+++ b/docs/source/checkpoints/convert_mlm.rst
@@ -10,11 +10,12 @@ You can convert your GPT-style model checkpoints trained with Megatron-LM into t
 
 .. code-block:: bash
 
-    /examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py \
+    python -m torch.distributed.launch --nproc_per_node=4 /examples/nlp/language_modeling/megatron_ckpt_to_nemo.py \
      --checkpoint_folder \
      --checkpoint_name megatron_gpt--val_loss=99.99-step={steps}-consumed_samples={consumed}.0 \
      --nemo_file_path \
      --model_type \
+     --hparams_file \
      --tensor_model_parallel_size \
      --pipeline_model_parallel_size \
      --gpus_per_node
diff --git a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py
index 87b7151aa961..e0158b935c29 100644
--- a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py
+++ b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py
@@ -116,6 +116,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) ->
     dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
     model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True)
     model_config.tensor_model_parallel_size = 1
+    model_config.sequence_parallel = 0
     model_config.pipeline_model_parallel_size = 1
     model_config.name = "te_gpt"
     if cpu_only:
@@ -151,19 +152,19 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) ->
     ffn_hidden_size = model.cfg.ffn_hidden_size
     num_query_groups = model.cfg.get("num_query_groups", head_num)  # different num_query_groups for 70B
 
-    head_size = model.cfg.get("kv_channels") or (hidden_size // head_num)  # equivalent to hf's head_dim
+    head_size = model.cfg.get("kv_channels") or (hidden_size // head_num)
     heads_per_group = head_num // num_query_groups
     qkv_total_dim = head_num + 2 * num_query_groups
 
     # Embedding
-    embed_weight = model.state_dict()[f'model.embedding.word_embeddings.weight']
+    embed_weight = model.model[0].state_dict()[f'embedding.word_embeddings.weight']
     embed_weights_base_name = f'model.embed_tokens.weight'
     checkpoint[embed_weights_base_name] = param_to_weights(embed_weight)
 
     for l in range(int(num_layers)):
         print(f"converting layer {l}")
 
-        qkv_weights = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.weight']
+        qkv_weights = model.model[0].state_dict()[f'decoder.layers.{l}.self_attention.linear_qkv.weight']
         qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])
 
         q_slice = torch.cat(
@@ -193,12 +194,12 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) ->
         checkpoint[v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size))
 
         # attention dense
-        o_weight = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_proj.weight']
+        o_weight = model.model[0].state_dict()[f'decoder.layers.{l}.self_attention.linear_proj.weight']
         o_weight_base_name = f'model.layers.{l}.self_attn.o_proj.weight'
         checkpoint[o_weight_base_name] = param_to_weights(o_weight)
 
         # mlp
-        mlp_weights = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.weight']
+        mlp_weights = model.model[0].state_dict()[f'decoder.layers.{l}.mlp.linear_fc1.weight']
         mlp_down_proj_weight = mlp_weights[:ffn_hidden_size, :]
         mlp_gate_proj_weight = mlp_weights[ffn_hidden_size:, :]
 
@@ -208,26 +209,26 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) ->
         checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
         checkpoint[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
 
-        mlp_up_proj_weight = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc2.weight']
+        mlp_up_proj_weight = model.model[0].state_dict()[f'decoder.layers.{l}.mlp.linear_fc2.weight']
         mlp_up_proj_base_name = f'model.layers.{l}.mlp.down_proj.weight'
         checkpoint[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)
 
         # layernorm
-        input_ln_weight = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight']
+        input_ln_weight = model.model[0].state_dict()[f'decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight']
         input_ln_base_name = f'model.layers.{l}.input_layernorm.weight'
         checkpoint[input_ln_base_name] = param_to_weights(input_ln_weight)
 
-        post_attn_ln_weight = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight']
+        post_attn_ln_weight = model.model[0].state_dict()[f'decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight']
         post_attn_ln_base_name = f'model.layers.{l}.post_attention_layernorm.weight'
         checkpoint[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
 
         print(f"done layer {l}")
 
-    final_ln_weight = model.state_dict()[f'model.decoder.final_layernorm.weight']
+    final_ln_weight = model.model[0].state_dict()[f'decoder.final_layernorm.weight']
     final_ln_base_name = f'model.norm.weight'
     checkpoint[final_ln_base_name] = param_to_weights(final_ln_weight)
 
-    output_layer_weight = model.state_dict()[f'model.output_layer.weight']
+    output_layer_weight = model.model[0].state_dict()[f'output_layer.weight']
     output_layer_base_name = f'lm_head.weight'
     checkpoint[output_layer_base_name] = param_to_weights(output_layer_weight)
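
Reviewer note, not part of the patch: the `q_slice = torch.cat(` line that the third hunk ends on is where the converter splits the fused `linear_qkv` weight, which is stored in an interleaved grouped-query-attention layout of `qkv_total_dim = head_num + 2 * num_query_groups` heads. The sketch below is a minimal, self-contained illustration of that slicing under the layout the reshape implies; all dimension values are made up and the tensor is random, so it only demonstrates the index arithmetic, not the real checkpoint contents.

import torch

# Hypothetical dimensions for illustration only; real values come from model.cfg.
hidden_size = 64
head_num = 8
num_query_groups = 2                               # GQA: fewer KV heads than Q heads
head_size = hidden_size // head_num
heads_per_group = head_num // num_query_groups
qkv_total_dim = head_num + 2 * num_query_groups    # each group is [q_0 .. q_{hpg-1}, k, v]

# Dummy fused QKV weight, reshaped the same way as in the converter.
qkv_weights = torch.randn(qkv_total_dim * head_size, hidden_size)
qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])

# Index maps over the interleaved per-group layout.
q_slice = torch.cat(
    [
        torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
        for i in range(num_query_groups)
    ]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2)
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2)

# Flatten back to the 2-D shapes the Hugging Face checkpoint expects.
q_proj = qkv_weights[q_slice].reshape(-1, hidden_size)   # [head_num * head_size, hidden_size]
k_proj = qkv_weights[k_slice].reshape(-1, hidden_size)   # [num_query_groups * head_size, hidden_size]
v_proj = qkv_weights[v_slice].reshape(-1, hidden_size)   # [num_query_groups * head_size, hidden_size]
print(q_proj.shape, k_proj.shape, v_proj.shape)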