diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
index 3d77ac62a44..e54f7aa6cc2 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
@@ -17,7 +17,7 @@
 import os
 import torch
 from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
+from accelerate.utils import set_module_tensor_to_device as fill_model
 from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
 
@@ -57,12 +57,18 @@ def process_mixtral(name, tensor):
         module_name = get_mixtral_module_name(name)
         if 'ffn_gate_inp' in name:
             # gguf weight needs to reshape for ffn_gate_inp
-            set_module_tensor_to_device(model, module_name, "cpu", \
-                tensor.reshape(num_local_experts, hidden_size), dtype=dtype)
+            fill_model(model,
+                       module_name,
+                       "cpu",
+                       tensor.reshape(num_local_experts, hidden_size),
+                       dtype=dtype)
         else:
-            set_module_tensor_to_device(model, module_name, "cpu", \
-                tensor, dtype=dtype)
-
+            fill_model(model,
+                       module_name,
+                       "cpu",
+                       tensor,
+                       dtype=dtype)
+
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
 
@@ -83,6 +89,7 @@ def process_mixtral(name, tensor):
 
     return model, tokenizer
 
+
 def get_mixtral_module_name(name):
     if name == 'token_embd.weight':
         return 'model.embed_tokens.weight'
@@ -92,7 +99,7 @@ def get_mixtral_module_name(name):
         return 'lm_head.weight'
     layer_id = name.split('.')[1]
    if 'attn_q' in name:
-        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
+        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
     if 'attn_k' in name:
         return f'model.layers.{layer_id}.self_attn.k_proj.weight'
     if 'attn_v' in name:
@@ -115,4 +122,3 @@ def get_mixtral_module_name(name):
     if 'ffn_up' in name:
         return f'model.layers.{layer_id}.' + \
             f'block_sparse_moe.experts.{local_expert_id}.w3.weight'
-