diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 5ad827689b41..3230952bf5c7 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -675,6 +675,8 @@ def forward(
         v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
         w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
         for expert_idx in range(0, self.moe_num_experts):
+            # (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`.)
+            # (Setting `torch._dynamo.config.capture_dynamic_output_shape_ops = True` may help, but this is untested.)
             topk_idx, token_idx = torch.where(expert_mask[expert_idx])
             if token_idx.shape[0] == 0:
                 continue
@@ -831,7 +833,6 @@ class DbrxPreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
 
     def _init_weights(self, module: nn.Module):
         std = self.config.initializer_range
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index 77ab0cece3ea..306457d572e8 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -330,6 +330,8 @@ def forward(self, hidden_states):
         )  # [num_tokens, num_experts]
         gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
         expert_size = gates.long().sum(0)  # [num_experts,]
+        # (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
+        # (and `DataDependentOutputException`)
        expert_size = expert_size.tolist()
 
         # sort and group input tokens according to expert assignment
@@ -841,7 +843,6 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -1155,8 +1156,6 @@ def _update_causal_mask(
 
         if attention_mask is not None and attention_mask.dim() == 4:
             # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
-            if attention_mask.max() != 0:
-                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
             causal_mask = attention_mask
         else:
             causal_mask = torch.full(
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 2e502d02fdef..3d0a956a7bf7 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -868,7 +868,7 @@ def forward(
             )
         hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
         # Fill in zeros for cross_attention hidden_states of tokens attending to no images
-        hidden_states[cross_attention_gate == 0] = hidden_states[cross_attention_gate == 0].fill_(0)
+        hidden_states = hidden_states.masked_fill((cross_attention_gate == 0)[:, :, None], 0.0)
         hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states
 
         # Fully Connected
@@ -917,7 +917,6 @@ class IdeficsPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_static_cache = True
 
     def _init_weights(self, module):
         # important: this ported version of Idefics isn't meant for training from scratch - only
@@ -1155,7 +1154,7 @@ def forward(
         elif position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
+        if sum([x is None for x in [pixel_values, image_encoder_embeddings, perceiver_embeddings]]) != 2:
             raise ValueError(
                 "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None."
             )
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index fca47eb3fa0d..a5aa6e8a9537 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -216,6 +216,8 @@ def forward(self, hidden_states):
         )  # [num_tokens, num_experts]
         gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
         expert_size = gates.long().sum(0)  # [num_experts,]
+        # (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
+        # (and `DataDependentOutputException`)
         expert_size = expert_size.tolist()
 
         # sort and group input tokens according to expert assignment
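
For reference, a minimal standalone sketch of the workaround hinted at in the dbrx comment above. The flag name comes from that comment; the `route_tokens` function, shapes, and tensors below are illustrative toys, not the actual model code, and the behavior under torch.compile is untested on the real models:

    import torch
    import torch._dynamo

    # Assumption: letting dynamo capture data-dependent output shapes allows it to trace
    # torch.where / aten.nonzero instead of raising `Unsupported`. As the comments in the
    # diff note, this may come with graph breaks or recompilations and is not verified here.
    torch._dynamo.config.capture_dynamic_output_shape_ops = True

    @torch.compile
    def route_tokens(expert_mask: torch.Tensor, expert_idx: int):
        # Data-dependent shape: how many tokens hit this expert is only known at runtime.
        topk_idx, token_idx = torch.where(expert_mask[expert_idx])
        return topk_idx, token_idx

    # Toy routing mask shaped [num_experts, top_k, num_tokens], as in the dbrx expert loop.
    expert_mask = torch.randint(0, 2, (4, 2, 16), dtype=torch.bool)
    print(route_tokens(expert_mask, 0))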