
Commit

clean branch
leonardozcm committed Jul 5, 2024
1 parent 72b4efa commit 11b8246
Showing 2 changed files with 23 additions and 49 deletions.
5 changes: 5 additions & 0 deletions python/llm/src/ipex_llm/transformers/model.py
@@ -332,6 +332,11 @@ def from_pretrained(cls,
             else:
                 kwargs["pretraining_tp"] = 1
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
+
+            invalidInputError(q_k not in ["sym_int4_rtn", "sym_int8_rtn"],
+                              f"The dtype {q_k} is specified for NPU"
+                              " and cannot be used on CPU and GPU")
+
             imatrix_file = kwargs.pop("imatrix", None)
             if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s"]:
                 invalidInputError(imatrix_file is not None,
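For context, a minimal usage sketch of the guard added above; the checkpoint id below is only a placeholder and is not part of the commit. On CPU and GPU the regular low-bit formats still load as before, while the NPU-only *_rtn formats are now rejected up front through invalidInputError.

# Hypothetical example; "meta-llama/Llama-2-7b-chat-hf" is a placeholder checkpoint.
from ipex_llm.transformers import AutoModelForCausalLM

# Unchanged path: sym_int4 remains a valid CPU/GPU low-bit format.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    load_in_low_bit="sym_int4",
    trust_remote_code=True,
)

# New behavior: the RTN formats are reserved for the NPU frontend,
# so this call is expected to fail via invalidInputError.
try:
    AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        load_in_low_bit="sym_int4_rtn",
        trust_remote_code=True,
    )
except Exception as err:
    print("rejected on CPU/GPU:", err)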
67 changes: 18 additions & 49 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -90,23 +90,12 @@ def from_pretrained(cls,
             warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
             kwargs['torch_dtype'] = torch.float
 
-        low_bit = kwargs.pop('load_in_low_bit', 'fp32')
-        try:
-            # for intel_npu_acceleration_library >= 1.1.0
-            from intel_npu_acceleration_library.dtypes import int8, int4
-            qtype_map = {
-                'sym_int4': "sym_int4_rtn",
-                'sym_int8': "sym_int8_rtn",
-                'fp16': torch.half,
-                'fp32': torch.float,
-            }
-        except ImportError as _e:
-            # for intel_npu_acceleration_library < 1.1.0
-            qtype_map = {
-                'sym_int8': torch.int8,
-                'fp16': torch.half,
-                'fp32': torch.float,
-            }
+        low_bit = kwargs.pop('load_in_low_bit', 'sym_int4')
+        qtype_map = {
+            'sym_int4': "sym_int4_rtn",
+            'sym_int8': "sym_int8_rtn",
+        }
 
         invalidInputError(low_bit in qtype_map.keys(),
                           f"unsupported low_bit: {low_bit}, "
                           f"only {list(qtype_map.keys())} are supported")
@@ -143,22 +132,13 @@ def from_pretrained(cls,
         model.config.update({"bigdl_lcmu_enabled": False})
 
         logger.info(f"Converting model, it may takes up to several minutes ...")
-        try:
-            # for intel_npu_acceleration_library >= 1.1.0
-            from intel_npu_acceleration_library.quantization import quantize_model
-            from intel_npu_acceleration_library.compiler import create_npu_kernels
-            with torch.no_grad():
-                optimize_llm(model)
-                if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
-                    cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
-                else:
-                    if not qtype.is_floating_point:
-                        model = quantize_model(model, qtype)
-                    create_npu_kernels(model)
-            model = model.eval()
-        except ImportError as _e:
-            # for intel_npu_acceleration_library < 1.1.0
-            model = npu_lib.compile(model, qtype, False)
+
+        with torch.no_grad():
+            optimize_llm(model)
+            cls.load_convert(qtype, model, 'cpu', *args, **kwargs)
+
+        model = model.eval()
 
         logger.info(f"Finish to convert model")
 
         model.config.update({"bigdl_transformers_low_bit": qtype})
@@ -313,22 +293,11 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
         # Loading args may differ based on their usage
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         logger.info(f"Converting model, it may takes up to several minutes ...")
-        try:
-            # for intel_npu_acceleration_library >= 1.1.0
-            from intel_npu_acceleration_library.quantization import quantize_model
-            from intel_npu_acceleration_library.compiler import create_npu_kernels
-            with torch.no_grad():
-                optimize_llm(model)
-                if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
-                    cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
-                else:
-                    if not qtype.is_floating_point:
-                        model = quantize_model(model, qtype)
-                    create_npu_kernels(model)
-            model = model.eval()
-        except ImportError as _e:
-            # for intel_npu_acceleration_library < 1.1.0
-            model = npu_lib.compile(model, qtype, False)
+        with torch.no_grad():
+            optimize_llm(model)
+            cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
+
+        model = model.eval()
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
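For reference, a hedged end-to-end sketch of the NPU flow that npu_model.py now implements; the checkpoint id and saved-model directory are placeholders, not part of the diff. load_in_low_bit is restricted to sym_int4 (the new default) and sym_int8, which map to the "sym_int4_rtn" / "sym_int8_rtn" qtypes and are converted through optimize_llm and cls.load_convert.

# Hypothetical example; the model id and directory below are placeholders.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# 'sym_int4' is now the default and maps to the "sym_int4_rtn" qtype;
# 'fp16'/'fp32' are no longer accepted by this frontend.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    load_in_low_bit="sym_int4",
    trust_remote_code=True,
)

# Reloading a previously saved low-bit checkpoint goes through load_low_bit,
# which now always runs optimize_llm + cls.load_convert (the old
# intel_npu_acceleration_library fallback path is gone).
model = AutoModelForCausalLM.load_low_bit("./llama2-7b-npu-sym-int4")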

