diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 269a7271fda0ca..db4f8aa83daefd 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -74,6 +74,7 @@ Ready-made configurations include the following architectures: - GPT-J - GroupViT - I-BERT +- ImageGPT - LayoutLM - LayoutLMv3 - LeViT diff --git a/src/transformers/models/imagegpt/__init__.py b/src/transformers/models/imagegpt/__init__.py index ecf7ba9408d1bf..c2189a3ce17ac7 100644 --- a/src/transformers/models/imagegpt/__init__.py +++ b/src/transformers/models/imagegpt/__init__.py @@ -21,7 +21,9 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available -_import_structure = {"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"]} +_import_structure = { + "configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig", "ImageGPTOnnxConfig"] +} try: if not is_vision_available(): @@ -48,7 +50,7 @@ if TYPE_CHECKING: - from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig + from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig, ImageGPTOnnxConfig try: if not is_vision_available(): diff --git a/src/transformers/models/imagegpt/configuration_imagegpt.py b/src/transformers/models/imagegpt/configuration_imagegpt.py index 60719831fe10b5..7a043c31ceb255 100644 --- a/src/transformers/models/imagegpt/configuration_imagegpt.py +++ b/src/transformers/models/imagegpt/configuration_imagegpt.py @@ -14,10 +14,17 @@ # limitations under the License. """ OpenAI ImageGPT configuration""" +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Mapping, Optional + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging +if TYPE_CHECKING: + from ... import FeatureExtractionMixin, TensorType + logger = logging.get_logger(__name__) IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { @@ -140,3 +147,56 @@ def __init__( self.tie_word_embeddings = tie_word_embeddings super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class ImageGPTOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ] + ) + + def generate_dummy_inputs( + self, + preprocessor: "FeatureExtractionMixin", + batch_size: int = 1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional["TensorType"] = None, + num_channels: int = 3, + image_width: int = 32, + image_height: int = 32, + ) -> Mapping[str, Any]: + """ + Generate inputs to provide to the ONNX exporter for the specific framework + + Args: + preprocessor ([`PreTrainedTokenizerBase`] or [`FeatureExtractionMixin`]): + The preprocessor associated with this model configuration. + batch_size (`int`, *optional*, defaults to -1): + The batch size to export the model for (-1 means dynamic axis). + num_choices (`int`, *optional*, defaults to -1): + The number of candidate answers provided for multiple choice task (-1 means dynamic axis). + seq_length (`int`, *optional*, defaults to -1): + The sequence length to export the model for (-1 means dynamic axis). + is_pair (`bool`, *optional*, defaults to `False`): + Indicate if the input is a pair (sentence 1, sentence 2) + framework (`TensorType`, *optional*, defaults to `None`): + The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for. + num_channels (`int`, *optional*, defaults to 3): + The number of channels of the generated images. + image_width (`int`, *optional*, defaults to 40): + The width of the generated images. + image_height (`int`, *optional*, defaults to 40): + The height of the generated images. + + Returns: + Mapping[str, Tensor] holding the kwargs to provide to the model's forward function + """ + + input_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width) + inputs = dict(preprocessor(input_image, framework)) + + return inputs diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 6a0ec0f7c70794..878fcce651f0a9 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -341,6 +341,9 @@ class FeaturesManager: "question-answering", onnx_config_cls="models.ibert.IBertOnnxConfig", ), + "imagegpt": supported_features_mapping( + "default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig" + ), "layoutlm": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 81cd55d3bb5a81..eac6ee06348532 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -193,6 +193,7 @@ def test_values_override(self): ("detr", "facebook/detr-resnet-50"), ("distilbert", "distilbert-base-cased"), ("electra", "google/electra-base-generator"), + ("imagegpt", "openai/imagegpt-small"), ("resnet", "microsoft/resnet-50"), ("roberta", "roberta-base"), ("roformer", "junnyu/roformer_chinese_base"),