Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformers 4.48 #2158

Merged
merged 26 commits into from
Jan 29, 2025
Merged
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
5190280
test
IlyasMoutawwakil Jan 16, 2025
6a03d76
testing tensor cache x)
IlyasMoutawwakil Jan 20, 2025
7207215
fix logger
IlyasMoutawwakil Jan 20, 2025
6261094
condition cache class usage
IlyasMoutawwakil Jan 20, 2025
822066d
update opset for beit and data2vec vision and skip flattened/fused pk…
IlyasMoutawwakil Jan 20, 2025
3ab38fd
style
IlyasMoutawwakil Jan 20, 2025
d713e5a
fix args patcher
IlyasMoutawwakil Jan 20, 2025
bf4d1f3
fix modernbert testing
IlyasMoutawwakil Jan 20, 2025
230c3a0
adaot to new whisper returned generation length
IlyasMoutawwakil Jan 20, 2025
3d5d9c9
fix is_causal in transformers
IlyasMoutawwakil Jan 20, 2025
96e2714
fix modernbert failures
IlyasMoutawwakil Jan 20, 2025
78a2dba
style
IlyasMoutawwakil Jan 20, 2025
967c6e2
traceable cache
IlyasMoutawwakil Jan 20, 2025
1d74388
use pkv index
IlyasMoutawwakil Jan 24, 2025
d452c46
add version gard and clean up other model patcher version gards
IlyasMoutawwakil Jan 24, 2025
5dcab7f
patch sdpa attention in optimum for now
IlyasMoutawwakil Jan 24, 2025
656941a
remove modernbert condition
IlyasMoutawwakil Jan 24, 2025
1bcb38f
style
IlyasMoutawwakil Jan 24, 2025
23fa20e
fix MistralModelPatcher
IlyasMoutawwakil Jan 24, 2025
24c8f4b
correctly patch gpt2 in vision encoder decoder
IlyasMoutawwakil Jan 24, 2025
3694ea4
patch sdpa attention forward everywhere
IlyasMoutawwakil Jan 26, 2025
3d7d586
fix gpt2 cross attention in seq2seq as well
IlyasMoutawwakil Jan 26, 2025
10833d8
moved traceable cache to a file for simplicity of model patcher
IlyasMoutawwakil Jan 29, 2025
9491d17
Apply suggestions from code review
IlyasMoutawwakil Jan 29, 2025
2b73129
style
IlyasMoutawwakil Jan 29, 2025
dea98a0
fix
IlyasMoutawwakil Jan 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
patch sdpa attention in optimum for now
IlyasMoutawwakil committed Jan 24, 2025
commit 5dcab7f1ba003c88f703a03693ff1fa9c4430cf4
65 changes: 59 additions & 6 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
@@ -31,12 +31,15 @@
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
if is_transformers_version(">=", "4.36"):
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
if is_transformers_version(">=", "4.43"):
from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention
if is_transformers_version(">=", "4.42"):
from transformers.cache_utils import SlidingWindowCache, StaticCache
if is_transformers_version(">=", "4.48"):
import transformers.integrations.sdpa_attention
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
if is_transformers_version(">=", "4.43"):
from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention
from transformers.integrations.sdpa_attention import repeat_kv


if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel
@@ -532,27 +535,74 @@ def _prepare_4d_causal_attention_mask_for_sdpa_patched(
return attention_mask


def patched_sdpa_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
is_causal: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
# Reference: https://github.com/pytorch/pytorch/issues/112577.
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
if is_causal is None:
is_causal = causal_mask is None and query.shape[2] > 1

# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
is_causal = is_causal.item()

attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
dropout_p=dropout,
scale=scaling,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None


class DecoderModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.36"):
AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod

if is_transformers_version(">=", "4.36"):
patch_everywhere(
"_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched
)

if is_transformers_version(">=", "4.48"):
transformers.integrations.sdpa_attention.sdpa_attention_forward = patched_sdpa_attention_forward

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.36"):
AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended)

if is_transformers_version(">=", "4.36"):
patch_everywhere(
"_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa
)

if is_transformers_version(">=", "4.48"):
transformers.integrations.sdpa_attention.sdpa_attention_forward = self.original_sdpa_attention_forward

def __init__(
self,
config: "OnnxConfig",
@@ -565,6 +615,9 @@ def __init__(
self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa
self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended

if is_transformers_version(">=", "4.48"):
self.original_sdpa_attention_forward = transformers.integrations.sdpa_attention.sdpa_attention_forward


def falcon_build_alibi_tensor_patched(
attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
12 changes: 4 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
@@ -50,8 +50,7 @@
"datasets>=1.2.1",
"evaluate",
"protobuf>=3.20.1",
# "transformers>=4.36,<4.49.0",
"transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal",
"transformers>=4.36,<4.49.0",
],
"onnxruntime-gpu": [
"onnx",
@@ -60,22 +59,19 @@
"evaluate",
"protobuf>=3.20.1",
"accelerate", # ORTTrainer requires it.
# "transformers>=4.36,<4.49.0",
"transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal",
"transformers>=4.36,<4.49.0",
],
"exporters": [
"onnx",
"onnxruntime",
"timm",
# "transformers>=4.36,<4.49.0",
"transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal",
"transformers>=4.36,<4.49.0",
],
"exporters-gpu": [
"onnx",
"onnxruntime-gpu",
"timm",
# "transformers>=4.36,<4.49.0",
"transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal",
"transformers>=4.36,<4.49.0",
],
"exporters-tf": [
"tensorflow>=2.4,<=2.12.1",