diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index 065b48bd349..acee0d860fd 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -109,11 +109,13 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
                                           enable_scale_search=enable_scale_search,
                                           imatrix=imatrix)
         if qtype == "sym_int4_rtn" and os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0":
-            from .quantize import scale_grid_search
+            from .quantize import update_scale_grid_search
             # scale grid search
-            qweights, scale = scale_grid_search(layer.weight.data.to(torch.float32),
-                                                scale.to(torch.float32),
-                                                qweights)
+            print("=====original: ", qweights.shape, scale.shape)
+            qweights, scale = update_scale_grid_search(layer.weight.data.to(torch.float32),
+                                                       (1.0 / scale.to(torch.float32)),
+                                                       [-8, 7])
+            print("=====update: ", qweights.shape, scale.shape)
         zero = None
         # split scale to scale & zero
         if qtype == "asym_int4_rtn":
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/linear.py b/python/llm/src/ipex_llm/transformers/npu_models/linear.py
index c8a5dd467ae..a626d9ec042 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/linear.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/linear.py
@@ -155,7 +155,7 @@ def __init__(
             False,
             (
                 f"Quantized weight must be in torch.(u)int8"
-                " dtype instead of {self.weight.dtype}"
+                f" dtype instead of {self.weight.dtype}"
             )
         )
         self.outC, self.inC = self.weight.shape
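Not part of the patch, only a reading aid: a minimal sketch of the call-site contract that convert.py now relies on, assuming the update_scale_grid_search signature introduced in the quantize.py hunk below. The toy shapes and the max/7 scale initialization are illustrative assumptions, not values from the patch.

    import torch
    from ipex_llm.transformers.npu_models.quantize import update_scale_grid_search

    weight = torch.randn(64, 128, dtype=torch.float32)  # outC x inC fp32 weight
    scale = weight.abs().max(dim=1).values / 7.0        # per-row dequantization scale (assumed init)
    qweights, scale = update_scale_grid_search(weight,
                                               1.0 / scale,  # the helper expects the inverse scale
                                               [-8, 7])      # sym_int4 value range
    print(qweights.shape, qweights.dtype)  # torch.Size([64, 64]) torch.int8, two int4 values per byte
    print(scale.shape, scale.dtype)        # torch.Size([64]) torch.float16
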
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py
index 303bdcc0eba..6fee413c3b9 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py
@@ -38,49 +38,60 @@
 from typing import Union
 
 
-def update_scale_grid_search(x: Tensor, scale: Tensor, min_max: list, N: int = 128 + 1):
+def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = 128 + 1):
+    iscale = iscale.unsqueeze(1)
     print(x.shape)
-    print(scale.shape)
+    print(iscale.shape)
     assert N % 2 == 1, "Please check whether N: odd number"
     rng_dump = 0.05  # 0.05 / 1.
     z_val = 2e-4
-    device = scale.device
-    dtype = scale.dtype
+    device = iscale.device
+    dtype = iscale.dtype
     ###############################
-    print("init scale shape is : ", scale.shape)
-    W_q = (x / scale).clamp(min_max[0], min_max[1])
+    print("init scale shape is : ", iscale.shape)
+    W_q = (x * iscale).clamp(min_max[0], min_max[1])
     n_clusters = W_q.shape[0]
-    rng = torch.abs(scale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump
+    rng = torch.abs(iscale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump
     print("rng is : ", rng)
-    scale_shifted = (
+    iscale_shifted = (
         torch.linspace(-rng, rng, N)[None, :]
         .to(dtype=dtype, device=device)
         .repeat(n_clusters, 1)
-    )
+    ) + iscale
-    scale_shifted += scale
+    print(iscale_shifted.shape)
     # Safe inverse
-    scale_shifted[
-        torch.logical_and(scale_shifted >= 0, torch.abs(scale_shifted) <= z_val)
+    iscale_shifted[
+        torch.logical_and(iscale_shifted >= 0, torch.abs(iscale_shifted) <= z_val)
     ] = z_val
-    scale_shifted[
-        torch.logical_and(scale_shifted < 0, torch.abs(scale_shifted) <= z_val)
+    iscale_shifted[
+        torch.logical_and(iscale_shifted < 0, torch.abs(iscale_shifted) <= z_val)
     ] = -z_val
     err = torch.empty([n_clusters, N], dtype=dtype, device=device)
     for i in range(N):
-        W_r = W_q * scale_shifted[:, i][:, None]
+        W_r = W_q / iscale_shifted[:, i][:, None]  # dequantize: divide by the candidate inverse scale
         err[:, i] = torch.abs(x - W_r).mean(axis=1, keepdim=True).squeeze()
-        print(f"err [{i}] shape is ", err[i].shape)
-
+
     ind_r = torch.argmin(err, axis=1).to(torch.int32)
     ind_c = torch.arange(len(ind_r), dtype=torch.int32, device=device)
-    scale_b = scale_shifted[ind_c, ind_r]
-
+    iscale_b = iscale_shifted[ind_c, ind_r]
+    scale_b = 1.0 / iscale_b
+    iscale_b = iscale_b.unsqueeze(1)
+    print(iscale_b.shape)
     # obtain qwights based on scale_b
+    qweights = torch.round(x * iscale_b).clamp(min_max[0], min_max[1]).to(torch.int8)  # m * n
+    qweights = qweights.reshape(x.shape[0], -1, 2)  # m * n/2 * 2
+    print(qweights.split(1, dim=-1))
+    high_bit, low_bit = qweights.split(1, dim=-1)
+    print(high_bit.shape)
+    high_bit = high_bit.squeeze().view(torch.int8)
+    low_bit = low_bit.squeeze().view(torch.int8)
+    high_bit = high_bit << 4
+    qweights = high_bit | (low_bit & 0x0F)  # mask sign-extension bits so the low nibble cannot clobber the high nibble
-    return scale_b, qweights
+    return qweights, scale_b.to(torch.float16)
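Also a reading aid rather than part of the patch: the packing step above stores two signed int4 values per int8 byte, which is why the low nibble is masked before the bitwise OR. A self-contained round-trip check, using hypothetical pack_int4/unpack_int4 helpers that mirror the layout used in update_scale_grid_search:

    import torch

    def pack_int4(w):
        # w: int8 tensor of sym_int4 values in [-8, 7]; the pair (w[:, 2i], w[:, 2i+1])
        # lands in the (high, low) nibbles of one byte, matching the hunk above.
        w = w.reshape(w.shape[0], -1, 2)
        high, low = w.split(1, dim=-1)
        return (high.squeeze(-1) << 4) | (low.squeeze(-1) & 0x0F)

    def unpack_int4(packed):
        high = packed >> 4                         # arithmetic shift keeps the sign
        low = packed & 0x0F
        low = torch.where(low > 7, low - 16, low)  # sign-extend the low nibble
        return torch.stack([high, low], dim=-1).reshape(packed.shape[0], -1)

    w = torch.randint(-8, 8, (4, 16), dtype=torch.int8)
    assert torch.equal(unpack_int4(pack_int4(w)), w)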