[WIP] Support npu load_low_bit method (#11502)
* npu_load_low_bit
leonardozcm authored Jul 4, 2024
1 parent f079379 commit 57b8adb
Showing 4 changed files with 440 additions and 28 deletions.
7 changes: 5 additions & 2 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -225,6 +225,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
device=device)
dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
else:
@@ -239,7 +240,6 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
k, hist, enable_scale_search)
dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
return dst_tensor, scale.type(torch.float16)
else:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
@@ -252,7 +252,10 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
n // in_features, in_features,
hist, imatrix)
return dst_tensor
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
return dst_tensor, scale.type(torch.float16)
else:
return dst_tensor


def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
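The hunks above change `ggml_convert_qtype` so that, for the RTN qtypes (`SYM_INT8_RTN`, `SYM_INT4_RTN`), the destination tensor is reshaped before quantization and a `(qweights, scale)` tuple is returned, while every other qtype still returns a single tensor. A caller therefore has to branch on the return type. A minimal sketch, with the weight shape chosen only for illustration and the call mirroring the one in convert.py further down:

    # Sketch only: unpacking the new return value of ggml_convert_qtype.
    # The qtype name comes from the diff; the weight shape is illustrative.
    import torch
    from ipex_llm.ggml.quantize import ggml_tensor_qtype
    from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype

    weight = torch.randn(4096, 4096, dtype=torch.float32)   # example linear weight
    iqtype = ggml_tensor_qtype["sym_int4_rtn"]

    result = ggml_convert_qtype(weight, iqtype, device="cpu")
    if isinstance(result, tuple):
        qweights, scale = result      # RTN path: quantized weights plus fp16 scales
    else:
        qweights = result             # all other qtypes: quantized weights only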
254 changes: 236 additions & 18 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -15,13 +15,15 @@
#

import os
import copy
import types
import warnings
import torch
import transformers
from typing import List
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from transformers.configuration_utils import PretrainedConfig

import intel_npu_acceleration_library as npu_lib

@@ -44,6 +46,23 @@ def ignore_argument(kwargs: dict, key: 'str'):
warnings.warn(f"argument `{key}={arg}` will be ignored")


def save_low_bit(self, model_dir: str, *args, **kwargs):
origin_device = self.device
kwargs['safe_serialization'] = False
self.save_pretrained(model_dir, *args, **kwargs)
import json
import os
# Save all of the model's state-dict keys so that a later
# low_cpu_mem_usage load can recover the checkpoint keys
# without loading the entire model, and so we avoid relying
# on gc that might not be triggered.
load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
with open(os.path.join(model_dir, "load_keys.json"), "w") as json_file:
json.dump(load_keys, json_file)
if origin_device != 'cpu':
self.to(origin_device)
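The helper above also writes a `load_keys.json` next to the checkpoint so that `load_low_bit` can recover the full list of state-dict keys without instantiating the model. A minimal sketch of reading it back, with a hypothetical output directory:

    # Sketch only: reading the key list that save_low_bit stores alongside the
    # checkpoint. The directory name is hypothetical.
    import json
    import os

    model_dir = "qwen2-npu-low-bit"   # hypothetical save_low_bit output directory
    with open(os.path.join(model_dir, "load_keys.json")) as f:
        load_keys = json.load(f)

    # Mirrors model.state_dict().keys() at save time.
    checkpoint_keys = load_keys["all_checkpoint_keys"]
    print(f"{len(checkpoint_keys)} checkpoint keys recorded")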


class _BaseAutoModelClass:
HF_MODEL = None

@@ -110,7 +129,18 @@ def from_pretrained(cls,
ignore_argument(kwargs, "speculative")
ignore_argument(kwargs, "pipeline_parallel_stages")

model = cls.HF_Model.from_pretrained(*args, **kwargs)
_args = copy.deepcopy(args)
_kwargs = copy.deepcopy(kwargs)
try:
# To handle the input CUDA setting (such as 'device_map={"":0}'), ignore it
kwargs.pop('device_map', None)
model = cls.HF_Model.from_pretrained(*args, **kwargs)
except NotImplementedError:
logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
"will fall to traditional load method with higher memory consumption.")
_kwargs["low_cpu_mem_usage"] = False
model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
model.config.update({"bigdl_lcmu_enabled": False})

logger.info(f"Converting model, it may takes up to several minutes ...")
try:
@@ -120,7 +150,7 @@ def from_pretrained(cls,
with torch.no_grad():
optimize_llm(model)
if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
cls.load_convert(qtype, model, *args, **kwargs)
cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
else:
if not qtype.is_floating_point:
model = quantize_model(model, qtype)
@@ -131,27 +161,21 @@ def from_pretrained(cls,
model = npu_lib.compile(model, qtype, False)
logger.info(f"Finish to convert model")

model.config.update({"bigdl_transformers_low_bit": qtype})

# add save_low_bit to pretrained model dynamically
model.save_low_bit = types.MethodType(cls.save_low_bit, model)
model.save_low_bit = types.MethodType(save_low_bit, model)

return model

@classmethod
def load_convert(cls, q_k, optimize_model, *arg, **kwarg):
def load_convert(cls, q_k, optimize_model, device, *arg, **kwarg):
from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
replace_with_QuantizedLinear(optimize_model, q_k)
replace_with_QuantizedLinear(optimize_model, q_k, device=device)

@staticmethod
def save_low_bit(self, model_dir: str, *args, **kwargs):
os.makedirs(model_dir, exist_ok=True)
model_name = "pytorch_npu_model.pt"
model_path = os.path.join(model_dir, model_name)
del self.save_low_bit # workaround a bug
torch.save(self, model_path)

@staticmethod
@classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
def load_low_bit(model_dir: str, *args, **kwargs):
def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
if kwargs.pop('torch_dtype', None) not in [None, 'auto', torch.float]:
warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")

@@ -165,9 +189,203 @@ def load_low_bit(model_dir: str, *args, **kwargs):
ignore_argument(kwargs, "speculative")
ignore_argument(kwargs, "pipeline_parallel_stages")

model_name = "pytorch_npu_model.pt"
model_path = os.path.join(model_dir, model_name)
return torch.load(model_path)
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
from transformers.dynamic_module_utils import resolve_trust_remote_code, \
get_class_from_dynamic_module
from transformers.models.auto.auto_factory import _get_model_class
from transformers.utils.generic import ContextManagers
from transformers.generation.configuration_utils import GenerationConfig
from ipex_llm.transformers.utils import extract_local_archive_file, get_local_shard_files, \
load_state_dict
from accelerate.big_modeling import init_empty_weights

trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs_orig = copy.deepcopy(kwargs)

config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
return_unused_kwargs=True,
trust_remote_code=trust_remote_code,
**kwargs,
)

# if torch_dtype=auto was passed here, ensure to pass it on
if kwargs_orig.get("torch_dtype", None) == "auto":
kwargs["torch_dtype"] = "auto"

# May be needed by extract_local_archive_file
subfolder = kwargs.get("subfolder", "")
variant = kwargs.get("variant", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
torch_dtype = kwargs.pop("torch_dtype", "auto")
sharded_metadata = None

config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
qtype = config_dict.pop("bigdl_transformers_low_bit", False)
bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)

invalidInputError(qtype,
"Detect this model is not a low-bit model, Please use from_pretrained"
" with load_in_4bit or load_in_low_bit to get a low-bit model , and "
" serialize the model using save_low_bit first.")

invalidInputError(qtype in ["sym_int8_rtn", "sym_int4_rtn"],
f"Unknown bigdl_transformers_low_bit value: {qtype},"
f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")

has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if has_remote_code and trust_remote_code:
class_ref = config.auto_map[cls.HF_Model.__name__]
model_class = get_class_from_dynamic_module(
class_ref, pretrained_model_name_or_path, **kwargs
)
if os.path.isdir(pretrained_model_name_or_path):
model_class.register_for_auto_class(cls.HF_Model.__name__)
else:
cls.HF_Model.register(config.__class__, model_class, exist_ok=True)
elif type(config) in cls.HF_Model._model_mapping.keys():
model_class = _get_model_class(config, cls.HF_Model._model_mapping)

resolved_archive_file, is_sharded = extract_local_archive_file(
pretrained_model_name_or_path,
subfolder,
variant)

if is_sharded:
resolved_archive_file, sharded_metadata = \
get_local_shard_files(pretrained_model_name_or_path,
resolved_archive_file,
subfolder=subfolder)

# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict,
# by checking its first weights entry that is of a floating type
# - we assume all floating dtype weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
dtype_orig = None

if torch_dtype is not None:
if isinstance(torch_dtype, str):
if torch_dtype == "auto":
if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
torch_dtype = config.torch_dtype

else:
if is_sharded and "dtype" in sharded_metadata:
torch_dtype = sharded_metadata["dtype"]
else:
one_state_dict = load_state_dict(resolved_archive_file[0])
torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory
else:
invalidInputError(False,
f'`torch_dtype` can be either `torch.dtype` or `"auto"`, '
f'but received {torch_dtype}')
dtype_orig = model_class._set_default_torch_dtype(torch_dtype)

# Pretrained Model
_fast_init = kwargs.pop("_fast_init", True)
init_contexts = [no_init_weights(_enable=_fast_init)]
init_contexts.append(init_empty_weights())

if bigdl_lcmu_enabled:
with ContextManagers(init_contexts):
if config.architectures is not None and config.architectures[0] in \
["ChatGLMModel", "ChatGLMForConditionalGeneration"]:

"""
ChatGLMModel uses skip_init by default, which forces modules onto the CPU
when no device is specified. This in turn makes the replaced linear layers
allocate their memory on the CPU.
"""
kwargs["device"] = "meta"
model = model_class(config, *model_args, **kwargs)
else:
model = model_class(config, *model_args, **kwargs)

# Loading args may differ based on their usage
quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
logger.info(f"Converting model, it may takes up to several minutes ...")
try:
# for intel_npu_acceleration_library >= 1.1.0
from intel_npu_acceleration_library.quantization import quantize_model
from intel_npu_acceleration_library.compiler import create_npu_kernels
with torch.no_grad():
optimize_llm(model)
if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
else:
if not qtype.is_floating_point:
model = quantize_model(model, qtype)
create_npu_kernels(model)
model = model.eval()
except ImportError as _e:
# for intel_npu_acceleration_library < 1.1.0
model = npu_lib.compile(model, qtype, False)

if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
import os
import json
with open(os.path.join(pretrained_model_name_or_path,
"load_keys.json"), "r") as json_file:
loaded_data = json.load(json_file)
loaded_state_dict_keys = loaded_data["all_checkpoint_keys"]

# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=False, # always false to avoid pre-init behaviors
low_cpu_mem_usage=bigdl_lcmu_enabled,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)

# make sure token embedding weights are still tied if needed
model.tie_weights()

# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()

# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate():
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder=subfolder,
**kwargs,
)
except (OSError, TypeError):
pass
for param in model.parameters():
param.requires_grad_(False)

return model


class AutoModelForCausalLM(_BaseAutoModelClass):
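Taken together, `save_low_bit` and the new `load_low_bit` replace the old `torch.save`/`torch.load` round trip of the whole module. A usage sketch, assuming the low-bit type is selected through the usual `load_in_low_bit` argument and using a hypothetical checkpoint and output directory:

    # Sketch only: end-to-end save/load flow added in this commit. The model id,
    # output directory, and the load_in_low_bit argument are assumptions, not
    # taken verbatim from the diff.
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",          # hypothetical checkpoint
        load_in_low_bit="sym_int4_rtn",
        trust_remote_code=True,
    )
    model.save_low_bit("llama2-7b-npu-int4")      # writes weights + load_keys.json

    # Later, rebuild the quantized model directly from the saved folder.
    model = AutoModelForCausalLM.load_low_bit(
        "llama2-7b-npu-int4",
        trust_remote_code=True,
    )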
15 changes: 7 additions & 8 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -15,8 +15,7 @@


import torch
import importlib
from intel_npu_acceleration_library.nn import QuantizedLinear
from ipex_llm.transformers.npu_models.linear import QuantizedLinear


def module_optimization(func) -> torch.nn.Module:
@@ -31,7 +30,7 @@ def module_optimization(func) -> torch.nn.Module:
torch.nn.Module: optimized module
"""

def wrapper(model: torch.nn.Module, qtype, *args, **kwargs):
def wrapper(model: torch.nn.Module, qtype, device, *args, **kwargs):
"""Recursively apply the optimization function.
Args:
@@ -41,23 +40,23 @@ def wrapper(model: torch.nn.Module, qtype, *args, **kwargs):
"""
for name, layer in model.named_children():
new_layer = func(layer, qtype, *args, **kwargs)
new_layer = func(layer, qtype, device, *args, **kwargs)
if new_layer:
model.add_module(name, new_layer)
wrapper(new_layer, qtype, *args, **kwargs)
wrapper(new_layer, qtype, device, *args, **kwargs)
else:
wrapper(layer, qtype, *args, **kwargs)
wrapper(layer, qtype, device, *args, **kwargs)

return wrapper


@module_optimization
def replace_with_QuantizedLinear(layer, qtype):
def replace_with_QuantizedLinear(layer, qtype, device):
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype
iqtype = ggml_tensor_qtype[qtype]
if isinstance(layer, torch.nn.Linear):
qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, 'cpu')
qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, device=device)
return QuantizedLinear(qweights, scale, layer.bias)


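With the `device` argument threaded through `module_optimization`, the quantized weights can be created on `meta` during low-CPU-memory loading and on `cpu` otherwise. A sketch of applying the decorated function directly to a toy module; the module itself is illustrative, and the call shape follows the wrapper signature above:

    # Sketch only: the toy module is illustrative; the (model, qtype, device)
    # call shape follows the wrapper signature in the diff above.
    import torch
    from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear

    toy = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096, bias=False),
        torch.nn.Linear(4096, 11008, bias=False),
    )

    with torch.no_grad():
        # Recursively swaps every torch.nn.Linear for the NPU QuantizedLinear,
        # quantizing each weight on the requested device.
        replace_with_QuantizedLinear(toy, "sym_int4_rtn", device="cpu")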
[Diff for the fourth changed file was not loaded.]
