[NPU] Support minicpm-v with python cpp backend (#12637)
plusbang authored Jan 2, 2025
1 parent f289f68 commit 534566e
Showing 5 changed files with 66 additions and 31 deletions.
2 changes: 1 addition & 1 deletion python/llm/dev/benchmark/all-in-one/run.py
@@ -650,7 +650,7 @@ def transformers_int4_npu_win(repo_id,
load_time = end - st
print(">> loading of model costs {}s".format(load_time))

if not hasattr(model, "model_ptr"):
if not hasattr(model, "model_ptr") or repo_id in MINICPM_V_IDS:
model = BenchmarkWrapper(model)

result = {}
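For context, a minimal sketch of the wrapping decision changed above. `MINICPM_V_IDS` is defined elsewhere in run.py; its contents and the helper name here are illustrative assumptions.

```python
# Illustrative only: the real MINICPM_V_IDS collection lives elsewhere in run.py.
MINICPM_V_IDS = {"openbmb/MiniCPM-V-2_6"}

def needs_benchmark_wrapper(model, repo_id):
    # Wrap when there is no C++ pipeline handle, or when the checkpoint is a
    # MiniCPM-V model whose outer multimodal module should still be timed.
    return not hasattr(model, "model_ptr") or repo_id in MINICPM_V_IDS

class Dummy:
    model_ptr = 0x1234  # pretend the C++ pipeline handle is set

print(needs_benchmark_wrapper(Dummy(), "openbmb/MiniCPM-V-2_6"))        # True
print(needs_benchmark_wrapper(Dummy(), "meta-llama/Llama-2-7b-chat-hf"))  # False
```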
6 changes: 3 additions & 3 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -301,8 +301,7 @@ def optimize_npu_model(cls, *args, **kwargs):
model.share_memory()

if not pipeline:
if (not hasattr(model, 'llm') and
model.config.model_type in ["qwen2", "llama", "minicpm"]):
if model.config.model_type in ["qwen2", "llama", "minicpm"]:
from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
optimize_llm_single_process(
llm,
@@ -312,7 +311,8 @@ def optimize_npu_model(cls, *args, **kwargs):
group_size=quantization_group_size,
qtype=qtype,
save_directory=save_directory,
fuse_layers=fuse_layers
fuse_layers=fuse_layers,
has_llm=hasattr(model, "llm")
)
else:
optimize_llm(
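A hedged sketch of what the relaxed condition enables: checkpoints that expose their language model as `model.llm` (e.g. MiniCPM-V) are now routed into the single-process pipeline too, with `has_llm` telling the converter not to re-patch the outer model's `generate`. The selection of `llm` shown here is an assumption based on the surrounding code, not a quote of it.

```python
def pick_pipeline_target(model):
    # Assumed selection logic: the language model is what gets lowered to the
    # NPU pipeline -- model.llm for multimodal wrappers, the model itself otherwise.
    has_llm = hasattr(model, "llm")
    llm = model.llm if has_llm else model
    return llm, has_llm

class FakeQwen2LLM:
    pass

class FakeMiniCPMV:
    llm = FakeQwen2LLM()

print(pick_pipeline_target(FakeMiniCPMV()))   # (<FakeQwen2LLM object ...>, True)
print(pick_pipeline_target(FakeQwen2LLM()))   # (<FakeQwen2LLM object ...>, False)
```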
58 changes: 40 additions & 18 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -449,7 +449,8 @@ def optimize_llm_single_process(
group_size: int,
qtype: str,
save_directory: str,
fuse_layers: int=None
fuse_layers: int=None,
has_llm: bool=False
):
from ipex_llm.transformers.npu_pipeline_model.convert_pipeline import convert_llm
from .npu_llm_cpp import load_model_from_file
@@ -468,8 +469,13 @@ def optimize_llm_single_process(
model.kv_len = kv_len
model.model_ptr = model_ptr
model.save_directory = save_directory
model.vocab_size = model.config.vocab_size
if model.config.vocab_size == 151666:
# for MiniCPM-V 2.6, 152064 is vocab_size of Qwen2-7B
model.vocab_size = 152064
else:
model.vocab_size = model.config.vocab_size
model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
model.max_prompt_len = max_prompt_len
except:
invalidInputError(False,
"False to InitLLMPipeline.")
@@ -478,9 +484,10 @@ def optimize_llm_single_process(
general_convert(model, PreTrainedModel, prepare_input_ids, "prepare_inputs_for_generation")
general_convert(model, PreTrainedModel, causal_lm_forward)
# patch generate function
import types
model.original_generate = model.generate
model.generate = types.MethodType(generate, model)
if not has_llm:
import types
model.original_generate = model.generate
model.generate = types.MethodType(generate, model)
return model
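The `has_llm` guard keeps the outer multimodal `generate` (which builds image embeddings before delegating to the language model) intact. A self-contained sketch of the `types.MethodType` patching pattern used here, with placeholder names:

```python
import types

class TinyModel:
    def generate(self, prompt):
        return f"original generate: {prompt}"

def pipeline_generate(self, prompt):
    # Stand-in for the C++-pipeline-backed generate loop.
    return f"pipeline generate: {prompt}"

model = TinyModel()
has_llm = False  # True when the model wraps its language model as model.llm
if not has_llm:
    model.original_generate = model.generate
    model.generate = types.MethodType(pipeline_generate, model)

print(model.generate("hi"))  # "pipeline generate: hi" only when has_llm is False
```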


@@ -491,9 +498,10 @@ def prepare_input_ids(
else: # prefill, reset the model here
from .npu_llm_cpp import reset
reset(self.model_ptr)
model_inputs = {
"input_ids": input_ids
}
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
return model_inputs
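A small illustration of the dispatch above, assuming the multimodal prefill hands over precomputed image+text embeddings exactly once (no `past_key_values` yet), while text-only prefill and every decode step pass token ids; shapes are made up.

```python
import torch

def select_model_inputs(input_ids, inputs_embeds, past_key_values):
    if inputs_embeds is not None and past_key_values is None:
        return {"inputs_embeds": inputs_embeds}   # multimodal prefill
    return {"input_ids": input_ids}               # text prefill or any decode step

prefill_embeds = torch.zeros(1, 24, 3584, dtype=torch.float16)
print(select_model_inputs(None, prefill_embeds, None).keys())            # inputs_embeds
print(select_model_inputs(torch.tensor([[42]]), None, "cached").keys())  # input_ids
```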


@@ -511,17 +519,31 @@ def causal_lm_forward(
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
from .npu_llm_cpp import run_prefill_with_logits, run_decode_with_logits
if isinstance(input_ids[0], torch.Tensor):
input_list = input_ids[0].flatten().tolist()
else:
input_list = input_ids[0]
input_length = len(input_list)
if input_length > 1:
logits = run_prefill_with_logits(self.model_ptr, input_list,
self.logits_buffer, self.vocab_size)
if input_ids is not None:
if isinstance(input_ids[0], torch.Tensor):
input_list = input_ids[0].flatten().tolist()
else:
input_list = input_ids[0]
input_length = len(input_list)
if input_length > 1:
logits = run_prefill_with_logits(self.model_ptr, input_list,
self.logits_buffer, self.vocab_size)
else:
logits = run_decode_with_logits(self.model_ptr, input_list[0],
self.logits_buffer, self.vocab_size)
elif inputs_embeds is not None:
seq_len = inputs_embeds.shape[1]
pad_len = self.max_prompt_len - seq_len
inputs_embeds = torch.nn.functional.pad(inputs_embeds.to(torch.float16),
(0, 0, 0, pad_len), value=0.0)
logits = run_prefill_with_logits(self.model_ptr, None, self.logits_buffer,
self.vocab_size, inputs_embeds, seq_len)
else:
logits = run_decode_with_logits(self.model_ptr, input_list[0],
self.logits_buffer, self.vocab_size)
invalidInputError(False, "Please specify either input_ids or inputs_embeds.")

if self.config.vocab_size == 151666:
# for MiniCPM-V 2.6
logits = logits[:, :, :151666]

return CausalLMOutputWithPast(
loss=None,
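A hedged, self-contained illustration of the two MiniCPM-V 2.6 details in this hunk: the embedding prompt is right-padded up to `max_prompt_len` before prefill, and logits come back in a buffer sized for the padded Qwen2-7B vocab (152064) and are then sliced to the model's reported vocab (151666). All shapes are illustrative.

```python
import torch
import torch.nn.functional as F

max_prompt_len, hidden = 512, 3584
padded_vocab, real_vocab = 152064, 151666

inputs_embeds = torch.randn(1, 37, hidden, dtype=torch.float16)  # image+text prompt
pad_len = max_prompt_len - inputs_embeds.shape[1]
inputs_embeds = F.pad(inputs_embeds, (0, 0, 0, pad_len), value=0.0)  # pad seq dim
print(inputs_embeds.shape)   # torch.Size([1, 512, 3584])

logits = torch.empty(1, 1, padded_vocab, dtype=torch.float32)    # reused buffer
logits = logits[:, :, :real_vocab]                               # drop the padding
print(logits.shape)          # torch.Size([1, 1, 151666])
```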
27 changes: 18 additions & 9 deletions python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
@@ -48,8 +48,8 @@ def get_shared_lib_info(lib_base_name: str):
_lib.load_model_from_file.argtypes = [ctypes.c_char_p]
_lib.load_model_from_file.restype = ctypes.c_void_p

_lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int,
ctypes.c_float]
_lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int,
ctypes.c_float, ctypes.c_bool]
_lib.run_prefill.restype = ctypes.POINTER(ctypes.c_float)

_lib.run_decode.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_float]
@@ -61,8 +61,10 @@ def get_shared_lib_info(lib_base_name: str):
_lib.reset.argtypes = [ctypes.c_void_p]
_lib.reset.restype = None

_lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int),
ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
_lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_int, ctypes.POINTER(ctypes.c_float),
ctypes.c_int, ctypes.c_bool]

_lib.run_prefill_with_logits.restype = None

_lib.run_decode_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_int,
@@ -77,7 +79,7 @@ def load_model_from_file(model_dir: str):
def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
input_len = len(input_ids)
plogits = _lib.run_prefill(model_ptr, input_ptr, input_len, repetition_penalty)
plogits = _lib.run_prefill(model_ptr, input_ptr, input_len, repetition_penalty, False)
new_token = _lib.llm_sample_token(plogits, True, vocab_size)
return new_token

@@ -88,12 +90,19 @@ def run_decode(model_ptr, input_id, vocab_size, repetition_penalty=1.0):
return new_token


def run_prefill_with_logits(model_ptr, input_ids, logits, vocab_size):
input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
input_len = len(input_ids)
def run_prefill_with_logits(model_ptr, input_ids, logits, vocab_size,
inputs_embeds=None, seq_len=None):
if input_ids is not None:
input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
input_len = len(input_ids)
else:
input_ptr = inputs_embeds.contiguous().data.data_ptr()
input_ptr = ctypes.cast(input_ptr, ctypes.c_void_p)
input_len = seq_len
logits_ptr = logits.data.data_ptr()
logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
_lib.run_prefill_with_logits(model_ptr, input_ptr, input_len, logits_ptr, vocab_size)
_lib.run_prefill_with_logits(model_ptr, input_ptr, input_len, logits_ptr,
vocab_size, (input_ids is None))
return logits


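A minimal sketch of the pointer handling behind the widened C signature: the prefill entry point now receives a `void*` plus a flag, so Python passes either a packed int32 token array or the raw buffer of a contiguous float16 embedding tensor. The helper name and return layout are illustrative, not part of the library API.

```python
import ctypes
import torch

def as_prefill_args(input_ids=None, inputs_embeds=None, seq_len=None):
    if input_ids is not None:
        ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)   # packed token ids
        return ptr, len(input_ids), False                     # is_embeds=False
    # Hand the C side the raw address of the contiguous fp16 embedding buffer.
    ptr = ctypes.cast(inputs_embeds.contiguous().data_ptr(), ctypes.c_void_p)
    return ptr, seq_len, True                                 # is_embeds=True

print(as_prefill_args(input_ids=[1, 2, 3])[1:])               # (3, False)
embeds = torch.zeros(1, 4, 8, dtype=torch.float16)
print(as_prefill_args(inputs_embeds=embeds, seq_len=4)[1:])   # (4, True)
```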
@@ -34,6 +34,10 @@ def convert_lm_head_and_embedding(model, temp_dir, weight_dir,
lm_head_n_splits = 1
asym = getattr(model.config, "asym", False)

if vocab_size == 151666:
# for MiniCPM-V 2.6 lm_head on NPU
vocab_size = 152064

if not isinstance(lm_head, SlicedLMHead):
asym = lm_head.qtype == "asym_int4_rtn"
if asym:
