Support lightweight-serving with internlm-xcomposer2-vl-7b multimodal input (#11703)

* init image_list

* enable internlm-xcomposer2 image input

* update style

* add readme

* update model

* update readme
hzjane authored Aug 5, 2024
1 parent aa98ef9 commit 493cbd9
Showing 4 changed files with 111 additions and 22 deletions.
35 changes: 34 additions & 1 deletion python/llm/example/GPU/Lightweight-Serving/README.md
@@ -18,6 +18,10 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
pip install fastapi uvicorn openai
pip install gradio # for gradio web UI
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc

# for internlm-xcomposer2-vl-7b
pip install transformers==4.31.0
pip install accelerate timm==0.4.12 sentencepiece==0.1.99 gradio==3.44.4 markdown2==2.4.10 xlsxwriter==3.1.2 einops
```

#### 1.2 Installation on Windows
@@ -172,10 +176,39 @@ curl http://localhost:8000/v1/chat/completions \
}'
```

##### Image input

Image input is currently only supported for [internlm-xcomposer2-vl-7b](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b), and `transformers==4.31.0` must be installed to run it. A `curl` example follows, with an equivalent Python client sketch after it.
```bash
wget -O ./test.jpg http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-xcomposer2-vl-7b",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What'\''s in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "./test.jpg"
}
}
]
}
],
"max_tokens": 128
}'
```
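
The same request can also be issued with the `openai` Python client installed earlier. This is a minimal sketch, assuming the server exposes the OpenAI-compatible endpoint at `http://localhost:8000/v1`, does not validate the API key, and can read `./test.jpg` from its working directory:

```python
from openai import OpenAI

# Point the client at the local lightweight-serving endpoint; the API key is a placeholder.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

response = client.chat.completions.create(
    model="internlm-xcomposer2-vl-7b",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                # The image path is resolved on the server side, as in the curl example above.
                {"type": "image_url", "image_url": {"url": "./test.jpg"}},
            ],
        }
    ],
    max_tokens=128,
)
print(response.choices[0].message.content)
```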

#### /v1/completions

```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
49 changes: 35 additions & 14 deletions python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -47,12 +47,17 @@
class InputsRequest(BaseModel):
inputs: str
parameters: Optional[Parameters] = None
image_list: Optional[list] = None
stream: Optional[bool] = False
req_type: str = 'completion'


class ChatCompletionRequest(BaseModel):
messages: List[ChatMessage]
messages: Union[
str,
List[Dict[str, str]],
List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
]
model: str
max_tokens: Optional[int] = None
min_tokens: Optional[int] = None
@@ -266,7 +271,7 @@ async def generate_stream(inputs_request: InputsRequest):

def get_prompt(messages) -> str:
if "codegeex" in local_model.model_name.lower():
query = messages[-1].content
query = messages[-1]["content"]
if len(messages) <= 1:
history = []
else:
@@ -277,18 +282,33 @@ def get_prompt(messages) -> str:
return inputs
else:
prompt = ""
image_list = []
for msg in messages:
role = msg.role
content = msg.content
if role == "system":
prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
elif role == "user":
prompt += f"[INST] {content} [/INST] "
elif role == "assistant":
prompt += f"{content} "
role = msg["role"]
content = msg["content"]
if type(content) == list:
image_list1 = [
item["image_url"]["url"]
for item in content
if item["type"] == "image_url"
]
image_list.extend(image_list1)
text_list = [
item["text"]
for item in content
if item["type"] == "text"
]
prompt = "".join(text_list)
else:
invalidInputError(False, f"Unknown role: {role}")
return prompt.strip()
if role == "system":
prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
elif role == "user":
prompt += f"[INST] {content} [/INST] "
elif role == "assistant":
prompt += f"{content} "
else:
invalidInputError(False, f"Unknown role: {role}")
return prompt.strip(), image_list


def set_parameters(req):
@@ -313,11 +333,12 @@ def set_parameters(req):

@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
print(request)
model_name = local_model.model_name
prompt, image_list = get_prompt(request.messages)
inputs_request = InputsRequest(
inputs=get_prompt(request.messages),
inputs=prompt,
parameters=set_parameters(request),
image_list=image_list if len(image_list) >= 1 else None,
stream=request.stream,
req_type="chat"
)
47 changes: 41 additions & 6 deletions python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -60,15 +60,40 @@ async def add_request(self, tokenizer):
tmp_result = await self.waiting_requests.get()
request_id, prompt_request = tmp_result
plain_texts = prompt_request.inputs
inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to('xpu')
input_ids = None
inputs_embeds = None
if "internlm-xcomposer2-vl-7b" in self.model_name.lower():
lines = [
"You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).",
"- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language "
"model that is developed by Shanghai AI Laboratory (上海人工智能实验室). "
"It is designed to be helpful, honest, and harmless.",
"- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in "
"the language chosen by the user such as English and 中文.",
"- InternLM-XComposer (浦语·灵笔) is capable of comprehending and "
"articulating responses effectively based on the provided image."
]
meta_instruction = "\n".join(lines)
if prompt_request.image_list is None:
inputs = self.model.build_inputs(tokenizer, plain_texts, [], meta_instruction)
im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
input_ids = inputs["input_ids"].to('xpu')
else:
image = self.model.encode_img(prompt_request.image_list[0])
plain_texts = "<ImageHere>" + plain_texts
inputs, im_mask = self.model.interleav_wrap_chat(tokenizer, plain_texts,
image, [], meta_instruction)
inputs_embeds = inputs["inputs_embeds"].to('xpu').to(self.dtype)
else:
inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to('xpu')
parameters = prompt_request.parameters
return input_ids, parameters, request_id
return input_ids, parameters, request_id, inputs_embeds

@torch.no_grad()
async def process_step(self, tokenizer, result_dict):
if not self.waiting_requests.empty():
input_ids, parameters, request_id = await self.add_request(tokenizer)
input_ids, parameters, request_id, inputs_embeds = await self.add_request(tokenizer)
self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

def model_generate():
@@ -78,8 +103,18 @@ def model_generate():
tokenizer.convert_tokens_to_ids("<|user|>"),
tokenizer.convert_tokens_to_ids("<|observation|>")]
generate_kwargs["eos_token_id"] = eos_token_id
self.model.generate(input_ids,
streamer=self.streamer[request_id], **generate_kwargs)
elif "internlm-xcomposer2-vl-7b" in self.model_name.lower():
eos_token_id = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
]
generate_kwargs["eos_token_id"] = eos_token_id
if input_ids is not None:
self.model.generate(input_ids,
streamer=self.streamer[request_id], **generate_kwargs)
elif inputs_embeds is not None:
self.model.generate(inputs_embeds=inputs_embeds,
streamer=self.streamer[request_id], **generate_kwargs)
torch.xpu.empty_cache()
torch.xpu.synchronize()

2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/utils/benchmark_util.py
@@ -574,7 +574,7 @@ def _prepare_model_inputs(
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
if not self.config.is_encoder_decoder:
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
inspect.signature(self.model.prepare_inputs_for_generation).parameters.keys()
)
if not has_inputs_embeds_forwarding:
raise ValueError(
