From f151d4fb371877cc35e49354f5eda9c7b6dc066d Mon Sep 17 00:00:00 2001 From: plusbang Date: Wed, 10 Jul 2024 02:01:29 +0800 Subject: [PATCH 1/2] first add --- .../GPU/Pipeline-Parallel-Inference/README.md | 15 ++++ .../glm_4v_generate.py | 87 +++++++++++++++++++ .../run_glm_4v_arc_2_card.sh | 31 +++++++ .../ipex_llm/transformers/models/chatglm4v.py | 9 +- .../transformers/pipeline_parallel.py | 53 ++++++++--- 5 files changed, 177 insertions(+), 18 deletions(-) create mode 100644 python/llm/example/GPU/Pipeline-Parallel-Inference/glm_4v_generate.py create mode 100644 python/llm/example/GPU/Pipeline-Parallel-Inference/run_glm_4v_arc_2_card.sh diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md index cb9df2d00c6..c350be36263 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md +++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md @@ -17,6 +17,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir - [Qwen/Qwen-VL-Chat](./run_qwen_vl_arc_2_card.sh) - [Qwen/CodeQwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh) - [THUDM/glm-4-9b-chat](./run_chatglm_arc_2_card.sh) +- [THUDM/glm-4v-9b](./run_glm_4v_arc_2_card.sh) - [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh) - [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh) - [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh) @@ -145,6 +146,20 @@ bash run_chatglm_arc_2_card.sh +
+ Show glm-4v example + +#### Run glm-4v-9b on two Intel Arc A770 + +You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for glm-4v-9b to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine. + +```bash +pip install transformers==4.37.0 tiktoken +bash run_glm_4v_arc_2_card.sh +``` + +
+
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/glm_4v_generate.py b/python/llm/example/GPU/Pipeline-Parallel-Inference/glm_4v_generate.py new file mode 100644 index 00000000000..f788a362f4e --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/glm_4v_generate.py @@ -0,0 +1,87 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import time +import torch +import argparse +import requests + +from PIL import Image +from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel +from transformers import AutoTokenizer + +init_pipeline_parallel() + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for THUDM/glm-4v-9b model') + parser.add_argument('--repo-id-or-model-path', type=str, default="THUDM/glm-4v-9b", + help='The huggingface repo id for the THUDM/glm-4v-9b model to be downloaded' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--image-url-or-path', type=str, + default='http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg', + help='The URL or path to the image to infer') + parser.add_argument('--prompt', type=str, default="这是什么?", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + parser.add_argument('--low-bit', type=str, default='sym_int4', help='The quantization type the model will convert to.') + parser.add_argument('--gpu-num', type=int, default=2, help='GPU number to use') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + image_path = args.image_url_or_path + + model = AutoModelForCausalLM.from_pretrained(model_path, + load_in_low_bit=args.low_bit, + optimize_model=True, + trust_remote_code=True, + use_cache=True, + pipeline_parallel_stages=args.gpu_num) + model = model.half() + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + local_rank = torch.distributed.get_rank() + + query = args.prompt + if os.path.exists(image_path): + image = Image.open(image_path) + else: + image = Image.open(requests.get(image_path, stream=True).raw) + + # here the prompt tuning refers to https://huggingface.co/THUDM/glm-4v-9b/blob/main/README.md + inputs = tokenizer.apply_chat_template([{"role": "user", "image": image, "content": query}], + add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + return_dict=True) # chat mode + inputs = inputs.to(f'xpu:{local_rank}') + all_input = [{'image': image_path}, {'text': query}] + + # Generate predicted tokens + with torch.inference_mode(): + gen_kwargs = {"max_new_tokens": args.n_predict, "do_sample": False,} + st = time.time() + outputs = model.generate(**inputs, **gen_kwargs) + outputs = outputs[:, inputs['input_ids'].shape[1]:] + end = time.time() + if local_rank == args.gpu_num - 1: + print(f'Inference time: {end-st} s') + output_str = tokenizer.decode(outputs[0]) + print('-'*20, 'Input', '-'*20) + print(f'Message: {all_input}') + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/run_glm_4v_arc_2_card.sh b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_glm_4v_arc_2_card.sh new file mode 100644 index 00000000000..98e1fa484a3 --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_glm_4v_arc_2_card.sh @@ -0,0 +1,31 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +source /opt/intel/oneapi/setvars.sh +export MASTER_ADDR=127.0.0.1 +export MASTER_PORT=9090 +export FI_PROVIDER=tcp +export USE_XETLA=OFF +export OMP_NUM_THREADS=6 +if [[ $KERNEL_VERSION != *"6.5"* ]]; then + export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 +fi +export TORCH_LLM_ALLREDUCE=0 + +NUM_GPUS=2 # number of used GPU +# To run glm-4v-9b +CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \ + glm_4v_generate.py --repo-id-or-model-path 'THUDM/glm-4v-9b' --gpu-num $NUM_GPUS --low-bit 'sym_int4' diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py index 2b848e1b26c..c8f87e61d5e 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py @@ -55,9 +55,7 @@ def chatglm4v_model_forward( # generate mode with past_key_values. the image features are already mapped if past_key_values is None: # not allow for inputs_embeds, because we want to process image feature - invalidInputError(input_ids is not None and inputs_embeds is None, - f"{input_ids} should not be None, {inputs_embeds} should be None.") - if not is_empty(images): # multi-modality + if not is_empty(images) and input_ids is not None: # multi-modality image_size: int = self.config.vision_config['image_size'] patch_size: int = self.config.vision_config['patch_size'] num_patches = (image_size // patch_size // 2) ** 2 @@ -99,10 +97,11 @@ def chatglm4v_model_forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size, seq_length = input_ids.shape - if inputs_embeds is None: + batch_size, seq_length = input_ids.shape inputs_embeds = self.embedding(input_ids) + else: + batch_size, seq_length, _ = inputs_embeds.shape if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or\ diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index de5e5f56df8..f0bd3829b78 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -229,13 +229,14 @@ def generate( generation_config.pad_token_id = eos_token_id if generation_config is not None and generation_config.max_new_tokens is not None: - max_new_tokens = generation_config.max_new_tokens + max_new_tokens = generation_config.pop("max_new_tokens") else: - max_new_tokens = kwargs.get("max_new_tokens", None) + max_new_tokens = kwargs.pop("max_new_tokens", None) return self.pipeline_parallel_generate(inputs=inputs, max_new_tokens=max_new_tokens, - generation_config=generation_config,) + generation_config=generation_config, + **kwargs) return original_generate(self, inputs=inputs, @@ -257,6 +258,23 @@ def pipeline_parallel_generate(self, max_new_tokens: int = 32, generation_config: Optional[GenerationConfig] = None, **kwargs): + model_kwargs = generation_config.update(**kwargs) + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + bs = inputs_tensor.shape[0] + if self.config.is_encoder_decoder: + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=bs, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + device=inputs_tensor.device, + ) + else: + input_ids = inputs_tensor if model_input_name == "input_ids" \ + else model_kwargs.pop("input_ids") local_rank = dist.get_rank() pre_rank = (local_rank - 1) % self.pipeline_parallel_stages next_rank = (local_rank + 1) % self.pipeline_parallel_stages @@ -272,36 +290,44 @@ def pipeline_parallel_generate(self, eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(inputs.device) \ + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) \ if eos_token_id is not None else None _input_ids = None _past_key_values = None - bs = inputs.shape[0] - output_ids = inputs.clone() + + bs = input_ids.shape[0] + output_ids = input_ids.clone() _check_quantize_kv_cache(self, layer_start, bs) step = 0 # keep track of which sequences are already finished - unfinished_sequences = torch.ones(inputs.shape[0], dtype=torch.long, device=inputs.device) + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False while True: if step >= max_new_tokens: break if _input_ids is None: - _input_ids = inputs + _input_ids = input_ids tic = time.time() if local_rank == 0: outputs = self(input_ids=_input_ids, inputs_embeds=None, - past_key_values=_past_key_values, use_cache=True) + past_key_values=_past_key_values, use_cache=True, **model_kwargs) else: - inputs_embeds = torch.empty(_input_ids.shape + (self.config.hidden_size,), + _inputs_shape = _input_ids.shape + (self.config.hidden_size,) + if step == 0 and self.config.model_type == "chatglm" \ + and hasattr(self.config, "vision_config"): + # for glm-4v, image features are mapped during 1st token + # 1597 are computed according to computation process of conv + _images_feature = 1597 + _input_ids.shape[0] * 2 + _input_ids.shape[1] + _inputs_shape = (_input_ids.shape[0], _images_feature, self.config.hidden_size,) + inputs_embeds = torch.empty(_inputs_shape, device=f'xpu:{local_rank}', dtype=self.dtype) dist.recv(inputs_embeds, src=pre_rank) outputs = self(input_ids=None, inputs_embeds=inputs_embeds, - past_key_values=_past_key_values, use_cache=True) + past_key_values=_past_key_values, use_cache=True, **model_kwargs) if local_rank == self.pipeline_parallel_stages - 1: logits = outputs.logits @@ -323,7 +349,8 @@ def pipeline_parallel_generate(self, "make sure that `pad_token_id` is defined.") next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - if self.config.model_type == "chatglm" and self.config.num_layers == 40: + if self.config.model_type == "chatglm" and self.config.num_layers == 40 \ + and not hasattr(self.config, "vision_config"): # for glm-4-9b-chat if step == 0: value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0]) @@ -337,7 +364,7 @@ def pipeline_parallel_generate(self, _past_key_values = outputs.past_key_values elif self.config.model_type in ["baichuan", "chatglm"] or \ (self.config.model_type == "qwen" and hasattr(self.config, "visual")): - # for baichuan2, chatglm3, Qwen-VL-Chat + # for baichuan2, chatglm3, Qwen-VL-Chat, glm-4v-9b if local_rank != 0: value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0]) past_key_values_placeholder = tuple( From 2e8a4d85041d9e4cebb31415bb6a33031f7b0564 Mon Sep 17 00:00:00 2001 From: plusbang Date: Thu, 11 Jul 2024 21:51:21 +0800 Subject: [PATCH 2/2] fix --- python/llm/src/ipex_llm/transformers/models/chatglm4v.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py index c8f87e61d5e..a315124b0ef 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py @@ -102,6 +102,8 @@ def chatglm4v_model_forward( inputs_embeds = self.embedding(input_ids) else: batch_size, seq_length, _ = inputs_embeds.shape + input_ids = torch.empty((batch_size, seq_length), + dtype=inputs_embeds.dtype, device=inputs_embeds.device) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or\