Skip to content

Commit

Permalink
update, v1 scale search
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwang04 committed Dec 12, 2024
1 parent b7d7268 commit b3025a0
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions python/llm/src/ipex_llm/transformers/npu_models/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from torch import Tensor


def c_round(x: Tensor):
    """Round each element of ``x`` half away from zero.

    Unlike ``torch.round`` (which uses banker's rounding, i.e. rounds
    halves to the nearest even integer), this rounds exact .5 values
    away from zero: 0.5 -> 1, -0.5 -> -1, 2.5 -> 3.

    Args:
        x: input tensor of floating-point values.

    Returns:
        A tensor of the same shape with each element rounded
        half-away-from-zero (still a floating-point dtype).
    """
    # For non-negative values, shifting up by 0.5 and flooring rounds
    # halves upward; for negatives, shifting down by 0.5 and ceiling
    # rounds halves downward — together: half away from zero.
    rounded_pos = torch.floor(x + 0.5)
    rounded_neg = torch.ceil(x - 0.5)
    return torch.where(x >= 0, rounded_pos, rounded_neg)

def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = 128 + 1):
iscale = iscale.unsqueeze(1)

Expand All @@ -45,7 +48,7 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int =
device = iscale.device
dtype = iscale.dtype
###############################
W_q = torch.round(x * iscale).clamp(min_max[0], min_max[1])
W_q = c_round(x * iscale).clamp(min_max[0], min_max[1])
n_clusters = W_q.shape[0]
rng = torch.abs(iscale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump

Expand All @@ -55,8 +58,6 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int =
.repeat(n_clusters, 1)
) + iscale

print(iscale_shifted.shape)

# Safe inverse
iscale_shifted[
torch.logical_and(iscale_shifted >= 0, torch.abs(iscale_shifted) <= z_val)
Expand All @@ -76,11 +77,12 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int =
scale_b = 1.0 / iscale_b
iscale_b = iscale_b.unsqueeze(1)

# obtain qweights based on scale_b
qweights = (torch.round(x * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n
# test with original
# scale_b = (1.0 / iscale).squeeze()
# qweights = (torch.round(x * iscale)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n
# qweights = (c_round(x * iscale)).clamp(-8.0, 7.0).to(torch.int8) # m * n

# obtain qweights based on scale_b
qweights = (c_round(x * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n
qweights = qweights.reshape(x.shape[0], -1 , 2) # m * n/2 * 2
low_bit, high_bit = qweights.split(1, dim=-1)
high_bit = high_bit.squeeze().view(torch.int8)
Expand Down

0 comments on commit b3025a0

Please sign in to comment.