From 58b272922fee450726b9313f64c0ce49aeaa703e Mon Sep 17 00:00:00 2001 From: Qiyuan Gong Date: Thu, 13 Jun 2024 17:53:03 +0800 Subject: [PATCH] Duplicate check v2 --- python/llm/src/ipex_llm/__init__.py | 4 ++- .../llm/src/ipex_llm/utils/ipex_importer.py | 27 +++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/__init__.py b/python/llm/src/ipex_llm/__init__.py index 748a1b3d74f6..43f64690ac6e 100644 --- a/python/llm/src/ipex_llm/__init__.py +++ b/python/llm/src/ipex_llm/__init__.py @@ -31,7 +31,9 @@ if BIGDL_IMPORT_IPEX: # Import Intel Extension for PyTorch as ipex if XPU version is installed from .utils.ipex_importer import ipex_importer - ipex_importer.import_ipex() + # Avoid duplicate import + if ipex_importer.get_ipex_version() is None: + ipex_importer.import_ipex() # Default is true, set to true to auto patching bigdl-llm to ipex_llm. BIGDL_COMPATIBLE_MODE = os.getenv("BIGDL_COMPATIBLE_MODE", 'True').lower() in ('true', '1', 't') diff --git a/python/llm/src/ipex_llm/utils/ipex_importer.py b/python/llm/src/ipex_llm/utils/ipex_importer.py index 3fb543af6d43..793a66c864b5 100644 --- a/python/llm/src/ipex_llm/utils/ipex_importer.py +++ b/python/llm/src/ipex_llm/utils/ipex_importer.py @@ -19,14 +19,35 @@ import builtins import sys from ipex_llm.utils.common import log4Error -import inspect + # Save the original __import__ function -original_import = builtins.__import__ +RAW_IMPORT = builtins.__import__ ipex_duplicate_import_error = "intel_extension_for_pytorch has already been automatically " + \ "imported. Please avoid importing it again!" +def replace_import(): + builtins.__import__ = custom_ipex_import + + +def revert_import(): + builtins.__import__ = RAW_IMPORT + + +def custom_ipex_import(name, globals=None, locals=None, fromlist=(), level=0): + """ + Custom import function to avoid importing ipex again + """ + if fromlist is not None or '.' in name: + return RAW_IMPORT(name, globals, locals, fromlist, level) + # Only check ipex for main thread + if name == "ipex" or name == "intel_extension_for_pytorch": + log4Error.invalidInputError(False, + ipex_duplicate_import_error) + return RAW_IMPORT(name, globals, locals, fromlist, level) + + class IPEXImporter: """ Auto import Intel Extension for PyTorch as ipex, @@ -71,6 +92,8 @@ def import_ipex(self): ipex_duplicate_import_error) self.directly_import_ipex() self.ipex_version = ipex.__version__ + # Replace builtin import to avoid duplicate ipex import + replace_import() logging.info("intel_extension_for_pytorch auto imported") def directly_import_ipex(self):