Skip to content

Commit

Permalink
Duplicate check v2
Browse files Browse the repository at this point in the history
  • Loading branch information
qiyuangong committed Jun 13, 2024
1 parent 5e25766 commit 58b2729
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
if BIGDL_IMPORT_IPEX:
# Import Intel Extension for PyTorch as ipex if XPU version is installed
from .utils.ipex_importer import ipex_importer
ipex_importer.import_ipex()
# Avoid duplicate import
if ipex_importer.get_ipex_version() is None:
ipex_importer.import_ipex()

# Default is true, set to true to auto patching bigdl-llm to ipex_llm.
BIGDL_COMPATIBLE_MODE = os.getenv("BIGDL_COMPATIBLE_MODE", 'True').lower() in ('true', '1', 't')
Expand Down
27 changes: 25 additions & 2 deletions python/llm/src/ipex_llm/utils/ipex_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,35 @@
import builtins
import sys
from ipex_llm.utils.common import log4Error
import inspect


# Save the original __import__ function
original_import = builtins.__import__
RAW_IMPORT = builtins.__import__
ipex_duplicate_import_error = "intel_extension_for_pytorch has already been automatically " + \
"imported. Please avoid importing it again!"


def replace_import():
builtins.__import__ = custom_ipex_import


def revert_import():
builtins.__import__ = RAW_IMPORT


def custom_ipex_import(name, globals=None, locals=None, fromlist=(), level=0):
"""
Custom import function to avoid importing ipex again
"""
if fromlist is not None or '.' in name:
return RAW_IMPORT(name, globals, locals, fromlist, level)
# Only check ipex for main thread
if name == "ipex" or name == "intel_extension_for_pytorch":
log4Error.invalidInputError(False,
ipex_duplicate_import_error)
return RAW_IMPORT(name, globals, locals, fromlist, level)


class IPEXImporter:
"""
Auto import Intel Extension for PyTorch as ipex,
Expand Down Expand Up @@ -71,6 +92,8 @@ def import_ipex(self):
ipex_duplicate_import_error)
self.directly_import_ipex()
self.ipex_version = ipex.__version__
# Replace builtin import to avoid duplicate ipex import
replace_import()
logging.info("intel_extension_for_pytorch auto imported")

def directly_import_ipex(self):
Expand Down

0 comments on commit 58b2729

Please sign in to comment.