Skip to content

Commit

Permalink
Patch sdpa check function in specific module attributes table (intel-…
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardozcm authored Oct 29, 2024
1 parent 3700e81 commit 546f455
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class _BaseAutoModelClass:

@classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
@patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
def from_pretrained(cls,
*args,
**kwargs):
Expand Down Expand Up @@ -542,7 +542,7 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):

@classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
@patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
def load_low_bit(cls,
pretrained_model_name_or_path,
*model_args,
Expand Down
10 changes: 9 additions & 1 deletion python/llm/src/ipex_llm/transformers/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from typing import List
from transformers.dynamic_module_utils import get_imports
from ipex_llm.utils.ipex_importer import IPEXImporter


def patch_flash_attn_import(filename: str) -> List[str]:
Expand All @@ -28,4 +29,11 @@ def patch_flash_attn_import(filename: str) -> List[str]:


def patch_sdpa_available() -> bool:
return False
if IPEXImporter.is_xpu_version_installed():
return False
else:
try:
from transformers.utils import is_torch_sdpa_available
return is_torch_sdpa_available()
except ImportError:
return False

0 comments on commit 546f455

Please sign in to comment.