Commit 2b23136
Resolved line too long pylint issues and merged main
Signed-off-by: taejinp <[email protected]>
tango4j committed Nov 22, 2024
2 parents 07f791a + 86315db commit 2b23136
Showing 6 changed files with 248 additions and 22 deletions.
116 changes: 116 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -2109,6 +2109,121 @@ jobs:
# }
# }

  L2_Megatron_LM_To_NeMo_Conversion:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Megatron_LM_To_NeMo_Conversion') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 Megatron-LM/pretrain_gpt.py \
          --mock-data \
          --distributed-timeout-minutes 60 \
          --use-mcore-models \
          --no-mmap-bin-files \
          --untie-embeddings-and-output-weights \
          --disable-bias-linear \
          --train-samples 80 \
          --init-method-std 0.014 \
          --position-embedding-type rope \
          --rotary-base 1000000 \
          --rotary-percent 1.0 \
          --squared-relu \
          --num-layers 4 \
          --hidden-size 384 \
          --num-attention-heads 8 \
          --group-query-attention \
          --num-query-groups 8 \
          --ffn-hidden-size 1536 \
          --kv-channels 128 \
          --normalization RMSNorm \
          --attention-dropout 0.0 \
          --hidden-dropout 0.0 \
          --exit-duration-in-mins 5750 \
          --tensor-model-parallel-size 1 \
          --pipeline-model-parallel-size 1 \
          --seq-length 8192 \
          --max-position-embeddings 8192 \
          --micro-batch-size 1 \
          --global-batch-size 8 \
          --lr 6e-4 \
          --min-lr 6e-6 \
          --weight-decay 0.1 \
          --clip-grad 1.0 \
          --lr-decay-style cosine \
          --log-interval 1 \
          --eval-iters 1 \
          --eval-interval 10 \
          --tokenizer-type GPT2BPETokenizer \
          --tokenizer-model /home/TestData/nlp/gpt2_tokenizer \
          --vocab-file /home/TestData/nlp/gpt2_tokenizer/vocab.json \
          --merge-file /home/TestData/nlp/gpt2_tokenizer/merges.txt \
          --save /tmp/mlm_conversion_ckpt \
          --save-interval 10 \
          --ckpt-format torch_dist \
          --ckpt-fully-parallel-save \
          --ckpt-fully-parallel-load \
          --async-save \
          --ckpt-assume-constant-structure \
          --timing-log-option minmax \
          --log-params-norm \
          --log-num-zeros-in-grad \
          --log-throughput \
          --bf16 \
          --adam-beta1 0.9 \
          --adam-beta2 0.95 \
          --use-distributed-optimizer \
          --overlap-grad-reduce \
          --overlap-param-gather \
          --manual-gc \
          --num-workers 2

        python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
          model.data.data_impl=mock \
          model.data.data_prefix=[] \
          model.skip_train=True \
          model.transformer_engine=True \
          model.use_flash_attention=False \
          model.normalization=rmsnorm \
          model.num_layers=4 \
          model.hidden_size=384 \
          model.ffn_hidden_size=1536 \
          model.num_attention_heads=8 \
          model.num_query_groups=8 \
          model.bias=False \
          model.bias_activation_fusion=False \
          model.bias_dropout_add_fusion=True \
          model.masked_softmax_fusion=True \
          model.encoder_seq_length=8192 \
          model.max_position_embeddings=8192 \
          model.data.seq_length=8192 \
          model.activation=squared-relu \
          model.transformer_block_type=True \
          model.micro_batch_size=1 \
          model.global_batch_size=8 \
          ++model.rotary_base=1000000 \
          model.rotary_percentage=1.0 \
          model.apply_query_key_layer_scaling=False \
          ++model.group_query_attention=True \
          model.apply_rope_fusion=True \
          model.kv_channels=128 \
          ++model.bert_binary_head=True \
          ++model.position_embedding_type=rope \
          ++model.add_position_embedding=True \
          trainer.limit_val_batches=1 \
          exp_manager.exp_dir=/tmp/nemo_conversion_ckpt

        python -m torch.distributed.launch --nproc_per_node=1 examples/nlp/language_modeling/megatron_ckpt_to_nemo.py \
          --checkpoint_folder /tmp/mlm_conversion_ckpt \
          --checkpoint_name iter_0000010 \
          --nemo_file_path /tmp/mlm_to_nemo_test.nemo \
          --tensor_model_parallel_size 1 \
          --pipeline_model_parallel_size 1 \
          --gpus_per_node 1 \
          --model_type gpt \
          --hparams_file /tmp/nemo_conversion_ckpt/megatron_gpt/version_0/hparams.yaml \
          --convert_mlm

  L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
@@ -4432,6 +4547,7 @@ jobs:
- L2_RAG_Pipeline_Generating
- L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_Skip_Train
- L2_Megatron_LM_To_NeMo_Conversion
- L2_Megatron_GPT_with_Rope_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_Drop_Optimizer_States_TP2
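The new L2_Megatron_LM_To_NeMo_Conversion job above chains three steps: a short Megatron-LM pretraining run that writes a torch_dist checkpoint to /tmp/mlm_conversion_ckpt, a skip-train NeMo run that mainly serves to materialize hparams.yaml under /tmp/nemo_conversion_ckpt, and finally megatron_ckpt_to_nemo.py with the new --convert_mlm flag, which writes /tmp/mlm_to_nemo_test.nemo. A .nemo file is a tar archive (the tokenizer-extraction change later in this commit opens one the same way), so a minimal local smoke test of the conversion output could look like the sketch below; the path is the one hard-coded in the job, and the check itself is illustrative rather than part of the commit.

# Sketch, not part of the commit: confirm the converted checkpoint exists and is a readable tar archive.
import os
import tarfile

nemo_path = "/tmp/mlm_to_nemo_test.nemo"  # output path used by the CI job above

if os.path.isfile(nemo_path):
    # Listing the archive members is a cheap way to verify the conversion produced a well-formed .nemo file.
    with tarfile.open(nemo_path, "r") as archive:
        for member in archive.getmembers():
            print(member.name, member.size)
else:
    print(f"{nemo_path} not found; run the conversion step first.")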
9 changes: 8 additions & 1 deletion examples/nlp/language_modeling/megatron_ckpt_to_nemo.py
@@ -112,6 +112,11 @@ def get_args():
        choices=['32-true', '16-mixed', 'bf16-mixed'],
        help="Precision value for the trainer that matches with precision of the ckpt",
    )
    parser.add_argument(
        "--convert_mlm",
        action="store_true",
        help="Use this flag to convert megatron-lm checkpoints.",
    )

    args = parser.parse_args()
    return args
@@ -195,7 +200,9 @@ def convert(local_rank, rank, world_size, args):
    )

    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
        model = MegatronGPTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer, load_mlm=args.convert_mlm
        )
    elif args.model_type == 'sft':
        model = MegatronGPTSFTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
6 changes: 5 additions & 1 deletion nemo/collections/asr/parts/utils/manifest_utils.py
@@ -373,7 +373,11 @@ def create_segment_manifest(
    segments_manifest_file = write_rttm2manifest(AUDIO_RTTM_MAP, segment_manifest_path, deci)
    subsegments_manifest_file = subsegment_manifest_path
    segments_manifest_to_subsegments_manifest(
        segments_manifest_file, subsegments_manifest_file, window, shift, min_subsegment_duration,
        segments_manifest_file,
        subsegments_manifest_file,
        window,
        shift,
        min_subsegment_duration,
    )
    subsegments_dict = get_subsegment_dict(subsegments_manifest_file, window, shift, deci)
    write_truncated_subsegments(input_manifest_dict, subsegments_dict, output_manifest_path, step_count, deci)
17 changes: 16 additions & 1 deletion nemo/collections/nlp/models/nlp_model.py
@@ -397,7 +397,22 @@ def dummy():
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
sharded_state_dict = model.sharded_state_dict()
checkpoint['state_dict'] = sharded_state_dict
if kwargs.get("load_mlm", False):
    mlm_sharded_state_dict = {}
    for k, v in sharded_state_dict.items():
        # Remove 'model.' from the sharded_state_dict keys
        new_key = k.replace('model.', '', 1)

        # Update the key attribute of the ShardedTensor value
        new_value = v
        if hasattr(v, 'key'):
            new_value.key = v.key.replace('model.', '', 1)

        # Add the updated key-value pair to the new dictionary
        mlm_sharded_state_dict[new_key] = new_value
    checkpoint['state_dict'] = mlm_sharded_state_dict
else:
    checkpoint['state_dict'] = sharded_state_dict
# load the checkpoint from disk
checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
# restore the weights
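The new load_mlm branch above strips the leading 'model.' prefix from both the sharded state dict keys and each ShardedTensor's own key attribute, so that a checkpoint written by Megatron-LM, which stores parameters without NeMo's 'model.' wrapper, can be matched during distributed loading. Below is a self-contained sketch of the same remapping; the FakeShardedTensor stand-in and the parameter names are made up for illustration, not taken from NeMo.

# Stand-alone sketch of the key remapping added above (not NeMo code). The real values are
# megatron.core dist_checkpointing ShardedTensors; a tiny stand-in with a `key` attribute is
# enough to show what the rename does.
from dataclasses import dataclass


@dataclass
class FakeShardedTensor:
    key: str  # checkpoint key the loader will look up on disk


sharded_state_dict = {
    "model.decoder.layers.0.self_attention.linear_qkv.weight": FakeShardedTensor(
        key="model.decoder.layers.0.self_attention.linear_qkv.weight"
    ),
    "model.embedding.word_embeddings.weight": FakeShardedTensor(
        key="model.embedding.word_embeddings.weight"
    ),
}

# Same logic as the load_mlm branch: drop the leading 'model.' from both the dict key and the
# tensor's own key so the entries line up with a Megatron-LM checkpoint layout.
mlm_sharded_state_dict = {}
for k, v in sharded_state_dict.items():
    new_key = k.replace("model.", "", 1)
    if hasattr(v, "key"):
        v.key = v.key.replace("model.", "", 1)
    mlm_sharded_state_dict[new_key] = v

print(sorted(mlm_sharded_state_dict))
# ['decoder.layers.0.self_attention.linear_qkv.weight', 'embedding.word_embeddings.weight']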
38 changes: 37 additions & 1 deletion scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py
@@ -21,7 +21,7 @@
import torch
from lightning.pytorch import Trainer
from transformers import LlamaTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import LlamaConverter
from transformers.convert_slow_tokenizer import LlamaConverter, TikTokenConverter

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
@@ -130,6 +130,20 @@ def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path,
    json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2)


def convert_tiktoken(vocab_file) -> None:
    with open(vocab_file, 'r') as f:
        vocab = json.load(f)
    os.remove(vocab_file)

    lines = []
    for line in vocab:
        lines.append(f"{line['token_bytes']} {line['rank']}")

    for line in lines:
        with open(vocab_file, 'a') as f:
            f.write(line + '\n')


def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None:
    """
    Convert NeMo weights to HF weights
@@ -323,6 +337,28 @@ def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tok
        )
        tokenizer.save_pretrained(output_hf_path)
        logging.info(f"Setencepiece tokenizer has been saved to {output_tokenizer}")
    elif tokenizer_cfg.library == "tiktoken":
        tokenizer_fn = tokenizer_cfg.model[5:]
        special_tokens = ["<unk>", "<s>", "</s>"]
        import tarfile

        archive = tarfile.open(nemo_file, "r")
        tokenizer_filename = "./" + tokenizer_fn  # exclude 'nemo:' prefix
        archive.extract(tokenizer_filename, output_hf_path)
        archive.close()
        vocab_file = os.path.join(output_hf_path, tokenizer_fn)
        convert_tiktoken(vocab_file)
        converted_tokenizer = TikTokenConverter(
            vocab_file=vocab_file, additional_special_tokens=special_tokens
        ).converted()
        os.remove(vocab_file)
        tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=converted_tokenizer,
            model_input_names=["input_ids", "attention_mask"],
            bos_token="<s>",
            eos_token="</s>",
        )
        tokenizer.save_pretrained(output_hf_path)
    elif isinstance(nemo_tokenizer, AutoTokenizer):
        nemo_tokenizer.tokenizer.save_pretrained(output_hf_path)
        logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}")
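The convert_tiktoken helper added above rewrites the tiktoken vocabulary shipped inside the .nemo archive from a JSON list of {"token_bytes", "rank"} records into plain "token_bytes rank" lines before handing the file to TikTokenConverter, whose output is then wrapped in a PreTrainedTokenizerFast. The sketch below replays that reshaping on two made-up entries (in real vocab files these are typically base64-encoded token bytes); it illustrates the data format and is not code from the commit.

# Toy illustration (not part of the commit) of the vocab-file reshaping performed by convert_tiktoken:
# a JSON list of {"token_bytes": ..., "rank": ...} records becomes one "token_bytes rank" line per entry.
import json
import os
import tempfile

records = [  # made-up entries for illustration only
    {"token_bytes": "IQ==", "rank": 0},
    {"token_bytes": "Ig==", "rank": 1},
]

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(records, f)
    vocab_file = f.name

# Same reshaping as convert_tiktoken, written without the delete/append round-trip.
with open(vocab_file) as f:
    vocab = json.load(f)
with open(vocab_file, "w") as f:
    for entry in vocab:
        f.write(f"{entry['token_bytes']} {entry['rank']}\n")

print(open(vocab_file).read())
# IQ== 0
# Ig== 1
os.remove(vocab_file)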