From f3dbc6f743ac1cee831545cd35201665b6316e60 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 27 Nov 2024 15:38:26 +0100
Subject: [PATCH 1/9] try 1

---
 src/transformers/models/dbrx/modeling_dbrx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 659fa154ecf776..312878fced951b 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -832,7 +832,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _supports_static_cache = False
 
     def _init_weights(self, module: nn.Module):
         std = self.config.initializer_range

From c9358789f3a701be658c1b351ce3ffa4a2d2fbf2 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 27 Nov 2024 15:42:31 +0100
Subject: [PATCH 2/9] try 1

---
 src/transformers/models/dbrx/modeling_dbrx.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 312878fced951b..7fb46ea2c32415 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -832,7 +832,6 @@ class DbrxPreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = False
 
     def _init_weights(self, module: nn.Module):
         std = self.config.initializer_range

From 853ed93bc774b68c12733deb8bd2ffaa6c2cf1bf Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 27 Nov 2024 15:48:37 +0100
Subject: [PATCH 3/9] try 1

---
 src/transformers/models/idefics/modeling_idefics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 8bd24728b03885..e24c7eb3f8a1ff 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -1155,7 +1155,7 @@ def forward(
         elif position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
+        if len([x is None for x in [pixel_values, image_encoder_embeddings, perceiver_embeddings]]) != 2:
             raise ValueError(
                 "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None."
             )

From bb921edef5be4cf45c09b2c2618052f28b09a027 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 27 Nov 2024 15:55:45 +0100
Subject: [PATCH 4/9] try 1

---
 src/transformers/models/idefics/modeling_idefics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index e24c7eb3f8a1ff..336a160c54a9ff 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -1155,7 +1155,7 @@ def forward(
         elif position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        if len([x is None for x in [pixel_values, image_encoder_embeddings, perceiver_embeddings]]) != 2:
+        if sum([x is None for x in [pixel_values, image_encoder_embeddings, perceiver_embeddings]]) != 2:
             raise ValueError(
                 "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None."
             )

From 565a6b13691d30d432b95f5702972c2ca0475f4b Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 27 Nov 2024 16:02:25 +0100
Subject: [PATCH 5/9] try 1

---
 src/transformers/models/idefics/modeling_idefics.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 336a160c54a9ff..ca0f44d447033f 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -917,7 +917,6 @@ class IdeficsPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_static_cache = True
 
     def _init_weights(self, module):
         # important: this ported version of Idefics isn't meant for training from scratch - only

From 2facf2624d4c1f288958f9e944b1ea0f637120c9 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 27 Nov 2024 16:11:35 +0100
Subject: [PATCH 6/9] try 1

---
 src/transformers/models/dbrx/modeling_dbrx.py       | 2 ++
 src/transformers/models/idefics/modeling_idefics.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 7fb46ea2c32415..3b1b5536e48854 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -676,6 +676,8 @@ def forward(
         v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
         w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
         for expert_idx in range(0, self.moe_num_experts):
+            # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`)
+            # (set torch._dynamo.config.capture_dynamic_output_shape_ops = True may help but not tested)
             topk_idx, token_idx = torch.where(expert_mask[expert_idx])
             if token_idx.shape[0] == 0:
                 continue
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index ca0f44d447033f..619ef5ff1688e7 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -868,6 +868,8 @@ def forward(
         )
         hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
         # Fill in zeros for cross_attention hidden_states of tokens attending to no images
+        # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`)
+        # (set torch._dynamo.config.capture_dynamic_output_shape_ops = True may help but not tested)
         hidden_states[cross_attention_gate == 0] = hidden_states[cross_attention_gate == 0].fill_(0)
         hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states
 

From b63002499dfb3f2728dbe4f404b8db0b5ce53d3c Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 27 Nov 2024 16:22:44 +0100
Subject: [PATCH 7/9] try 1

---
 src/transformers/models/granitemoe/modeling_granitemoe.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index 4871fc3584faee..3028162a1f95a7 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -1152,8 +1152,6 @@ def _update_causal_mask(
 
         if attention_mask is not None and attention_mask.dim() == 4:
             # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
-            if attention_mask.max() != 0:
-                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
             causal_mask = attention_mask
         else:
             causal_mask = torch.full(

From 711756699588e36e4fff7cbcdab7220a417ae112 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 27 Nov 2024 16:29:22 +0100
Subject: [PATCH 8/9] try 1

---
 src/transformers/models/granitemoe/modeling_granitemoe.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index 3028162a1f95a7..8ef46bccad10f2 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -330,6 +330,8 @@ def forward(self, hidden_states):
         )  # [num_tokens, num_experts]
         gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
         expert_size = gates.long().sum(0)  # [num_experts,]
+        # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
+        # (and `DataDependentOutputException`)
         expert_size = expert_size.tolist()
 
         # sort and group input tokens according to expert assignment
@@ -838,7 +840,6 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range

From b371e98ad2fd7f6a317c888c1e8c9db47b3a1c5a Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Wed, 27 Nov 2024 16:44:33 +0100
Subject: [PATCH 9/9] try 1

---
 src/transformers/models/jetmoe/modeling_jetmoe.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index a4bb1d78fdc5ce..81e15a4af8a2d5 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -214,6 +214,8 @@ def forward(self, hidden_states):
         )  # [num_tokens, num_experts]
         gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
         expert_size = gates.long().sum(0)  # [num_experts,]
+        # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
+        # (and `DataDependentOutputException`)
         expert_size = expert_size.tolist()
 
         # sort and group input tokens according to expert assignment
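
For context, below is a minimal, self-contained sketch (not part of the patch series) of the torch.compile failure that the comments added in patches 6, 8 and 9 refer to: torch.where on a boolean mask lowers to aten.nonzero, whose output shape depends on the tensor's values, so graph capture rejects it by default. The mask shape and values are invented for illustration, and the capture_dynamic_output_shape_ops flag at the end is the same option the patch comments describe as a possible but untested workaround.

import torch
import torch._dynamo


def route_tokens(expert_mask: torch.Tensor) -> torch.Tensor:
    # torch.where on a boolean mask lowers to aten.nonzero, whose output shape
    # depends on the mask's values (a data-dependent shape).
    topk_idx, token_idx = torch.where(expert_mask)
    return token_idx


expert_mask = torch.tensor([[True, False, True], [False, True, False]])

try:
    torch.compile(route_tokens, fullgraph=True)(expert_mask)
except Exception as err:
    # With default settings this raises:
    # torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default
    print(type(err).__name__, err)

# Flag mentioned in the patch comments as a possible (untested) workaround: it lets the
# compiler trace data-dependent output shapes as unbacked symbolic sizes.
torch._dynamo.reset()
torch._dynamo.config.capture_dynamic_output_shape_ops = True
print(torch.compile(route_tokens, fullgraph=True)(expert_mask))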