diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index acee0d860fd..4fcd3f901ad 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -111,11 +111,9 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, if qtype == "sym_int4_rtn" and os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0": from .quantize import update_scale_grid_search # scale grid search - print("=====original: ", qweights.shape, scale.shape) qweights, scale = update_scale_grid_search(layer.weight.data.to(torch.float32), (1.0 / scale.to(torch.float32)), [-8, 7]) - print("=====update: ", qweights.shape, scale.shape) zero = None # split scale to scale & zero if qtype == "asym_int4_rtn": diff --git a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py index 6fee413c3b9..f9a127e3c9f 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py @@ -32,16 +32,11 @@ # limitations under the License. import torch -import numpy as np -from torch import float32, float16, Tensor -from functools import partial -from typing import Union +from torch import Tensor def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = 128 + 1): iscale = iscale.unsqueeze(1) - print(x.shape) - print(iscale.shape) assert N % 2 == 1, "Please check whether N: odd number" rng_dump = 0.05 # 0.05 / 1. @@ -50,11 +45,9 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = device = iscale.device dtype = iscale.dtype ############################### - print("init scale shape is : ", iscale.shape) - W_q = (x * iscale).clamp(min_max[0], min_max[1]) + W_q = torch.round(x * iscale).clamp(min_max[0], min_max[1]) n_clusters = W_q.shape[0] rng = torch.abs(iscale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump - print("rng is : ", rng) iscale_shifted = ( torch.linspace(-rng, rng, N)[None, :] @@ -74,7 +67,7 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = err = torch.empty([n_clusters, N], dtype=dtype, device=device) for i in range(N): - W_r = W_q * iscale_shifted[:, i][:, None] + W_r = W_q * iscale_shifted[:, i][:, None] err[:, i] = torch.abs(x - W_r).mean(axis=1, keepdim=True).squeeze() ind_r = torch.argmin(err, axis=1).to(torch.int32) @@ -82,16 +75,18 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = iscale_b = iscale_shifted[ind_c, ind_r] scale_b = 1.0 / iscale_b iscale_b = iscale_b.unsqueeze(1) - print(iscale_b.shape) + # obtain qwights based on scale_b - qweights = (x * iscale_b).to(torch.int8) # m * n + qweights = (torch.round(x * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n + # test with original + # scale_b = (1.0 / iscale).squeeze() + # qweights = (torch.round(x * iscale)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n qweights = qweights.reshape(x.shape[0], -1 , 2) # m * n/2 * 2 - print(qweights.split(1, dim=-1)) - high_bit, low_bit = qweights.split(1, dim=-1) - print(high_bit.shape) + low_bit, high_bit = qweights.split(1, dim=-1) high_bit = high_bit.squeeze().view(torch.int8) low_bit = low_bit.squeeze().view(torch.int8) high_bit = high_bit << 4 + low_bit = low_bit & 0x0f qweights = high_bit | low_bit - return qweights, scale_b.to(torch.float16) + return qweights.view(torch.uint8), scale_b.to(torch.float16)