-
Notifications
You must be signed in to change notification settings - Fork 27.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove _supports_static_cache = True for some model classes #34975
base: main
Are you sure you want to change the base?
Changes from all commits
f3dbc6f
c935878
853ed93
bb921ed
565a6b1
2facf26
b630024
7117566
b371e98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -330,6 +330,8 @@ def forward(self, hidden_states): | |
) # [num_tokens, num_experts] | ||
gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts] | ||
expert_size = gates.long().sum(0) # [num_experts,] | ||
# (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`) | ||
# (and `DataDependentOutputException`) | ||
expert_size = expert_size.tolist() | ||
|
||
# sort and group input tokens according to expert assignment | ||
|
@@ -838,7 +840,6 @@ class GraniteMoePreTrainedModel(PreTrainedModel): | |
_supports_sdpa = True | ||
_supports_cache_class = True | ||
_supports_quantized_cache = True | ||
_supports_static_cache = True | ||
|
||
def _init_weights(self, module): | ||
std = self.config.initializer_range | ||
|
@@ -1152,8 +1153,6 @@ def _update_causal_mask( | |
|
||
if attention_mask is not None and attention_mask.dim() == 4: | ||
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing | ||
if attention_mask.max() != 0: | ||
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") | ||
Comment on lines
-1155
to
-1156
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is the only place I see an extra check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @mayank31398 you may know better if this check is really necessary There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it is not in granite modeling code though |
||
causal_mask = attention_mask | ||
else: | ||
causal_mask = torch.full( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -868,6 +868,8 @@ def forward( | |
) | ||
hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) | ||
# Fill in zeros for cross_attention hidden_states of tokens attending to no images | ||
# (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`) | ||
# (set torch._dynamo.config.capture_dynamic_output_shape_ops = True may help but not tested) | ||
hidden_states[cross_attention_gate == 0] = hidden_states[cross_attention_gate == 0].fill_(0) | ||
hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states | ||
|
||
|
@@ -917,7 +919,6 @@ class IdeficsPreTrainedModel(PreTrainedModel): | |
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] | ||
_supports_sdpa = True | ||
_supports_cache_class = True | ||
_supports_static_cache = True | ||
|
||
def _init_weights(self, module): | ||
# important: this ported version of Idefics isn't meant for training from scratch - only | ||
|
@@ -1155,7 +1156,7 @@ def forward( | |
elif position_ids is None: | ||
position_ids = cache_position.unsqueeze(0) | ||
|
||
if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this will fail torch compile with another different type error. |
||
if sum([x is None for x in [pixel_values, image_encoder_embeddings, perceiver_embeddings]]) != 2: | ||
raise ValueError( | ||
"Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None." | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jimba
has this lineexpert_size = expert_size.tolist()
too and it has no_supports_static_cache = True
. Let do the same for this model.