From 4127b99ed6fd742227f9e80ddeea4f478dc1383a Mon Sep 17 00:00:00 2001
From: ZehuaCao <47251317+Romanticoseu@users.noreply.github.com>
Date: Thu, 30 May 2024 16:16:10 +0800
Subject: [PATCH] Fix null pointer dereferences error. (#11125)

* delete unused function on tgi_server

* update

* update

* fix style
---
 python/llm/dev/print_glib_requirement.py     | 15 ++---
 python/llm/portable-zip/chat.py              |  3 +-
 .../serving/fastchat/tgi_api_server.py       | 59 -------------------
 .../llm/src/ipex_llm/transformers/convert.py | 27 +++++----
 4 files changed, 26 insertions(+), 78 deletions(-)

diff --git a/python/llm/dev/print_glib_requirement.py b/python/llm/dev/print_glib_requirement.py
index 758116f560e..6e49ba60188 100644
--- a/python/llm/dev/print_glib_requirement.py
+++ b/python/llm/dev/print_glib_requirement.py
@@ -29,13 +29,14 @@ def _check_version(filename, flag="GLIBC"):
     if flag == "GLIBCXX":
         subfile = _check_glibcxx_version(filename)
     max_version = None
-    for version_string in subfile.split():
-        try:
-            version = Version(version_string.split("_")[1])
-            if max_version is None or version > max_version:
-                max_version = version
-        except Exception:
-            pass
+    if subfile:
+        for version_string in subfile.split():
+            try:
+                version = Version(version_string.split("_")[1])
+                if max_version is None or version > max_version:
+                    max_version = version
+            except Exception:
+                pass
     return max_version
 
 
diff --git a/python/llm/portable-zip/chat.py b/python/llm/portable-zip/chat.py
index 7794e50d629..c851e947e93 100644
--- a/python/llm/portable-zip/chat.py
+++ b/python/llm/portable-zip/chat.py
@@ -113,7 +113,8 @@ def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len, s
 
         if pred_token_idx == tokenizer.eos_token_id:
             break
-    print(" ".join(generated_text[pos:]).strip('\n<'), flush=True)
+    if generated_text:
+        print(" ".join(generated_text[pos:]).strip('\n<'), flush=True)
     return past_key_values
 
 @torch.no_grad()
diff --git a/python/llm/src/ipex_llm/serving/fastchat/tgi_api_server.py b/python/llm/src/ipex_llm/serving/fastchat/tgi_api_server.py
index 499239c90b5..d26344f2b02 100644
--- a/python/llm/src/ipex_llm/serving/fastchat/tgi_api_server.py
+++ b/python/llm/src/ipex_llm/serving/fastchat/tgi_api_server.py
@@ -607,65 +607,6 @@ async def chat_completion_stream_generator(
         yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
     yield "data: [DONE]\n\n"
 
-async def generate_completion_stream_generator(
-    request: CompletionRequest, n: int, worker_addr: str
-):
-    model_name = request.model
-    id = f"cmpl-{shortuuid.random()}"
-    finish_stream_events = []
-    for text in request.prompt:
-        for i in range(n):
-            previous_text = ""
-            gen_params = await get_gen_params(
-                request.model,
-                worker_addr,
-                text,
-                temperature=request.temperature,
-                top_p=request.top_p,
-                top_k=request.top_k,
-                presence_penalty=request.presence_penalty,
-                frequency_penalty=request.frequency_penalty,
-                max_tokens=request.max_tokens,
-                logprobs=request.logprobs,
-                echo=request.echo,
-                stop=request.stop,
-            )
-            async for content in generate_completion_stream(gen_params, worker_addr):
-                if content["error_code"] != 0:
-                    yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
-                    yield "data: [DONE]\n\n"
-                    return
-                decoded_unicode = content["text"].replace("\ufffd", "")
-                delta_text = decoded_unicode[len(previous_text) :]
-                previous_text = (
-                    decoded_unicode
-                    if len(decoded_unicode) > len(previous_text)
-                    else previous_text
-                )
-                # todo: index is not apparent
-                choice_data = CompletionResponseStreamChoice(
-                    index=i,
-                    text=delta_text,
-                    logprobs=create_openai_logprobs(content.get("logprobs", None)),
-                    finish_reason=content.get("finish_reason", None),
-                )
-                chunk = CompletionStreamResponse(
-                    id=id,
-                    object="text_completion",
-                    choices=[choice_data],
-                    model=model_name,
-                )
-                if len(delta_text) == 0:
-                    if content.get("finish_reason", None) is not None:
-                        finish_stream_events.append(chunk)
-                    continue
-                yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
-    # There is not "content" field in the last delta message, so exclude_none to exclude field "content".
-    for finish_chunk in finish_stream_events:
-        yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
-    yield "data: [DONE]\n\n"
-
-
 async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str):
     controller_address = app_settings.controller_address
     async with httpx.AsyncClient() as client:
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index e73fe09aaed..04efa8ae4a2 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -487,20 +487,25 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
         FP16Linear, BF16Linear
 
     has_been_replaced = False
+    splits = []
     if "." in module_name:
         splits = module_name.split(".")
-    parent_module = getattr(model, splits[0])
-
-    if "lm_head" not in module_name:
-        for split in splits[1:-2]:
-            new_module = getattr(parent_module, split)
-            parent_module = new_module
-        module = getattr(parent_module, splits[-2])
-        module_name = splits[-2]
+    if not splits:
+        invalidInputError(False,
+                          "Please provide a valid module_name with hierarchical structure")
     else:
-        module = parent_module
-        parent_module = model
-        module_name = splits[0]
+        parent_module = getattr(model, splits[0])
+
+        if "lm_head" not in module_name:
+            for split in splits[1:-2]:
+                new_module = getattr(parent_module, split)
+                parent_module = new_module
+            module = getattr(parent_module, splits[-2])
+            module_name = splits[-2]
+        else:
+            module = parent_module
+            parent_module = model
+            module_name = splits[0]
 
     if current_key_name is None:
         current_key_name = []
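
Note (illustrative, not part of the patch): the fixes above share one pattern, guard a value that may be None or empty before dereferencing it (`subfile` before `.split()`, `generated_text` before joining, `splits` before indexing). Below is a minimal standalone Python sketch of that pattern; the helper name `max_glibcxx_version` is a hypothetical stand-in for the real `_check_version` in print_glib_requirement.py.

from typing import Optional

from packaging.version import Version


def max_glibcxx_version(subfile: Optional[str]) -> Optional[Version]:
    """Return the highest GLIBCXX_x.y.z version found in `subfile`, or None."""
    max_version = None
    # Guard: `subfile` may be None or empty, so check it before calling .split().
    if subfile:
        for version_string in subfile.split():
            try:
                version = Version(version_string.split("_")[1])
                if max_version is None or version > max_version:
                    max_version = version
            except Exception:
                # Ignore tokens that are not of the form GLIBCXX_x.y.z.
                pass
    return max_version


print(max_glibcxx_version("GLIBCXX_3.4.29 GLIBCXX_3.4.30"))  # 3.4.30
print(max_glibcxx_version(None))                             # None, no crash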