Init NPU quantize method and support q8_0_rtn #11452

Merged 8 commits on Jul 1, 2024
27 changes: 27 additions & 0 deletions python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py
@@ -991,6 +991,33 @@ def ggml_quantize_tensor_with_weights(
 _lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t


+# GGML API
+def ggml_quantize_tensor_rtn(
+    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    dst: ctypes.c_void_p,
+    scale_ptr,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    qtype: ctypes.c_int,
+    n: ctypes.c_size_t,
+    k: ctypes.c_int,
+    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
+) -> int:
+    return _lib.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n, k, hist, scale_search)
+
+
+_lib.ggml_quantize_tensor_rtn.argtypes = [
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_void_p,
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_int,
+    ctypes.c_size_t,
+    ctypes.c_int,
+    ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
+]
+_lib.ggml_quantize_tensor_rtn.restype = ctypes.c_size_t
+
+
 def ggml_type_size(qtype: ctypes.c_int) -> int:
     return _lib.ggml_type_size(qtype)

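Usage sketch (not part of the diff): the binding can be driven directly with ctypes-cast torch pointers, mirroring the call made in low_bit_linear.py later in this PR. The module path is inferred from the file location above, and the output-buffer layout (one int8 per weight, with the per-group scales written to the separate scale buffer) is an assumption for illustration.

# Hedged sketch: exercise ggml_quantize_tensor_rtn directly; normally ggml_convert_qtype does this.
# Assumptions: import path matches the file location; q8_0_rtn writes one int8 per element and
# one fp32 scale per group of k elements into the separate scale buffer.
import ctypes
import torch
import ipex_llm.ggml.model.llama.llama_cpp as ggml
from ipex_llm.ggml.quantize import ggml_tensor_qtype

weight = torch.randn(64, 64, dtype=torch.float32).contiguous()
n, k = weight.numel(), weight.shape[-1]

dst = torch.empty(n, dtype=torch.int8)            # assumed: one int8 per weight
scale = torch.empty(n // k, dtype=torch.float32)  # one scale per k-sized group
hist = (ctypes.c_int64 * 16)()                    # histogram buffer, as in the existing bindings

src_ptr = ctypes.cast(weight.data_ptr(), ctypes.POINTER(ctypes.c_float))
scale_ptr = ctypes.cast(scale.data_ptr(), ctypes.POINTER(ctypes.c_float))
ggml.ggml_quantize_tensor_rtn(src_ptr, ctypes.c_void_p(dst.data_ptr()), scale_ptr,
                              ggml_tensor_qtype["sym_int8_rtn"], n, k, hist, False)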
Expand Down
2 changes: 2 additions & 0 deletions python/llm/src/ipex_llm/ggml/quantize.py
@@ -50,6 +50,8 @@
"q5_k": 28,
"fp6": 29,
"fp6_k": 30,
"sym_int4_rtn": 31,
"sym_int8_rtn": 32,
}

# mixed precison from llama.cpp
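Quick sanity check (assuming an ipex-llm build that includes this PR): the new entries are plain integer ids alongside the existing ones.

# Hedged check of the qtype ids added above.
from ipex_llm.ggml.quantize import ggml_tensor_qtype

assert ggml_tensor_qtype["sym_int4_rtn"] == 31
assert ggml_tensor_qtype["sym_int8_rtn"] == 32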
20 changes: 17 additions & 3 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -81,6 +81,7 @@
 Q6_K = ggml_tensor_qtype["q6_k"]
 Q5_K = ggml_tensor_qtype["q5_k"]
 FP6_K = ggml_tensor_qtype["fp6_k"]
+SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]


 # For sym_int4
@@ -216,14 +217,27 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
f"Last dim of input tensor must be multiple of {QK}")

dst_size = (n // QK) * block_size_in_bytes
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
device=device)
if qtype in [SYM_INT8_RTN]:
dst_tensor = torch.empty(dst_size, dtype=torch.int8,
device=device)
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
else:
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
device=device)

if not convert_shape_only and device != 'meta':
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
hist = (ctypes.c_int64 * 16)()
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
if qtype in [SYM_INT8_RTN]:
scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
k, hist, enable_scale_search)
dst_tensor = dst_tensor.reshape_as(tensor)
return dst_tensor, scale.type(torch.float16)
else:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
else:
if imatrix is not None:
# quantize with importance matrix
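Note that for sym_int8_rtn, ggml_convert_qtype now returns a (qweights, scale) pair instead of a single packed tensor. A minimal sketch of that path, assuming an ipex-llm build containing this PR:

# Hedged sketch of the new SYM_INT8_RTN return path shown above.
import torch
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype

weight = torch.randn(128, 64, dtype=torch.float32)
qweights, scale = ggml_convert_qtype(weight, ggml_tensor_qtype["sym_int8_rtn"], 'cpu')
print(qweights.dtype, tuple(qweights.shape))  # torch.int8, reshaped back to weight's shape
print(scale.dtype, tuple(scale.shape))        # torch.float16, one scale per group of k values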
16 changes: 12 additions & 4 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -77,7 +77,7 @@ def from_pretrained(cls,
         from intel_npu_acceleration_library.dtypes import int8, int4
         qtype_map = {
             'sym_int4': int4,
-            'sym_int8': int8,
+            'sym_int8': "sym_int8_rtn",
             'fp16': torch.half,
             'fp32': torch.float,
         }
@@ -119,9 +119,12 @@ def from_pretrained(cls,
             from intel_npu_acceleration_library.compiler import create_npu_kernels
             with torch.no_grad():
                 optimize_llm(model)
-                if not qtype.is_floating_point:
-                    model = quantize_model(model, qtype)
-                create_npu_kernels(model)
+                if qtype == "sym_int8_rtn":
+                    cls.load_convert(qtype, model, *args, **kwargs)
+                else:
+                    if not qtype.is_floating_point:
+                        model = quantize_model(model, qtype)
+                    create_npu_kernels(model)
             model = model.eval()
         except ImportError as _e:
             # for intel_npu_acceleration_library < 1.1.0
@@ -133,6 +136,11 @@

         return model

+    @classmethod
+    def load_convert(cls, q_k, optimize_model, *arg, **kwarg):
+        from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
+        replace_with_QuantizedLinear(optimize_model, q_k)
+
     @staticmethod
     def save_low_bit(self, model_dir: str, *args, **kwargs):
         os.makedirs(model_dir, exist_ok=True)
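End to end, requesting sym_int8 on the NPU now routes through the RTN path instead of the NPU library quantizer. A hedged sketch of the user-facing call; the load_in_low_bit keyword, import path, and model id follow the usual ipex-llm NPU examples and are assumptions here, not part of this diff.

# Hedged usage sketch: 'sym_int8' is mapped to "sym_int8_rtn" by qtype_map above, which
# triggers cls.load_convert() and the QuantizedLinear replacement during loading.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # hypothetical model id
    load_in_low_bit="sym_int8",
    trust_remote_code=True,
)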
43 changes: 43 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -15,6 +15,49 @@


 import torch
+from intel_npu_acceleration_library.nn import QuantizedLinear
+
+
+def module_optimization(func) -> torch.nn.Module:
+    """Recursively optimize a torch.nn.Module with a specific function.
+
+    The function `func` is called recursively on every module in the network.
+
+    Args:
+        func (Callable): optimization function
+
+    Returns:
+        torch.nn.Module: optimized module
+    """
+
+    def wrapper(model: torch.nn.Module, qtype, *args, **kwargs):
+        """Recursively apply the optimization function.
+
+        Args:
+            model (torch.nn.Module): original module
+            qtype: target quantization type
+            args (Any): positional arguments
+            kwargs (Any): keyword arguments
+        """
+        for name, layer in model.named_children():
+            new_layer = func(layer, qtype, *args, **kwargs)
+            if new_layer:
+                model.add_module(name, new_layer)
+                wrapper(new_layer, qtype, *args, **kwargs)
+            else:
+                wrapper(layer, qtype, *args, **kwargs)
+
+    return wrapper
+
+
+@module_optimization
+def replace_with_QuantizedLinear(layer, qtype):
+    from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+    iqtype = ggml_tensor_qtype[qtype]
+    if isinstance(layer, torch.nn.Linear):
+        qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, 'cpu')
+        return QuantizedLinear(qweights, scale, layer.bias)


 def convert_forward(m, target_m, new_forward):
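Because replace_with_QuantizedLinear is wrapped by module_optimization, a single call rewrites every nn.Linear child of a model in place. A small sketch; the toy model and layer sizes are illustrative only.

# Hedged sketch: the decorator walks named_children and swaps each nn.Linear for an
# intel_npu_acceleration_library QuantizedLinear built from the RTN-quantized weights.
import torch
from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear

toy = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
)
replace_with_QuantizedLinear(toy, "sym_int8_rtn")  # modifies `toy` in place
print(toy)  # both Linear layers are now QuantizedLinear modules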