Skip to content

Commit

Permalink
Merge branch 'trintamaki/yi-support' into 'main'
Browse files Browse the repository at this point in the history
Experimental Yi conversion support

See merge request ADLR/megatron-lm!1534
  • Loading branch information
jon-barker committed Jun 17, 2024
2 parents 13c762e + 2b45e60 commit e33c8f7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
10 changes: 8 additions & 2 deletions docs/llama_mistral.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Llama and Mistral support in Megatron-LM
# Llama, Mistral and other Llama-like model support in Megatron-LM

NOTE: Llama-3 and Mistral support in Megatron is currently experimental and we are still evaluating benchmark results to confirm model conversion, training and inference correctness.

Expand Down Expand Up @@ -386,6 +386,12 @@ If loading for either inference or finetuning, use the following arguments:
--attention-softmax-in-fp32
```

# Benchmark results
## Benchmark results

Mistral-7B support in Megatron is currently experimental and we are still carrying out benchmark evaluations.

# Other Llama-like model support

*Note: Experimental*

Many models such as Yi-34B use the Llama architecture and may be converted from HuggingFace to Megatron using the commands in [Llama3](#llama-3).
7 changes: 4 additions & 3 deletions tools/checkpoint/loader_llama_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def add_arguments(parser):

# TODO(jbarker): Need assertion to make sure *exactly* one of these is used
parser.add_argument('--model-size', type=str, required=True,
choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf'],
choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B'],
help='Model size can be `llama2-7B`, `llama2-13B`, `llama2-70B`, `llama3-8B`, `llama3-70B`, `mistral-7B` (for pretrained models), '
'and `llama2-7Bf`, `llama2-13Bf`, `llama2-70Bf`, `llama3-8Bf`, `llama3-70bf` and `mistral-7Bf` (for chat-finetuned models).')
parser.add_argument('--checkpoint-type', type=str, required=True,
Expand Down Expand Up @@ -58,6 +58,7 @@ def verify_transformers_version():
"llama3-70Bf": 8,
"mistral-7B": 1,
"mistral-7Bf": 1,
"yi-34B": 8,
}


Expand Down Expand Up @@ -394,7 +395,7 @@ def load_checkpoint_to_model(args):
'''Set model params.'''

from pretrain_gpt import model_provider
if "llama" in args.model_size:
if "llama" in args.model_size or "yi" in args.model_size:
from transformers import LlamaForCausalLM as ModelForCausalLM
elif "mistral" in args.model_size:
from transformers import MistralForCausalLM as ModelForCausalLM
Expand Down Expand Up @@ -465,7 +466,7 @@ def _load_checkpoint(queue, args):
margs.tokenizer_model = args.tokenizer_model
load_args_from_checkpoint(margs)

if "llama2" in args.model_size:
if "llama2" in args.model_size or "yi" in args.model_size:
margs.tokenizer_type = "Llama2Tokenizer"
elif "llama3" in args.model_size:
margs.tokenizer_type = "Llama3Tokenizer"
Expand Down

0 comments on commit e33c8f7

Please sign in to comment.