Skip to content

Commit

Permalink
LLM: Fix vLLM CPU model convert mismatch (#11254)
Browse files Browse the repository at this point in the history
Fix vLLM CPU model convert mismatch.
  • Loading branch information
xiangyuT authored Jun 7, 2024
1 parent 42fab48 commit 4b07712
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 28 deletions.
7 changes: 3 additions & 4 deletions python/llm/src/ipex_llm/vllm/cpu/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def from_engine_args(
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: str = "sym_int4",
load_in_low_bit: Optional[str] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Enable ipex-llm optimizations
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
load_in_low_bit: str = "sym_int4",
load_in_low_bit: Optional[str] = None,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
Expand Down Expand Up @@ -136,8 +136,7 @@ def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: str = "sym_int4",
# ipex_llm_optimize_mode: str = 'NATIVE',
load_in_low_bit: Optional[str] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def parse_args():
parser.add_argument(
"--load-in-low-bit",
type=str,
default="sym_int4",
default=None,
help="Low-bit quantization for IPEX-LLM models")
return parser.parse_args()

Expand Down
79 changes: 56 additions & 23 deletions python/llm/src/ipex_llm/vllm/cpu/model_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.utils import get_model_architecture
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
Expand All @@ -24,11 +25,14 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.config import DeviceConfig
from vllm.logger import init_logger

from vllm._C import ops
from ipex_llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union

logger = init_logger(__name__)


def _MLP_forward(self, x):
gate_up = self.gate_up_proj(x)
Expand Down Expand Up @@ -59,10 +63,10 @@ def _QWen_Attention_forward(
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv = self.c_attn(hidden_states)
qkv = self.c_attn(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output = self.c_proj(attn_output)
return output

Expand All @@ -74,6 +78,21 @@ def _QWen_MLP_forward(self, x):
return x


def _Qwen2_Attention_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output = self.o_proj(attn_output)
return output


def _ChatGLM_MLP_forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
Expand All @@ -90,11 +109,11 @@ def _Baichuan_Attention_forward(
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv = self.W_pack(hidden_states)
qkv = self.W_pack(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output = self.o_proj(attn_output)
return output

Expand All @@ -106,7 +125,7 @@ def _ChatGLM_Attention_forward(
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
Expand All @@ -123,18 +142,25 @@ def _ChatGLM_Attention_forward(
LlamaMLP: _MLP_forward,
Qwen2MLP: _MLP_forward,
BaiChuanMLP: _MLP_forward,
QWenMLP: _QWen_MLP_forward,
# QWenMLP: _QWen_MLP_forward,
GLMMLP: _ChatGLM_MLP_forward
}

_REPLACED_ATTENTION_LAYERS = {
LlamaAttention: _Attention_forward,
Qwen2Attention: _Attention_forward,
QWenAttention: _QWen_Attention_forward,
Qwen2Attention: _Qwen2_Attention_forward,
# QWenAttention: _QWen_Attention_forward,
BaiChuanAttention: _Baichuan_Attention_forward,
GLMAttention: _ChatGLM_Attention_forward
}

_IPEX_LLM_SUPPORTED_MODELS = [
"LlamaForCausalLM",
"BaichuanForCausalLM",
"ChatGLMForCausalLM",
"Qwen2ForCausalLM",
]


def _model_mlp_convert():
for module, replaced_func in _REPLACED_MLP_LAYERS.items():
Expand All @@ -147,6 +173,8 @@ def _model_attention_convert():


def _ipex_llm_convert(load_in_low_bit):
if load_in_low_bit is None:
return
from vllm.worker.cpu_model_runner import CPUModelRunner
import vllm.model_executor.model_loader as model_loader
setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
Expand Down Expand Up @@ -206,6 +234,26 @@ def _ipex_llm_rmsnorm_forward(

def get_load_function(low_bit):
def _ipex_llm_load_model(self) -> None:
model_class = get_model_architecture(self.model_config)[1]
cur_model_list = ", ".join(_IPEX_LLM_SUPPORTED_MODELS)
if low_bit != "bf16":
invalidInputError(model_class in _IPEX_LLM_SUPPORTED_MODELS,
f"Currently IPEX-LLM vLLM convert only support {cur_model_list}.")
else:
if model_class not in _IPEX_LLM_SUPPORTED_MODELS:
logger.warning(
f"Currently IPEX-LLM vLLM convert only support {cur_model_list}."
)
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=self.vision_language_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
return

_model_mlp_convert()
_model_attention_convert()

Expand All @@ -221,19 +269,4 @@ def _ipex_llm_load_model(self) -> None:
from ipex_llm import optimize_model
optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)

if self.lora_config:
invalidInputError(hasattr(self.model, "supported_lora_modules")
and self.model.supported_lora_modules,
"Model does not support LoRA")
invalidInputError(hasattr(self.model, "embedding_modules"),
"Model does not have embedding_modules")
invalidInputError(hasattr(self.model, "embedding_padding_modules"),
"Model does not have embedding_padding_modules")
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens +
self.scheduler_config.max_paddings, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)
return _ipex_llm_load_model

0 comments on commit 4b07712

Please sign in to comment.