fix python style (intel-analytics#9742)
* fix python style

* fix

* fix
Uxito-Ada authored Dec 21, 2023
1 parent 782d20e commit de1ef6c
Showing 1 changed file with 14 additions and 8 deletions.
python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py (22 changes: 14 additions & 8 deletions)
@@ -17,7 +17,7 @@
 import os
 import torch
 from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
+from accelerate.utils import set_module_tensor_to_device as fill_model
 from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
 
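The alias keeps the reflowed call sites below within the project's line-length limit. For context, a minimal sketch of how this accelerate helper pairs with init_empty_weights to materialize weights one tensor at a time; the tiny config values here are invented for illustration, while the real loader derives them from GGUF metadata:

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device as fill_model
from transformers import MixtralConfig, MixtralForCausalLM

# Deliberately tiny config so the sketch runs quickly; the real loader
# builds the full Mixtral config from the GGUF file's metadata instead.
config = MixtralConfig(hidden_size=64, intermediate_size=128,
                       num_hidden_layers=2, num_attention_heads=4,
                       num_key_value_heads=2, vocab_size=256)
with init_empty_weights():
    model = MixtralForCausalLM(config)  # parameters live on the meta device

# Materialize a single named parameter on CPU, converting dtype on the way.
weight = torch.zeros(config.vocab_size, config.hidden_size)
fill_model(model, 'model.embed_tokens.weight', "cpu", weight,
           dtype=torch.float16)
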
@@ -57,12 +57,18 @@ def process_mixtral(name, tensor):
         module_name = get_mixtral_module_name(name)
         if 'ffn_gate_inp' in name:
             # gguf weight needs to reshape for ffn_gate_inp
-            set_module_tensor_to_device(model, module_name, "cpu", \
-                tensor.reshape(num_local_experts, hidden_size), dtype=dtype)
+            fill_model(model,
+                       module_name,
+                       "cpu",
+                       tensor.reshape(num_local_experts, hidden_size),
+                       dtype=dtype)
         else:
-            set_module_tensor_to_device(model, module_name, "cpu", \
-                tensor, dtype=dtype)
-
+            fill_model(model,
+                       module_name,
+                       "cpu",
+                       tensor,
+                       dtype=dtype)
+
 
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)

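The reflow replaces the backslash continuations, which PEP 8 discourages, with implicit continuation inside the call's parentheses, one argument per line aligned under the opening delimiter. The same transformation in isolation, using a hypothetical stand-in function rather than the repo's code:

def fill(model, name, device, tensor, dtype=None):
    """Hypothetical stand-in for the aliased accelerate helper."""
    print(name, device, dtype)

# Before: explicit backslash continuation, the style the commit removes.
fill('model', 'layer.weight', 'cpu', \
     [1.0, 2.0], dtype='float16')

# After: implicit continuation inside parentheses, arguments aligned
# with the opening delimiter, every line within the length limit.
fill('model',
     'layer.weight',
     'cpu',
     [1.0, 2.0],
     dtype='float16')
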
@@ -83,6 +89,7 @@ def process_mixtral(name, tensor):
 
     return model, tokenizer
 
+
 def get_mixtral_module_name(name):
     if name == 'token_embd.weight':
         return 'model.embed_tokens.weight'
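
The only change in this hunk is a second blank line before the top-level def: PEP 8 (pycodestyle E302) expects two blank lines between top-level definitions. In isolation, with illustrative names:

def load_model():
    return None


def get_module_name():  # the two blank lines above satisfy E302
    return 'model.embed_tokens.weight'
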
@@ -92,7 +99,7 @@ def get_mixtral_module_name(name):
         return 'lm_head.weight'
     layer_id = name.split('.')[1]
     if 'attn_q' in name:
-            return f'model.layers.{layer_id}.self_attn.q_proj.weight'
+        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
     if 'attn_k' in name:
         return f'model.layers.{layer_id}.self_attn.k_proj.weight'
     if 'attn_v' in name:
@@ -115,4 +122,3 @@ def get_mixtral_module_name(name):
     if 'ffn_up' in name:
         return f'model.layers.{layer_id}.' + \
             f'block_sparse_moe.experts.{local_expert_id}.w3.weight'
-
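
For reference, a small usage sketch of the mapping this function implements; it assumes the package is installed so the module path in the file header is importable, and that per-layer GGUF tensors carry the usual blk.<n>. prefix:

# Hypothetical usage sketch; assumes bigdl-llm is installed so the
# module path from the diff's file header is importable.
from bigdl.llm.transformers.gguf.models.mixtral import get_mixtral_module_name

# Direct mapping visible in the diff.
assert get_mixtral_module_name('token_embd.weight') == 'model.embed_tokens.weight'

# Per-layer GGUF tensors use a 'blk.<n>.' prefix, so split('.')[1] yields
# the layer id used in the Hugging Face module path.
assert (get_mixtral_module_name('blk.3.attn_q.weight')
        == 'model.layers.3.self_attn.q_proj.weight')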