
Commit

fix style
echarlaix committed Jan 15, 2025
1 parent 58b906f commit 46f1c26
Showing 4 changed files with 12 additions and 45 deletions.
39 changes: 8 additions & 31 deletions optimum/exporters/base.py
@@ -14,28 +14,15 @@
# limitations under the License.
"""Base exporters config."""

from abc import ABC



import copy
import enum
import gc
import inspect
import itertools
import os
import re
from abc import ABC, abstractmethod
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import numpy as np
from transformers.utils import is_accelerate_available, is_torch_available
from transformers.utils import is_torch_available


if is_torch_available():
import torch.nn as nn
pass

from .utils import (
DEFAULT_DUMMY_SHAPES,
@@ -46,19 +33,18 @@
from .utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
from .utils.doc import add_dynamic_docstring
from .utils.import_utils import is_torch_version, is_transformers_version
from .error_utils import MissingMandatoryAxisDimension


# from .model_patcher import ModelPatcher

if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers import PretrainedConfig

from .model_patcher import PatchingSpec

logger = logging.get_logger(__name__)



GENERATE_DUMMY_DOCSTRING = r"""
    Generates the dummy inputs necessary for tracing the model. If not explicitly specified, default input shapes are used.
@@ -90,13 +76,11 @@
"""



# TODO: Remove
class ExportConfig(ABC):
pass



class ExportersConfig(ABC):
"""
Base class describing metadata on how to export the model through the ONNX format.
@@ -141,19 +125,19 @@ class ExportersConfig(ABC):
"audio-xvector": ["logits"], # for onnx : ["logits", "embeddings"]
"depth-estimation": ["predicted_depth"],
"document-question-answering": ["logits"],
"feature-extraction": ["last_hidden_state"], # for neuron : ["last_hidden_state", "pooler_output"]
"feature-extraction": ["last_hidden_state"], # for neuron : ["last_hidden_state", "pooler_output"]
"fill-mask": ["logits"],
"image-classification": ["logits"],
"image-segmentation": ["logits"], # for tflite : ["logits", "pred_boxes", "pred_masks"]
"image-to-text": ["logits"],
"image-to-image": ["reconstruction"],
"mask-generation": ["logits"],
"masked-im": ["logits"], # for onnx : ["reconstruction"]
"masked-im": ["logits"], # for onnx : ["reconstruction"]
"multiple-choice": ["logits"],
"object-detection": ["logits", "pred_boxes"],
"question-answering": ["start_logits", "end_logits"],
"semantic-segmentation": ["logits"],
"text2text-generation": ["logits"], # for tflite : ["logits", "encoder_last_hidden_state"],
"text2text-generation": ["logits"], # for tflite : ["logits", "encoder_last_hidden_state"],
"text-classification": ["logits"],
"text-generation": ["logits"],
"time-series-forecasting": ["prediction_outputs"],
@@ -179,7 +163,6 @@ def __init__(
self.mandatory_axes = ()
self._axes: Dict[str, int] = {}


def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
"""
Instantiates the dummy input generators from `self.DUMMY_INPUT_GENERATOR_CLASSES`.
@@ -190,7 +173,6 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene
# self._validate_mandatory_axes()
return [cls_(self.task, self._normalized_config, **kwargs) for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES]


@property
@abstractmethod
def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -213,7 +195,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = self._TASK_TO_COMMON_OUTPUTS[self.task]
return copy.deepcopy(common_outputs)


@property
def values_override(self) -> Optional[Dict[str, Any]]:
"""
@@ -251,18 +232,15 @@ def is_torch_support_available(self) -> bool:

return False


@add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES)
def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:

"""
Generates dummy inputs that the exported model should be able to process.
This method is actually used to determine the input specs that are needed for the export.
Returns:
`Dict[str, [tf.Tensor, torch.Tensor]]`: A dictionary mapping input names to dummy tensors.
"""


dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
dummy_inputs = {}
@@ -303,5 +281,4 @@ def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
# ) -> ModelPatcher:
# return ModelPatcher(self, model, model_kwargs=model_kwargs)


############################################################################################################################################################
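As a usage sketch of the dummy-input path documented above (assuming the existing `BertOnnxConfig` from `optimum.exporters.onnx.model_configs` and the default dummy shapes are unchanged by this refactor), `generate_dummy_inputs(framework="pt")` is what produces the tensors the exporter traces the model with:

from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import BertOnnxConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
onnx_config = BertOnnxConfig(config, task="feature-extraction")

# Dummy PyTorch tensors built by the registered dummy input generators.
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
print({name: tuple(tensor.shape) for name, tensor in dummy_inputs.items()})
# e.g. {'input_ids': (2, 16), 'attention_mask': (2, 16), 'token_type_ids': (2, 16)}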
12 changes: 2 additions & 10 deletions optimum/exporters/onnx/base.py
@@ -21,7 +21,7 @@
import itertools
import os
import re
from abc import ABC, abstractmethod
from abc import ABC
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -41,16 +41,13 @@
is_diffusers_available,
logging,
)
from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
from ...utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
from ...utils.doc import add_dynamic_docstring
from ...utils.import_utils import (
is_onnx_available,
is_onnxruntime_available,
is_torch_version,
is_transformers_version,
)
from ..base import ExportConfig, ExportersConfig
from ..base import ExportersConfig
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import ModelPatcher, Seq2SeqModelPatcher

@@ -66,7 +63,6 @@
if is_diffusers_available():
from diffusers import ModelMixin

from .model_patcher import PatchingSpec

logger = logging.get_logger(__name__)

@@ -103,7 +99,6 @@


class OnnxConfig(ExportersConfig):

DEFAULT_ONNX_OPSET = 11
VARIANTS = {"default": "The default ONNX variant."}
DEFAULT_VARIANT = "default"
@@ -281,7 +276,6 @@ def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> ModelPatcher:
return ModelPatcher(self, model, model_kwargs=model_kwargs)


@property
def torch_to_onnx_input_map(self) -> Dict[str, str]:
@@ -348,8 +342,6 @@ def ordered_inputs(self, model: Union["PreTrainedModel", "TFPreTrainedModel"]) -
ordered_inputs[name] = dynamic_axes
return ordered_inputs



# TODO: use instead flatten_inputs and remove
@classmethod
def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]:
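A minimal sketch of the subclass contract that `OnnxConfig` (now deriving from `ExportersConfig`) expects from concrete model configs: declare the normalized config, the dummy input generators, and the `inputs` axes mapping. The class and generator names below are the existing optimum ones and are assumed to stay as-is in this refactor.

from typing import Dict

from optimum.exporters.onnx import OnnxConfig
from optimum.utils import DummyTextInputGenerator, NormalizedTextConfig


class CustomTextOnnxConfig(OnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Dynamic axes exported as symbolic dimensions in the ONNX graph.
        return {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
        }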
4 changes: 1 addition & 3 deletions optimum/exporters/tflite/base.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""TensorFlow Lite configuration base classes."""

from abc import ABC, abstractmethod
from abc import ABC
from ctypes import ArgumentError
from dataclasses import dataclass
from enum import Enum
@@ -192,7 +186,6 @@ def __init__(
point_batch_size: Optional[int] = None,
nb_points_per_image: Optional[int] = None,
):

super().__init__(config=config, task=task, int_dtype="int64", float_dtype="fp32")

# self.mandatory_axes = ()
@@ -269,7 +268,6 @@ def _create_dummy_input_generator_classes(self) -> List["DummyInputGenerator"]:
def generate_dummy_inputs(self) -> Dict[str, "tf.Tensor"]:
return super().generate_dummy_inputs(framework="tf")


@property
def inputs_specs(self) -> List["TensorSpec"]:
"""
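Hypothetical check, not part of this commit: because `generate_dummy_inputs` above hard-codes `framework="tf"`, a TFLite config always produces TensorFlow tensors. `BertTFLiteConfig` and its constructor arguments are assumed to match the existing optimum TFLite exporter.

import tensorflow as tf
from transformers import AutoConfig
from optimum.exporters.tflite.model_configs import BertTFLiteConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
tflite_config = BertTFLiteConfig(config, task="feature-extraction", batch_size=1, sequence_length=16)

# All dummy inputs come back as tf.Tensor because framework="tf" is forced.
dummy_inputs = tflite_config.generate_dummy_inputs()
assert all(isinstance(tensor, tf.Tensor) for tensor in dummy_inputs.values())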
2 changes: 1 addition & 1 deletion optimum/exporters/utils.py
@@ -545,7 +545,7 @@ def get_speecht5_models_for_export(
use_past=use_past,
use_past_in_inputs=False, # Irrelevant here.
behavior=config._behavior, # Irrelevant here.
preprocessors=config._preprocessors,
# preprocessors=config._preprocessors,
is_postnet_and_vocoder=True,
legacy=config.legacy,
)
