Add tf available and version (#2154)
* remove torch version

* add tf check

* fix
echarlaix authored Jan 9, 2025
1 parent 605ed7e commit 8d1347f
Showing 6 changed files with 68 additions and 27 deletions.
10 changes: 7 additions & 3 deletions optimum/exporters/onnx/base.py
@@ -44,7 +44,12 @@
 from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
 from ...utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION
 from ...utils.doc import add_dynamic_docstring
-from ...utils.import_utils import is_onnx_available, is_onnxruntime_available, is_transformers_version
+from ...utils.import_utils import (
+    is_onnx_available,
+    is_onnxruntime_available,
+    is_torch_version,
+    is_transformers_version,
+)
 from ..base import ExportConfig
 from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
 from .model_patcher import ModelPatcher, Seq2SeqModelPatcher
@@ -386,9 +391,8 @@ def is_torch_support_available(self) -> bool:
             `bool`: Whether the installed version of PyTorch is compatible with the model.
         """
         if is_torch_available():
-            from ...utils import torch_version
+            return is_torch_version(">=", self.MIN_TORCH_VERSION.base_version)
 
-            return torch_version >= self.MIN_TORCH_VERSION
         return False
 
     @property
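For context on the base.py change above: `is_torch_support_available` now delegates to `is_torch_version`, which compares only the base version of the installed PyTorch (local build suffixes such as "+cu121" are stripped before comparison). A minimal standalone sketch of that comparison, with an illustrative helper name not taken from the repository:

from packaging import version


def satisfies_minimum(installed: str, minimum: str) -> bool:
    # Compare base versions only, e.g. "2.1.0+cu121" -> "2.1.0",
    # mirroring what is_torch_version(">=", ...) does in the diff above.
    return version.parse(version.parse(installed).base_version) >= version.parse(minimum)


print(satisfies_minimum("2.1.0+cu121", "1.11.0"))  # True
print(satisfies_minimum("1.9.0", "1.11.0"))        # False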
7 changes: 3 additions & 4 deletions optimum/exporters/onnx/convert.py
@@ -851,17 +851,16 @@ def export(
         )
 
     if is_torch_available() and isinstance(model, nn.Module):
-        from ...utils import torch_version
+        from ...utils.import_utils import _torch_version
 
         if not is_torch_onnx_support_available():
             raise MinimumVersionError(
-                f"Unsupported PyTorch version, minimum required is {TORCH_MINIMUM_VERSION}, got: {torch_version}"
+                f"Unsupported PyTorch version, minimum required is {TORCH_MINIMUM_VERSION}, got: {_torch_version}"
             )
 
         if not config.is_torch_support_available:
             raise MinimumVersionError(
-                f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION},"
-                f" got: {torch.__version__}"
+                f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION}, got: {_torch_version}"
             )
 
     export_output = export_pytorch(
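The convert.py hunk keeps the same guards but formats the error with the `_torch_version` string already detected in import_utils instead of importing torch just to read `torch.__version__`. A rough sketch of that guard pattern, using a plain RuntimeError in place of optimum's MinimumVersionError and an illustrative helper name:

from packaging import version

TORCH_MINIMUM_VERSION = version.parse("1.11.0")


def ensure_torch_supported(detected: str) -> None:
    # Fail early and report the detected version string, as export() does above.
    if version.parse(detected) < TORCH_MINIMUM_VERSION:
        raise RuntimeError(
            f"Unsupported PyTorch version, minimum required is {TORCH_MINIMUM_VERSION}, got: {detected}"
        )


ensure_torch_supported("2.1.0")  # passes silently for a supported version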
2 changes: 2 additions & 0 deletions optimum/utils/__init__.py
@@ -43,7 +43,9 @@
     is_onnxruntime_available,
     is_pydantic_available,
     is_sentence_transformers_available,
+    is_tf_available,
     is_timm_available,
+    is_torch_available,
     is_torch_onnx_support_available,
     is_torch_version,
     is_transformers_available,
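With these two re-exports, downstream code can take every availability check from `optimum.utils` rather than `transformers.utils`. A minimal usage sketch, assuming a package built with the changes in this commit:

from optimum.utils import is_tf_available, is_torch_available

if is_torch_available():
    import torch  # only imported when a torch distribution was detected

if is_tf_available():
    import tensorflow as tf  # only when a TensorFlow >= 2 distribution was found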
65 changes: 52 additions & 13 deletions optimum/utils/import_utils.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Import utilities."""
 
-import importlib.metadata as importlib_metadata
+import importlib.metadata
 import importlib.util
 import inspect
 import operator as op
@@ -23,7 +23,6 @@

 import numpy as np
 from packaging import version
-from transformers.utils import is_torch_available
 
 
 TORCH_MINIMUM_VERSION = version.parse("1.11.0")
@@ -64,14 +63,46 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _datasets_available = _is_package_available("datasets")
 _diffusers_available, _diffusers_version = _is_package_available("diffusers", return_version=True)
 _transformers_available, _transformers_version = _is_package_available("transformers", return_version=True)
+_torch_available, _torch_version = _is_package_available("torch", return_version=True)
 
 # importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.)
 _onnxruntime_available = _is_package_available("onnxruntime", return_version=False)
 
 
 # TODO : Remove
-torch_version = None
-if is_torch_available():
-    torch_version = version.parse(importlib_metadata.version("torch"))
+torch_version = version.parse(importlib.metadata.version("torch")) if _torch_available else None
 
 
+# Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
+# with tensorflow-cpu to make sure it still works!
+_tf_available = importlib.util.find_spec("tensorflow") is not None
+_tf_version = None
+if _tf_available:
+    candidates = (
+        "tensorflow",
+        "tensorflow-cpu",
+        "tensorflow-gpu",
+        "tf-nightly",
+        "tf-nightly-cpu",
+        "tf-nightly-gpu",
+        "tf-nightly-rocm",
+        "intel-tensorflow",
+        "intel-tensorflow-avx512",
+        "tensorflow-rocm",
+        "tensorflow-macos",
+        "tensorflow-aarch64",
+    )
+    # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+    for pkg in candidates:
+        try:
+            _tf_version = importlib.metadata.version(pkg)
+            break
+        except importlib.metadata.PackageNotFoundError:
+            pass
+    _tf_available = _tf_version is not None
+if _tf_available:
+    if version.parse(_tf_version) < version.parse("2"):
+        _tf_available = False
+_tf_version = _tf_version or "N/A"
+
+
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
@@ -91,7 +122,7 @@ def compare_versions(library_or_version: Union[str, version.Version], operation:
         raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
     operation = STR_OPERATION_TO_FUNC[operation]
     if isinstance(library_or_version, str):
-        library_or_version = version.parse(importlib_metadata.version(library_or_version))
+        library_or_version = version.parse(importlib.metadata.version(library_or_version))
     return operation(library_or_version, version.parse(requirement_version))


@@ -117,15 +148,15 @@ def is_torch_version(operation: str, reference_version: str):
"""
Compare the current torch version to a given reference with an operation.
"""
if not is_torch_available():
if not _torch_available:
return False

import torch

return compare_versions(version.parse(version.parse(torch.__version__).base_version), operation, reference_version)


_is_torch_onnx_support_available = is_torch_available() and is_torch_version(">=", TORCH_MINIMUM_VERSION.base_version)
_is_torch_onnx_support_available = _torch_available and is_torch_version(">=", TORCH_MINIMUM_VERSION.base_version)


def is_torch_onnx_support_available():
@@ -176,9 +207,17 @@ def is_transformers_available():
     return _transformers_available
 
 
+def is_torch_available():
+    return _torch_available
+
+
+def is_tf_available():
+    return _tf_available
+
+
 def is_auto_gptq_available():
     if _auto_gptq_available:
-        v = version.parse(importlib_metadata.version("auto_gptq"))
+        v = version.parse(importlib.metadata.version("auto_gptq"))
         if v >= AUTOGPTQ_MINIMUM_VERSION:
             return True
         else:
@@ -189,7 +228,7 @@ def is_auto_gptq_available():

 def is_gptqmodel_available():
     if _gptqmodel_available:
-        v = version.parse(importlib_metadata.version("gptqmodel"))
+        v = version.parse(importlib.metadata.version("gptqmodel"))
         if v >= GPTQMODEL_MINIMUM_VERSION:
             return True
         else:
@@ -260,10 +299,10 @@ def check_if_torch_greater(target_version: str) -> bool:
     Returns:
         bool: whether the check is True or not.
     """
-    if not is_torch_available():
+    if not _torch_available:
         return False
 
-    return torch_version >= version.parse(target_version)
+    return version.parse(_torch_version) >= version.parse(target_version)
 
 
 @contextmanager
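The TensorFlow detection added in import_utils.py cannot rely on a single distribution name: importlib.metadata.version("tensorflow") raises PackageNotFoundError when only a variant such as tensorflow-cpu or tensorflow-macos is installed, so the new code probes a list of candidates. A self-contained sketch of the same pattern, trimmed to a few candidate names for illustration:

import importlib.metadata
import importlib.util
from typing import Optional


def detect_tf_version() -> Optional[str]:
    """Return the installed TensorFlow version string, or None if it is absent."""
    if importlib.util.find_spec("tensorflow") is None:
        return None
    for pkg in ("tensorflow", "tensorflow-cpu", "tensorflow-macos", "intel-tensorflow"):
        try:
            return importlib.metadata.version(pkg)
        except importlib.metadata.PackageNotFoundError:
            continue
    return None


print(detect_tf_version())  # e.g. "2.15.0", or None when TensorFlow is not installed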
3 changes: 1 addition & 2 deletions optimum/utils/input_generators.py
@@ -20,9 +20,8 @@
 from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
-from transformers.utils import is_tf_available, is_torch_available
 
-from ..utils import is_diffusers_version, is_transformers_version
+from ..utils import is_diffusers_version, is_tf_available, is_torch_available, is_transformers_version
 from .normalized_config import (
     NormalizedConfig,
     NormalizedEncoderDecoderConfig,
8 changes: 3 additions & 5 deletions tests/exporters/onnx/test_onnx_export.py
@@ -207,20 +207,18 @@ def _onnx_export(
         model.config.pad_token_id = 0
 
         if is_torch_available():
-            from optimum.utils import torch_version
+            from optimum.utils.import_utils import _torch_version, _transformers_version
 
             if not onnx_config.is_transformers_support_available:
-                import transformers
-
                 pytest.skip(
                     "Skipping due to incompatible Transformers version. Minimum required is"
-                    f" {onnx_config.MIN_TRANSFORMERS_VERSION}, got: {transformers.__version__}"
+                    f" {onnx_config.MIN_TRANSFORMERS_VERSION}, got: {_transformers_version}"
                 )
 
             if not onnx_config.is_torch_support_available:
                 pytest.skip(
                     "Skipping due to incompatible PyTorch version. Minimum required is"
-                    f" {onnx_config.MIN_TORCH_VERSION}, got: {torch_version}"
+                    f" {onnx_config.MIN_TORCH_VERSION}, got: {_torch_version}"
                 )
 
         atol = onnx_config.ATOL_FOR_VALIDATION
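In the test, the skip messages now read the module-level `_torch_version` and `_transformers_version` strings rather than importing the libraries inside the test body. A short sketch of the same skip-guard pattern in a pytest test, with placeholder version values for illustration:

import pytest
from packaging import version

MIN_TORCH_VERSION = version.parse("1.11.0")
_torch_version = "1.9.0"  # placeholder; in practice this is the detected install


def test_export_requires_supported_torch():
    if version.parse(_torch_version) < MIN_TORCH_VERSION:
        pytest.skip(
            "Skipping due to incompatible PyTorch version. Minimum required is"
            f" {MIN_TORCH_VERSION}, got: {_torch_version}"
        )
    # ... the actual export checks would run here ...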
