From 60de428b373bf56c4d959447aaec0166c32da029 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Thu, 4 Jul 2024 18:03:57 +0800
Subject: [PATCH] Support pipeline parallel for qwen-vl (#11503)
---
.../GPU/Pipeline-Parallel-Inference/README.md | 25 +++
.../GPU/Pipeline-Parallel-Inference/chat.py | 76 +++++++
.../run_qwen_vl_arc_2_card.sh | 32 +++
.../llm/src/ipex_llm/transformers/convert.py | 4 +
.../ipex_llm/transformers/models/qwen_vl.py | 209 +++++++++++++++++-
.../transformers/pipeline_parallel.py | 35 ++-
6 files changed, 370 insertions(+), 11 deletions(-)
create mode 100644 python/llm/example/GPU/Pipeline-Parallel-Inference/chat.py
create mode 100644 python/llm/example/GPU/Pipeline-Parallel-Inference/run_qwen_vl_arc_2_card.sh
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
index 66cc3cf18a7..cb9df2d00c6 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
@@ -14,6 +14,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir
- [Qwen/Qwen1.5-14B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-32B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-MoE-A2.7B-Chat](./run_qwen1.5_arc_2_card.sh)
+- [Qwen/Qwen-VL-Chat](./run_qwen_vl_arc_2_card.sh)
- [Qwen/CodeQwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [THUDM/glm-4-9b-chat](./run_chatglm_arc_2_card.sh)
- [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh)
@@ -114,6 +115,22 @@ bash run_qwen1.5_arc_2_card.sh
+
+<details>
+  <summary> Show Qwen-VL example </summary>
+
+#### Run Qwen-VL-Chat on two Intel Arc A770
+
+You could specify `--repo-id-or-model-path` in the test script to be the Hugging Face repo id for Qwen-VL-Chat to be downloaded, or the path to the Hugging Face checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
+
+```bash
+pip install transformers==4.32.0 tiktoken einops transformers_stream_generator==0.0.4 scipy torchvision pillow tensorboard matplotlib
+bash run_qwen_vl_arc_2_card.sh
+```
+
+</details>
+
+
   <summary> Show chatglm example </summary>
@@ -250,3 +267,11 @@ Once upon a time, there existed a little girl who liked to have adventures. She
One day, the little girl
```
+
+#### [Qwen/Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
+```log
+-------------------- Input --------------------
+Message: [{'image': 'http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg'}, {'text': '这是什么?'}]
+-------------------- Output --------------------
+这是一张图片,展现了一个穿着粉色条纹连衣裙的小女孩,她正拿着一只穿粉色裙子的白色玩具小熊。
+```
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/chat.py b/python/llm/example/GPU/Pipeline-Parallel-Inference/chat.py
new file mode 100644
index 00000000000..55ae0762e6f
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/chat.py
@@ -0,0 +1,76 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import os
+
+import torch
+import time
+from transformers import AutoTokenizer
+from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
+
+init_pipeline_parallel()
+torch.manual_seed(1234)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `chat()` API for large vision language model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="Qwen/Qwen-VL-Chat",
+ help='The huggingface repo id for the Qwen-VL-Chat model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--image-url-or-path', type=str,
+ default="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg",
+ help='The URL or path to the image to infer')
+ parser.add_argument('--prompt', type=str, default="这是什么?",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+ parser.add_argument('--low-bit', type=str, default='sym_int4', help='The quantization type the model will convert to.')
+    parser.add_argument('--gpu-num', type=int, default=2, help='Number of GPUs to use')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+ image_path = args.image_url_or_path
+
+ # Load model
+ # For successful IPEX-LLM optimization on Qwen-VL-Chat, skip the 'c_fc' and 'out_proj' modules during optimization
+    # For Windows users running LLMs on Intel iGPUs, we recommend setting `cpu_embedding=True` in the from_pretrained function.
+    # This allows the memory-intensive embedding layer to run on the CPU instead of the iGPU.
+ model = AutoModelForCausalLM.from_pretrained(model_path,
+ load_in_low_bit=args.low_bit,
+ optimize_model=True,
+ trust_remote_code=True,
+ use_cache=True,
+ torch_dtype=torch.float32,
+ modules_to_not_convert=['c_fc', 'out_proj'],
+ pipeline_parallel_stages=args.gpu_num)
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ local_rank = torch.distributed.get_rank()
+
+    all_input = [{'image': image_path}, {'text': args.prompt}]
+ input_list = [_input for _input in all_input if list(_input.values())[0] != '']
+ query = tokenizer.from_list_format(input_list)
+
+ with torch.inference_mode():
+ response, _ = model.chat(tokenizer, query=query, history=[])
+ torch.xpu.synchronize()
+
+ if local_rank == args.gpu_num - 1:
+ print('-'*20, 'Input', '-'*20)
+ print(f'Message: {all_input}')
+ print('-'*20, 'Output', '-'*20)
+ print(response)
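
A minimal launch sketch for the new `chat.py`; the local image path is a placeholder, and the oneAPI/oneCCL environment set up by `run_qwen_vl_arc_2_card.sh` below is assumed to be active:

```bash
# Sketch: run chat.py across two Arc GPUs with an explicit image and prompt.
# './demo.jpg' is a placeholder; replace it with a real image URL or path.
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node 2 \
    chat.py --repo-id-or-model-path 'Qwen/Qwen-VL-Chat' \
            --image-url-or-path './demo.jpg' \
            --prompt 'What is in this picture?' \
            --low-bit 'sym_int4' --gpu-num 2
```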
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/run_qwen_vl_arc_2_card.sh b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_qwen_vl_arc_2_card.sh
new file mode 100644
index 00000000000..a237a443e71
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_qwen_vl_arc_2_card.sh
@@ -0,0 +1,32 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+source /opt/intel/oneapi/setvars.sh
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=9090
+export FI_PROVIDER=tcp
+export USE_XETLA=OFF
+export OMP_NUM_THREADS=6
+KERNEL_VERSION=$(uname -r)  # SYCL immediate command lists are enabled on all kernels except 6.5
+if [[ $KERNEL_VERSION != *"6.5"* ]]; then
+ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+fi
+export TORCH_LLM_ALLREDUCE=0
+
+NUM_GPUS=2 # number of GPUs to use
+
+# To run Qwen-VL-Chat
+CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
+ chat.py --repo-id-or-model-path 'Qwen/Qwen-VL-Chat' --gpu-num $NUM_GPUS --low-bit 'sym_int4'
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 5d9d47468cc..c419eebf65b 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1269,10 +1269,14 @@ def _optimize_post(model, lightweight_bmm=False):
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.qwen_vl import qwen_attention_forward_vl
+ from ipex_llm.transformers.models.qwen_vl import qwen_vl_model_forward
convert_forward(model,
module.QWenAttention,
qwen_attention_forward_vl
)
+ convert_forward(model,
+ module.QWenModel,
+ qwen_vl_model_forward)
else:
# for Qwen-7B and Qwen-14B
modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
index 3f494de9e3e..a2f6e9481ae 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
@@ -33,7 +33,8 @@
from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import rotate_half
from ipex_llm.transformers.models.utils import use_sdp
-
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from ipex_llm.utils.common import invalidInputError
import os
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
@@ -243,3 +244,209 @@ def qwen_vl_vision_transformer_forward(self, x: torch.Tensor):
x = x @ self.proj
return x
+
+
+def qwen_vl_model_forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+):
+ # bigdl-llm change starts
+ input = input_ids if input_ids is not None else inputs_embeds
+ # bigdl-llm change ends
+ if past_key_values is None and torch.any(input == self.config.visual['image_start_id']):
+ bos_pos = torch.where(input == self.config.visual['image_start_id'])
+ eos_pos = torch.where(input == self.config.visual['image_start_id'] + 1)
+ invalidInputError((bos_pos[0] == eos_pos[0]).all(),
+ 'bos_pos[0] should be same as eos_pos[0]')
+ img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
+ images = []
+ for i, a, b in img_pos:
+ image = input[i][a + 1: b - 1].tolist()
+ image = image[: image.index(self.config.visual['image_start_id'] + 2)]
+ images.append(bytes(image).decode('utf-8'))
+
+ images = self.visual.encode(images)
+ invalidInputError(images.shape[0] == len(images),
+ 'images.shape[0] should be same as len(images)')
+ fake_images = None
+ elif self.training:
+ fake_images = torch.zeros(1, 3, 224, 224).to(
+ dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
+ images = self.visual(fake_images)
+ else:
+ fake_images = None
+ images = None
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if input_ids is not None and inputs_embeds is not None:
+ invalidInputError(False,
+ "You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * len(self.h))
+ else:
+ past_length = past_key_values[0][0].size(-2)
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_length,
+ input_shape[-1] + past_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+ encoder_attention_mask = None
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ if batch_size <= 0:
+ invalidInputError(False, "batch_size has to be defined and > 0")
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_length
+ )
+
+ hidden_states = inputs_embeds
+
+ kv_seq_len = hidden_states.size()[1]
+ if past_key_values[0] is not None:
+ # past key values[0][0] shape: bs * seq_len * head_num * dim
+ kv_seq_len += past_key_values[0][0].shape[1]
+ if (
+ self.use_dynamic_ntk
+ and kv_seq_len == hidden_states.size()[1]
+ and not self.training
+ ):
+ context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
+ ntk_alpha = 2 ** math.ceil(context_value) - 1
+ ntk_alpha = max(ntk_alpha, 1)
+ else:
+ ntk_alpha = self.rotary_emb._ntk_alpha_cached
+
+ rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
+ for idx in range(len(rotary_pos_emb)):
+ rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
+
+ hidden_states = self.drop(hidden_states).clone()
+ if fake_images is not None:
+ hidden_states = hidden_states + images.mean()*0
+ elif images is not None:
+ for idx, (i, a, b) in enumerate(img_pos):
+ hidden_states[i][a + 1: b] = images[idx]
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache, output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ rotary_pos_emb,
+ self.registered_causal_mask,
+ None,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ rotary_pos_emb=rotary_pos_emb,
+ registered_causal_mask=self.registered_causal_mask,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ hidden_states = self.ln_f(hidden_states)
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, presents, all_hidden_states] if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
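
As an aside on the image-handling branch at the top of `qwen_vl_model_forward`: Qwen-VL wraps each image reference between `image_start_id` and `image_start_id + 1` marker tokens and pads the encoded URL/path with `image_start_id + 2`; the forward decodes the bytes in between back into a string before calling the visual encoder. A toy reproduction of that extraction with made-up token ids:

```python
import torch

IMG_START = 1000          # stands in for config.visual['image_start_id']
IMG_END = IMG_START + 1   # closing marker
IMG_PAD = IMG_START + 2   # padding token after the encoded reference

# One sequence containing an image reference encoded as UTF-8 bytes between the markers.
ref = list("demo.jpg".encode("utf-8"))
input_ids = torch.tensor([[1, 2, IMG_START, *ref, IMG_PAD, IMG_PAD, IMG_END, 3]])

bos_pos = torch.where(input_ids == IMG_START)
eos_pos = torch.where(input_ids == IMG_END)
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)

images = []
for i, a, b in img_pos:
    span = input_ids[i][a + 1: b - 1].tolist()   # tokens between the markers
    span = span[: span.index(IMG_PAD)]           # strip the padding tokens
    images.append(bytes(span).decode("utf-8"))

print(images)  # ['demo.jpg']
```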
diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 5ecfa903c39..23d945314c0 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -123,9 +123,20 @@ def pipeline_parallel(model, pipeline_parallel_stages):
layer_start = slice_size * local_rank
layer_end = layer_start + min(slice_size, num_layers - layer_start)
- if model.config.architectures is not None \
- and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
- # for chatglm3-6b
+ if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
+ # for Qwen-VL-Chat
+ for i in range(num_layers):
+ if i < layer_start or i >= layer_end:
+ model._modules['transformer'].h[i] = Dummy_DecoderLayer()
+ if local_rank != 0:
+ model._modules['transformer'].wte = DummyLayer()
+ model._modules['transformer'].drop = DummyLayer()
+ if local_rank != pipeline_parallel_stages - 1:
+ model._modules['transformer'].ln_f = DummyLayer()
+ model._modules['ln_f'] = DummyLayer()
+ model._modules['lm_head'] = DummyLayer()
+ elif model.config.model_type == "chatglm":
+ # for chatglm3-6b, glm-4-9b-chat
for i in range(num_layers):
if i < layer_start or i >= layer_end:
model._modules['transformer'].encoder.layers[i] = Dummy_GLMBlock()
@@ -296,13 +307,17 @@ def pipeline_parallel_generate(self,
_past_key_values = past_key_values_placeholder
else:
_past_key_values = outputs.past_key_values
- elif self.config.model_type in ["baichuan", "chatglm"] and local_rank != 0:
- # for baichuan2 and chatglm3
- value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
- past_key_values_placeholder = tuple(
- (value_placeholder, value_placeholder) for _ in range(layer_start)
- ) + (outputs.past_key_values)[layer_start:]
- _past_key_values = past_key_values_placeholder
+ elif self.config.model_type in ["baichuan", "chatglm"] or \
+ (self.config.model_type == "qwen" and hasattr(self.config, "visual")):
+ # for baichuan2, chatglm3, Qwen-VL-Chat
+ if local_rank != 0:
+ value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
+ past_key_values_placeholder = tuple(
+ (value_placeholder, value_placeholder) for _ in range(layer_start)
+ ) + (outputs.past_key_values)[layer_start:]
+ _past_key_values = past_key_values_placeholder
+ else:
+ _past_key_values = outputs.past_key_values
else:
_past_key_values = outputs.past_key_values
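
Finally, a minimal sketch of the per-rank layer partitioning that the Qwen-VL branch of `pipeline_parallel` relies on. The `slice_size` formula is an assumption (it is computed outside the quoted hunk), while `layer_start`/`layer_end` mirror the patch; every block outside `[layer_start, layer_end)` is replaced by a dummy layer on that rank:

```python
import math

def layer_range(num_layers: int, num_stages: int, local_rank: int):
    slice_size = math.ceil(num_layers / num_stages)  # assumed split per stage
    layer_start = slice_size * local_rank
    layer_end = layer_start + min(slice_size, num_layers - layer_start)
    return layer_start, layer_end

# Example: 32 decoder blocks split across 2 Arc GPUs.
for rank in range(2):
    print(rank, layer_range(32, 2, rank))
# 0 (0, 16)
# 1 (16, 32)
```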