Skip to content

Commit

Permalink
temp save
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwang04 committed Dec 12, 2024
1 parent 509bdb4 commit a34493d
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 2 deletions.
15 changes: 13 additions & 2 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,23 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
if (layer.in_features == 18944 and layer.out_features == 3584):
qtype = "sym_int8_rtn"
iqtype = ggml_tensor_qtype[qtype]
enable_scale_search = (os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0" or
os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0")
if qtype == "asym_int4_rtn":
enable_scale_search = (os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0" or
os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0")
elif qtype == "sym_int4_rtn":
enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
else:
enable_scale_search = False
qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
iqtype, device=device,
enable_scale_search=enable_scale_search,
imatrix=imatrix)
if qtype == "sym_int4_rtn" and os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0":
from .quantize import scale_grid_search
# scale grid search
qweights, scale = scale_grid_search(layer.weight.data.to(torch.float32),
scale.to(torch.float32),
qweights)
zero = None
# split scale to scale & zero
if qtype == "asym_int4_rtn":
Expand Down
86 changes: 86 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/mobiusml/hqq/blob/master/hqq/core/optimize.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
from torch import float32, float16, Tensor
from functools import partial
from typing import Union


def update_scale_grid_search(x: Tensor, scale: Tensor, min_max: list, N: int = 128 + 1):
print(x.shape)
print(scale.shape)

assert N % 2 == 1, "Please check whether N: odd number"
rng_dump = 0.05 # 0.05 / 1.
z_val = 2e-4

device = scale.device
dtype = scale.dtype
###############################
print("init scale shape is : ", scale.shape)
W_q = (x / scale).clamp(min_max[0], min_max[1])
n_clusters = W_q.shape[0]
rng = torch.abs(scale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump
print("rng is : ", rng)

scale_shifted = (
torch.linspace(-rng, rng, N)[None, :]
.to(dtype=dtype, device=device)
.repeat(n_clusters, 1)
)

scale_shifted += scale

# Safe inverse
scale_shifted[
torch.logical_and(scale_shifted >= 0, torch.abs(scale_shifted) <= z_val)
] = z_val
scale_shifted[
torch.logical_and(scale_shifted < 0, torch.abs(scale_shifted) <= z_val)
] = -z_val

err = torch.empty([n_clusters, N], dtype=dtype, device=device)
for i in range(N):
W_r = W_q * scale_shifted[:, i][:, None]
err[:, i] = torch.abs(x - W_r).mean(axis=1, keepdim=True).squeeze()
print(f"err [{i}] shape is ", err[i].shape)

ind_r = torch.argmin(err, axis=1).to(torch.int32)
ind_c = torch.arange(len(ind_r), dtype=torch.int32, device=device)
scale_b = scale_shifted[ind_c, ind_r]

# obtain qwights based on scale_b

return scale_b, qweights

0 comments on commit a34493d

Please sign in to comment.