Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove _supports_static_cache = True for some model classes #34975

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,8 @@ def forward(
v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
for expert_idx in range(0, self.moe_num_experts):
# (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`)
# (set torch._dynamo.config.capture_dynamic_output_shape_ops = True may help but not tested)
topk_idx, token_idx = torch.where(expert_mask[expert_idx])
if token_idx.shape[0] == 0:
continue
Expand Down Expand Up @@ -832,7 +834,6 @@ class DbrxPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True

def _init_weights(self, module: nn.Module):
std = self.config.initializer_range
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ def forward(self, hidden_states):
) # [num_tokens, num_experts]
gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
expert_size = gates.long().sum(0) # [num_experts,]
# (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
# (and `DataDependentOutputException`)
expert_size = expert_size.tolist()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jimba has this line expert_size = expert_size.tolist() too and it has no _supports_static_cache = True. Let do the same for this model.


# sort and group input tokens according to expert assignment
Expand Down Expand Up @@ -838,7 +840,6 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True

def _init_weights(self, module):
std = self.config.initializer_range
Expand Down Expand Up @@ -1152,8 +1153,6 @@ def _update_causal_mask(

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
Comment on lines -1155 to -1156
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the only place I see an extra check attention_mask.max() != 0 within if attention_mask is not None and attention_mask.dim() == 4. Not sure if we really need it, but it gives another different error (different from what expert_size = expert_size.tolist() give above) if we use torch compile

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @mayank31398 you may know better if this check is really necessary

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is not in granite modeling code though

causal_mask = attention_mask
else:
causal_mask = torch.full(
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/models/idefics/modeling_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,8 @@ def forward(
)
hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
# Fill in zeros for cross_attention hidden_states of tokens attending to no images
# (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`)
# (set torch._dynamo.config.capture_dynamic_output_shape_ops = True may help but not tested)
hidden_states[cross_attention_gate == 0] = hidden_states[cross_attention_gate == 0].fill_(0)
hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states

Expand Down Expand Up @@ -917,7 +919,6 @@ class IdeficsPreTrainedModel(PreTrainedModel):
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True

def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only
Expand Down Expand Up @@ -1155,7 +1156,7 @@ def forward(
elif position_ids is None:
position_ids = cache_position.unsqueeze(0)

if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will fail torch compile with another different type error.

if sum([x is None for x in [pixel_values, image_encoder_embeddings, perceiver_embeddings]]) != 2:
raise ValueError(
"Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None."
)
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def forward(self, hidden_states):
) # [num_tokens, num_experts]
gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
expert_size = gates.long().sum(0) # [num_experts,]
# (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
# (and `DataDependentOutputException`)
expert_size = expert_size.tolist()

# sort and group input tokens according to expert assignment
Expand Down