From 75b23349e7a5db59ac216f1de97c5f2dce3d318e Mon Sep 17 00:00:00 2001
From: Lai Wei
Date: Sun, 6 Oct 2024 05:15:32 +0000
Subject: [PATCH 1/2] add mllama convert

---
 .../inference/checkpoint_converter_fsdp_hf.py |  5 +++--
 src/llama_recipes/inference/model_utils.py    | 14 ++++++++------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py b/src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py
index a8c5e646f..f57db9d06 100644
--- a/src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py
+++ b/src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py
@@ -25,7 +25,8 @@ def main(
     fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints
     consolidated_model_path="", # Path to save the HF converted model checkpoints
-    HF_model_path_or_name="" # Path/ name of the HF model that include config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf)
+    HF_model_path_or_name="", # Path/ name of the HF model that include config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf)
+    multimodal=False # Use MllamaConfig for llama 3.2 vision models
     ):
 
     try:
@@ -50,7 +51,7 @@ def main(
 
     #load the HF model definition from config
-    model_def = load_llama_from_config(HF_model_path_or_name)
+    model_def = load_llama_from_config(HF_model_path_or_name, multimodal)
     print("model is loaded from config")
     #load the FSDP sharded checkpoints into the model
     model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path)
diff --git a/src/llama_recipes/inference/model_utils.py b/src/llama_recipes/inference/model_utils.py
index 2b150eea3..6e8d59a81 100644
--- a/src/llama_recipes/inference/model_utils.py
+++ b/src/llama_recipes/inference/model_utils.py
@@ -4,7 +4,7 @@
 from llama_recipes.utils.config_utils import update_config
 from llama_recipes.configs import quantization_config as QUANT_CONFIG
 from peft import PeftModel
-from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
+from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig, MllamaForConditionalGeneration, MllamaConfig
 from warnings import warn
 
 # Function to load the main model for text generation
@@ -41,9 +41,11 @@ def load_peft_model(model, peft_model):
     return peft_model
 
 # Loading the model from config to load FSDP checkpoints into that
-def load_llama_from_config(config_path):
-    model_config = LlamaConfig.from_pretrained(config_path)
-    model = LlamaForCausalLM(config=model_config)
+def load_llama_from_config(config_path, multimodal):
+    if multimodal:
+        model_config = MllamaConfig.from_pretrained(config_path)
+        model = MllamaForConditionalGeneration(config=model_config)
+    else:
+        model_config = LlamaConfig.from_pretrained(config_path)
+        model = LlamaForCausalLM(config=model_config)
     return model
-
-
\ No newline at end of file

From 54d18952ad7df7ebb9773f8554a1770e935beaee Mon Sep 17 00:00:00 2001
From: Lai Wei
Date: Sun, 6 Oct 2024 05:31:45 +0000
Subject: [PATCH 2/2] add doc

---
 recipes/quickstart/inference/local_inference/README.md | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/recipes/quickstart/inference/local_inference/README.md b/recipes/quickstart/inference/local_inference/README.md
index 7691566ca..3f574d09f 100644
--- a/recipes/quickstart/inference/local_inference/README.md
+++ b/recipes/quickstart/inference/local_inference/README.md
@@ -96,6 +96,14 @@
 python inference.py --model_name --prompt_file
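
Usage sketch (illustrative only, not part of either patch): with these changes applied, the new `multimodal` switch can be exercised directly through `load_llama_from_config`. Passing `multimodal=True` builds the model definition from `MllamaConfig`/`MllamaForConditionalGeneration` instead of the text-only Llama classes, so FSDP shards from a Llama 3.2 vision fine-tune can be loaded into it. The model id below is an assumed placeholder; any local path or Hub name that provides a config.json would work.

```python
# Minimal sketch, assuming llama-recipes is installed and transformers >= 4.45
# (the release that introduced the Mllama classes). The model id is illustrative.
from llama_recipes.inference.model_utils import load_llama_from_config

# Build a randomly initialized model definition matching the config; with
# multimodal=True the Mllama vision-text classes are instantiated instead of
# LlamaForCausalLM, ready for the FSDP shards to be loaded into it.
model_def = load_llama_from_config(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder config source
    multimodal=True,
)
print(type(model_def).__name__)  # expected: MllamaForConditionalGeneration
```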