diff --git a/optimum/exporters/onnx/_traceable_cache.py b/optimum/exporters/onnx/_traceable_cache.py new file mode 100644 index 0000000000..052cb04b12 --- /dev/null +++ b/optimum/exporters/onnx/_traceable_cache.py @@ -0,0 +1,95 @@ +import logging +from typing import Any, Dict, Optional, Tuple + +import torch + + +logger = logging.getLogger(__name__) + + +# Simply removing the nn.Module, same as in https://github.com/huggingface/transformers/pull/35873 +class TraceableCache: + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" + # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles + # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so + # we change naming to be more explicit + def get_max_length(self) -> Optional[int]: + logger.warning_once( + "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. " + "Calling `get_max_cache()` will raise error from v4.48" + ) + return self.get_max_cache_shape() + + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length (i.e. 
max capacity) of the cache object""" + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] != []: + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx] != []: + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " + "model input instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 503f28d057..f765eb7042 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -843,7 +843,7 @@ class DeiTOnnxConfig(ViTOnnxConfig): class BeitOnnxConfig(ViTOnnxConfig): - DEFAULT_ONNX_OPSET = 11 + DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. class ConvNextOnnxConfig(ViTOnnxConfig): @@ -1573,13 +1573,12 @@ class Data2VecTextOnnxConfig(DistilBertOnnxConfig): class Data2VecVisionOnnxConfig(ViTOnnxConfig): - DEFAULT_ONNX_OPSET = 11 + DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. class Data2VecAudioOnnxConfig(AudioOnnxConfig): - NORMALIZED_CONFIG_CLASS = NormalizedConfig - ATOL_FOR_VALIDATION = 1e-4 DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. 
+ NORMALIZED_CONFIG_CLASS = NormalizedConfig class PerceiverDummyInputGenerator(DummyVisionInputGenerator): diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 80293e7b95..53476ff206 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -20,35 +20,34 @@ import types from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +import torch import transformers -from packaging import version from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet -from transformers.utils import is_torch_available +from ...utils import is_transformers_version, logging +from ._traceable_cache import TraceableCache -if is_torch_available(): - import torch -from ...configuration_utils import _transformers_version -from ...utils import logging - - -if _transformers_version > version.parse("4.34.99"): +if is_transformers_version(">=", "4.35"): from transformers.modeling_attn_mask_utils import AttentionMaskConverter -if _transformers_version >= version.parse("4.36"): +if is_transformers_version(">=", "4.36"): from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa -else: - _prepare_4d_causal_attention_mask_for_sdpa = None - AttentionMaskConverter = None - -if _transformers_version >= version.parse("4.42"): +if is_transformers_version(">=", "4.43"): + from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention +if is_transformers_version(">=", "4.42"): from transformers.cache_utils import SlidingWindowCache, StaticCache +if is_transformers_version(">=", "4.48"): + from transformers.cache_utils import DynamicCache, EncoderDecoderCache + from transformers.integrations.sdpa_attention import repeat_kv, sdpa_attention_forward + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + if TYPE_CHECKING: from transformers import PreTrainedModel, TFPreTrainedModel from .base import OnnxConfig + logger = logging.get_logger(__name__) @@ -158,6 +157,7 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step): UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)] +CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)] class ModelPatcher: @@ -171,6 +171,7 @@ def __init__( patching_specs = config.PATCHING_SPECS or [] patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC) + patching_specs.extend(CACHE_PATCHING_SPEC) self._patching_specs = [] for spec in patching_specs: @@ -197,6 +198,39 @@ def patched_forward(*args, **kwargs): signature = inspect.signature(self.orig_forward) args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) + if is_transformers_version(">=", "4.48"): + if "past_key_values" in signature.parameters: + pkv_index = list(signature.parameters.keys()).index("past_key_values") + + if ( + pkv_index < len(args) # pkv is in args + and isinstance(args[pkv_index], (list, tuple)) + and isinstance(args[pkv_index][0], (list, tuple)) + ): + if len(args[pkv_index][0]) == 2: + args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index]) + elif len(args[pkv_index][0]) == 4: + args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index]) + else: + raise ValueError( + f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements" + ) + elif ( + "past_key_values" in kwargs # pkv is in kwargs 
+ and isinstance(kwargs["past_key_values"], (list, tuple)) + and isinstance(kwargs["past_key_values"][0], (list, tuple)) + ): + if len(kwargs["past_key_values"][0]) == 2: + kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"]) + elif len(kwargs["past_key_values"][0]) == 4: + kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache( + kwargs["past_key_values"] + ) + else: + raise ValueError( + f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements" + ) + outputs = self.orig_forward(*args, **kwargs) # This code block handles different cases of the filterd_outputs input to align it with the expected @@ -230,6 +264,11 @@ def patched_forward(*args, **kwargs): filterd_outputs[name] = outputs name = list(config.outputs.keys())[0] filterd_outputs[name] = outputs + + if is_transformers_version(">=", "4.48"): + if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)): + filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + return filterd_outputs self.patched_forward = patched_forward @@ -259,6 +298,18 @@ def __call__(self, *args, **kwargs): class Seq2SeqModelPatcher(ModelPatcher): + def __enter__(self): + super().__enter__() + if is_transformers_version(">=", "4.48"): + # this is required when gpt2 is used as decoder in any + # encoder-decoder model with cross attention blocks + ALL_ATTENTION_FUNCTIONS["sdpa"] = patched_sdpa_attention_forward + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if is_transformers_version(">=", "4.48"): + ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward + def __init__( self, config: "OnnxConfig", @@ -310,6 +361,51 @@ def patched_forward(*args, **kwargs): self.patched_forward = patched_forward +def patched_sdpa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: Optional[bool] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions + # Reference: https://github.com/pytorch/pytorch/issues/112577. + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + if is_causal is None: + is_causal = causal_mask is None and query.shape[2] > 1 + + # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. + # We convert it to a bool for the SDPA kernel that only accepts bools. 
+ if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + is_causal = is_causal.item() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, None + + class VisionEncoderDecoderPatcher(Seq2SeqModelPatcher): def __init__( self, @@ -324,14 +420,17 @@ def __init__( model.decoder.model.decoder.config.use_cache = True -def _unmask_unattended_patched_legacy( - expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] -): - return expanded_mask +if is_transformers_version(">=", "4.39"): + + def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float): + return expanded_mask +else: -def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float): - return expanded_mask + def _unmask_unattended_patched( + expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] + ): + return expanded_mask def _make_causal_mask_patched( @@ -366,14 +465,6 @@ def _make_causal_mask_patched( return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) -_make_causal_mask_patched_staticmethod = staticmethod(_make_causal_mask_patched) - -if _transformers_version >= version.parse("4.39.0"): - _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched) -else: - _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched_legacy) - - # Adapted from _prepare_4d_causal_attention_mask def _prepare_4d_causal_attention_mask_for_sdpa_patched( attention_mask: Optional[torch.Tensor], @@ -412,28 +503,22 @@ def _prepare_4d_causal_attention_mask_for_sdpa_patched( class DecoderModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() - if AttentionMaskConverter is not None: - # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35 - AttentionMaskConverter._make_causal_mask = _make_causal_mask_patched_staticmethod + if is_transformers_version(">=", "4.35"): + AttentionMaskConverter._make_causal_mask = staticmethod(_make_causal_mask_patched) - if _transformers_version >= version.parse("4.36"): - AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod - - if _transformers_version >= version.parse("4.36"): + if is_transformers_version(">=", "4.36"): + AttentionMaskConverter._unmask_unattended = staticmethod(_unmask_unattended_patched) patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if AttentionMaskConverter is not None: - # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35 - AttentionMaskConverter._make_causal_mask = staticmethod(self.original_make_causal) - - if _transformers_version >= version.parse("4.36"): - AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) + if is_transformers_version(">=", "4.35"): + AttentionMaskConverter._make_causal_mask = staticmethod(self.original_make_causal_mask) - if _transformers_version >= version.parse("4.36"): + if is_transformers_version(">=", "4.36"): + AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", 
self.original_prepare_4d_causal_attention_mask_for_sdpa ) @@ -446,13 +531,12 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if _transformers_version >= version.parse("4.36"): - self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa - self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended + if is_transformers_version(">=", "4.35"): + self.original_make_causal_mask = AttentionMaskConverter._make_causal_mask - # TODO: Remove this if once transformers if much above 4.35 - if AttentionMaskConverter is not None: - self.original_make_causal = AttentionMaskConverter._make_causal_mask + if is_transformers_version(">=", "4.36"): + self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended + self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa def falcon_build_alibi_tensor_patched( @@ -833,14 +917,22 @@ def patched_forward( class SentenceTransformersTransformerPatcher(ModelPatcher): def __enter__(self): super().__enter__() - if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral": + if ( + is_transformers_version(">=", "4.42") + and is_transformers_version("<", "4.48") + and self.real_config._config.model_type == "mistral" + ): self._model[0].auto_model._update_causal_mask = types.MethodType( _update_causal_mask_patched, self._model[0].auto_model ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral": + if ( + is_transformers_version(">=", "4.42") + and is_transformers_version("<", "4.48") + and self.real_config._config.model_type == "mistral" + ): self._model[0].auto_model._update_causal_mask = types.MethodType( self._update_causal_mask_original, self._model[0].auto_model ) @@ -853,7 +945,11 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral": + if ( + is_transformers_version(">=", "4.42") + and is_transformers_version("<", "4.48") + and self.real_config._config.model_type == "mistral" + ): self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask def patched_forward(input_ids, attention_mask): @@ -1132,36 +1228,25 @@ def _update_causal_mask_patched( padding_mask, min_dtype ) - # if ( - # self.config._attn_implementation == "sdpa" - # and attention_mask is not None - # and attention_mask.device.type == "cuda" - # and not output_attentions - # ): - # # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # # Details: https://github.com/pytorch/pytorch/issues/110213 - # causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask -class MistralModelPatcher(ModelPatcher): +class MistralModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() - if AttentionMaskConverter is not None: - # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35 - AttentionMaskConverter._make_causal_mask = _make_causal_mask_patched_staticmethod - - if _transformers_version >= version.parse("4.36"): - AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod - if _transformers_version >= version.parse("4.36"): - patch_everywhere( - "_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched - ) - - if _transformers_version >= version.parse("4.42"): + if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): if hasattr(self._model, "model"): self._model.model._update_causal_mask = types.MethodType( _update_causal_mask_patched, self._model.model @@ -1171,19 +1256,8 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if AttentionMaskConverter is not None: - # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35 - AttentionMaskConverter._make_causal_mask = staticmethod(self.original_make_causal) - - if _transformers_version >= version.parse("4.36"): - AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) - - if _transformers_version >= version.parse("4.36"): - patch_everywhere( - "_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa - ) - if _transformers_version >= version.parse("4.42"): + if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): if hasattr(self._model, "model"): self._model.model._update_causal_mask = types.MethodType( self._update_causal_mask_original, self._model.model @@ -1199,15 +1273,7 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if _transformers_version >= version.parse("4.36"): - self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa - self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended - - # TODO: Remove this if once transformers if much above 4.35 - if AttentionMaskConverter is not None: - self.original_make_causal = AttentionMaskConverter._make_causal_mask - - if _transformers_version >= version.parse("4.42"): + if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): if hasattr(self._model, "model"): self._update_causal_mask_original = self._model.model._update_causal_mask else: @@ -1217,15 +1283,11 @@ def __init__( class CLIPModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() - - if _transformers_version >= version.parse("4.43"): - from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention - - self.original_sdpa_forward, CLIPSdpaAttention.forward = CLIPSdpaAttention.forward, CLIPAttention.forward + if is_transformers_version(">=", "4.43"): + self.original_sdpa_forward = CLIPSdpaAttention.forward + CLIPSdpaAttention.forward = CLIPAttention.forward def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if _transformers_version >= version.parse("4.43"): - from transformers.models.clip.modeling_clip import 
CLIPSdpaAttention - + if is_transformers_version(">=", "4.43"): CLIPSdpaAttention.forward = self.original_sdpa_forward diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 3793a56068..1a216ce6e8 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -897,9 +897,7 @@ class TasksManager: "feature-extraction", "fill-mask", "text-classification", - "multiple-choice", "token-classification", - "question-answering", onnx="ModernBertOnnxConfig", ), "mpnet": supported_tasks_mapping( diff --git a/setup.py b/setup.py index ec15277f18..4c4d9c43a0 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ "datasets>=1.2.1", "evaluate", "protobuf>=3.20.1", - "transformers>=4.36,<4.48.0", + "transformers>=4.36,<4.49.0", ], "onnxruntime-gpu": [ "onnx", @@ -59,19 +59,19 @@ "evaluate", "protobuf>=3.20.1", "accelerate", # ORTTrainer requires it. - "transformers>=4.36,<4.48.0", + "transformers>=4.36,<4.49.0", ], "exporters": [ "onnx", "onnxruntime", "timm", - "transformers>=4.36,<4.48.0", + "transformers>=4.36,<4.49.0", ], "exporters-gpu": [ "onnx", "onnxruntime-gpu", "timm", - "transformers>=4.36,<4.48.0", + "transformers>=4.36,<4.49.0", ], "exporters-tf": [ "tensorflow>=2.4,<=2.12.1", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index ee31397fd8..2b9bca7a73 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + VALIDATE_EXPORT_ON_SHAPES_SLOW = { "batch_size": [1, 3, 5], "sequence_length": [8, 33, 96, 154], diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index d92888a8dd..c341bd88a9 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -4612,7 +4612,9 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) - if model_arch == "whisper" and is_transformers_version(">=", "4.43"): + if model_arch == "whisper" and is_transformers_version(">=", "4.48"): + gen_length = self.GENERATION_LENGTH + elif model_arch == "whisper" and is_transformers_version(">=", "4.43"): gen_length = self.GENERATION_LENGTH + 2 else: gen_length = self.GENERATION_LENGTH + 1
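For reference, the past_key_values handling that ModelPatcher.patched_forward adds for transformers>=4.48 can be summarized outside the diff as follows. This is a minimal illustrative sketch, not code from the patch itself: the helper name to_cache_object and the tensor shapes are made up, while DynamicCache.from_legacy_cache, EncoderDecoderCache.from_legacy_cache and to_legacy_cache are the transformers APIs the patcher actually calls.

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

def to_cache_object(past_key_values):
    # Leave None, empty inputs, or already-built Cache instances untouched.
    if (
        not isinstance(past_key_values, (list, tuple))
        or len(past_key_values) == 0
        or not isinstance(past_key_values[0], (list, tuple))
    ):
        return past_key_values
    if len(past_key_values[0]) == 2:
        # (key, value) per layer -> decoder-only model
        return DynamicCache.from_legacy_cache(past_key_values)
    if len(past_key_values[0]) == 4:
        # (self-attention key/value, cross-attention key/value) per layer -> encoder-decoder model
        return EncoderDecoderCache.from_legacy_cache(past_key_values)
    raise ValueError(
        f"past_key_values should have either 2 or 4 elements, but it has {len(past_key_values[0])} elements"
    )

# Example: a 2-layer decoder-only legacy cache with batch=1, heads=8, seq=3, head_dim=64.
legacy = tuple((torch.zeros(1, 8, 3, 64), torch.zeros(1, 8, 3, 64)) for _ in range(2))
cache = to_cache_object(legacy)        # what patched_forward does with the inputs
round_trip = cache.to_legacy_cache()   # what patched_forward does with the outputs

Converting back with to_legacy_cache() keeps the exported ONNX signature on the legacy tuple layout, which is why the output-filtering step in patched_forward applies it to past_key_values before returning.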
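The patched_sdpa_attention_forward override exists mainly because of tracing semantics: under torch.jit.trace, shape accesses such as query.shape[2] return 0-dim tensors, so the derived is_causal flag must be turned back into a Python bool before it reaches scaled_dot_product_attention. A small sketch of just that detail (the function name and shapes are illustrative, not taken from the patch):

import torch

def resolve_is_causal(query, causal_mask=None):
    # In eager mode this is a plain bool; under torch.jit.trace the comparison yields a 0-dim tensor.
    is_causal = causal_mask is None and query.shape[2] > 1
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()  # SDPA only accepts a Python bool here
    return is_causal

query = torch.randn(1, 8, 4, 64)  # (batch, heads, seq_len, head_dim)
attn = torch.nn.functional.scaled_dot_product_attention(
    query, query, query, is_causal=resolve_is_causal(query)
)

Seq2SeqModelPatcher installs the patched function under ALL_ATTENTION_FUNCTIONS["sdpa"] in __enter__ and restores the stock sdpa_attention_forward in __exit__, so the change only affects tracing during export.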