diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index 070395be0bf..e5d862f559c 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -225,6 +225,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
     if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
         dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
                                  device=device)
+        dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
         scale = torch.empty(n // k, dtype=torch.float32,
                             device=device)
     else:
@@ -239,7 +240,6 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
             scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
             ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n, k, hist, enable_scale_search)
-            dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
             return dst_tensor, scale.type(torch.float16)
         else:
             ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
@@ -252,7 +252,10 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
             ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
                                                    n // in_features, in_features,
                                                    hist, imatrix)
-    return dst_tensor
+    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
+        return dst_tensor, scale.type(torch.float16)
+    else:
+        return dst_tensor


 def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index cec844113d3..8d3afed64d9 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -15,6 +15,7 @@
 #

 import os
+import copy
 import types
 import warnings
 import torch
@@ -22,6 +23,7 @@
 from typing import List
 from unittest.mock import patch
 from transformers.dynamic_module_utils import get_imports
+from transformers.configuration_utils import PretrainedConfig

 import intel_npu_acceleration_library as npu_lib

@@ -44,6 +46,23 @@ def ignore_argument(kwargs: dict, key: 'str'):
         warnings.warn(f"argument `{key}={arg}` will be ignored")


+def save_low_bit(self, model_dir: str, *args, **kwargs):
+    origin_device = self.device
+    kwargs['safe_serialization'] = False
+    self.save_pretrained(model_dir, *args, **kwargs)
+    import json
+    import os
+    # Save all of the model's state_dict keys alongside the checkpoint so that
+    # a later `low_cpu_mem_usage` load can read them directly, instead of loading
+    # the entire model just to extract its keys (this also avoids relying on a
+    # gc pass that may never be triggered).
+    load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
+    with open(os.path.join(model_dir, "load_keys.json"), "w") as json_file:
+        json.dump(load_keys, json_file)
+    if origin_device != 'cpu':
+        self.to(origin_device)
+
+
 class _BaseAutoModelClass:
     HF_MODEL = None

@@ -110,7 +129,18 @@ def from_pretrained(cls,
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")

-        model = cls.HF_Model.from_pretrained(*args, **kwargs)
+        _args = copy.deepcopy(args)
+        _kwargs = copy.deepcopy(kwargs)
+        try:
+            # Ignore any device placement passed by the caller (such as 'device_map={"":0}')
+            kwargs.pop('device_map', None)
+            model = cls.HF_Model.from_pretrained(*args, **kwargs)
+        except NotImplementedError:
+            logger.info("Failed to load the model with `low_cpu_mem_usage` specified, "
+                        "falling back to the traditional loading method "
+                        "with higher memory consumption.")
+            _kwargs["low_cpu_mem_usage"] = False
+            model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
+            model.config.update({"bigdl_lcmu_enabled": False})

         logger.info(f"Converting model, it may takes up to several minutes ...")
         try:
@@ -120,7 +150,7 @@ def from_pretrained(cls,
             with torch.no_grad():
                 optimize_llm(model)
                 if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
-                    cls.load_convert(qtype, model, *args, **kwargs)
+                    cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
                 else:
                     if not qtype.is_floating_point:
                         model = quantize_model(model, qtype)
@@ -131,27 +161,21 @@ def from_pretrained(cls,
             model = npu_lib.compile(model, qtype, False)
         logger.info(f"Finish to convert model")

+        model.config.update({"bigdl_transformers_low_bit": qtype})
+
         # add save_low_bit to pretrained model dynamically
-        model.save_low_bit = types.MethodType(cls.save_low_bit, model)
+        model.save_low_bit = types.MethodType(save_low_bit, model)

         return model

     @classmethod
-    def load_convert(cls, q_k, optimize_model, *arg, **kwarg):
+    def load_convert(cls, q_k, optimize_model, device, *arg, **kwarg):
         from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
-        replace_with_QuantizedLinear(optimize_model, q_k)
+        replace_with_QuantizedLinear(optimize_model, q_k, device=device)

-    @staticmethod
-    def save_low_bit(self, model_dir: str, *args, **kwargs):
-        os.makedirs(model_dir, exist_ok=True)
-        model_name = "pytorch_npu_model.pt"
-        model_path = os.path.join(model_dir, model_name)
-        del self.save_low_bit  # workaround a bug
-        torch.save(self, model_path)
-
-    @staticmethod
+    @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    def load_low_bit(model_dir: str, *args, **kwargs):
+    def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
         if kwargs.pop('torch_dtype', None) not in [None, 'auto', torch.float]:
             warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")

@@ -165,9 +189,203 @@ def load_low_bit(model_dir: str, *args, **kwargs):
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")

-        model_name = "pytorch_npu_model.pt"
-        model_path = os.path.join(model_dir, model_name)
-        return torch.load(model_path)
+        from transformers.models.auto.configuration_auto import AutoConfig
+        from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
+        from transformers.dynamic_module_utils import resolve_trust_remote_code, \
+            get_class_from_dynamic_module
+        from transformers.models.auto.auto_factory import _get_model_class
+        from transformers.utils.generic import ContextManagers
+        from transformers.generation.configuration_utils import GenerationConfig
+        from ipex_llm.transformers.utils import extract_local_archive_file, get_local_shard_files, \
+            load_state_dict
+        from accelerate.big_modeling import init_empty_weights
+
+        trust_remote_code = kwargs.pop("trust_remote_code", None)
+        kwargs_orig = copy.deepcopy(kwargs)
+
+        config, kwargs = AutoConfig.from_pretrained(
+            pretrained_model_name_or_path,
+            return_unused_kwargs=True,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        # if torch_dtype=auto was passed here, ensure to pass it on
+        if kwargs_orig.get("torch_dtype", None) == "auto":
+            kwargs["torch_dtype"] = "auto"
+
+        # May be needed later by extract_local_archive_file
+        subfolder = kwargs.get("subfolder", "")
+        variant = kwargs.get("variant", None)
+        offload_folder = kwargs.pop("offload_folder", None)
+        offload_state_dict = kwargs.pop("offload_state_dict", False)
+        torch_dtype = kwargs.pop("torch_dtype", "auto")
+        sharded_metadata = None
+
+        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
+        qtype = config_dict.pop("bigdl_transformers_low_bit", False)
+        bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)
+
+        invalidInputError(qtype,
+                          "This model does not appear to be a low-bit model. Please use"
+                          " from_pretrained with load_in_4bit or load_in_low_bit to get a"
+                          " low-bit model, and serialize the model with save_low_bit first.")
+
+        invalidInputError(qtype in ["sym_int8_rtn", "sym_int4_rtn"],
+                          f"Unknown bigdl_transformers_low_bit value: {qtype},"
+                          f" expected: sym_int8_rtn or sym_int4_rtn.")
+
+        has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
+        has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
+        trust_remote_code = resolve_trust_remote_code(
+            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
+        )
+        if has_remote_code and trust_remote_code:
+            class_ref = config.auto_map[cls.HF_Model.__name__]
+            model_class = get_class_from_dynamic_module(
+                class_ref, pretrained_model_name_or_path, **kwargs
+            )
+            if os.path.isdir(pretrained_model_name_or_path):
+                model_class.register_for_auto_class(cls.HF_Model.__name__)
+            else:
+                cls.HF_Model.register(config.__class__, model_class, exist_ok=True)
+        elif type(config) in cls.HF_Model._model_mapping.keys():
+            model_class = _get_model_class(config, cls.HF_Model._model_mapping)
+
+        resolved_archive_file, is_sharded = extract_local_archive_file(
+            pretrained_model_name_or_path,
+            subfolder,
+            variant)
+
+        if is_sharded:
+            resolved_archive_file, sharded_metadata = \
+                get_local_shard_files(pretrained_model_name_or_path,
+                                      resolved_archive_file,
+                                      subfolder=subfolder)
+
+        # set dtype to instantiate the model under:
+        # 1. If torch_dtype is not None, we use that dtype
+        # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict,
+        #    by checking its first weights entry that is of a floating type
+        #    - we assume all floating dtype weights are of the same dtype
+        # we also may have config.torch_dtype available, but we won't rely on it till v5
+        dtype_orig = None
+
+        if torch_dtype is not None:
+            if isinstance(torch_dtype, str):
+                if torch_dtype == "auto":
+                    if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
+                        torch_dtype = config.torch_dtype
+                    else:
+                        if is_sharded and "dtype" in sharded_metadata:
+                            torch_dtype = sharded_metadata["dtype"]
+                        else:
+                            one_state_dict = load_state_dict(resolved_archive_file[0])
+                            torch_dtype = get_state_dict_dtype(one_state_dict)
+                            del one_state_dict  # free CPU memory
+                else:
+                    invalidInputError(False,
+                                      f'`torch_dtype` can be either `torch.dtype` or `"auto"`, '
+                                      f'but received {torch_dtype}')
+            dtype_orig = model_class._set_default_torch_dtype(torch_dtype)
+
+        # Pretrained Model
+        _fast_init = kwargs.pop("_fast_init", True)
+        init_contexts = [no_init_weights(_enable=_fast_init)]
+        init_contexts.append(init_empty_weights())
+
+        if bigdl_lcmu_enabled:
+            with ContextManagers(init_contexts):
+                if config.architectures is not None and config.architectures[0] in \
+                        ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+                    """
+                    ChatGLMModel uses skip_init by default, which forces modules onto the CPU
+                    when no device is specified. This in turn makes the replaced linear layers
+                    allocate memory on the CPU.
+                    """
+                    kwargs["device"] = "meta"
+                    model = model_class(config, *model_args, **kwargs)
+                else:
+                    model = model_class(config, *model_args, **kwargs)
+
+        # Loading args may differ based on their usage
+        quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
+        logger.info(f"Converting model, it may take up to several minutes ...")
+        try:
+            # for intel_npu_acceleration_library >= 1.1.0
+            from intel_npu_acceleration_library.quantization import quantize_model
+            from intel_npu_acceleration_library.compiler import create_npu_kernels
+            with torch.no_grad():
+                optimize_llm(model)
+                if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
+                    cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
+                else:
+                    if not qtype.is_floating_point:
+                        model = quantize_model(model, qtype)
+                    create_npu_kernels(model)
+            model = model.eval()
+        except ImportError as _e:
+            # for intel_npu_acceleration_library < 1.1.0
+            model = npu_lib.compile(model, qtype, False)
+
+        if is_sharded:
+            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+        else:
+            import os
+            import json
+            with open(os.path.join(pretrained_model_name_or_path,
+                                   "load_keys.json"), "r") as json_file:
+                loaded_data = json.load(json_file)
+            loaded_state_dict_keys = loaded_data["all_checkpoint_keys"]
+
+        # restore default dtype
+        if dtype_orig is not None:
+            torch.set_default_dtype(dtype_orig)
+
+        (
+            model,
+            missing_keys,
+            unexpected_keys,
+            mismatched_keys,
+            offload_index,
+            error_msgs,
+        ) = model_class._load_pretrained_model(
+            model,
+            None,
+            loaded_state_dict_keys,  # XXX: rename?
+            resolved_archive_file,
+            pretrained_model_name_or_path,
+            sharded_metadata=sharded_metadata,
+            _fast_init=False,  # always false to avoid pre-init behaviors
+            low_cpu_mem_usage=bigdl_lcmu_enabled,
+            offload_folder=offload_folder,
+            offload_state_dict=offload_state_dict,
+            dtype=torch_dtype,
+            keep_in_fp32_modules=[],
+        )
+
+        # make sure token embedding weights are still tied if needed
+        model.tie_weights()
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.eval()
+
+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    subfolder=subfolder,
+                    **kwargs,
+                )
+            except (OSError, TypeError):
+                pass
+
+        for param in model.parameters():
+            param.requires_grad_(False)
+
+        return model


 class AutoModelForCausalLM(_BaseAutoModelClass):
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index 8103df5cf6e..fcd70595baf 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -15,8 +15,7 @@

 import torch
-import importlib
-from intel_npu_acceleration_library.nn import QuantizedLinear
+from ipex_llm.transformers.npu_models.linear import QuantizedLinear


 def module_optimization(func) -> torch.nn.Module:
@@ -31,7 +30,7 @@ def module_optimization(func) -> torch.nn.Module:
         torch.nn.Module: optimized module
     """

-    def wrapper(model: torch.nn.Module, qtype, *args, **kwargs):
+    def wrapper(model: torch.nn.Module, qtype, device, *args, **kwargs):
         """Recursively apply the optimization function.

         Args:
@@ -41,23 +40,23 @@ def wrapper(model: torch.nn.Module, qtype, *args, **kwargs):
         """
         for name, layer in model.named_children():
-            new_layer = func(layer, qtype, *args, **kwargs)
+            new_layer = func(layer, qtype, device, *args, **kwargs)
             if new_layer:
                 model.add_module(name, new_layer)
-                wrapper(new_layer, qtype, *args, **kwargs)
+                wrapper(new_layer, qtype, device, *args, **kwargs)
             else:
-                wrapper(layer, qtype, *args, **kwargs)
+                wrapper(layer, qtype, device, *args, **kwargs)

     return wrapper


 @module_optimization
-def replace_with_QuantizedLinear(layer, qtype):
+def replace_with_QuantizedLinear(layer, qtype, device):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
     if isinstance(layer, torch.nn.Linear):
-        qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, 'cpu')
+        qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, device=device)
         return QuantizedLinear(qweights, scale, layer.bias)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/linear.py b/python/llm/src/ipex_llm/transformers/npu_models/linear.py
new file mode 100644
index 00000000000..9c9022e787e
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_models/linear.py
@@ -0,0 +1,192 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is adapted from
+# https://github.com/intel/intel-npu-acceleration-library/blob/main/intel_npu_acceleration_library/nn/linear.py
+
+#
+# Copyright © 2024 Intel Corporation
+# SPDX-License-Identifier: Apache 2.0
+#
+
+from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
+from intel_npu_acceleration_library.nn.autograd import AutogradMatMul
+from intel_npu_acceleration_library.backend import run_matmul
+from intel_npu_acceleration_library.dtypes import NPUDtype
+from typing import Optional, Union
+import torch
+from torch.nn import Parameter
+import uuid
+import math
+
+from ipex_llm.utils.common import invalidInputError
+
+
+class Linear(torch.nn.Module):
+    """Torch Linear operation NPU backend."""
+
+    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+        """Initialize the Linear class.
+
+        Args:
+            weight (torch.Tensor): Linear operation weight
+            bias (Optional[torch.Tensor], optional): Linear operation optional bias.
+                Defaults to None.
+        """
+        super().__init__()
+
+        self.weight = torch.nn.Parameter(weight)
+        self.bias = torch.nn.Parameter(bias) if isinstance(bias, torch.Tensor) else None
+        self.outC, self.inC = self.weight.shape
+        self.op_id = str(uuid.uuid4())
+        self._mm = AutogradMatMul.apply
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Torch module forward method.
+
+        Args:
+            x (torch.Tensor): Input tensor
+
+        Returns:
+            torch.Tensor: result
+        """
+        if self.training:
+            out = self._mm(x, self.weight, None)
+        else:
+            out = run_matmul(x, self.weight, None, self.op_id)
+
+        if self.bias is None:
+            return out
+        return out + self.bias
+
+    @staticmethod
+    def fromTorch(
+        layer: torch.nn.Linear, dtype: torch.dtype = torch.float16
+    ) -> Union["Linear", "QuantizedLinear"]:
+        """Generate a NPU Linear layer from a torch one.
+
+        Args:
+            layer (torch.nn.Linear): the original torch.nn.Linear model to run on the NPU
+            dtype (torch.dtype): the desired datatype
+
+        Returns:
+            Union[Linear, QuantizedLinear]: A NPU linear layer
+        """
+        if any(dim > 2**17 for dim in layer.weight.shape):
+            return layer
+        return Linear.fromTensor(layer.weight, getattr(layer, "bias", None), dtype)
+
+    @staticmethod
+    def fromTensor(
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        dtype: torch.dtype = torch.float16,
+    ) -> Union["Linear", "QuantizedLinear"]:
+        """Generate a NPU Linear layer from a torch one.
+
+        Args:
+            weight (torch.Tensor): the original weight tensor
+            bias (Optional[torch.Tensor]): the original bias tensor
+            dtype (torch.dtype): the desired datatype
+
+        Raises:
+            RuntimeError: dtype not supported
+
+        Returns:
+            Union[Linear, QuantizedLinear]: A NPU linear layer
+        """
+        if dtype.is_floating_point:
+            if bias is None:
+                return Linear(weight.to(dtype), None)
+            return Linear(weight.to(dtype), bias.to(dtype))
+        elif isinstance(dtype, NPUDtype):
+            weights_quant, scale = quantize_tensor(weight, (dtype.min, dtype.max))
+            if dtype.bits == 4:
+                weights_quant = compress_to_i4(weights_quant)
+            return QuantizedLinear(weights_quant, scale, bias)
+        elif dtype == torch.int8:
+            weights_quant, scale = quantize_tensor(weight)
+            return QuantizedLinear(weights_quant, scale, bias)
+        else:
+            invalidInputError(False,
+                              f"The NPU does not yet support the requested datatype: {dtype}")
+
+
+class QuantizedLinear(torch.nn.Module):
+    """Torch Quantized Linear operation NPU backend."""
+
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """Initialize the QuantizedLinear class.
+
+        Args:
+            weight (torch.Tensor): Linear operation weight
+            scale (torch.Tensor): Quantization scale
+            bias (Optional[torch.Tensor], optional): Linear operation optional bias.
+                Defaults to None.
+
+        Raises:
+            RuntimeError: Quantized weight must be in torch.int8 format
+        """
+        super().__init__()
+
+        self.weight = Parameter(weight, requires_grad=False)
+        if self.weight.dtype not in (torch.int8, torch.uint8):
+            invalidInputError(
+                False,
+                (
+                    f"Quantized weight must be in torch.(u)int8 dtype"
+                    f" instead of {self.weight.dtype}"
+                )
+            )
+        self.outC, self.inC = self.weight.shape
+        if self.weight.dtype == torch.uint8:
+            # In the Int4 case we need to double the input channels because weights are compressed
+            self.inC *= 2
+        self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
+        self.bias = bias
+        self.op_id = str(uuid.uuid4())
+        self._mm = AutogradMatMul.apply
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Torch module forward method.
+
+        Args:
+            x (torch.Tensor): Input tensor
+
+        Raises:
+            RuntimeError: Training is not supported for QuantizedLinear layer.
+                Use `.eval()` to do inference only
+
+        Returns:
+            torch.Tensor: result
+        """
+        if self.training:
+            invalidInputError(
+                False,
+                (
+                    "Training is not supported for QuantizedLinear layer. "
+                    "Use `.eval()` to do inference only"
+                )
+            )
+        out = run_matmul(x, self.weight.data, self.scale.data, self.op_id)
+
+        if self.bias is None:
+            return out
+        return out + self.bias
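For context, below is a minimal usage sketch of the save/load round trip this patch enables. The model path, output directory, prompt, and the exact low-bit keyword argument (shown as load_in_low_bit="sym_int4_rtn", matching the qtype strings checked in this patch) are illustrative assumptions rather than part of the change itself.

from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "path/to/original/hf-model"   # hypothetical HF checkpoint
save_dir = "path/to/npu-low-bit-model"     # hypothetical output directory

# First run: quantize while loading, then serialize the low-bit weights
# together with load_keys.json via the new save_low_bit().
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="sym_int4_rtn",  # assumed kwarg/value
                                             trust_remote_code=True)
model.save_low_bit(save_dir)

# Later runs: restore the quantized model directly with load_low_bit(),
# skipping the original checkpoint and re-quantization.
model = AutoModelForCausalLM.load_low_bit(save_dir, trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))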