
Commit

Skip merge_qkv if quant_method is 'gptq'
liu-shaojun committed Jun 26, 2024
1 parent 9f6e5b4 commit 9798e05
Showing 4 changed files with 30 additions and 11 deletions.
@@ -18,7 +18,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
pip install transformers==4.34.0
BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install optimum==1.14.0
```

On Windows:
@@ -30,7 +30,7 @@ pip install --pre --upgrade ipex-llm[all]
pip install transformers==4.34.0
set BUILD_CUDA_EXT=0
pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install optimum==1.14.0
```

### 2. Run
@@ -19,7 +19,7 @@
import argparse

from ipex_llm.transformers import AutoModelForCausalLM
-from transformers import LlamaTokenizer, GPTQConfig
+from transformers import LlamaTokenizer, AutoTokenizer

# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
@@ -50,7 +50,10 @@
trust_remote_code=True,)

# Load tokenizer
-tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+if "qwen" in model_path.lower():
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+else:
+    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Generate predicted tokens
with torch.inference_mode():
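For context, here is a minimal sketch of how the modified example script fits together once the tokenizer selection above is in place. The model path, prompt, generation length, and the `load_in_4bit`/`torch_dtype` arguments are illustrative assumptions, not part of this diff:

```python
import torch
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer

model_path = "Qwen/Qwen1.5-7B-Chat-GPTQ-Int4"  # hypothetical GPTQ checkpoint

# Load the GPTQ checkpoint; ipex-llm converts it to its low-bit format on load.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             torch_dtype=torch.float,
                                             trust_remote_code=True)

# Qwen GPTQ checkpoints do not ship a Llama tokenizer, so use AutoTokenizer for them.
if "qwen" in model_path.lower():
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
else:
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

prompt = "What is AI?"
input_ids = tokenizer.encode(prompt, return_tensors="pt")
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```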
9 changes: 8 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
@@ -734,7 +734,14 @@ def _optimize_pre(model):
    # for qwen2
    if model.config.model_type == "qwen2":
        from ipex_llm.transformers.models.qwen2 import merge_qkv
-        model.apply(merge_qkv)
+        # Skip merge_qkv if quant_method is 'gptq'
+        should_apply_merge_qkv = (
+            not hasattr(model.config, "quantization_config") or
+            not hasattr(model.config.quantization_config, "quant_method") or
+            model.config.quantization_config.quant_method != "gptq"
+        )
+        if should_apply_merge_qkv:
+            model.apply(merge_qkv)
        from ipex_llm.transformers.models.qwen2 import padding_mlp
        model.apply(padding_mlp)
    if model.config.model_type == "qwen2_moe":
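The optimized forward path in qwen2.py (next file) expects a fused `qkv_proj` that `merge_qkv` builds by concatenating the separate q/k/v projections. A GPTQ checkpoint stores those projections as quantized modules (packed integer weights with scales and zero points) rather than plain float `nn.Linear` layers, so that concatenation does not apply and the merge is skipped. A simplified sketch of the merge idea follows; this is not the actual ipex-llm implementation, and the attribute names simply follow standard Qwen2 attention modules:

```python
import torch
import torch.nn as nn

def merge_qkv_sketch(module: nn.Module):
    # Only touch attention blocks that still have separate q/k/v projections.
    if not all(hasattr(module, name) for name in ("q_proj", "k_proj", "v_proj")):
        return
    q, k, v = module.q_proj, module.k_proj, module.v_proj
    # Concatenating weights only works for ordinary float Linear layers;
    # GPTQ QuantLinear modules hold packed qweight/qzeros/scales instead.
    if not isinstance(q, nn.Linear):
        return
    qkv = nn.Linear(q.in_features,
                    q.out_features + k.out_features + v.out_features,
                    bias=q.bias is not None)
    qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
    if q.bias is not None:
        qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
    # Replace the three projections with the fused one.
    module.qkv_proj = qkv
    del module.q_proj, module.k_proj, module.v_proj
```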
21 changes: 15 additions & 6 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -405,12 +405,21 @@ def qwen2_attention_forward(
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device

-    qkv = self.qkv_proj(hidden_states)
-    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
-    qkv = qkv.transpose(1, 2)
-    query_states, key_states, value_states = qkv.split([self.num_heads,
-                                                        self.num_key_value_heads,
-                                                        self.num_key_value_heads], dim=1)
+    if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
+        qkv = self.qkv_proj(hidden_states)
+        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+        qkv = qkv.transpose(1, 2)
+        query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                            self.num_key_value_heads,
+                                                            self.num_key_value_heads], dim=1)
+    else:
+        # when quant_method is 'gptq'
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
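As a sanity check on the fallback branch, the fused-projection path and the separate q/k/v path produce query/key/value tensors of identical shapes. A small self-contained example with made-up dimensions:

```python
import torch

bsz, q_len, num_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16
hidden = torch.randn(bsz, q_len, num_heads * head_dim)

# Fused path: one projection, then split along the head dimension.
qkv_proj = torch.nn.Linear(num_heads * head_dim, (num_heads + 2 * num_kv_heads) * head_dim)
qkv = qkv_proj(hidden).view(bsz, q_len, num_heads + 2 * num_kv_heads, head_dim).transpose(1, 2)
q1, k1, v1 = qkv.split([num_heads, num_kv_heads, num_kv_heads], dim=1)

# Separate path: three projections, as in the GPTQ fallback.
q_proj = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim)
k_proj = torch.nn.Linear(num_heads * head_dim, num_kv_heads * head_dim)
v_proj = torch.nn.Linear(num_heads * head_dim, num_kv_heads * head_dim)
q2 = q_proj(hidden).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
k2 = k_proj(hidden).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
v2 = v_proj(hidden).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)

assert q1.shape == q2.shape and k1.shape == k2.shape and v1.shape == v2.shape
print("shapes match:", q1.shape, k1.shape, v1.shape)
```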
