diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 432a1d0e396..f81ee840942 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -44,7 +44,6 @@
 from typing import List
 from unittest.mock import patch
 from transformers.configuration_utils import PretrainedConfig
-from transformers.dynamic_module_utils import get_imports
 
 from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
 from ipex_llm.utils.common import invalidInputError
@@ -115,7 +114,7 @@ class _BaseAutoModelClass:
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available)
+    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
     def from_pretrained(cls,
                         *args,
                         **kwargs):
@@ -543,7 +542,7 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available)
+    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
     def load_low_bit(cls,
                      pretrained_model_name_or_path,
                      *model_args,
diff --git a/python/llm/src/ipex_llm/transformers/patches.py b/python/llm/src/ipex_llm/transformers/patches.py
index e733910498a..f115ffa5402 100644
--- a/python/llm/src/ipex_llm/transformers/patches.py
+++ b/python/llm/src/ipex_llm/transformers/patches.py
@@ -16,6 +16,7 @@
 #
 
 from typing import List
+from transformers.dynamic_module_utils import get_imports
 
 
 def patch_flash_attn_import(filename: str) -> List[str]:
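
The sketch below is not part of the PR; it only illustrates, under the assumption that some `transformers` versions do not define `is_torch_sdpa_available` in `transformers.utils`, why `create=True` is passed to `unittest.mock.patch` above: without it, `patch` raises `AttributeError` when the target attribute is missing, while `create=True` creates the attribute for the duration of the patch and removes it afterwards. The names `fake_module` and `always_false` are hypothetical stand-ins.

```python
from types import SimpleNamespace
from unittest.mock import patch

# Stand-in for a module that does not expose the attribute we want to patch.
fake_module = SimpleNamespace()

def always_false():
    return False

# Without create=True this would raise AttributeError, because fake_module
# has no attribute named "is_torch_sdpa_available".
with patch.object(fake_module, "is_torch_sdpa_available", always_false, create=True):
    assert fake_module.is_torch_sdpa_available() is False

# On exit the temporarily created attribute is removed again.
assert not hasattr(fake_module, "is_torch_sdpa_available")
```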