diff --git a/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py b/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py
index abba27673f4..8308308045a 100644
--- a/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py
+++ b/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py
@@ -991,6 +991,33 @@ def ggml_quantize_tensor_with_weights(
 _lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t
 
 
+# GGML API
+def ggml_quantize_tensor_rtn(
+    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    dst: ctypes.c_void_p,
+    scale_ptr,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    qtype: ctypes.c_int,
+    n: ctypes.c_size_t,
+    k: ctypes.c_int,
+    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
+) -> int:
+    return _lib.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n, k, hist, scale_search)
+
+
+_lib.ggml_quantize_tensor_rtn.argtypes = [
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_void_p,
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_int,
+    ctypes.c_size_t,
+    ctypes.c_int,
+    ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
+]
+_lib.ggml_quantize_tensor_rtn.restype = ctypes.c_size_t
+
+
 def ggml_type_size(qtype: ctypes.c_int) -> int:
     return _lib.ggml_type_size(qtype)
 
diff --git a/python/llm/src/ipex_llm/ggml/quantize.py b/python/llm/src/ipex_llm/ggml/quantize.py
index 8388fc3bede..76702e88117 100644
--- a/python/llm/src/ipex_llm/ggml/quantize.py
+++ b/python/llm/src/ipex_llm/ggml/quantize.py
@@ -50,6 +50,8 @@
                      "q5_k": 28,
                      "fp6": 29,
                      "fp6_k": 30,
+                     "sym_int4_rtn": 31,
+                     "sym_int8_rtn": 32,
                      }
 
 # mixed precison from llama.cpp
diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index 1038688ed78..1d632d6f01c 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -81,6 +81,7 @@
 Q6_K = ggml_tensor_qtype["q6_k"]
 Q5_K = ggml_tensor_qtype["q5_k"]
 FP6_K = ggml_tensor_qtype["fp6_k"]
+SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
 
 
 # For sym_int4
@@ -216,14 +217,27 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                       f"Last dim of input tensor must be multiple of {QK}")
 
     dst_size = (n // QK) * block_size_in_bytes
-    dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
-                             device=device)
+    if qtype in [SYM_INT8_RTN]:
+        dst_tensor = torch.empty(dst_size, dtype=torch.int8,
+                                 device=device)
+        scale = torch.empty(n // k, dtype=torch.float32,
+                            device=device)
+    else:
+        dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
+                                 device=device)
 
     if not convert_shape_only and device != 'meta':
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
         hist = (ctypes.c_int64 * 16)()
         if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
-            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
+            if qtype in [SYM_INT8_RTN]:
+                scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
+                ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
+                                              k, hist, enable_scale_search)
+                dst_tensor = dst_tensor.reshape_as(tensor)
+                return dst_tensor, scale.type(torch.float16)
+            else:
+                ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
         else:
             if imatrix is not None:
                 # quantize with importance matrix
diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index a941ae2984b..eb11bcef034 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -77,7 +77,7 @@ def from_pretrained(cls,
         from intel_npu_acceleration_library.dtypes import int8, int4
         qtype_map = {
             'sym_int4': int4,
-            'sym_int8': int8,
+            'sym_int8': "sym_int8_rtn",
             'fp16': torch.half,
             'fp32': torch.float,
         }
@@ -119,9 +119,12 @@ def from_pretrained(cls,
             from intel_npu_acceleration_library.compiler import create_npu_kernels
             with torch.no_grad():
                 optimize_llm(model)
-                if not qtype.is_floating_point:
-                    model = quantize_model(model, qtype)
-                create_npu_kernels(model)
+                if qtype == "sym_int8_rtn":
+                    cls.load_convert(qtype, model, *args, **kwargs)
+                else:
+                    if not qtype.is_floating_point:
+                        model = quantize_model(model, qtype)
+                    create_npu_kernels(model)
             model = model.eval()
         except ImportError as _e:
             # for intel_npu_acceleration_library < 1.1.0
@@ -133,6 +136,11 @@ def from_pretrained(cls,
 
         return model
 
+    @classmethod
+    def load_convert(cls, q_k, optimize_model, *arg, **kwarg):
+        from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
+        replace_with_QuantizedLinear(optimize_model, q_k)
+
     @staticmethod
     def save_low_bit(self, model_dir: str, *args, **kwargs):
         os.makedirs(model_dir, exist_ok=True)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index ba0918159fa..fc120c8efb4 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -15,6 +15,49 @@
 
 
 import torch
+from intel_npu_acceleration_library.nn import QuantizedLinear
+
+
+def module_optimization(func) -> torch.nn.Module:
+    """Recursively optimize a torch.nn.Module with a specific function.
+
+    The function `func` is called recursively on every module in the network.
+
+    Args:
+        func (Callable): optimization function
+
+    Returns:
+        torch.nn.Module: optimized module
+    """
+
+    def wrapper(model: torch.nn.Module, qtype, *args, **kwargs):
+        """Recursively apply the optimization function.
+
+        Args:
+            model (torch.nn.Module): original module
+            args (Any): positional arguments
+            kwargs (Any): keyword arguments
+
+        """
+        for name, layer in model.named_children():
+            new_layer = func(layer, qtype, *args, **kwargs)
+            if new_layer:
+                model.add_module(name, new_layer)
+                wrapper(new_layer, qtype, *args, **kwargs)
+            else:
+                wrapper(layer, qtype, *args, **kwargs)
+
+    return wrapper
+
+
+@module_optimization
+def replace_with_QuantizedLinear(layer, qtype):
+    from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+    iqtype = ggml_tensor_qtype[qtype]
+    if isinstance(layer, torch.nn.Linear):
+        qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, 'cpu')
+        return QuantizedLinear(qweights, scale, layer.bias)
 
 
 def convert_forward(m, target_m, new_forward):
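
Usage sketch (not part of the patch): a minimal illustration of the sym_int8_rtn path introduced above. It assumes an ipex-llm build whose native ggml library actually exports ggml_quantize_tensor_rtn, plus an environment with intel_npu_acceleration_library installed; the tensor shapes below are arbitrary placeholders, chosen so the last weight dimension is a multiple of the qtype's QK block size.

import torch
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear

# A toy fp32 weight laid out like torch.nn.Linear.weight (out_features x in_features).
weight = torch.randn(128, 4096, dtype=torch.float32)

# For sym_int8_rtn, ggml_convert_qtype returns both the int8 weights and a per-row
# fp16 scale, instead of the single packed uint8 buffer used for the other qtypes.
iqtype = ggml_tensor_qtype["sym_int8_rtn"]
qweights, scale = ggml_convert_qtype(weight, iqtype, 'cpu')
print(qweights.dtype, tuple(qweights.shape))  # torch.int8, (128, 4096)
print(scale.dtype, tuple(scale.shape))        # torch.float16, (128,)

# replace_with_QuantizedLinear walks a model, applies the same conversion to every
# torch.nn.Linear, and swaps it for intel_npu_acceleration_library's QuantizedLinear.
model = torch.nn.Sequential(torch.nn.Linear(4096, 128, bias=False))
replace_with_QuantizedLinear(model, "sym_int8_rtn")

As the qtype_map change above shows, a 'sym_int8' low-bit request on the NPU path now resolves to "sym_int8_rtn" and is handled by load_convert / replace_with_QuantizedLinear rather than intel_npu_acceleration_library's quantize_model.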