Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IPEX Duplicate importer V2 #11310

Merged
merged 10 commits into from
Jun 19, 2024
6 changes: 4 additions & 2 deletions python/llm/src/ipex_llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@
import sys
import types

# Default is false, set to true to auto importing Intel Extension for PyTorch.
# Default is True, set to False to disable auto importing Intel Extension for PyTorch.
USE_NPU = os.getenv("BIGDL_USE_NPU", 'False').lower() in ('true', '1', 't')
BIGDL_IMPORT_IPEX = os.getenv("BIGDL_IMPORT_IPEX", 'True').lower() in ('true', '1', 't')
BIGDL_IMPORT_IPEX = not USE_NPU and BIGDL_IMPORT_IPEX
if BIGDL_IMPORT_IPEX:
# Import Intel Extension for PyTorch as ipex if XPU version is installed
from .utils.ipex_importer import ipex_importer
ipex_importer.import_ipex()
# Avoid duplicate import
if ipex_importer.get_ipex_version() is None:
ipex_importer.import_ipex()

# Default is true, set to true to auto patching bigdl-llm to ipex_llm.
BIGDL_COMPATIBLE_MODE = os.getenv("BIGDL_COMPATIBLE_MODE", 'True').lower() in ('true', '1', 't')
Expand Down
3 changes: 3 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,9 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
f"{list(gguf_mixed_qtype.keys())[index]} "
f"format......")
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
# Disable ipex duplicate import checker
from ipex_llm.utils.ipex_importer import revert_import
revert_import()

# using ipex_llm optimizer before changing to bigdl linear
_enable_ipex = get_enable_ipex()
Expand Down
3 changes: 3 additions & 0 deletions python/llm/src/ipex_llm/transformers/gguf/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@

def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float, low_bit: str = "sym_int4"):
from .gguf import GGUFFileLoader
# Disable ipex duplicate import checker
from ipex_llm.utils.ipex_importer import revert_import
revert_import()

loader = GGUFFileLoader(fpath)
model_family = loader.config["general.architecture"]
Expand Down
67 changes: 64 additions & 3 deletions python/llm/src/ipex_llm/utils/ipex_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,73 @@
import logging
import builtins
import sys
from ipex_llm.utils.common import log4Error
import os
import inspect
from ipex_llm.utils.common import log4Error

# Save the original __import__ function
original_import = builtins.__import__

# Default is True, set to False to disable IPEX duplicate checker
BIGDL_CHECK_DUPLICATE_IMPORT = os.getenv("BIGDL_CHECK_DUPLICATE_IMPORT",
'True').lower() in ('true', '1', 't')
RAW_IMPORT = None
IS_IMPORT_REPLACED = False
ipex_duplicate_import_error = "intel_extension_for_pytorch has already been automatically " + \
"imported. Please avoid importing it again!"


def replace_import():
global RAW_IMPORT, IS_IMPORT_REPLACED
# Avoid multiple replacement
if not IS_IMPORT_REPLACED and RAW_IMPORT is None:
# Save the original __import__ function
RAW_IMPORT = builtins.__import__
builtins.__import__ = custom_ipex_import
IS_IMPORT_REPLACED = True


def revert_import():
if not BIGDL_CHECK_DUPLICATE_IMPORT:
return
global RAW_IMPORT, IS_IMPORT_REPLACED
# Only revert once
if RAW_IMPORT is not None and IS_IMPORT_REPLACED:
builtins.__import__ = RAW_IMPORT
IS_IMPORT_REPLACED = False


def get_calling_package():
"""
Return calling package name, e.g., ipex_llm.transformers
"""
# Get the current stack frame
frame = inspect.currentframe()
# Get the caller's frame
caller_frame = frame.f_back.f_back
# Get the caller's module
module = inspect.getmodule(caller_frame)
if module:
# Return the module's package name
return module.__package__
return None


def custom_ipex_import(name, globals=None, locals=None, fromlist=(), level=0):
"""
Custom import function to avoid importing ipex again
"""
if fromlist is not None or '.' in name:
return RAW_IMPORT(name, globals, locals, fromlist, level)
# Avoid errors in submodule import
calling = get_calling_package()
if calling is not None:
return RAW_IMPORT(name, globals, locals, fromlist, level)
# Only check ipex for main thread
if name == "ipex" or name == "intel_extension_for_pytorch":
log4Error.invalidInputError(False,
ipex_duplicate_import_error)
return RAW_IMPORT(name, globals, locals, fromlist, level)


class IPEXImporter:
"""
Auto import Intel Extension for PyTorch as ipex,
Expand Down Expand Up @@ -71,6 +129,9 @@ def import_ipex(self):
ipex_duplicate_import_error)
self.directly_import_ipex()
self.ipex_version = ipex.__version__
# Replace builtin import to avoid duplicate ipex import
if BIGDL_CHECK_DUPLICATE_IMPORT:
replace_import()
logging.info("intel_extension_for_pytorch auto imported")

def directly_import_ipex(self):
Expand Down
Loading