diff --git a/docs/llama_mistral.md b/docs/llama_mistral.md index 0e3d4b2fb8..dd96923974 100644 --- a/docs/llama_mistral.md +++ b/docs/llama_mistral.md @@ -1,4 +1,4 @@ -# Llama and Mistral support in Megatron-LM +# Llama, Mistral and other Llama-like model support in Megatron-LM NOTE: Llama-3 and Mistral support in Megatron is currently experimental and we are still evaluting benchmark results to confirm model conversion, training and inference correctness. @@ -386,6 +386,12 @@ If loading for either inference or finetuning, use the following arguments: --attention-softmax-in-fp32 ``` -# Benchmark results +## Benchmark results Mistral-7B support in Megatron is currently experimental and we are still carrying out benchmark evaluations. + +# Other Llama-like model support + +*Note: Experimental* + +Many models such as Yi-34B use the Llama architecture and may be converted from HuggingFace to Megatron using the commands in [Llama3](#llama-3). diff --git a/tools/checkpoint/loader_llama_mistral.py b/tools/checkpoint/loader_llama_mistral.py index 52a8df7925..cba0bd3e1b 100644 --- a/tools/checkpoint/loader_llama_mistral.py +++ b/tools/checkpoint/loader_llama_mistral.py @@ -19,7 +19,7 @@ def add_arguments(parser): # TODO(jbarker): Need assertion to make sure *exactly* one of these is used parser.add_argument('--model-size', type=str, required=True, - choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf'], + choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B'], help='Model size can be `llama2-7B`, `llama2-13B`, `llama2-70B`, `llama3-8B`, `llama3-70B`, `mistral-7B` (for pretrained models), ' 'and `llama2-7Bf`, `llama2-13Bf`, `llama2-70Bf`, `llama3-8Bf`, `llama3-70bf` and `mistral-7Bf` (for chat-finetuned models).') parser.add_argument('--checkpoint-type', type=str, required=True, @@ -58,6 +58,7 @@ def verify_transformers_version(): "llama3-70Bf": 8, "mistral-7B": 1, "mistral-7Bf": 1, + "yi-34B": 8, } @@ -394,7 +395,7 @@ def load_checkpoint_to_model(args): '''Set model params.''' from pretrain_gpt import model_provider - if "llama" in args.model_size: + if "llama" in args.model_size or "yi" in args.model_size: from transformers import LlamaForCausalLM as ModelForCausalLM elif "mistral" in args.model_size: from transformers import MistralForCausalLM as ModelForCausalLM @@ -465,7 +466,7 @@ def _load_checkpoint(queue, args): margs.tokenizer_model = args.tokenizer_model load_args_from_checkpoint(margs) - if "llama2" in args.model_size: + if "llama2" in args.model_size or "yi" in args.model_size: margs.tokenizer_type = "Llama2Tokenizer" elif "llama3" in args.model_size: margs.tokenizer_type = "Llama3Tokenizer"