diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md
index 104032c6c9a..4cb29db1efc 100644
--- a/python/llm/example/GPU/Lightweight-Serving/README.md
+++ b/python/llm/example/GPU/Lightweight-Serving/README.md
@@ -18,6 +18,10 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
 pip install fastapi uvicorn openai
 pip install gradio # for gradio web UI
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+
+# additional dependencies for internlm-xcomposer2-vl-7b
+pip install transformers==4.31.0
+pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
 ```
 
 #### 1.2 Installation on Windows
@@ -172,10 +176,39 @@ curl http://localhost:8000/v1/chat/completions \
   }'
 ```
 
-#### /v1/completions
+##### Image input
 
+Image input is currently only supported for [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b), and `transformers==4.31.0` must be installed to run it.
 ```bash
+wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "internlm-xcomposer2-vl-7b",
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {
+            "type": "text",
+            "text": "What'\''s in this image?"
+          },
+          {
+            "type": "image_url",
+            "image_url": {
+              "url": "./test.jpg"
+            }
+          }
+        ]
+      }
+    ],
+    "max_tokens": 128
+  }'
+```
+
+#### /v1/completions
+
+```bash
 curl http://localhost:8000/v1/completions \
   -H "Content-Type: application/json" \
   -d '{
diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
index 0cc12bf35e9..88c856180e5 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -47,12 +47,17 @@
 class InputsRequest(BaseModel):
     inputs: str
     parameters: Optional[Parameters] = None
+    image_list: Optional[list] = None
     stream: Optional[bool] = False
     req_type: str = 'completion'
 
 
 class ChatCompletionRequest(BaseModel):
-    messages: List[ChatMessage]
+    messages: Union[
+        str,
+        List[Dict[str, str]],
+        List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
+    ]
     model: str
     max_tokens: Optional[int] = None
     min_tokens: Optional[int] = None
@@ -266,7 +271,7 @@ async def generate_stream(inputs_request: InputsRequest):
 
 def get_prompt(messages) -> str:
     if "codegeex" in local_model.model_name.lower():
-        query = messages[-1].content
+        query = messages[-1]["content"]
         if len(messages) <= 1:
             history = []
         else:
@@ -277,18 +282,33 @@ def get_prompt(messages) -> str:
         return inputs
     else:
         prompt = ""
+        image_list = []
         for msg in messages:
-            role = msg.role
-            content = msg.content
-            if role == "system":
-                prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
-            elif role == "user":
-                prompt += f"[INST] {content} [/INST] "
-            elif role == "assistant":
-                prompt += f"{content} "
+            role = msg["role"]
+            content = msg["content"]
+            if isinstance(content, list):
+                msg_image_urls = [
+                    item["image_url"]["url"]
+                    for item in content
+                    if item["type"] == "image_url"
+                ]
+                image_list.extend(msg_image_urls)
+                text_list = [
+                    item["text"]
+                    for item in content
+                    if item["type"] == "text"
+                ]
+                prompt += "".join(text_list)
             else:
-                invalidInputError(False, f"Unknown role: {role}")
-        return prompt.strip()
+                if role == "system":
+                    prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
+                elif role == "user":
+                    prompt += f"[INST] {content} [/INST] "
+                elif role == "assistant":
+                    prompt += f"{content} "
+                else:
+                    invalidInputError(False, f"Unknown role: {role}")
+        return prompt.strip(), image_list
 
 
 def set_parameters(req):
@@ -313,11 +333,12 @@ def set_parameters(req):
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest):
-    print(request)
     model_name = local_model.model_name
+    prompt, image_list = get_prompt(request.messages)
     inputs_request = InputsRequest(
-        inputs=get_prompt(request.messages),
+        inputs=prompt,
         parameters=set_parameters(request),
+        image_list=image_list if image_list else None,
         stream=request.stream,
         req_type="chat"
     )
diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
index 099cfc0c700..0fe0f88cc44 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -60,15 +60,40 @@ async def add_request(self, tokenizer):
         tmp_result = await self.waiting_requests.get()
         request_id, prompt_request = tmp_result
         plain_texts = prompt_request.inputs
-        inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
-        input_ids = inputs.input_ids.to('xpu')
+        input_ids = None
+        inputs_embeds = None
+        if "internlm-xcomposer2-vl-7b" in self.model_name.lower():
+            lines = [
+                "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).",
+                "- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language "
+                "model that is developed by Shanghai AI Laboratory (上海人工智能实验室). "
+                "It is designed to be helpful, honest, and harmless.",
+                "- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in "
+                "the language chosen by the user such as English and 中文.",
+                "- InternLM-XComposer (浦语·灵笔) is capable of comprehending and "
+                "articulating responses effectively based on the provided image."
+            ]
+            meta_instruction = "\n".join(lines)
+            if prompt_request.image_list is None:
+                inputs = self.model.build_inputs(tokenizer, plain_texts, [], meta_instruction)
+                im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
+                input_ids = inputs["input_ids"].to('xpu')
+            else:
+                image = self.model.encode_img(prompt_request.image_list[0])
+                plain_texts = "<ImageHere>" + plain_texts
+                inputs, im_mask = self.model.interleav_wrap_chat(tokenizer, plain_texts,
+                                                                 image, [], meta_instruction)
+                inputs_embeds = inputs["inputs_embeds"].to('xpu').to(self.dtype)
+        else:
+            inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
+            input_ids = inputs.input_ids.to('xpu')
         parameters = prompt_request.parameters
-        return input_ids, parameters, request_id
+        return input_ids, parameters, request_id, inputs_embeds
 
     @torch.no_grad()
     async def process_step(self, tokenizer, result_dict):
         if not self.waiting_requests.empty():
-            input_ids, parameters, request_id = await self.add_request(tokenizer)
+            input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
             self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
             def model_generate():
@@ -78,8 +103,18 @@ def model_generate():
                                     tokenizer.convert_tokens_to_ids("<|user|>"),
                                     tokenizer.convert_tokens_to_ids("<|observation|>")]
                     generate_kwargs["eos_token_id"] = eos_token_id
-                self.model.generate(input_ids,
-                                    streamer=self.streamer[request_id], **generate_kwargs)
+                elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
+                    eos_token_id = [
+                        tokenizer.eos_token_id,
+                        tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
+                    ]
+                    generate_kwargs["eos_token_id"] = eos_token_id
+                if input_ids is not None:
+                    self.model.generate(input_ids,
+                                        streamer=self.streamer[request_id], **generate_kwargs)
+                elif inputs_embeds is not None:
+                    self.model.generate(inputs_embeds=inputs_embeds,
+                                        streamer=self.streamer[request_id], **generate_kwargs)
                 torch.xpu.empty_cache()
                 torch.xpu.synchronize()
diff --git a/python/llm/src/ipex_llm/utils/benchmark_util.py b/python/llm/src/ipex_llm/utils/benchmark_util.py
index cbbd6c6a9e3..d64631f1f4c 100644
--- a/python/llm/src/ipex_llm/utils/benchmark_util.py
+++ b/python/llm/src/ipex_llm/utils/benchmark_util.py
@@ -574,7 +574,7 @@ def _prepare_model_inputs(
         if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
             if not self.config.is_encoder_decoder:
                 has_inputs_embeds_forwarding = "inputs_embeds" in set(
-                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+                    inspect.signature(self.model.prepare_inputs_for_generation).parameters.keys()
                 )
                 if not has_inputs_embeds_forwarding:
                     raise ValueError(
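The README hunk above documents the new image input only with curl. The same request can be sent with the `openai` Python client that the README already installs; the snippet below is a minimal sketch, assuming the lightweight server is running on `localhost:8000` with `internlm-xcomposer2-vl-7b` loaded, that it does not validate API keys, and that `./test.jpg` (downloaded as in the curl example) is readable by the serving process.

```python
# Minimal client sketch for the image-input chat endpoint.
# Assumptions: server at localhost:8000, no API-key check, ./test.jpg readable
# by the serving process, OpenAI-compatible response fields.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")

response = client.chat.completions.create(
    model="internlm-xcomposer2-vl-7b",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                # The URL is forwarded to the server-side model as-is, so a local
                # path only works when the serving process can read that file.
                {"type": "image_url", "image_url": {"url": "./test.jpg"}},
            ],
        }
    ],
    max_tokens=128,
)
print(response.choices[0].message.content)
```

Since `create_chat_completion` forwards the `stream` field through `InputsRequest`, passing `stream=True` to the same call should yield incremental chunks instead of a single response.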