From e6ca2538fbea17823a650922788c0dcaff860ed3 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Thu, 19 Oct 2023 22:46:32 +0000 Subject: [PATCH 1/9] Support deepspeed --- .../llm/src/bigdl/llm/transformers/convert.py | 32 +++++++++++++++---- .../bigdl/llm/transformers/low_bit_linear.py | 12 +++++-- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index c1902d357cd..2283586b82b 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -43,7 +43,7 @@ import importlib from bigdl.llm.ggml.quantize import ggml_tensor_qtype from .utils import logger - +from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, current_key_name=None, convert_shape_only=False): @@ -54,17 +54,28 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if current_key_name is None: current_key_name = [] - if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + if (isinstance(module, nn.Linear) or isinstance(module, LinearLayer) or isinstance(module, LinearAllreduce)) and name not in modules_to_not_convert: # Check if the current key is not in the `modules_to_not_convert` if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): with init_empty_weights(): + if isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + else: + in_features = module.weight.shape[1] + out_features = module.weight.shape[0] + if isinstance(module, LinearAllreduce): + mp_group = module.mp_group + else: + mp_group = None new_linear = None if qtype != ggml_tensor_qtype["fp16"]: new_linear = LowBitLinear( - module.in_features, - module.out_features, + in_features, + out_features, qtype, module.bias is not None, + mp_group=mp_group, ) device_type = module.weight.data.device.type @@ -82,10 +93,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if module.in_features in [4096, 11008]: # esimd fp16 path new_linear = FP16Linear( - module.in_features, - module.out_features, + in_features, + out_features, qtype, module.bias is not None, + mp_group=mp_group, ) device_type = module.weight.data.device.type @@ -104,7 +116,13 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, model._modules[name] = new_linear has_been_replaced = True # Force requires grad to False to avoid unexpected errors - model._modules[name].requires_grad_(False) + try: + model._modules[name].requires_grad_(False) + except Exception as e: + logger.warning( + f"Failed to set `requires_grad=False` on {name} due to the following error: {e}" + ) + print(new_linear) module.weight = None diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index b1026ec78ae..c5b85312ade 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -328,7 +328,7 @@ def backward(ctx, grad_output): class LowBitLinear(nn.Linear): def __init__(self, input_features, output_features, qtype, bias=True, - conver_to_half=True): + conver_to_half=True, mp_group=None): super().__init__(input_features, output_features, bias) self.weight = FP4Params(self.weight.data, requires_grad=False, @@ -339,6 +339,7 @@ def __init__(self, input_features, output_features, qtype, bias=True, self.weight_length = self.out_len * self.in_len self.qtype = qtype self.conver_to_half = conver_to_half + self.mp_group = mp_group def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: @@ -378,6 +379,9 @@ def forward(self, x: torch.Tensor): input_seq_size) new_shape = x_shape[:-1] + (self.out_len,) result = result.view(new_shape) + if self.mp_group is not None: + from deepspeed import comm as dist + dist.inference_all_reduce(result, group=self.mp_group) if self.bias is not None: result += self.bias else: @@ -400,7 +404,7 @@ def forward(self, x: torch.Tensor): class FP16Linear(nn.Linear): def __init__(self, input_features, output_features, qtype, bias=True, - conver_to_half=True): + conver_to_half=True, mp_group=None): super().__init__(input_features, output_features, bias) self.in_len = input_features self.out_len = output_features @@ -408,6 +412,7 @@ def __init__(self, input_features, output_features, qtype, bias=True, self.weight_length = self.out_len * self.in_len self.qtype = qtype self.conver_to_half = conver_to_half + self.mp_group = mp_group def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: @@ -442,6 +447,9 @@ def forward(self, x: torch.Tensor): new_shape = x_shape[:-1] + (self.out_len,) result = result.view(new_shape) + if self.mp_group is not None: + from deepspeed import comm as dist + dist.inference_all_reduce(result, group=self.mp_group) if self.bias is not None: result += self.bias From 3693c691f717d1b2009a91c5e64c8209423338da Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 23 Oct 2023 06:02:48 +0000 Subject: [PATCH 2/9] add test script --- run_deepspeed.sh | 7 ++++ test_deepspeed.py | 97 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 run_deepspeed.sh create mode 100644 test_deepspeed.py diff --git a/run_deepspeed.sh b/run_deepspeed.sh new file mode 100644 index 00000000000..0997df4b571 --- /dev/null +++ b/run_deepspeed.sh @@ -0,0 +1,7 @@ +export MASTER_ADDR=127.0.0.1 +export CCL_ZE_IPC_EXCHANGE=sockets +export OMP_NUM_THREADS=28 +torchrun --standalone \ + --nnodes=1 \ + --nproc-per-node 4 \ + test_deepspeed.py diff --git a/test_deepspeed.py b/test_deepspeed.py new file mode 100644 index 00000000000..d2da0ea4859 --- /dev/null +++ b/test_deepspeed.py @@ -0,0 +1,97 @@ +import os +import torch +import transformers +import deepspeed +from gpu_benchmark_util import BenchmarkWrapper +local_rank = int(os.getenv("LOCAL_RANK", "0")) +world_size = int(os.getenv("WORLD_SIZE", "1")) + +from bigdl.llm import optimize_model + +import torch +import intel_extension_for_pytorch as ipex +import time +import argparse + +# from bigdl.llm.transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM +from transformers import LlamaTokenizer, AutoTokenizer + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') + parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf", + help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + # Load model in 4 bit, + # which convert the relevant layers in the model into INT4 format + # model_path = "/home/sdp/yang/bigdl/alpaca-lora-xpu/finetune_merged_llama_70b_step_700" + model_path = "meta-llama/Llama-2-7b-hf" + model_path = "bigscience/bloom-7b1" + # with deepspeed.OnDevice(dtype=torch.float16, device="meta"): + model = AutoModelForCausalLM.from_pretrained(model_path, + # load_in_4bit=True, + # optimize_model=True, + device_map={"": "cpu"}, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + trust_remote_code=True, + use_cache=True) + # model = BenchmarkWrapper(model) + + model = deepspeed.init_inference( + model, + mp_size=world_size, + dtype=torch.float16, + replace_method="auto", + # checkpoint="/home/sdp/yang/bigdl/save_deepspeed_llama_70b_sharded/ds_inference_config.json", + # replace_with_kernel_inject=True, + ) + + model = optimize_model(model.module.to(f'cpu')) + model = model.to(f'xpu:{local_rank}') + print(model) + + model = BenchmarkWrapper(model) + + # Load tokenizer + # tokenizer_path = "meta-llama/Llama-2-7b-hf" + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Generate predicted tokens + with torch.inference_mode(): + # prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) + prompt = args.prompt + # input_str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n" + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}') + # ipex model needs a warmup, then inference time can be accurate + output = model.generate(input_ids, + max_new_tokens=args.n_predict, + use_cache=True) + + # start inference + st = time.time() + # if your selected model is capable of utilizing previous key/value attentions + # to enhance decoding speed, but has `"use_cache": false` in its model config, + # it is important to set `use_cache=True` explicitly in the `generate` function + # to obtain optimal performance with BigDL-LLM INT4 optimizations + output = model.generate(input_ids, + do_sample=False, + max_new_tokens=args.n_predict) + torch.xpu.synchronize() + end = time.time() + if local_rank == 0: + output = output.cpu() + output_str = tokenizer.decode(output[0], skip_special_tokens=True) + print(f'Inference time: {end-st} s') + print('-'*20, 'Prompt', '-'*20) + print(prompt) + print('-'*20, 'Output', '-'*20) + print(output_str) From e0a835068bc3e98f948516bc02f6b9fd4197e21d Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 23 Oct 2023 22:37:07 +0000 Subject: [PATCH 3/9] refactor convert --- .../llm/src/bigdl/llm/transformers/convert.py | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 2283586b82b..ed004746292 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -43,7 +43,39 @@ import importlib from bigdl.llm.ggml.quantize import ggml_tensor_qtype from .utils import logger -from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce +from transformers.integrations.deepspeed import is_deepspeed_available + +def is_linear_module(module): + + in_features = None + out_features = None + mp_group = None + + if isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + mp_group = None + result = True + else: + if is_deepspeed_available(): + from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce + if isinstance(module, LinearLayer): + in_features = module.in_features + out_features = module.out_features + mp_group = None + result = True + elif isinstance(module, LinearAllreduce): + in_features = module.in_features + out_features = module.out_features + mp_group = module.mp_group + result = True + else: + result = False + else: + result = False + + return result, (in_features, out_features, mp_group) + def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, current_key_name=None, convert_shape_only=False): @@ -54,20 +86,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if current_key_name is None: current_key_name = [] - if (isinstance(module, nn.Linear) or isinstance(module, LinearLayer) or isinstance(module, LinearAllreduce)) and name not in modules_to_not_convert: + is_linear, linear_args = is_linear_module(module) + if is_linear and name not in modules_to_not_convert: # Check if the current key is not in the `modules_to_not_convert` if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): + in_features, out_features, mp_group = linear_args with init_empty_weights(): - if isinstance(module, nn.Linear): - in_features = module.in_features - out_features = module.out_features - else: - in_features = module.weight.shape[1] - out_features = module.weight.shape[0] - if isinstance(module, LinearAllreduce): - mp_group = module.mp_group - else: - mp_group = None new_linear = None if qtype != ggml_tensor_qtype["fp16"]: new_linear = LowBitLinear( From ddd6eca7fcf4331191d8f3906b4cd22a5ae7662a Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 23 Oct 2023 23:39:17 +0000 Subject: [PATCH 4/9] refine example --- .../GPU/Deepspeed-AutoTP/deepspeed_autotp.py | 35 ++++++------------- .../GPU/Deepspeed-AutoTP/run_deepspeed.sh | 12 +++++++ run_deepspeed.sh | 7 ---- 3 files changed, 23 insertions(+), 31 deletions(-) rename test_deepspeed.py => python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py (70%) create mode 100644 python/llm/example/GPU/Deepspeed-AutoTP/run_deepspeed.sh delete mode 100644 run_deepspeed.sh diff --git a/test_deepspeed.py b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py similarity index 70% rename from test_deepspeed.py rename to python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py index d2da0ea4859..1022b99febe 100644 --- a/test_deepspeed.py +++ b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py @@ -13,8 +13,7 @@ import time import argparse -# from bigdl.llm.transformers import AutoModelForCausalLM -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM # export AutoModelForCausalLM from transformers so that deepspeed use it from transformers import LlamaTokenizer, AutoTokenizer if __name__ == '__main__': @@ -30,39 +29,27 @@ args = parser.parse_args() model_path = args.repo_id_or_model_path - # Load model in 4 bit, - # which convert the relevant layers in the model into INT4 format - # model_path = "/home/sdp/yang/bigdl/alpaca-lora-xpu/finetune_merged_llama_70b_step_700" - model_path = "meta-llama/Llama-2-7b-hf" - model_path = "bigscience/bloom-7b1" - # with deepspeed.OnDevice(dtype=torch.float16, device="meta"): - model = AutoModelForCausalLM.from_pretrained(model_path, - # load_in_4bit=True, - # optimize_model=True, - device_map={"": "cpu"}, - low_cpu_mem_usage=True, - torch_dtype=torch.float16, - trust_remote_code=True, - use_cache=True) - # model = BenchmarkWrapper(model) + model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + trust_remote_code=True, + use_cache=True) model = deepspeed.init_inference( model, mp_size=world_size, dtype=torch.float16, replace_method="auto", - # checkpoint="/home/sdp/yang/bigdl/save_deepspeed_llama_70b_sharded/ds_inference_config.json", - # replace_with_kernel_inject=True, ) - model = optimize_model(model.module.to(f'cpu')) - model = model.to(f'xpu:{local_rank}') - print(model) + # move model to cpu and use bigdl-llm `optimize_model` to convert the + # model into optimized low bit format + model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4') - model = BenchmarkWrapper(model) + # move model back to xpu + model = model.to(f'xpu:{local_rank}') # Load tokenizer - # tokenizer_path = "meta-llama/Llama-2-7b-hf" tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) # Generate predicted tokens diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/run_deepspeed.sh b/python/llm/example/GPU/Deepspeed-AutoTP/run_deepspeed.sh new file mode 100644 index 00000000000..79adc15d2a0 --- /dev/null +++ b/python/llm/example/GPU/Deepspeed-AutoTP/run_deepspeed.sh @@ -0,0 +1,12 @@ +source bigdl-llm-init +export MASTER_ADDR=127.0.0.1 +export CCL_ZE_IPC_EXCHANGE=sockets +if [[ -n $OMP_NUM_THREADS ]]; then + export OMP_NUM_THREADS=$(($OMP_NUM_THREADS / 4)) +else + export OMP_NUM_THREADS=$(($(nproc) / 4)) +fi +torchrun --standalone \ + --nnodes=1 \ + --nproc-per-node 4 \ + deepspeed_autotp.py diff --git a/run_deepspeed.sh b/run_deepspeed.sh deleted file mode 100644 index 0997df4b571..00000000000 --- a/run_deepspeed.sh +++ /dev/null @@ -1,7 +0,0 @@ -export MASTER_ADDR=127.0.0.1 -export CCL_ZE_IPC_EXCHANGE=sockets -export OMP_NUM_THREADS=28 -torchrun --standalone \ - --nnodes=1 \ - --nproc-per-node 4 \ - test_deepspeed.py From 495e1076bd87ec7a647852bccda7a15afcccb1cc Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 23 Oct 2023 23:43:37 +0000 Subject: [PATCH 5/9] refine --- .../llm/example/GPU/Deepspeed-AutoTP/{run_deepspeed.sh => run.sh} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename python/llm/example/GPU/Deepspeed-AutoTP/{run_deepspeed.sh => run.sh} (100%) diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/run_deepspeed.sh b/python/llm/example/GPU/Deepspeed-AutoTP/run.sh similarity index 100% rename from python/llm/example/GPU/Deepspeed-AutoTP/run_deepspeed.sh rename to python/llm/example/GPU/Deepspeed-AutoTP/run.sh From c40ea3c1aca1d653e9447c3ee16f4393c53779fe Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Tue, 24 Oct 2023 00:09:21 +0000 Subject: [PATCH 6/9] refine example --- .../llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py | 4 +++- python/llm/example/GPU/Deepspeed-AutoTP/run.sh | 2 +- python/llm/src/bigdl/llm/transformers/convert.py | 8 ++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py index 1022b99febe..256e435b1d1 100644 --- a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py +++ b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py @@ -2,7 +2,7 @@ import torch import transformers import deepspeed -from gpu_benchmark_util import BenchmarkWrapper + local_rank = int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) @@ -49,6 +49,8 @@ # move model back to xpu model = model.to(f'xpu:{local_rank}') + print(model) + # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/run.sh b/python/llm/example/GPU/Deepspeed-AutoTP/run.sh index 79adc15d2a0..77f216b1a90 100644 --- a/python/llm/example/GPU/Deepspeed-AutoTP/run.sh +++ b/python/llm/example/GPU/Deepspeed-AutoTP/run.sh @@ -9,4 +9,4 @@ fi torchrun --standalone \ --nnodes=1 \ --nproc-per-node 4 \ - deepspeed_autotp.py + deepspeed_autotp.py --repo-id-or-model-path "meta-llama/Llama-2-7b-hf" diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index ed004746292..7c0c28bf816 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -60,13 +60,13 @@ def is_linear_module(module): if is_deepspeed_available(): from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce if isinstance(module, LinearLayer): - in_features = module.in_features - out_features = module.out_features + in_features = module.weight.shape[1] + out_features = module.weight.shape[0] mp_group = None result = True elif isinstance(module, LinearAllreduce): - in_features = module.in_features - out_features = module.out_features + in_features = module.weight.shape[1] + out_features = module.weight.shape[0] mp_group = module.mp_group result = True else: From bab15533f4e55a7b025f48594c5433cd2bb794db Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Tue, 24 Oct 2023 00:21:37 +0000 Subject: [PATCH 7/9] fix style --- .../GPU/Deepspeed-AutoTP/deepspeed_autotp.py | 16 ++++++++++++++++ python/llm/src/bigdl/llm/transformers/convert.py | 9 ++------- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py index 256e435b1d1..3f858e68ed2 100644 --- a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py +++ b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py @@ -1,3 +1,19 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import os import torch import transformers diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 7c0c28bf816..2c2dc083633 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -45,6 +45,7 @@ from .utils import logger from transformers.integrations.deepspeed import is_deepspeed_available + def is_linear_module(module): in_features = None @@ -140,13 +141,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, model._modules[name] = new_linear has_been_replaced = True # Force requires grad to False to avoid unexpected errors - try: - model._modules[name].requires_grad_(False) - except Exception as e: - logger.warning( - f"Failed to set `requires_grad=False` on {name} due to the following error: {e}" - ) - print(new_linear) + model._modules[name].requires_grad_(False) module.weight = None From 16295ee092bc15ee5c341e471007d37d856983b5 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Tue, 24 Oct 2023 18:24:23 -0700 Subject: [PATCH 8/9] refine example and adapte latest ipex --- .../GPU/Deepspeed-AutoTP/deepspeed_autotp.py | 5 +++-- .../llm/example/GPU/Deepspeed-AutoTP/run.sh | 2 +- .../llm/src/bigdl/llm/transformers/convert.py | 5 ++++- .../bigdl/llm/transformers/models/llama.py | 21 +++++++++++++++++-- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py index 3f858e68ed2..6b1309a7f30 100644 --- a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py +++ b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py @@ -59,8 +59,9 @@ ) # move model to cpu and use bigdl-llm `optimize_model` to convert the - # model into optimized low bit format - model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4') + # model into optimized low bit format + # convert the rest of the model into float16 to reduce allreduce traffic + model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4').to(torch.float16) # move model back to xpu model = model.to(f'xpu:{local_rank}') diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/run.sh b/python/llm/example/GPU/Deepspeed-AutoTP/run.sh index 77f216b1a90..972e8c9d247 100644 --- a/python/llm/example/GPU/Deepspeed-AutoTP/run.sh +++ b/python/llm/example/GPU/Deepspeed-AutoTP/run.sh @@ -1,4 +1,4 @@ -source bigdl-llm-init +source bigdl-llm-init -t -g export MASTER_ADDR=127.0.0.1 export CCL_ZE_IPC_EXCHANGE=sockets if [[ -n $OMP_NUM_THREADS ]]; then diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 2c2dc083633..c5168a42f71 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -43,7 +43,10 @@ import importlib from bigdl.llm.ggml.quantize import ggml_tensor_qtype from .utils import logger -from transformers.integrations.deepspeed import is_deepspeed_available + + +def is_deepspeed_available(): + return importlib.util.find_spec("deepspeed") is not None def is_linear_module(module): diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 0dd39ae669b..b3eb5ce6f89 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -57,11 +57,28 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: KV_CACHE_ALLOC_BLOCK_LENGTH = 256 +import importlib +def get_ipex_version(): + + if importlib.util.find_spec("intel_extension_for_pytorch") is not None: + import intel_extension_for_pytorch as ipex + return ipex.__version__ + else: + return None + + +ipex_version = get_ipex_version() + def llama_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): - hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states, - [self.weight.size(0)], self.weight) + if ipex_version == "2.0.110+xpu": + hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states, + [self.weight.size(0)], self.weight) + else: + hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states, + [self.weight.size(0)], self.weight, + self.variance_epsilon) else: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) From f2b3728423c0d07e780678c722113eb6a1954cb7 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Tue, 24 Oct 2023 22:53:34 -0700 Subject: [PATCH 9/9] fix style --- python/llm/src/bigdl/llm/transformers/models/llama.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index b3eb5ce6f89..94515ea00f1 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -32,6 +32,7 @@ # limitations under the License. import torch +import importlib import torch.nn as nn from typing import Optional, Tuple import math @@ -57,7 +58,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: KV_CACHE_ALLOC_BLOCK_LENGTH = 256 -import importlib + def get_ipex_version(): if importlib.util.find_spec("intel_extension_for_pytorch") is not None: @@ -74,11 +75,11 @@ def llama_rms_norm_forward(self, hidden_states): if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad): if ipex_version == "2.0.110+xpu": hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states, - [self.weight.size(0)], self.weight) + [self.weight.size(0)], self.weight) else: hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states, - [self.weight.size(0)], self.weight, - self.variance_epsilon) + [self.weight.size(0)], self.weight, + self.variance_epsilon) else: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32)