Skip to content

Commit

Permalink
Support segformer fx (huggingface#19924)
Browse files Browse the repository at this point in the history
* Support segformer fx

* Add fx_compatible attribute to test_modeling_segformer.py

* Update glpn model (fx support)

glpn model was copied from segformer.

* Update utils/fx.py | add semantic-segmentation

for SegformerForSemanticSegmentation model

* Fix minor import order(isort)

* Add random input generation for segformer fx

Co-authored-by: noelbird <[email protected]>
  • Loading branch information
dwlim-nota and NoelBird authored Oct 28, 2022
1 parent dcca71b commit 347ba38
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/transformers/models/glpn/modeling_glpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_

def transpose_for_scores(self, hidden_states):
new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
hidden_states = hidden_states.view(*new_shape)
hidden_states = hidden_states.view(new_shape)
return hidden_states.permute(0, 2, 1, 3)

def forward(
Expand Down Expand Up @@ -190,7 +190,7 @@ def forward(

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/segformer/modeling_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_

def transpose_for_scores(self, hidden_states):
new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
hidden_states = hidden_states.view(*new_shape)
hidden_states = hidden_states.view(new_shape)
return hidden_states.permute(0, 2, 1, 3)

def forward(
Expand Down Expand Up @@ -220,7 +220,7 @@ def forward(

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

Expand Down
4 changes: 4 additions & 0 deletions src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
MODEL_FOR_PRETRAINING_MAPPING_NAMES,
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
Expand Down Expand Up @@ -80,6 +81,7 @@ def _generate_supported_model_class_names(
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
"semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
}

if supported_tasks is None:
Expand Down Expand Up @@ -128,6 +130,7 @@ def _generate_supported_model_class_names(
"plbart",
"resnet",
"roberta",
"segformer",
"speech_to_text",
"speech_to_text_2",
"swin",
Expand Down Expand Up @@ -730,6 +733,7 @@ def _generate_dummy_input(
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
"GPT2DoubleHeadsModel",
]:
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
Expand Down
1 change: 1 addition & 0 deletions tests/models/segformer/test_modeling_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)

fx_compatible = True
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
Expand Down

0 comments on commit 347ba38

Please sign in to comment.