diff --git a/.github/workflows/test_modeling_ort.yml b/.github/workflows/test_modeling_ort.yml new file mode 100644 index 0000000000..1792c15369 --- /dev/null +++ b/.github/workflows/test_modeling_ort.yml @@ -0,0 +1,32 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions +name: Onnxruntime Models (Inference) / Python - Test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: [3.8, 3.9] + os: [ubuntu-20.04 ] #, windows-2019, macos-10.15] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install .[tests,onnxruntime] + - name: Test with pytest + shell: bash + run: | + pytest tests/onnxruntime/test_modeling_ort.py diff --git a/.gitignore b/.gitignore index 01cb4df7e6..88dea28593 100644 --- a/.gitignore +++ b/.gitignore @@ -131,3 +131,7 @@ dmypy.json # Models *.onnx +# include small test model for tests +!tests/assets/onnx/model.onnx + +.vscode \ No newline at end of file diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index d218f55a46..6709401554 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -3,8 +3,12 @@ title: 🤗 Optimum - local: quickstart title: Quickstart + - local: pipelines + title: Pipelines for inference title: Get started - sections: + - local: onnxruntime/modeling_ort + title: Inference - local: onnxruntime/configuration title: Configuration - local: onnxruntime/optimization diff --git a/docs/source/onnxruntime/modeling_ort.mdx b/docs/source/onnxruntime/modeling_ort.mdx new file mode 100644 index 0000000000..9dec0c6dab --- /dev/null +++ b/docs/source/onnxruntime/modeling_ort.mdx @@ -0,0 +1,103 @@ + + +# Optimum Inference with ONNX Runtime + +Optimum is a utility package for building and running inference with accelerated runtime like ONNX Runtime. +Optimum can be used to load optimized models from the [Hugging Face Hub](hf.co/models) and create pipelines +to run accelerated inference without rewriting your APIs. + +## Switching from Transformers to Optimum Inference + +The Optimum Inference models are API compatible with Hugging Face Transformers models. This means you can just replace your `AutoModelForXxx` class with the corresponding `ORTModelForXxx` class in `optimum`. For example, this is how you can use a question answering model in `optimum`: + +```diff +from transformers import AutoTokenizer, pipeline +-from transformers import AutoModelForQuestionAnswering ++from optimum.onnxruntime import ORTModelForQuestionAnswering + +-model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2") # pytorch checkpoint ++model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2") # onnx checkpoint +tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2") + +onnx_qa = pipeline("question-answering",model=model,tokenizer=tokenizer) + +question = "What's my name?" +context = "My name is Philipp and I live in Nuremberg." +pred = onnx_qa(question, context) +``` + +Optimum Inference also includes methods to convert vanilla Transformers models to optimized ones. 
Simply pass `from_transformers=True` to the `from_pretrained()` method, and your model will be loaded and converted to ONNX on-the-fly:
+
+```python
+>>> from transformers import AutoTokenizer, pipeline
+>>> from optimum.onnxruntime import ORTModelForSequenceClassification
+
+# load model from hub and convert
+>>> model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", from_transformers=True)
+>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+
+# create pipeline
+>>> onnx_classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
+
+>>> result = onnx_classifier(text="This is a great model")
+[{'label': 'POSITIVE', 'score': 0.9998838901519775}]
+```
+
+You can find a complete walkthrough of Optimum Inference for ONNX Runtime in this [notebook](xx).
+
+### Working with the [Hugging Face Model Hub](https://hf.co/models)
+
+The Optimum model classes like [`~ORTModelForSequenceClassification`] are integrated with the [Hugging Face Model Hub](https://hf.co/models), which means you can not only
+load models from the Hub, but also push your models to the Hub with the `push_to_hub()` method. Below is an example which downloads a vanilla Transformers model
+from the Hub, converts it to an optimum onnxruntime model, and pushes it back into a new repository.
+
+
+```python
+>>> from transformers import AutoTokenizer
+>>> from optimum.onnxruntime import ORTModelForSequenceClassification
+
+# load model from hub and convert
+>>> model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", from_transformers=True)
+>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+
+# save converted model
+>>> model.save_pretrained("a_local_path_for_converted_onnx_model")
+>>> tokenizer.save_pretrained("a_local_path_for_converted_onnx_model")
+
+# push converted onnx model to HF Hub
+>>> model.push_to_hub("a_local_path_for_converted_onnx_model",
+                      repository_id="my-onnx-repo",
+                      use_auth_token=True
+                      )
+```
+
+## ORTModel
+
+[[autodoc]] onnxruntime.modeling_ort.ORTModel
+
+## ORTModelForFeatureExtraction
+
+[[autodoc]] onnxruntime.modeling_ort.ORTModelForFeatureExtraction
+
+## ORTModelForQuestionAnswering
+
+[[autodoc]] onnxruntime.modeling_ort.ORTModelForQuestionAnswering
+
+## ORTModelForSequenceClassification
+
+[[autodoc]] onnxruntime.modeling_ort.ORTModelForSequenceClassification
+
+## ORTModelForTokenClassification
+
+[[autodoc]] onnxruntime.modeling_ort.ORTModelForTokenClassification
+
diff --git a/docs/source/pipelines.mdx b/docs/source/pipelines.mdx
new file mode 100644
index 0000000000..2e8e315af2
--- /dev/null
+++ b/docs/source/pipelines.mdx
@@ -0,0 +1,218 @@
+
+
+# Optimum pipelines for inference
+
+The [`pipeline`] makes it simple to use models from the [Model Hub](https://huggingface.co/models) for accelerated inference on a variety of tasks such as text classification.
+Even if you don't have experience with a specific modality or understand the code powering the models, you can still use them with the [`pipeline`]! This tutorial will teach you to create accelerated inference pipelines.
+
+
+
+You can also use the `pipeline()` function from Transformers and provide your `OptimumModel`, as shown in the example below.
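+
+For example, the minimal sketch below (mirroring the "Transformers pipeline usage" section later on this page) wraps an ONNX Runtime question answering checkpoint with the stock `transformers.pipeline` function:
+
+```python
+>>> from transformers import AutoTokenizer, pipeline
+>>> from optimum.onnxruntime import ORTModelForQuestionAnswering
+
+>>> # ONNX checkpoint that already ships a model.onnx file
+>>> model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2")
+>>> tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
+
+>>> # the vanilla transformers pipeline accepts the ORT-backed model directly
+>>> onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
+>>> onnx_qa(question="What's my name?", context="My name is Philipp and I live in Nuremberg.")
+```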
+
+
+Currently supported tasks are:
+
+**ONNX Runtime**
+
+* `feature-extraction`
+* `text-classification`
+* `token-classification`
+* `question-answering`
+* `zero-shot-classification`
+* `text-generation`
+
+## Optimum pipeline usage
+
+While each task has an associated pipeline class, it is simpler to use the general [~`pipeline`] abstraction, which contains all the task-specific pipelines.
+The [~`pipeline`] automatically loads a default model and tokenizer capable of inference for your task.
+
+1. Start by creating a [~`pipeline`] and specifying an inference task:
+
+```python
+>>> from optimum import pipeline
+
+>>> classifier = pipeline(task="text-classification", accelerator="ort")
+```
+
+2. Pass your input text to the [~`pipeline`]:
+
+```python
+>>> classifier("I like you. I love you.")
+[{'label': 'POSITIVE', 'score': 0.9998838901519775}]
+```
+
+_Note: The default models used in the [~`pipeline`] are not optimized or quantized, so there won't be a performance improvement compared to their PyTorch counterparts._
+
+### Using vanilla Transformers model and converting to ONNX
+
+The [`pipeline`] accepts any supported model from the [Model Hub](https://huggingface.co/models).
+There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task.
+Once you've picked an appropriate model, load it with the `from_pretrained("{model_id}", from_transformers=True)` method of the corresponding `ORTModelFor*` class
+and the [`AutoTokenizer`] class. For example, here's how you can load the [`ORTModelForQuestionAnswering`] class for question answering:
+
+```python
+>>> from transformers import AutoTokenizer
+>>> from optimum.onnxruntime import ORTModelForQuestionAnswering
+>>> from optimum import pipeline
+
+>>> tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
+>>> # loading the pytorch checkpoint and converting to ORT format by providing the from_transformers=True parameter
+>>> model = ORTModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2", from_transformers=True)
+
+>>> onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
+>>> question = "What's my name?"
+>>> context = "My name is Philipp and I live in Nuremberg."
+
+>>> pred = onnx_qa(question=question, context=context)
+```
+
+### Using Optimum models
+
+The [`pipeline`] is tightly integrated with the [Model Hub](https://huggingface.co/models) and can load optimized models directly, e.g. those created with ONNX Runtime.
+There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task.
+Once you've picked an appropriate model, load it with the `from_pretrained()` method of the corresponding `ORTModelFor*` class
+and the [`AutoTokenizer`] class. For example, here's how you can load an optimized model for question answering:
+
+```python
+>>> from transformers import AutoTokenizer
+>>> from optimum.onnxruntime import ORTModelForQuestionAnswering
+>>> from optimum import pipeline
+
+>>> tokenizer = AutoTokenizer.from_pretrained("optimum/roberta-base-squad2")
+>>> # loading already converted and optimized ORT checkpoint for inference
+>>> model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2")
+
+>>> onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
+>>> question = "What's my name?"
+>>> context = "My name is Philipp and I live in Nuremberg."
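+>>> # the pipeline call below runs the exported model through the ONNX Runtime InferenceSession rather than a PyTorch forward pass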
+
+>>> pred = onnx_qa(question=question, context=context)
+```
+
+
+### Optimizing and Quantizing in Pipelines
+
+The [`pipeline`] can run inference not only on vanilla ONNX Runtime checkpoints, but also on checkpoints optimized with the `ORTQuantizer` and the `ORTOptimizer`.
+Below you can find two examples of how you could use the [~`ORTQuantizer`] and the [~`ORTOptimizer`] to quantize/optimize your model and use it for inference afterwards.
+
+### Quantizing with [~`ORTQuantizer`]
+
+```python
+>>> from pathlib import Path
+>>> from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
+>>> from optimum.onnxruntime.configuration import AutoQuantizationConfig
+>>> from optimum.pipelines import pipeline
+>>> from transformers import AutoTokenizer
+
+# define model_id and load tokenizer
+>>> model_id = "distilbert-base-uncased-finetuned-sst-2-english"
+>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
+>>> save_path = Path("optimum_model")
+>>> save_path.mkdir(exist_ok=True)
+
+# use ORTQuantizer to export the model and define quantization configuration
+>>> quantizer = ORTQuantizer.from_pretrained(model_id, feature="sequence-classification")
+>>> qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True)
+
+# apply the quantization configuration to the model
+>>> quantizer.export(
+        onnx_model_path=save_path / "model.onnx",
+        onnx_quantized_model_output_path=save_path / "model-quantized.onnx",
+        quantization_config=qconfig,
+    )
+>>> quantizer.model.config.save_pretrained(save_path) # saves config.json
+
+# load quantized model from local path or repository
+>>> model = ORTModelForSequenceClassification.from_pretrained(save_path, file_name="model-quantized.onnx")
+
+# create transformers pipeline
+>>> onnx_clx = pipeline("text-classification", model=model, tokenizer=tokenizer)
+>>> text = "I like the new ORT pipeline"
+>>> pred = onnx_clx(text)
+>>> print(pred)
+
+# save model & push model to the hub
+>>> tokenizer.save_pretrained("new_path_for_directory")
+>>> model.save_pretrained("new_path_for_directory")
+>>> model.push_to_hub("new_path_for_directory",
+                      repository_id="my-onnx-repo",
+                      use_auth_token=True
+                      )
+```
+
+### Optimizing with [~`ORTOptimizer`]
+
+```python
+>>> from pathlib import Path
+>>> from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer
+>>> from optimum.onnxruntime.configuration import OptimizationConfig
+>>> from optimum.pipelines import pipeline
+>>> from transformers import AutoTokenizer
+
+# define model_id and load tokenizer
+>>> model_id = "distilbert-base-uncased-finetuned-sst-2-english"
+>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
+>>> save_path = Path("optimum_model")
+>>> save_path.mkdir(exist_ok=True)
+
+# use ORTOptimizer to export the model and define optimization configuration
+>>> optimizer = ORTOptimizer.from_pretrained(model_id, feature="sequence-classification")
+>>> optimization_config = OptimizationConfig(optimization_level=2)
+
+# apply the optimization configuration to the model
+>>> optimizer.export(
+        onnx_model_path=save_path / "model.onnx",
+        onnx_optimized_model_output_path=save_path / "model-optimized.onnx",
+        optimization_config=optimization_config,
+    )
+>>> optimizer.model.config.save_pretrained(save_path) # saves config.json
+
+# load optimized model from local path or repository
+>>> model = ORTModelForSequenceClassification.from_pretrained(save_path, file_name="model-optimized.onnx")
+
+# create transformers pipeline
+>>> onnx_clx = pipeline("text-classification", model=model, tokenizer=tokenizer)
+>>> text = "I
like the new ORT pipeline" +>>> pred = onnx_clx(text) +>>> print(pred) + +# save model & push model to the hub +>>> tokenizer.save_pretrained("new_path_for_directory") +>>> model.save_pretrained("new_path_for_directory") +>>> model.push_to_hub("new_path_for_directory", + repository_id="my-onnx-repo", + use_auth_token=True) +``` + +## Transformers pipeline usage + +The [`pipeline`] is just a light wrapper around the `transformers.pipeline` function to enable checks for supported tasks and additional features +, like quantization and optimization. This being said you can use the `transformers.pipeline` and just replace your `AutoFor*` with the optimum + `ORTModelFor*` class. + +```diff +from transformers import AutoTokenizer, pipeline +-from transformers import AutoModelForQuestionAnswering ++from optimum.onnxruntime import ORTModelForQuestionAnswering + +-model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2") ++model = ORTModelForQuestionAnswering.from_transformers("optimum/roberta-base-squad2") +tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2") + +onnx_qa = pipeline("question-answering",model=model,tokenizer=tokenizer) + +question = "What's my name?" +context = "My name is Philipp and I live in Nuremberg." +pred = onnx_qa(question, context) +``` \ No newline at end of file diff --git a/docs/source/quickstart.mdx b/docs/source/quickstart.mdx index dc3ddb63a0..5289d81891 100644 --- a/docs/source/quickstart.mdx +++ b/docs/source/quickstart.mdx @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. At its core, 🤗 Optimum uses _configuration objects_ to define parameters for optimization on different accelerators. These objects are then used to instantiate dedicated _optimizers_, _quantizers_, and _pruners_. For example, here's how you can apply dynamic quantization with ONNX Runtime: -```py +```python >>> from optimum.onnxruntime import ORTConfig, ORTQuantizer >>> # The model we wish to quantize @@ -32,7 +32,7 @@ At its core, 🤗 Optimum uses _configuration objects_ to define parameters for ``` In this example, we've quantized a model from the Hugging Face Hub, but it could also be a path to a local model directory. The `feature` argument in the `from_pretrained()` method corresponds to the type of task that we wish to quantize the model for. The result from applying the `export()` method is a `model-quantized.onnx` file that can be used to run inference. Here's an example of how to load an ONNX Runtime model and generate predictions with it: -```py +```python >>> from functools import partial >>> from datasets import Dataset >>> from optimum.onnxruntime import ORTModel @@ -55,13 +55,13 @@ In this example, we've quantized a model from the Hugging Face Hub, but it could Similarly, you can apply static quantization by simply setting `is_static` to `True` when instantiating the `QuantizationConfig` object: -```py +```python >>> qconfig = AutoQuantizationConfig.arm64(is_static=True, per_channel=False) ``` Static quantization relies on feeding batches of data through the model to estimate the activation quantization parameters ahead of inference time. To support this, 🤗 Optimum allows you to provide a _calibration dataset_. The calibration dataset can be a simple `Dataset` object from the 🤗 Datasets library, or any dataset that's hosted on the Hugging Face Hub. 
For this example, we'll pick the [`sst2`](https://huggingface.co/datasets/glue/viewer/sst2/test) dataset that the model was originally trained on: -```py +```python >>> from optimum.onnxruntime.configuration import AutoCalibrationConfig >>> # Create the calibration dataset @@ -92,7 +92,7 @@ Static quantization relies on feeding batches of data through the model to estim As a final example, let's take a look at applying _graph optimizations_ techniques such as operator fusion and constant folding. As before, we load a configuration object, but this time by setting the optimization level instead of the quantization approach: -```py +```python >>> from optimum.onnxruntime.configuration import OptimizationConfig >>> # optimization_config=99 enables all available graph optimisations @@ -101,7 +101,7 @@ As a final example, let's take a look at applying _graph optimizations_ techniqu Next, we load an _optimizer_ to apply these optimisations to our model: -```py +```python >>> from optimum.onnxruntime import ORTOptimizer >>> optimizer = ORTOptimizer.from_pretrained( diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py new file mode 100644 index 0000000000..9d1f8375ea --- /dev/null +++ b/optimum/modeling_base.py @@ -0,0 +1,232 @@ +import json +import logging +import os +import subprocess +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional, Union + +from transformers import AutoConfig + +import requests +from huggingface_hub import HfApi, HfFolder, hf_hub_download + +from .utils import CONFIG_NAME + + +logger = logging.getLogger(__name__) + + +class OptimizedModel(ABC): + config_class = AutoConfig + load_tf_weights = None + base_model_prefix = "optimized_model" + + def __init__(self, model=None, config=None, **kwargs): + super().__init__() + self.model = model + self.config = config + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @abstractmethod + def forward(self, *args, **kwargs): + """ + Forward pass of the model, needs to be overwritten. + """ + pass + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + push_to_hub: bool = False, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~OptimizedModel.from_pretrained`]` class method. + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. + + + + Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`, + which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing + folder. Pass along `temp_dir=True` to use a temporary directory instead. + + + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + # Save the config + self.config.save_pretrained(save_directory) + + # saving model weights/files + self._save_pretrained(save_directory, **kwargs) + + if push_to_hub: + return self.push_to_hub(save_directory, **kwargs) + + @abstractmethod + def _save_pretrained(self, save_directory, **kwargs): + """ + Save a model weights into a directory, so that it can be re-loaded using the + `[`~OptimizedModel.from_pretrained`]` class method. 
+ """ + pass + + def push_to_hub( + self, + save_directory: str = None, + repository_id: Optional[str] = None, + private: Optional[bool] = None, + use_auth_token: Optional[Union[bool, str]] = None, + ) -> str: + if isinstance(use_auth_token, str): + huggingface_token = use_auth_token + elif use_auth_token: + huggingface_token = HfFolder.get_token() + else: + raise ValueError("You need to proivde `use_auth_token` to be able to push to the hub") + api = HfApi() + + user = api.whoami(huggingface_token) + self.git_config_username_and_email(git_email=user["email"], git_user=user["fullname"]) + + api.create_repo( + token=huggingface_token, + name=repository_id, + organization=user["name"], + exist_ok=True, + private=private, + ) + for path, subdirs, files in os.walk(save_directory): + for name in files: + local_file_path = os.path.join(path, name) + _, hub_file_path = os.path.split(local_file_path) + # FIXME: when huggingface_hub fixes the return of upload_file + try: + api.upload_file( + token=huggingface_token, + repo_id=f"{user['name']}/{repository_id}", + path_or_fileobj=os.path.join(os.getcwd(), local_file_path), + path_in_repo=hub_file_path, + ) + except KeyError: + pass + except NameError: + pass + + def git_config_username_and_email(self, git_user: str = None, git_email: str = None): + """ + Set git user name and email (only in the current repo) + """ + try: + if git_user is not None: + subprocess.run( + ["git", "config", "--global", "user.name", git_user], + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + check=True, + encoding="utf-8", + ) + if git_email is not None: + subprocess.run( + ["git", "config", "--global", "user.email", git_email], + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + check=True, + encoding="utf-8", + ) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, os.PathLike], + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = True, + cache_dir: Optional[str] = None, + **kwargs, + ): + """Overwrite this method in subclass to define how to load your model from pretrained""" + raise NotImplementedError("Overwrite this method in subclass to define how to load your model from pretrained") + + @classmethod + def from_pretrained( + cls, + model_id: Union[str, Path], + from_transformers: bool = False, + force_download: bool = True, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, + **model_kwargs, + ): + revision = None + if len(str(model_id).split("@")) == 2: + model_id, revision = model_id.split("@") + + if os.path.isdir(model_id) and CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, CONFIG_NAME) + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + use_auth_token=use_auth_token, + ) + except requests.exceptions.RequestException: + logger.warning("config.json NOT FOUND in HuggingFace Hub") + config_file = None + + if config_file is not None: + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + model_kwargs.update({"config": config}) + + if from_transformers: + return cls._from_transformers( + model_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + use_auth_token=use_auth_token, + **model_kwargs, + ) + else: + return cls._from_pretrained( + 
model_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + use_auth_token=use_auth_token, + **model_kwargs, + ) + + @classmethod + def _from_transformers( + cls, + model_id: Union[str, os.PathLike], + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = True, + cache_dir: Optional[str] = None, + **kwargs, + ): + """Overwrite this method in subclass to define how to load your model from vanilla transformers model""" + raise NotImplementedError( + "Overwrite this method in subclass to define how to load your model from vanilla transformers model" + ) diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index 87380f5473..20ae1283e5 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -50,7 +50,15 @@ class ORTQuantizableOperator(Enum): from .configuration import ORTConfig from .model import ORTModel +from .modeling_ort import ( + ORTModelForCausalLM, + ORTModelForFeatureExtraction, + ORTModelForQuestionAnswering, + ORTModelForSequenceClassification, + ORTModelForTokenClassification, +) from .optimization import ORTOptimizer from .quantization import ORTQuantizer from .trainer import ORTTrainer from .trainer_seq2seq import ORTSeq2SeqTrainer +from .utils import ONNX_WEIGHTS_NAME diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py new file mode 100644 index 0000000000..054e7b77e5 --- /dev/null +++ b/optimum/onnxruntime/modeling_ort.py @@ -0,0 +1,689 @@ +import logging +import os +import shutil +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +from transformers import AutoTokenizer, PretrainedConfig +from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, default_cache_path +from transformers.generation_utils import GenerationMixin +from transformers.modeling_outputs import ( + BaseModelOutput, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.onnx import FeaturesManager, export + +import onnxruntime as ort +from huggingface_hub import HfApi, hf_hub_download + +from ..modeling_base import OptimizedModel +from .utils import ONNX_WEIGHTS_NAME, _is_gpu_available + + +logger = logging.getLogger(__name__) + + +_TOKENIZER_FOR_DOC = "AutoTokenizer" + +ONNX_MODEL_START_DOCSTRING = r""" + This model inherits from [`ORTModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving) + Parameters: + config ([`PretrainedConfig`](https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig)): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~ORTModel.from_pretrained`] method to load the model weights. + model ([`onnxruntime.InferenceSession`](https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession)): This is the main class used to run a model. Check out the [`~ORTModel.load_model`] + for more information. +""" + +ONNX_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. 
+ Indices can be obtained using [`AutoTokenizer`](https://huggingface.co/docs/transformers/autoclass_tutorial#autotokenizer). + See [`PreTrainedTokenizer.encode`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.encode) and + [`PreTrainedTokenizer.__call__`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.__call__) for details. + [What are input IDs?](https://huggingface.co/docs/transformers/glossary#input-ids) + attention_mask (`torch.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](https://huggingface.co/docs/transformers/glossary#attention-mask) + token_type_ids (`torch.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + - 1 for tokens that are **sentence A**, + - 0 for tokens that are **sentence B**. + [What are token type IDs?](https://huggingface.co/docs/transformers/glossary#token-type-ids) +""" + + +@add_start_docstrings( + """ + Base ORTModel class for implementing models using ONNX Runtime. The ORTModel implements generic methods for interacting + with the Hugging Face Hub as well as exporting vanilla transformers models to ONNX using `transformers.onnx` toolchain. + The ORTModel implements additionally generic methods for optimizing and quantizing Onnx models. + """, +) +class ORTModel(OptimizedModel): + base_model_prefix = "onnx_model" + + def __init__(self, model=None, config=None, **kwargs): + self.model = model + self.config = config + self.model_save_dir = kwargs.get("model_save_dir", None) + self.latest_model_name = kwargs.get("latest_model_name", "model.onnx") + + def forward(self, *args, **kwargs): + raise NotImplementedError + + @staticmethod + def load_model(path: Union[str, Path], provider=None): + """ + loads ONNX Inference session with Provider. Default Provider is if CUDAExecutionProvider GPU available else `CPUExecutionProvider` + Arguments: + path (:obj:`str` or :obj:`Path`): + Directory from which to load + provider(:obj:`str`): + Onnxruntime provider to use for loading the model, defaults to `CUDAExecutionProvider` if GPU is + available else `CPUExecutionProvider` + """ + if provider is None: + provider = "CUDAExecutionProvider" if _is_gpu_available() else "CPUExecutionProvider" + + return ort.InferenceSession(path, providers=[provider]) + + def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `:func:`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`` class method. It will always save the latest_model_name. + Arguments: + save_directory (:obj:`str` or :obj:`Path`): + Directory where to save the model file. + file_name(:obj:`str`): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the model with + a different name. 
+ """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + + src_path = self.model_save_dir.joinpath(self.latest_model_name) + dst_path = Path(save_directory).joinpath(model_file_name) + shutil.copyfile(src_path, dst_path) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = True, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + **kwargs, + ): + """ + Load a model and its configuration file from a directory or the HF Hub. + Implements: https://github.com/huggingface/huggingface_hub/blob/e67de48368bc1843e40afc1cc9d236402b9609ee/src/huggingface_hub/hub_mixin.py#L73 + Arguments: + model_id (:obj:`str` or :obj:`Path`): + Directory from which to load + use_auth_token (:obj:`str` or :obj:`bool`): + Is needed to load models from a private repository + revision (:obj:`str`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id + cache_dir (:obj:`Union[str, Path]`, `optional`): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + file_name(:obj:`str`): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load different model files from the same + repository or directory. + kwargs (:obj:`Dict`, `optional`):: + kwargs will be passed to the model during initialization + """ + config_dict = kwargs.pop("config", {}) + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + # load model from local directory + if os.path.isdir(model_id): + config = PretrainedConfig.from_dict(config_dict) + model = ORTModel.load_model(os.path.join(model_id, model_file_name)) + kwargs["model_save_dir"] = Path(model_id) + # load model from hub + else: + # download model + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_name, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + ) + kwargs["model_save_dir"] = Path(model_cache_path).parent + kwargs["latest_model_name"] = Path(model_cache_path).name + model = ORTModel.load_model(model_cache_path) + config = PretrainedConfig.from_dict(config_dict) + return cls(model=model, config=config, **kwargs) + + @classmethod + def _from_transformers( + cls, + model_id: str, + save_dir: Union[str, Path] = default_cache_path, + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = True, + cache_dir: Optional[str] = None, + **kwargs, + ): + """ + Converts a vanilla Transformers model into an optimized model using `transformers.onnx.export_onnx`. + Arguments: + model_id (:obj:`str` or :obj:`Path`): + Directory from which to load + save_dir (:obj:`str` or :obj:`Path`): + Directory where the onnx model should be saved, default to `transformers.file_utils.default_cache_path`, which is the cache dir for + transformers. + use_auth_token (:obj:`str` or :obj:`bool`): + Is needed to load models from a private repository + revision (:obj:`str`): + Revision is the specific model version to use. 
It can be a branch name, a tag name, or a commit id + cache_dir (:obj:`Union[str, Path]`, `optional`): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + kwargs (:obj:`Dict`, `optional`):: + kwargs will be passed to the model during initialization + """ + + # create local save dir in cache dir + save_dir = Path(save_dir).joinpath(model_id) + save_dir.mkdir(parents=True, exist_ok=True) + kwargs["model_save_dir"] = save_dir + + # reads pipeline task from ORTModelForXXX class if available else tries to extract from hub + if cls.pipeline_task is not None: + task = cls.pipeline_task + else: + task = HfApi().model_info(model_id, revision=revision).pipeline_tag + if task in ["sentiment-analysis", "text-classification", "zero-shot-classification"]: + task = "sequence-classification" + elif task in ["feature-extraction", "fill-mask"]: + task = "default" + # 2. convert to temp dir + # FIXME: transformers.onnx conversion doesn't support private models + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = FeaturesManager.get_model_from_feature(task, model_id) + _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=task) + onnx_config = model_onnx_config(model.config) + + # export model + export( + preprocessor=tokenizer, + model=model, + config=onnx_config, + opset=onnx_config.default_onnx_opset, + output=save_dir.joinpath(ONNX_WEIGHTS_NAME), + ) + kwargs["config"] = model.config.__dict__ + # 3. load normal model + return cls._from_pretrained(save_dir.as_posix(), **kwargs) + + +FEAUTRE_EXTRACTION_SAMPLE = r""" + Example of feature extraction: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> import torch + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("My name is Philipp and I live in Germany.", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + ``` + + Example using `transformers.pipeline`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.onnxruntime import {model_class} + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> onnx_extractor = pipeline("feature-extraction", model=model, tokenizer=tokenizer) + + >>> text = "My name is Philipp and I live in Germany." + >>> pred = onnx_extractor(text) + ``` +""" + + +@add_start_docstrings( + """ + Onnx Model with a MaskedLMOutput for feature-extraction tasks. + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForFeatureExtraction(ORTModel): + """ + Feature Extraction model for ONNX. 
+ """ + + # used in from_transformers to export model to onnx + pipeline_task = "default" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # create {name:idx} dict for model outputs + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} + + @add_start_docstrings_to_model_forward( + ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + FEAUTRE_EXTRACTION_SAMPLE.format( + processor_class=_TOKENIZER_FOR_DOC, + model_class="ORTModelForFeatureExtraction", + checkpoint="optimum/all-MiniLM-L6-v2", + ) + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + **kwargs, + ): + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "input_ids": input_ids.cpu().detach().numpy(), + "attention_mask": attention_mask.cpu().detach().numpy(), + } + if token_type_ids is not None: + onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy() + # run inference + outputs = self.model.run(None, onnx_inputs) + # converts output to namedtuple for pipelines post-processing + return BaseModelOutput( + last_hidden_state=torch.from_numpy(outputs[self.model_outputs["last_hidden_state"]]), + ) + + +QUESTION_ANSWERING_SAMPLE = r""" + Example of question answering: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> import torch + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> start_positions = torch.tensor([1]) + >>> end_positions = torch.tensor([3]) + + >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions) + >>> start_scores = outputs.start_logits + >>> end_scores = outputs.end_logits + ``` + Example using `transformers.pipeline`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.onnxruntime import {model_class} + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer) + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + >>> pred = onnx_qa(question, text) + ``` +""" + + +@add_start_docstrings( + """ + Onnx Model with a QuestionAnsweringModelOutput for extractive question-answering tasks like SQuAD. + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForQuestionAnswering(ORTModel): + """ + Question Answering model for ONNX. 
+ """ + + # used in from_transformers to export model to onnx + pipeline_task = "question-answering" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # create {name:idx} dict for model outputs + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} + + @add_start_docstrings_to_model_forward( + ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + QUESTION_ANSWERING_SAMPLE.format( + processor_class=_TOKENIZER_FOR_DOC, + model_class="ORTModelForQuestionAnswering", + checkpoint="optimum/roberta-base-squad2", + ) + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + **kwargs, + ): + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "input_ids": input_ids.cpu().detach().numpy(), + "attention_mask": attention_mask.cpu().detach().numpy(), + } + if token_type_ids is not None: + onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy() + # run inference + outputs = self.model.run(None, onnx_inputs) + # converts output to namedtuple for pipelines post-processing + return QuestionAnsweringModelOutput( + start_logits=torch.from_numpy(outputs[self.model_outputs["start_logits"]]), + end_logits=torch.from_numpy(outputs[self.model_outputs["end_logits"]]), + ) + + +SEQUENCE_CLASSIFICATION_SAMPLE = r""" + Example of single-label classification: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> import torch + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + ``` + + Example using `transformers.pipelines`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.onnxruntime import {model_class} + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> onnx_classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) + + >>> text = "Hello, my dog is cute" + >>> pred = onnx_classifier(text) + ``` + + Example using zero-shot-classification `transformers.pipelines`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.onnxruntime import {model_class} + + >>> tokenizer = {processor_class}.from_pretrained("optimum/distilbert-base-uncased-mnli") + >>> model = {model_class}.from_pretrained("optimum/distilbert-base-uncased-mnli") + >>> onnx_z0 = pipeline("zero-shot-classification", model=model, tokenizer=tokenizer) + + >>> sequence_to_classify = "Who are you voting for in 2020?" + >>> candidate_labels = ["Europe", "public health", "politics", "elections"] + >>> pred = onnx_z0(sequence_to_classify, candidate_labels, multi_class=True) + ``` +""" + + +@add_start_docstrings( + """ + Onnx Model with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForSequenceClassification(ORTModel): + """ + Sequence Classification model for ONNX. 
+ """ + + # used in from_transformers to export model to onnx + pipeline_task = "sequence-classification" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # create {name:idx} dict for model outputs + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} + self.model_inputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_inputs())} + + @add_start_docstrings_to_model_forward( + ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + SEQUENCE_CLASSIFICATION_SAMPLE.format( + processor_class=_TOKENIZER_FOR_DOC, + model_class="ORTModelForSequenceClassification", + checkpoint="optimum/distilbert-base-uncased-finetuned-sst-2-english", + ) + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + **kwargs, + ): + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "input_ids": input_ids.cpu().detach().numpy(), + "attention_mask": attention_mask.cpu().detach().numpy(), + } + + if token_type_ids is not None: + onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy() + # run inference + outputs = self.model.run(None, onnx_inputs) + # converts output to namedtuple for pipelines post-processing + return SequenceClassifierOutput( + logits=torch.from_numpy(outputs[self.model_outputs["logits"]]), + ) + + +TOKEN_CLASSIFICATION_SAMPLE = r""" + Example of token classification: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> import torch + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("My name is Philipp and I live in Germany.", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> list(logits.shape) + ``` + + Example using `transformers.pipelines`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.onnxruntime import {model_class} + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> onnx_ner = pipeline("token-classification", model=model, tokenizer=tokenizer) + + >>> text = "My name is Philipp and I live in Germany." + >>> pred = onnx_ner(text) + ``` +""" + + +@add_start_docstrings( + """ + Onnx Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForTokenClassification(ORTModel): + """ + Token Classification model for ONNX. 
+ """ + + # used in from_transformers to export model to onnx + pipeline_task = "token-classification" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # create {name:idx} dict for model outputs + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} + + @add_start_docstrings_to_model_forward( + ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + TOKEN_CLASSIFICATION_SAMPLE.format( + processor_class=_TOKENIZER_FOR_DOC, + model_class="ORTModelForTokenClassification", + checkpoint="optimum/bert-base-NER", + ) + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + **kwargs, + ): + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "input_ids": input_ids.cpu().detach().numpy(), + "attention_mask": attention_mask.cpu().detach().numpy(), + } + if token_type_ids is not None: + onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy() + # run inference + outputs = self.model.run(None, onnx_inputs) + # converts output to namedtuple for pipelines post-processing + return TokenClassifierOutput( + logits=torch.from_numpy(outputs[self.model_outputs["logits"]]), + ) + + +TEXT_GENERATION_SAMPLE = r""" + Example of text generation: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> import torch + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("My name is Philipp and I live in Germany.", return_tensors="pt") + + >>> gen_tokens = model.generate(**inputs,do_sample=True,temperature=0.9, min_length=20,max_length=20) + >>> tokenizer.batch_decode(gen_tokens) + ``` + + Example using `transformers.pipelines`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.onnxruntime import {model_class} + + >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> onnx_gen = pipeline("text-generation", model=model, tokenizer=tokenizer) + + >>> text = "My name is Philipp and I live in Germany." + >>> gen = onnx_gen(text) + ``` +""" + + +@add_start_docstrings( + """ + Onnx Model with a causal language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForCausalLM(ORTModel, GenerationMixin): + """ + Causal LM model for ONNX. + """ + + # used in from_transformers to export model to onnx + pipeline_task = "causal-lm" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # create {name:idx} dict for model outputs + self.main_input_name = "input_ids" + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: + """ + Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method. 
+ """ + inputs = {"input_ids": input_ids} + if kwargs.get("attention_mask", None) is not None: + inputs["attention_mask"] = kwargs["attention_mask"] + return inputs + + @add_start_docstrings_to_model_forward( + ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + TOKEN_CLASSIFICATION_SAMPLE.format( + processor_class=_TOKENIZER_FOR_DOC, + model_class="ORTModelForCausalLM", + checkpoint="optimum/gpt2", + ) + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "input_ids": input_ids.cpu().detach().numpy(), + "attention_mask": attention_mask.cpu().detach().numpy(), + } + # run inference + outputs = self.model.run(None, onnx_inputs) + # converts output to namedtuple for pipelines post-processing + return CausalLMOutputWithCrossAttentions( + logits=torch.from_numpy(outputs[self.model_outputs["logits"]]), + ) diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 24a0322168..e2f6210940 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -16,14 +16,31 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union +import torch from transformers.utils import logging import onnx +import onnxruntime as ort from onnx import ModelProto logger = logging.get_logger(__name__) +ONNX_WEIGHTS_NAME = "model.onnx" +OPTIMIZED_ONNX_WEIGHTS_NAME = "optimized_model.onnx" +QUANTIZED_ONNX_WEIGHTS_NAME = "q8_model.onnx" + + +def _is_gpu_available(): + """ + checks if a gpu is available. + """ + available_providers = ort.get_available_providers() + if "CUDAExecutionProvider" in available_providers and torch.cuda.is_available(): + return True + else: + return False + class ORTConfigManager: """ @@ -39,7 +56,7 @@ class ORTConfigManager: "bert": ("num_attention_heads", "hidden_size", "bert"), "albert": ("num_attention_heads", "hidden_size", "bert"), "camembert": ("num_attention_heads", "hidden_size", "bert"), - "distilbert": ("n_heads", "hidden_size", "bert"), + "distilbert": ("n_heads", "dim", "bert"), "electra": ("num_attention_heads", "hidden_size", "bert"), "roberta": ("num_attention_heads", "hidden_size", "bert"), "bart": ("encoder_attention_heads", "d_model", "bart"), diff --git a/optimum/pipelines.py b/optimum/pipelines.py new file mode 100644 index 0000000000..4cb6ff868f --- /dev/null +++ b/optimum/pipelines.py @@ -0,0 +1,106 @@ +from typing import Any, Optional, Union + +from transformers import ( + AutoTokenizer, + FeatureExtractionPipeline, + Pipeline, + PreTrainedTokenizer, + QuestionAnsweringPipeline, + TextClassificationPipeline, + TextGenerationPipeline, + TokenClassificationPipeline, + ZeroShotClassificationPipeline, +) +from transformers import pipeline as transformers_pipeline +from transformers.feature_extraction_utils import PreTrainedFeatureExtractor + +from optimum.onnxruntime.modeling_ort import ORTModelForCausalLM +from optimum.utils import is_onnxruntime_available + + +SUPPORTED_TASKS = {} + +if is_onnxruntime_available(): + from optimum.onnxruntime import ( + ORTModelForFeatureExtraction, + ORTModelForQuestionAnswering, + ORTModelForSequenceClassification, + ORTModelForTokenClassification, + ) + from optimum.onnxruntime.modeling_ort import ORTModel + + SUPPORTED_TASKS = { + "feature-extraction": { + "impl": FeatureExtractionPipeline, + "class": (ORTModelForFeatureExtraction,) if is_onnxruntime_available() else 
(), + "default": "distilbert-base-cased", + }, + "text-classification": { + "impl": TextClassificationPipeline, + "class": (ORTModelForSequenceClassification,) if is_onnxruntime_available() else (), + "default": "distilbert-base-uncased-finetuned-sst-2-english", + }, + "token-classification": { + "impl": TokenClassificationPipeline, + "class": (ORTModelForTokenClassification,) if is_onnxruntime_available() else (), + "default": "dbmdz/bert-large-cased-finetuned-conll03-english", + }, + "question-answering": { + "impl": QuestionAnsweringPipeline, + "class": (ORTModelForQuestionAnswering,) if is_onnxruntime_available() else (), + "default": "distilbert-base-cased-distilled-squad", + }, + "zero-shot-classification": { + "impl": ZeroShotClassificationPipeline, + "class": (ORTModelForSequenceClassification,) if is_onnxruntime_available() else (), + "default": "facebook/bart-large-mnli", + }, + "text-generation": { + "impl": TextGenerationPipeline, + "class": (ORTModelForCausalLM,) if is_onnxruntime_available() else (), + "default": "distilgpt2", + }, + } + + +def pipeline( + task: str = None, + model: Optional[Any] = None, + tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, + feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, + use_fast: bool = True, + use_auth_token: Optional[Union[str, bool]] = None, + accelerator: Optional[str] = "ort", + **kwargs, +) -> Pipeline: + + if task not in list(SUPPORTED_TASKS.keys()): + raise ValueError(f"Task {task} is not supported. Supported tasks are { list(SUPPORTED_TASKS.keys())}") + + if accelerator != "ort": + raise ValueError(f"Accelerator {accelerator} is not supported. Supported accelerators are ort") + + if model is None: + model_id = SUPPORTED_TASKS[task]["default"] + model = SUPPORTED_TASKS[task]["class"][0].from_pretrained(model_id, from_transformers=True) + elif isinstance(model, str): + model_id = model + model = SUPPORTED_TASKS[task]["class"][0].from_pretrained(model, from_transformers=True) + elif isinstance(model, ORTModel): + if tokenizer is None: + raise ValueError("If you pass a model as a ORTModel, you must pass a tokenizer as well") + else: + raise ValueError( + f"""Model {model} is not supported. Please provide a valid model either as string or ORTModel. + You can also provide non model then a default one will be used""" + ) + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model_id) + + return transformers_pipeline( + task, + model=model, + tokenizer=tokenizer, + use_fast=use_fast, + **kwargs, + ) diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 4823278c4d..9152e2ff7f 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -11,3 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util + + +CONFIG_NAME = "config.json" + +_onnxruntime_available = importlib.util.find_spec("onnxruntime") is not None + + +def is_onnxruntime_available(): + return _onnxruntime_available diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py new file mode 100644 index 0000000000..166f6fe9e4 --- /dev/null +++ b/optimum/utils/testing_utils.py @@ -0,0 +1,13 @@ +import os +import unittest + + +def require_hf_token(test_case): + """ + Decorator marking a test that requires huggingface hub token. 
+    """
+    use_auth_token = os.environ.get("HF_AUTH_TOKEN", None)
+    if use_auth_token is None:
+        return unittest.skip("test requires an HF token set as the `HF_AUTH_TOKEN` environment variable")(test_case)
+    else:
+        return test_case
diff --git a/setup.py b/setup.py
index 79efdd73f6..9a57e0a60b 100644
--- a/setup.py
+++ b/setup.py
@@ -12,14 +12,24 @@
     assert False, "Error: Could not open '%s' due %s\n" % (filepath, error)
-REQUIRED_PKGS = ["coloredlogs", "sympy", "transformers>=4.15.0", "torch>=1.9", "packaging"]
+REQUIRED_PKGS = [
+    "coloredlogs",
+    "sympy",
+    "transformers[sentencepiece]>=4.15.0",
+    "torch>=1.9",
+    "packaging",
+    "numpy",
+    "huggingface_hub==0.4.0",
+]
-TESTS_REQUIRE = ["pytest"]
+TESTS_REQUIRE = ["pytest", "requests", "parameterized", "pytest-xdist"]
 QUALITY_REQUIRE = ["black~=22.0", "flake8>=3.8.3", "isort>=5.5.4"]
 EXTRAS_REQUIRE = {
-    "onnxruntime": ["onnx", "onnxruntime", "datasets>=1.2.1"],
+    # pip install -e ".[onnxruntime,dev,intel]" git+https://github.com/huggingface/transformers.git@main --upgrade
+    "onnxruntime": ["onnx", "onnxruntime", "datasets>=1.2.1"],  # "transformers[sentencepiece]>4.17.0"],
+    "onnxruntime-gpu": ["onnx", "onnxruntime-gpu", "datasets>=1.2.1"],  # "transformers[sentencepiece]>4.17.0"],
     "intel": [
         "pycocotools",
         "neural_compressor>=1.9",
@@ -62,6 +72,7 @@
     packages=find_namespace_packages(include=["optimum*"]),
     install_requires=REQUIRED_PKGS,
     extras_require=EXTRAS_REQUIRE,
+    python_requires=">=3.8.0",
     include_package_data=True,
     zip_safe=False,
 )
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 0000000000..6ee77778c6
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,18 @@
+# Helpful tips for testing & debugging Optimum
+
+## VS Code
+
+If you are using VS Code, you might have a hard time discovering the tests from the "Testing" menu to run or debug them individually. You can copy the snippet below into `.vscode/settings.json`.
+
+```json
+{
+    "python.testing.pytestArgs": [
+        "tests/onnxruntime",
+        "tests/test_*"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
+}
+```
+
+This snippet discovers all base tests and the tests inside the `tests/onnxruntime` folder. If you also want to run the `intel` tests, add them to `python.testing.pytestArgs`.
\ No newline at end of file diff --git a/tests/assets/hub/config.json b/tests/assets/hub/config.json new file mode 100644 index 0000000000..5fe5579f21 --- /dev/null +++ b/tests/assets/hub/config.json @@ -0,0 +1 @@ +{"from_local":true} \ No newline at end of file diff --git a/tests/assets/onnx/config.json b/tests/assets/onnx/config.json new file mode 100644 index 0000000000..a7dcfd24b0 --- /dev/null +++ b/tests/assets/onnx/config.json @@ -0,0 +1,34 @@ +{ + "_name_or_path": "tiny-distilbert-classification", + "activation": "gelu", + "architectures": [ + "DistilBertForSequenceClassification" + ], + "attention_dropout": 0.1, + "dim": 2, + "dropout": 0.1, + "finetuning_task": "sst-2", + "hidden_dim": 2, + "id2label": { + "0": "NEGATIVE", + "1": "POSITIVE" + }, + "initializer_range": 0.02, + "label2id": { + "NEGATIVE": 0, + "POSITIVE": 1 + }, + "max_position_embeddings": 512, + "model_type": "distilbert", + "n_heads": 2, + "n_layers": 2, + "output_past": true, + "pad_token_id": 0, + "qa_dropout": 0.1, + "seq_classif_dropout": 0.2, + "sinusoidal_pos_embds": false, + "tie_weights_": true, + "torch_dtype": "float32", + "transformers_version": "4.10.0.dev0", + "vocab_size": 30522 +} diff --git a/tests/assets/onnx/model.onnx b/tests/assets/onnx/model.onnx new file mode 100644 index 0000000000..7ee2547229 Binary files /dev/null and b/tests/assets/onnx/model.onnx differ diff --git a/tests/onnxruntime/test_modeling_ort.py b/tests/onnxruntime/test_modeling_ort.py new file mode 100644 index 0000000000..968ad5601d --- /dev/null +++ b/tests/onnxruntime/test_modeling_ort.py @@ -0,0 +1,467 @@ +import os +import tempfile +import unittest + +import torch +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoModelForQuestionAnswering, + AutoModelForSequenceClassification, + AutoModelForTokenClassification, + AutoTokenizer, + PretrainedConfig, + pipeline, +) + +import onnxruntime +from optimum.onnxruntime import ( + ONNX_WEIGHTS_NAME, + ORTModelForCausalLM, + ORTModelForFeatureExtraction, + ORTModelForQuestionAnswering, + ORTModelForSequenceClassification, + ORTModelForTokenClassification, +) +from optimum.onnxruntime.modeling_ort import ORTModel +from optimum.utils import CONFIG_NAME +from optimum.utils.testing_utils import require_hf_token +from parameterized import parameterized + + +class ORTModelIntergrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.TEST_MODEL_ID = "sshleifer/tiny-distilbert-base-cased-distilled-squad" + self.LOCAL_MODEL_PATH = "tests/assets/onnx" + self.ONNX_MODEL_ID = "philschmid/distilbert-onnx" + self.FAIL_ONNX_MODEL_ID = "sshleifer/tiny-distilbert-base-cased-distilled-squad" + + def test_load_model_from_local_path(self): + model = ORTModel.from_pretrained(self.LOCAL_MODEL_PATH) + self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.config, PretrainedConfig) + + def test_load_model_from_hub(self): + model = ORTModel.from_pretrained(self.ONNX_MODEL_ID) + self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.config, PretrainedConfig) + + def test_load_model_from_hub_without_onnx_model(self): + with self.assertRaises(Exception) as context: + ORTModel.from_pretrained(self.FAIL_ONNX_MODEL_ID) + self.assertEqual("Not Found", context.exception.response.reason) + + @require_hf_token + def test_load_model_from_hub_private(self): + model = 
ORTModel.from_pretrained(self.ONNX_MODEL_ID, use_auth_token=os.environ.get("HF_AUTH_TOKEN", None)) + self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.config, PretrainedConfig) + + def test_save_model(self): + with tempfile.TemporaryDirectory() as tmpdirname: + model = ORTModel.from_pretrained(self.LOCAL_MODEL_PATH) + model.save_pretrained(tmpdirname) + # folder contains all config files and pytorch_model.bin + folder_contents = os.listdir(tmpdirname) + self.assertTrue(ONNX_WEIGHTS_NAME in folder_contents) + self.assertTrue(CONFIG_NAME in folder_contents) + + @require_hf_token + def test_save_model_from_hub(self): + with tempfile.TemporaryDirectory() as tmpdirname: + model = ORTModel.from_pretrained(self.LOCAL_MODEL_PATH) + model.save_pretrained( + tmpdirname, + use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + push_to_hub=True, + repository_id=self.HUB_REPOSITORY, + private=True, + ) + + +class ORTModelForQuestionAnsweringIntergrationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = { + "distilbert": "hf-internal-testing/tiny-random-distilbert", + "bert": "hf-internal-testing/tiny-random-bert", + # FIXME: Error: ONNX export failed: Couldn't export Python operator SymmetricQuantFunction + # "ibert": "hf-internal-testing/tiny-random-ibert", + "camembert": "etalab-ia/camembert-base-squadFR-fquad-piaf", + "roberta": "hf-internal-testing/tiny-random-roberta", + # TODO: used real model do to big difference in output + # "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta", + "xlm-roberta": "deepset/xlm-roberta-base-squad2", + "electra": "hf-internal-testing/tiny-random-electra", + "albert": "hf-internal-testing/tiny-random-albert", + "bart": "hf-internal-testing/tiny-random-bart", + "mbart": "hf-internal-testing/tiny-random-mbart", + } + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_supported_transformers_architectures(self, *args, **kwargs): + model_arch, model_id = args + model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) + self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.config, PretrainedConfig) + + def test_load_vanilla_transformers_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + model = ORTModelForQuestionAnswering.from_pretrained("t5-small") + + self.assertTrue("Unrecognized configuration class", context.exception) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_model_call(self, *args, **kwargs): + model_arch, model_id = args + model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer( + "This is a sample output", + return_tensors="pt", + ) + outputs = model(**tokens) + self.assertTrue("start_logits" in outputs) + self.assertTrue("end_logits" in outputs) + + self.assertTrue(isinstance(outputs.start_logits, torch.Tensor)) + self.assertTrue(isinstance(outputs.end_logits, torch.Tensor)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_compare_to_transformers(self, *args, **kwargs): + model_arch, model_id = args + onnx_model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) + trfs_model = AutoModelForQuestionAnswering.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + 
tokens = tokenizer(
+            "This is a sample output",
+            return_tensors="pt",
+        )
+        onnx_outputs = onnx_model(**tokens)
+        with torch.no_grad():
+            trtfs_outputs = trfs_model(**tokens)
+
+        # compare tensor outputs
+        self.assertTrue(torch.allclose(onnx_outputs.start_logits, trtfs_outputs.start_logits, atol=1e-4))
+        self.assertTrue(torch.allclose(onnx_outputs.end_logits, trtfs_outputs.end_logits, atol=1e-4))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_pipeline(self, *args, **kwargs):
+        model_arch, model_id = args
+        onnx_model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        pp = pipeline("question-answering", model=onnx_model, tokenizer=tokenizer)
+        question = "Whats my name?"
+        context = "My Name is Philipp and I live in Nuremberg."
+        outputs = pp(question, context)
+
+        # compare model output class
+        self.assertGreaterEqual(outputs["score"], 0.0)
+        self.assertTrue(isinstance(outputs["answer"], str))
+
+
+class ORTModelForSequenceClassificationIntergrationTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = {
+        "distilbert": "hf-internal-testing/tiny-random-distilbert",
+        "bert": "hf-internal-testing/tiny-random-bert",
+        # FIXME: Error: ONNX export failed: Couldn't export Python operator SymmetricQuantFunction
+        # "ibert": "hf-internal-testing/tiny-random-ibert",
+        "camembert": "cmarkea/distilcamembert-base-sentiment",
+        "roberta": "hf-internal-testing/tiny-random-roberta",
+        # TODO: used a real model due to the big difference in output
+        # "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta",
+        "xlm-roberta": "unitary/multilingual-toxic-xlm-roberta",
+        "electra": "hf-internal-testing/tiny-random-electra",
+        "albert": "hf-internal-testing/tiny-random-albert",
+        "bart": "hf-internal-testing/tiny-random-bart",
+        "mbart": "hf-internal-testing/tiny-random-mbart",
+    }
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_supported_transformers_architectures(self, *args, **kwargs):
+        model_arch, model_id = args
+        model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)
+        self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession)
+        self.assertIsInstance(model.config, PretrainedConfig)
+
+    def test_load_vanilla_transformers_which_is_not_supported(self):
+        with self.assertRaises(Exception) as context:
+            model = ORTModelForSequenceClassification.from_pretrained("t5-small", from_transformers=True)
+
+        self.assertTrue("Unrecognized configuration class", context.exception)
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_model_forward_call(self, *args, **kwargs):
+        model_arch, model_id = args
+        model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer(
+            "This is a sample output",
+            return_tensors="pt",
+        )
+        outputs = model(**tokens)
+        self.assertTrue("logits" in outputs)
+        self.assertTrue(isinstance(outputs.logits, torch.Tensor))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_compare_to_transformers(self, *args, **kwargs):
+        model_arch, model_id = args
+        onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)
+        trfs_model = AutoModelForSequenceClassification.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer(
+            "This is a sample output",
+            return_tensors="pt",
+        )
+        with torch.no_grad():
+            trtfs_outputs = trfs_model(**tokens)
+        onnx_outputs = onnx_model(**tokens)
+
+        # compare tensor outputs
+        self.assertTrue(torch.allclose(onnx_outputs.logits, trtfs_outputs.logits, atol=1e-4))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_pipeline(self, *args, **kwargs):
+        model_arch, model_id = args
+        onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        pp = pipeline("text-classification", model=onnx_model, tokenizer=tokenizer)
+        text = "My Name is Philipp and i live in Germany."
+        outputs = pp(text)
+
+        # compare model output class
+        self.assertGreaterEqual(outputs[0]["score"], 0.0)
+        self.assertTrue(isinstance(outputs[0]["label"], str))
+
+    def test_pipeline_zero_shot_classification(self):
+        onnx_model = ORTModelForSequenceClassification.from_pretrained(
+            "typeform/distilbert-base-uncased-mnli", from_transformers=True
+        )
+        tokenizer = AutoTokenizer.from_pretrained("typeform/distilbert-base-uncased-mnli")
+        pp = pipeline("zero-shot-classification", model=onnx_model, tokenizer=tokenizer)
+        sequence_to_classify = "Who are you voting for in 2020?"
+        candidate_labels = ["Europe", "public health", "politics", "elections"]
+        hypothesis_template = "This text is about {}."
+        outputs = pp(sequence_to_classify, candidate_labels, multi_class=True, hypothesis_template=hypothesis_template)
+
+        # compare model output class
+        self.assertTrue(any(score > 0.0 for score in outputs["scores"]))
+        self.assertTrue(any(isinstance(label, str) for label in outputs["labels"]))
+
+
+class ORTModelForTokenClassificationIntergrationTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = {
+        "distilbert": "hf-internal-testing/tiny-random-distilbert",
+        "bert": "hf-internal-testing/tiny-random-bert",
+        # FIXME: Error: ONNX export failed: Couldn't export Python operator SymmetricQuantFunction
+        # "ibert": "hf-internal-testing/tiny-random-ibert",
+        "camembert": "cmarkea/distilcamembert-base-ner",
+        "roberta": "hf-internal-testing/tiny-random-roberta",
+        # TODO: used a real model due to the big difference in output
+        # "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta",
+        "xlm-roberta": "Davlan/xlm-roberta-base-wikiann-ner",
+        "electra": "hf-internal-testing/tiny-random-electra",
+        "albert": "hf-internal-testing/tiny-random-albert",
+    }
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_supported_transformers_architectures(self, *args, **kwargs):
+        model_arch, model_id = args
+        model = ORTModelForTokenClassification.from_pretrained(model_id, from_transformers=True)
+        self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession)
+        self.assertIsInstance(model.config, PretrainedConfig)
+
+    def test_load_vanilla_transformers_which_is_not_supported(self):
+        with self.assertRaises(Exception) as context:
+            model = ORTModelForTokenClassification.from_pretrained("t5-small", from_transformers=True)
+
+        self.assertTrue("Unrecognized configuration class", context.exception)
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_model_call(self, *args, **kwargs):
+        model_arch, model_id = args
+        model = ORTModelForTokenClassification.from_pretrained(model_id, from_transformers=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer(
+            "This is a sample output",
+            return_tensors="pt",
+        )
+        outputs = model(**tokens)
+        self.assertTrue("logits" in outputs)
+        self.assertTrue(isinstance(outputs.logits, torch.Tensor))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_compare_to_transformers(self, *args, **kwargs):
+        model_arch, model_id = args
+        onnx_model = ORTModelForTokenClassification.from_pretrained(model_id, from_transformers=True)
+        trfs_model = AutoModelForTokenClassification.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer(
+            "This is a sample output",
+            return_tensors="pt",
+        )
+        onnx_outputs = onnx_model(**tokens)
+        with torch.no_grad():
+            trtfs_outputs = trfs_model(**tokens)
+
+        # compare tensor outputs
+        self.assertTrue(torch.allclose(onnx_outputs.logits, trtfs_outputs.logits, atol=1e-4))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_pipeline(self, *args, **kwargs):
+        model_arch, model_id = args
+        onnx_model = ORTModelForTokenClassification.from_pretrained(model_id, from_transformers=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        pp = pipeline("token-classification", model=onnx_model, tokenizer=tokenizer)
+        text = "My Name is Philipp and i live in Germany."
+        outputs = pp(text)
+
+        # compare model output class
+        self.assertTrue(any(item["score"] > 0.0 for item in outputs))
+
+
+class ORTModelForFeatureExtractionIntergrationTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = {
+        "distilbert": "hf-internal-testing/tiny-random-distilbert",
+        "bert": "hf-internal-testing/tiny-random-bert",
+        # FIXME: Error: ONNX export failed: Couldn't export Python operator SymmetricQuantFunction
+        # "ibert": "hf-internal-testing/tiny-random-ibert",
+        "camembert": "cmarkea/distilcamembert-base",
+        "roberta": "hf-internal-testing/tiny-random-roberta",
+        # TODO: used a real model due to the big difference in output
+        # "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta",
+        "xlm-roberta": "xlm-roberta-base",
+        "electra": "hf-internal-testing/tiny-random-electra",
+        "albert": "hf-internal-testing/tiny-random-albert",
+    }
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_supported_transformers_architectures(self, *args, **kwargs):
+        model_arch, model_id = args
+        model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True)
+        self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession)
+        self.assertIsInstance(model.config, PretrainedConfig)
+
+    def test_load_vanilla_transformers_which_is_not_supported(self):
+        with self.assertRaises(Exception) as context:
+            model = ORTModelForFeatureExtraction.from_pretrained("google/vit-base-patch16-224", from_transformers=True)
+
+        self.assertTrue("Unrecognized configuration class", context.exception)
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_model_call(self, *args, **kwargs):
+        model_arch, model_id = args
+        model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer(
+            "This is a sample output",
+            return_tensors="pt",
+        )
+        outputs = model(**tokens)
+        self.assertTrue("last_hidden_state" in outputs)
+        self.assertTrue(isinstance(outputs.last_hidden_state, torch.Tensor))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
+    def test_compare_to_transformers(self, *args, **kwargs):
+        model_arch, model_id = args
+        onnx_model =
ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True) + trfs_model = AutoModel.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer( + "This is a sample output", + return_tensors="pt", + ) + onnx_outputs = onnx_model(**tokens) + with torch.no_grad(): + trtfs_outputs = trfs_model(**tokens) + + # compare tensor outputs + self.assertTrue(torch.allclose(onnx_outputs.last_hidden_state, trtfs_outputs.last_hidden_state, atol=1e-4)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_pipeline(self, *args, **kwargs): + model_arch, model_id = args + onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + pp = pipeline("feature-extraction", model=onnx_model, tokenizer=tokenizer) + text = "My Name is Philipp and i live in Germany." + outputs = pp(text) + + # compare model output class + self.assertTrue(any(any(isinstance(item, float) for item in row) for row in outputs[0])) + + +class ORTModelForCausalLMIntergrationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = { + "gpt2": "hf-internal-testing/tiny-random-gpt2", + } + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_supported_transformers_architectures(self, *args, **kwargs): + model_arch, model_id = args + model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) + self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.config, PretrainedConfig) + + def test_load_vanilla_transformers_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + model = ORTModelForCausalLM.from_pretrained("google/vit-base-patch16-224", from_transformers=True) + + self.assertTrue("Unrecognized configuration class", context.exception) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_model_call(self, *args, **kwargs): + model_arch, model_id = args + model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer( + "This is a sample output", + return_tensors="pt", + ) + outputs = model(**tokens) + self.assertTrue("logits" in outputs) + self.assertTrue(isinstance(outputs.logits, torch.Tensor)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_generate_utils(self, *args, **kwargs): + model_arch, model_id = args + model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + text = "This is a sample output" + tokens = tokenizer( + text, + return_tensors="pt", + ) + outputs = model.generate(**tokens) + res = tokenizer.batch_decode(outputs, skip_special_tokens=True) + self.assertTrue(isinstance(res[0], str)) + self.assertTrue(len(res[0]) > len(text)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_compare_to_transformers(self, *args, **kwargs): + model_arch, model_id = args + onnx_model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) + trfs_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer( + "This is a sample output", + return_tensors="pt", + ) + onnx_outputs = onnx_model(**tokens) + with torch.no_grad(): + trtfs_outputs = trfs_model(**tokens) + + # 
compare tensor outputs + self.assertTrue(torch.allclose(onnx_outputs.logits, trtfs_outputs.logits, atol=1e-4)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_pipeline(self, *args, **kwargs): + model_arch, model_id = args + onnx_model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + pp = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer) + text = "My Name is Philipp and i live" + outputs = pp(text) + + # compare model output class + self.assertTrue(isinstance(outputs[0]["generated_text"], str)) + self.assertTrue(len(outputs[0]["generated_text"]) > len(text)) diff --git a/tests/test_modeling_base.py b/tests/test_modeling_base.py new file mode 100644 index 0000000000..74a049a965 --- /dev/null +++ b/tests/test_modeling_base.py @@ -0,0 +1,59 @@ +import os +import random +import tempfile +import unittest + +import torch +from transformers.configuration_utils import PretrainedConfig + +import requests as r +from optimum.modeling_base import OptimizedModel +from optimum.utils.testing_utils import require_hf_token + + +TEST_HUB_PATH = "philschmid/unit_test_model" +TEST_LOCAL_PATH = "tests/assets/hub" + + +class DummyModel(OptimizedModel): + def _save_pretrained(self, save_directory, **kwargs): + return + + @classmethod + def _from_pretrained(cls, **kwargs): + config = PretrainedConfig.from_dict(kwargs["config"]) + model = cls(model=torch.nn.Module, config=config) + return model + + def forward(self, *args, **kwargs): + pass + + +class TestOptimizedModel(unittest.TestCase): + def test_load_model_from_hub(self): + # TODO: figure out how to create repos and push stuff to staging + if os.getenv("HUGGINGFACE_CO_STAGING", False): + self.skipTest("Skip test on staging") + + dummy_model = DummyModel.from_pretrained(TEST_HUB_PATH) + self.assertTrue(dummy_model.config.remote) + + @require_hf_token + def test_push_to_hub(self): + with tempfile.TemporaryDirectory() as tmpdirname: + + model = DummyModel.from_pretrained(TEST_LOCAL_PATH) + # create remote hash to check if file was updated. + remote_hash = random.getrandbits(128) + model.config.from_local = remote_hash + + model.save_pretrained( + tmpdirname, + use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + push_to_hub=True, + repository_id="unit_test_save_model", + ) + # folder contains all config files and pytorch_model.bin + url = f"https://huggingface.co/philschmid/unit_test_save_model/raw/main/config.json" + response = r.get(url) + self.assertEqual(remote_hash, response.json()["from_local"])