diff --git a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py index f9a127e3c9f..c47df92d340 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py @@ -35,6 +35,9 @@ from torch import Tensor +def c_round(x: Tensor): + return torch.sign(x) * torch.floor(torch.abs(x) + 0.5) + def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = 128 + 1): iscale = iscale.unsqueeze(1) @@ -45,7 +48,7 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = device = iscale.device dtype = iscale.dtype ############################### - W_q = torch.round(x * iscale).clamp(min_max[0], min_max[1]) + W_q = c_round(x * iscale).clamp(min_max[0], min_max[1]) n_clusters = W_q.shape[0] rng = torch.abs(iscale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump @@ -55,8 +58,6 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = .repeat(n_clusters, 1) ) + iscale - print(iscale_shifted.shape) - # Safe inverse iscale_shifted[ torch.logical_and(iscale_shifted >= 0, torch.abs(iscale_shifted) <= z_val) @@ -76,11 +77,12 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = scale_b = 1.0 / iscale_b iscale_b = iscale_b.unsqueeze(1) - # obtain qwights based on scale_b - qweights = (torch.round(x * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n # test with original # scale_b = (1.0 / iscale).squeeze() - # qweights = (torch.round(x * iscale)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n + # qweights = (c_round(x * iscale)).clamp(-8.0, 7.0).to(torch.int8) # m * n + + # obtain qwights based on scale_b + qweights = (c_round(x * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n qweights = qweights.reshape(x.shape[0], -1 , 2) # m * n/2 * 2 low_bit, high_bit = qweights.split(1, dim=-1) high_bit = high_bit.squeeze().view(torch.int8)