From 1ce8d7bcd984016734cd77b0a7e91283adfabbb6 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Wed, 24 Apr 2024 10:17:13 -0700 Subject: [PATCH] Support the `desc_act` feature in GPTQ model (#10851) * support act_order * update versions * fix style * fix bug * clean up --- .../Advanced-Quantizations/GPTQ/README.md | 6 +-- .../Advanced-Quantizations/GPTQ/generate.py | 9 ++-- .../llm/src/ipex_llm/transformers/convert.py | 47 +++++++++++++++---- .../ipex_llm/transformers/low_bit_linear.py | 10 +++- python/llm/src/ipex_llm/transformers/model.py | 4 +- 5 files changed, 56 insertions(+), 20 deletions(-) diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md index 742ba6ec8bc..a8040d31a5a 100644 --- a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md +++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md @@ -13,9 +13,9 @@ conda create -n llm python=3.11 conda activate llm # below command will install intel_extension_for_pytorch==2.1.10+xpu as default pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -pip install transformers==4.34.0 -BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6 -pip install optimum==0.14.0 +pip install transformers==4.37.0 +pip install auto_gptq==0.7.1 +pip install optimum==1.14.0 ``` ### 2. Configures OneAPI environment variables diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py index 6a39250de75..c830d9106e0 100644 --- a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py +++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py @@ -18,7 +18,7 @@ import time import argparse from ipex_llm.transformers import AutoModelForCausalLM -from transformers import LlamaTokenizer, GPTQConfig +from transformers import AutoTokenizer, GPTQConfig # you could tune the prompt based on your own model, # here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style @@ -30,7 +30,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') - parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-GPTQ", + parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", help='The huggingface repo id' ', or the path to the huggingface checkpoint folder') parser.add_argument('--prompt', type=str, default="What is AI?", @@ -47,9 +47,10 @@ load_in_4bit=True, torch_dtype=torch.float, trust_remote_code=True,).to("xpu") - + + print(model) # Load tokenizer - tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) # Generate predicted tokens with torch.inference_mode(): diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 9ddf7bae4d5..897186980e6 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -99,6 +99,11 @@ def is_lm_head(name, model_config, out_features): return False +def is_gptq_linear(module): + return is_auto_gptq_available() and \ + (isinstance(module, QuantLinearCuda) or isinstance(module, QuantLinearCudaOld)) + + def is_linear_module(module): in_features = None @@ -122,7 +127,7 @@ def is_linear_module(module): mp_group = None else: result = False - elif is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld): + elif is_gptq_linear(module): in_features = module.infeatures out_features = module.outfeatures mp_group = None @@ -153,7 +158,7 @@ def is_linear_module(module): return result, (in_features, out_features, mp_group) -def convert_gptq(module, awq=False, llm_awq=False): +def convert_gptq(module, awq=False, llm_awq=False, act_order=False): from ipex_llm.transformers.low_bit_linear import get_block_size Q4_1 = get_block_size("asym_int4") @@ -164,6 +169,8 @@ def convert_gptq(module, awq=False, llm_awq=False): module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8) zeros = torch.bitwise_and(zeros, (2 ** module.bits) - 1) + g_id_map = None + if not awq: zeros = zeros + 1 zeros = zeros.reshape(scales.shape) @@ -183,6 +190,12 @@ def convert_gptq(module, awq=False, llm_awq=False): weight = torch.bitwise_and(weight, (2 ** module.bits) - 1) weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + if act_order: + invalidInputError(module.g_idx.shape[0] == weight.shape[0], + "g_idx and weight shape mismatch") + _, g_id_map = torch.sort(module.g_idx) + weight = weight[g_id_map, :] + # convert weight to ggml format weight = weight.reshape(weight.shape[0]//module.group_size, module.group_size, weight.shape[1]) weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2) @@ -219,7 +232,7 @@ def convert_gptq(module, awq=False, llm_awq=False): weight.view(torch.uint8)], dim=-1) ggml_weight = ggml_weight.reshape([-1]) - return ggml_weight + return ggml_weight, g_id_map def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, @@ -228,7 +241,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, imatrix_data=None, embedding_qtype=None, model_config=None, torch_dtype=torch.float32, enable_xetla=False, - mixed_precision=False): + mixed_precision=False, + act_order=False, + ): from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \ FP16Linear, BF16Linear from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding @@ -252,7 +267,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, optimize_lm_head = True with init_empty_weights(): new_linear = None - is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld) + is_gptq = is_gptq_linear(module) is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM) is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ if is_gptq or is_awq: @@ -264,14 +279,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, bias=has_bias, mp_group=mp_group, enable_xetla=enable_xetla, - optimize_lm_head=optimize_lm_head + optimize_lm_head=optimize_lm_head, + act_order=act_order, ) device = module.qweight.data.device invalidInputError(device.type != "meta", "converting from meta device is not supported") + weight, g_idx_map = convert_gptq(module, + awq=is_awq, + llm_awq=is_llm_awq, + act_order=act_order) + if act_order: + new_linear.g_idx_map = g_idx_map # Copy the weights - paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq, - llm_awq=is_llm_awq), + paramsLowBit = FP4Params(data=weight, requires_grad=False, quantized=True, _shape=(out_features, in_features), @@ -422,7 +443,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, model_config=model_config, torch_dtype=torch_dtype, enable_xetla=enable_xetla, - mixed_precision=mixed_precision + mixed_precision=mixed_precision, + act_order=act_order, ) has_been_replaced = _flag or has_been_replaced return model, has_been_replaced @@ -464,7 +486,7 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None, in_features, out_features, mp_group = linear_args with init_empty_weights(): new_linear = None - is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld) + is_gptq = is_gptq_linear(module) is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM) is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ if is_gptq or is_awq: @@ -721,6 +743,10 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True, if optimize_model: model = _optimize_pre(model) + act_order = False + if getattr(model, "quantization_method", None) == "gptq": + act_order = model.config.quantization_config.desc_act + # mixed quantization needs model_config to choose custom quantization strategy model, has_been_replaced = _replace_with_low_bit_linear( model, qtype, modules_to_not_convert, @@ -731,6 +757,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True, torch_dtype=torch_dtype, enable_xetla=enable_xetla, mixed_precision=mixed_precision, + act_order=act_order, ) if not has_been_replaced: warnings.warn( diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index 534e077dcb3..f26ad952757 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -579,7 +579,7 @@ def backward(ctx, grad_output): class LowBitLinear(nn.Linear): def __init__(self, input_features, output_features, qtype, bias=True, conver_to_half=True, mp_group=None, enable_xetla=False, - optimize_lm_head=False): + optimize_lm_head=False, act_order=False): super().__init__(input_features, output_features, bias) self.weight = FP4Params(self.weight.data, requires_grad=False, @@ -603,6 +603,11 @@ def __init__(self, input_features, output_features, qtype, bias=True, # since performance isn't impacted. self.is_lm_head = self.in_len * self.out_len >= 32000 * 4096 and self.bias is None self.low_memory_mode = self.is_lm_head + self.act_order = act_order + if act_order: + self.register_buffer( + "g_idx_map", + torch.tensor([i for i in range(self.in_len)], dtype=torch.int64)) def forward(self, x: torch.Tensor): # empty cache before and after lm_head at first token when input > 1024 @@ -640,6 +645,9 @@ def forward(self, x: torch.Tensor): return torch.empty(new_shape, dtype=x.dtype, device=x.device) x_2d = x.view(-1, x_shape[-1]) + + if self.act_order: + x_2d = x_2d[:, self.g_idx_map] # x0 for weight x0 = self.weight.data diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py index b1604b6c382..31d1b4e9a39 100644 --- a/python/llm/src/ipex_llm/transformers/model.py +++ b/python/llm/src/ipex_llm/transformers/model.py @@ -243,8 +243,6 @@ def from_pretrained(cls, if q_config["quant_method"] == "gptq": invalidInputError(q_config["bits"] == 4, "Only 4-bit gptq is supported in bigdl-llm.") - invalidInputError(q_config["desc_act"] is False, - "Only desc_act=False is supported in bigdl-llm.") if load_in_low_bit is not None: invalidInputError(load_in_low_bit == "asym_int4", "You can only load gptq model as aysm_int4 low bit type.") @@ -448,6 +446,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs): offload_dir=None ) else: + if quant_config is not None: + kwargs["quantization_config"] = quant_config _load_pre() try: # To handle the input CUDA setting (such as 'device_map={"":0}'), ignore it