diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py index fcd8ee06..c902758e 100644 --- a/optimum_benchmark/backends/onnxruntime/backend.py +++ b/optimum_benchmark/backends/onnxruntime/backend.py @@ -21,14 +21,20 @@ OptimizationConfig, QuantizationConfig, ) +from safetensors.torch import save_file from transformers import TrainerCallback, TrainerState +from transformers.modeling_utils import no_init_weights +from transformers.utils.logging import set_verbosity_error from ..base import Backend -from ..optimum_utils import main_export +from ..peft_utils import get_peft_config_class from ..pytorch.utils import randomize_weights from .config import ORTConfig from .utils import TASKS_TO_ORTMODELS, TASKS_TO_ORTSD, format_quantization_config +# disable transformers logging +set_verbosity_error() + LOGGER = getLogger("onnxruntime") @@ -59,14 +65,10 @@ def configure(self, config: ORTConfig) -> None: ortmodel_name = self.ortmodel_class.__name__ LOGGER.info(f"Inferred ORTModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}") - # Process torch dtype - self.torch_dtype = getattr(torch, self.config.torch_dtype) if self.config.torch_dtype is not None else None - - ###### Training with ORTModule ###### - # ort-training is basically a different package so we might need to separate these two backends in the future + ######## Training with ORTModule ######## if not self.config.use_inference_session: if self.config.no_weights: - self.load_automodel_from_config() + self.load_automodel_with_no_weights() else: self.load_automodel_from_pretrained() @@ -74,161 +76,182 @@ def configure(self, config: ORTConfig) -> None: LOGGER.info("\t+ Applying PEFT") from peft import get_peft_model - from ..peft_utils import get_peft_config_class - peft_config_class = get_peft_config_class(self.config.peft_strategy) peft_config = peft_config_class(**self.config.peft_config) self.pretrained_model = get_peft_model(self.pretrained_model, peft_config=peft_config) - # early exit because nothing of the following can be applied to training - return - - ###### Inference with ORTModelForxxx ###### - # Inference session options - self.session_options = SessionOptions() - for key, value in self.config.session_options.items(): - setattr(self.session_options, key, value) - # Exporting, optimizing, post-processing and quantizing with ORTModelForxxx - self.tmpdir = TemporaryDirectory() + return # early exit because nothing of the following can be applied to training - # Some statefullness to handle the different combinations of options + ######## Inference with ORTModelForxxx ######## self.export = self.config.export - self.use_merged = self.config.use_merged - self.provider_options = self.config.provider_options.copy() + self.tmpdir = TemporaryDirectory() + self.session_options = SessionOptions() + self.provider_options = self.config.provider_options - if self.is_diffusion_pipeline(): - self.load_ortmodel() - # early exit because nothing of the following can be applied to diffusion pipelines - return + for key, value in self.config.session_options.items(): + setattr(self.session_options, key, value) if self.config.no_weights: - self.load_automodel_from_config() # creates dummy automodel - self.export_automodel() # exports automodel - self.export = False + self.load_ortmodel_with_no_weights() else: - if self.config.export: - self.use_merged = False # merging is handled separately - self.load_automodel_from_pretrained() # creates automodel from pretrained 
- self.export_automodel() # exports automodel - self.export = False + self.load_ortmodel_from_pretrained() - self.delete_pretrained_model() # deletes automodel + if self.config.provider == "TensorrtExecutionProvider" and self.is_text_generation_model(): + return # deferred loading for trt text generation models - if self.config.auto_optimization or self.config.optimization: - self.optimize_onnx_files() + if self.is_optimized or self.is_quantized: + original_model = self.model + original_export = self.export - if self.config.use_merged: - self.merge_onnx_files() - self.use_merged = True + self.model = self.pretrained_model.model_save_dir # self.model will point to a directory from here on + self.export = False # we disable export because we want to load the optimized/quantized onnx files + + if self.is_optimized: + self.optimize_onnx_files() - if self.config.auto_quantization or self.config.quantization: + if self.is_quantized: self.quantize_onnx_files() - if self.config.provider == "TensorrtExecutionProvider" and self.is_text_generation_model(): - # deferred loading for trt text generation models - return + if self.is_optimized or self.is_quantized: + self.load_ortmodel_from_pretrained() # load optimized/quantized model + self.export = original_export + self.model = original_model - self.load_ortmodel() self.validate_provider() def validate_provider(self) -> None: - if self.config.provider == "TensorrtExecutionProvider": - assert self.pretrained_model.providers == [ - "TensorrtExecutionProvider", - "CUDAExecutionProvider", - "CPUExecutionProvider", - ], f"TensorrtExecutionProvider is not first in providers list: {self.pretrained_model.providers}" - - if self.config.provider == "ROCMExecutionProvider": - assert self.pretrained_model.providers == [ - "ROCMExecutionProvider", - "CPUExecutionProvider", - ], f"ROCMExecutionProvider is not first in providers list: {self.pretrained_model.providers}" - - def load_automodel_from_config(self) -> None: - from accelerate import init_empty_weights - - LOGGER.info("\t+ Loading AutoModel from config") - with init_empty_weights(): - self.pretrained_model = self.automodel_class.from_config( - self.pretrained_config, - torch_dtype=self.torch_dtype, - trust_remote_code=self.hub_kwargs.get("trust_remote_code", False), - ) - self.pretrained_model.to_empty(device=self.device) + assert ( + self.pretrained_model.providers[0] == self.config.provider + ), f"{self.config.provider} is not first in providers list: {self.pretrained_model.providers}" + + def load_automodel_with_no_weights(self) -> None: + original_model = self.model + no_weights_model = os.path.join(self.tmpdir.name, "no_weights") + + if not os.path.exists(no_weights_model): + LOGGER.info("\t+ Creating no weights model directory") + os.makedirs(no_weights_model) + + LOGGER.info("\t+ Saving pretrained config") + self.pretrained_config.save_pretrained(save_directory=no_weights_model) + + LOGGER.info("\t+ Creating no weights model") + state_dict = torch.nn.Linear(1, 1).state_dict() + + LOGGER.info("\t+ Saving no weights model") + save_file( + filename=os.path.join(no_weights_model, "model.safetensors"), + metadata={"format": "pt"}, + tensors=state_dict, + ) + + LOGGER.info("\t+ Loading no weights model") + with no_init_weights(): + self.model = no_weights_model + self.load_automodel_from_pretrained() + self.model = original_model + + LOGGER.info("\t+ Randomizing weights") randomize_weights(self.pretrained_model) + LOGGER.info("\t+ Tying model weights after randomization") + 
self.pretrained_model.tie_weights() def load_automodel_from_pretrained(self) -> None: LOGGER.info("\t+ Loading AutoModel from pretrained") self.pretrained_model = self.automodel_class.from_pretrained( self.model, - torch_dtype=self.torch_dtype, + **self.automodel_kwargs, **self.hub_kwargs, ).to(self.device) - def load_ortmodel(self) -> None: - LOGGER.info("\t+ Loading ORTModel") + def load_ortmodel_with_no_weights(self) -> None: + no_weights_model = os.path.join(self.tmpdir.name, "no_weights") + + LOGGER.info("\t+ Loading AutoModel with no weights") + self.load_automodel_with_no_weights() + self.delete_pretrained_model() + + LOGGER.info("\t+ Loading ORTModel with no weights") + with no_init_weights(): + original_model = self.model + self.model = no_weights_model + self.load_ortmodel_from_pretrained() + self.model = original_model + + def load_ortmodel_from_pretrained(self) -> None: self.pretrained_model = self.ortmodel_class.from_pretrained( self.model, export=self.export, - provider=self.config.provider, session_options=self.session_options, provider_options=self.provider_options, use_io_binding=self.config.use_io_binding, + provider=self.config.provider, **self.ortmodel_kwargs, **self.hub_kwargs, ) - # exported or not, the onnx model is/was here - self.model = self.pretrained_model.model_save_dir @property - def ortmodel_kwargs(self) -> Dict[str, Any]: - if self.is_text_generation_model(): - return {"use_cache": self.config.use_cache, "use_merged": self.use_merged} + def is_optimized(self) -> bool: + return self.config.auto_optimization or self.config.optimization + + @property + def is_quantized(self) -> bool: + return self.config.auto_quantization or self.config.quantization + + @property + def automodel_kwargs(self) -> Dict[str, Any]: + kwargs = {} + + if self.config.torch_dtype is not None and hasattr(torch, self.config.torch_dtype): + kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype) else: - return {} + kwargs["torch_dtype"] = self.config.torch_dtype + + return kwargs @property - def export_task(self) -> str: - return self.task + "-with-past" if self.config.use_cache and self.is_text_generation_model() else self.task + def ortmodel_kwargs(self) -> Dict[str, Any]: + kwargs = {} - def export_automodel(self) -> None: - LOGGER.info("\t+ Exporting AutoModel to ONNX") - exported_model_dir = f"{self.tmpdir.name}/exported_model" - self.merging_config, self.models_and_onnx_configs = main_export( - self.model, - output=exported_model_dir, - task=self.export_task, - device=self.device, - fp16=self.torch_dtype == torch.float16, - **self.hub_kwargs, - # we hijack the model instantiation and use our random weights model - model=self.pretrained_model, - ) - self.model = exported_model_dir + if self.is_text_generation_model(): + kwargs["use_cache"] = self.config.use_cache + kwargs["use_merged"] = self.config.use_merged - def merge_onnx_files(self) -> None: - LOGGER.info("\t+ Post-processing the exported model") - self.merging_config.post_process_exported_models(self.model, self.models_and_onnx_configs, None) + return kwargs @property def onnx_files_names(self): assert os.path.isdir(self.model), f"{self.model} is not a directory" return [file for file in os.listdir(self.model) if file.endswith(".onnx")] + @property + def onnx_files_names_to_quantize(self): + assert os.path.isdir(self.model), f"{self.model} is not a directory" + if self.config.use_merged: + # we filter merging components since they're not used for inference + # this also allows for calibration of one merged component 
models (like gpt2) + return [ + model + for model in self.onnx_files_names + if model not in [ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME] + ] + else: + return self.onnx_files_names + def optimize_onnx_files(self) -> None: LOGGER.info("\t+ Attempting optimization") - optimized_model_path = f"{self.tmpdir.name}/optimized" + optimized_model_path = os.path.join(self.tmpdir.name, "optimized") LOGGER.info("\t+ Processing optimization config") if self.config.auto_optimization is not None: optimization_config = AutoOptimizationConfig.with_optimization_level( optimization_level=self.config.auto_optimization, - for_gpu=self.device == "cuda", + for_gpu=(self.device == "cuda"), **self.config.auto_optimization_config, ) elif self.config.optimization: optimization_config = OptimizationConfig( - optimize_for_gpu=self.device == "cuda", **self.config.optimization_config + optimize_for_gpu=(self.device == "cuda"), + **self.config.optimization_config, ) LOGGER.info("\t+ Creating optimizer") optimizer = ORTOptimizer.from_pretrained(self.model, file_names=self.onnx_files_names) @@ -241,28 +264,26 @@ def optimize_onnx_files(self) -> None: use_external_data_format=None, one_external_file=True, ) - self.model = optimized_model_path - @property - def onnx_files_names_to_quantize(self): - assert os.path.isdir(self.model), f"{self.model} is not a directory" - if self.config.use_merged: - # we filter merging components since they're not used for inference - # this also allows for calibration of one merged component models (like gpt2) - return [ - model - for model in self.onnx_files_names - if model not in [ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME] - ] - else: - return self.onnx_files_names + if self.pretrained_processor is not None: + self.pretrained_processor.save_pretrained(optimized_model_path) + + if self.pretrained_config is not None: + self.pretrained_config.save_pretrained(optimized_model_path) + + self.model = optimized_model_path def quantize_onnx_files(self) -> None: LOGGER.info("\t+ Attempting quantization") quantized_model_path = f"{self.tmpdir.name}/quantized" - LOGGER.info("\t+ Processing quantization config") + if self.config.calibration and len(self.onnx_files_names_to_quantize) > 1: - raise NotImplementedError("Calibration is not supported for models with multiple components") + raise NotImplementedError( + "Calibration is not supported for models with multiple components. " + f"Found {len(self.onnx_files_names_to_quantize)} components." 
+ ) + + LOGGER.info("\t+ Processing quantization config") if self.config.auto_quantization is not None: self.config.auto_quantization_config = format_quantization_config(self.config.auto_quantization_config) auto_quantization_config_class = getattr(AutoQuantizationConfig, self.config.auto_quantization) @@ -270,10 +291,10 @@ def quantize_onnx_files(self) -> None: elif self.config.quantization: self.config.quantization_config = format_quantization_config(self.config.quantization_config) quantization_config = QuantizationConfig(**self.config.quantization_config) - LOGGER.info(f"\t+ Model has {len(self.onnx_files_names_to_quantize)} components to quantize") - if len(self.onnx_files_names_to_quantize) == 1: - LOGGER.info("\t+ Creating quantizer") - quantizer = ORTQuantizer.from_pretrained(self.model, file_name=self.onnx_files_names_to_quantize[0]) + + for onnx_file_name_to_quantize in self.onnx_files_names_to_quantize: + LOGGER.info(f"\t+ Creating quantizer for {onnx_file_name_to_quantize}") + quantizer = ORTQuantizer.from_pretrained(self.model, file_name=onnx_file_name_to_quantize) if self.config.calibration: LOGGER.info("\t+ Processing calibration config") preprocess_class = get_class(self.config.calibration_config.pop("preprocess_class")) @@ -295,6 +316,7 @@ def quantize_onnx_files(self) -> None: ) else: calibration_tensors_range = None + LOGGER.info("\t+ Quantizing model") quantizer.quantize( save_dir=quantized_model_path, @@ -304,20 +326,13 @@ def quantize_onnx_files(self) -> None: use_external_data_format=False, preprocessor=None, ) - else: - for onnx_file_name_to_quantize in self.onnx_files_names_to_quantize: - LOGGER.info(f"\t+ Creating quantizer for {onnx_file_name_to_quantize}") - quantizer = ORTQuantizer.from_pretrained(self.model, file_name=onnx_file_name_to_quantize) - LOGGER.info(f"\t+ Quantizing {onnx_file_name_to_quantize}") - quantizer.quantize( - save_dir=quantized_model_path, - quantization_config=quantization_config, - calibration_tensors_range=None, - file_suffix="", - # TODO: add support for these - use_external_data_format=False, - preprocessor=None, - ) + + if self.pretrained_processor is not None: + self.pretrained_processor.save_pretrained(quantized_model_path) + + if self.pretrained_config is not None: + self.pretrained_config.save_pretrained(quantized_model_path) + self.model = quantized_model_path @property @@ -327,7 +342,7 @@ def inputs_names(self) -> List[str]: elif hasattr(self.pretrained_model, "input_names"): return self.pretrained_model.input_names else: - return {} + return [] def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: inputs = super().prepare_inputs(inputs) @@ -368,7 +383,7 @@ def prepare_for_inference(self, **kwargs) -> None: f"position_ids:{batch_size}x{sequence_length + max_new_tokens}" ), } - self.load_ortmodel() + self.load_ortmodel_from_pretrained() self.validate_provider() def train( diff --git a/optimum_benchmark/backends/optimum_utils.py b/optimum_benchmark/backends/optimum_utils.py deleted file mode 100644 index 0ca4515b..00000000 --- a/optimum_benchmark/backends/optimum_utils.py +++ /dev/null @@ -1,379 +0,0 @@ -import os -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union - -import torch -from optimum.exporters.onnx.__main__ import ( - DEFAULT_DUMMY_SHAPES, - ONNX_WEIGHTS_NAME, - # UNPICKABLE_ARCHS, - # AtolError, - AutoTokenizer, - OnnxConfigWithPast, - # OutputMatchError, - RequestsConnectionError, - # ShapeError, - TasksManager, - _get_submodels_and_onnx_configs, - 
export_models, - is_torch_available, - logger, - maybe_load_preprocessors, - maybe_save_preprocessors, -) - -if TYPE_CHECKING: - from optimum.exporters.onnx import OnnxConfig - from transformers import PreTrainedModel - - -# rewrite of the main_export function from optimum.exporters.onnx.__main__ -# to use the model passed in as an argument instead of loading it from the model_name_or_path -def main_export( - model_name_or_path: str, - output: Union[str, Path], - task: str = "auto", - opset: Optional[int] = None, - device: str = "cpu", - fp16: Optional[bool] = False, - optimize: Optional[str] = None, - monolith: bool = False, - no_post_process: bool = False, - framework: Optional[str] = None, - atol: Optional[float] = None, - cache_dir: Optional[str] = None, - trust_remote_code: bool = False, - pad_token_id: Optional[int] = None, - subfolder: str = "", - revision: str = "main", - force_download: bool = False, - local_files_only: bool = False, - use_auth_token: Optional[Union[bool, str]] = None, - for_ort: bool = False, - do_validation: bool = True, - model_kwargs: Optional[Dict[str, Any]] = None, - custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, - fn_get_submodels: Optional[Callable] = None, - use_subprocess: bool = False, - _variant: str = "default", - ######################################## - model: Optional["PreTrainedModel"] = None, - ######################################## - **kwargs_shapes, -): - """ - Full-suite ONNX export. - - Args: - > Required parameters - - model_name_or_path (`str`): - Model ID on huggingface.co or path on disk to the model repository to export. - output (`Union[str, Path]`): - Path indicating the directory where to store the generated ONNX model. - - > Optional parameters - - task (`Optional[str]`, defaults to `None`): - The task to export the model for. If not specified, the task will be auto-inferred based on the model. For decoder models, - use `xxx-with-past` to export the model using past key values in the decoder. - opset (`Optional[int]`, defaults to `None`): - If specified, ONNX opset version to export the model with. Otherwise, the default opset for the given model architecture - will be used. - device (`str`, defaults to `"cpu"`): - The device to use to do the export. Defaults to "cpu". - fp16 (`Optional[bool]`, defaults to `"False"`): - Use half precision during the export. PyTorch-only, requires `device="cuda"`. - optimize (`Optional[str]`, defaults to `None`): - Allows to run ONNX Runtime optimizations directly during the export. Some of these optimizations are specific to - ONNX Runtime, and the resulting ONNX will not be usable with other runtime as OpenVINO or TensorRT. - Available options: `"O1", "O2", "O3", "O4"`. Reference: [`~optimum.onnxruntime.AutoOptimizationConfig`] - monolith (`bool`, defaults to `False`): - Forces to export the model as a single ONNX file. - no_post_process (`bool`, defaults to `False`): - Allows to disable any post-processing done by default on the exported ONNX models. - framework (`Optional[str]`, defaults to `None`): - The framework to use for the ONNX export (`"pt"` or `"tf"`). If not provided, will attempt to automatically detect - the framework for the checkpoint. - atol (`Optional[float]`, defaults to `None`): - If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used. - cache_dir (`Optional[str]`, defaults to `None`): - Path indicating where to store cache. The default Hugging Face cache path will be used by default. 
- trust_remote_code (`bool`, defaults to `False`): - Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories - you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the - model repository. - pad_token_id (`Optional[int]`, defaults to `None`): - This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it. - subfolder (`str`, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can - specify the folder name here. - revision (`str`, defaults to `"main"`): - Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. - force_download (`bool`, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - local_files_only (`Optional[bool]`, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`Optional[str]`, defaults to `None`): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): - Experimental usage: keyword arguments to pass to the model during - the export. This argument should be used along the `custom_onnx_configs` argument - in case, for example, the model inputs/outputs are changed (for example, if - `model_kwargs={"output_attentions": True}` is passed). - custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`): - Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model). - fn_get_submodels (`Optional[Callable]`, defaults to `None`): - Experimental usage: Override the default submodels that are used at the export. This is - especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success. - use_subprocess (`bool`): - Do the ONNX exported model validation in subprocesses. This is especially useful when - exporting on CUDA device, where ORT does not release memory at inference session - destruction. When set to `True`, the `main_export` call should be guarded in - `if __name__ == "__main__":` block. - _variant (`str`, defaults to `default`): - Specify the variant of the ONNX export to use. - **kwargs_shapes (`Dict`): - Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. - - Example usage: - ```python - >>> from optimum.exporters.onnx import main_export - - >>> main_export("gpt2", output="gpt2_onnx/") - ``` - """ - if optimize == "O4" and device != "cuda": - raise ValueError( - "Requested O4 optimization, but this optimization requires to do the export on GPU." - " Please pass the argument `--device cuda`." 
- ) - - if (framework == "tf" and fp16 is True) or not is_torch_available(): - raise ValueError("The --fp16 option is supported only for PyTorch.") - - if fp16 is True and device == "cpu": - raise ValueError( - "FP16 export is supported only when exporting on GPU. Please pass the option `--device cuda`." - ) - float_dtype = "fp16" - else: - float_dtype = "fp32" - - output = Path(output) - if not output.exists(): - output.mkdir(parents=True) - - if for_ort: - logger.warning( - "The option --for-ort was passed, but its behavior is now the default in the ONNX exporter" - " and passing it is not required anymore." - ) - - original_task = task - task = TasksManager.map_from_synonym(task) - - framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework) - - # get the shapes to be used to generate dummy inputs - input_shapes = {} - for input_name in DEFAULT_DUMMY_SHAPES.keys(): - input_shapes[input_name] = ( - kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] - ) - - torch_dtype = None if fp16 is False else torch.float16 - - if task == "auto": - try: - task = TasksManager.infer_task_from_model(model_name_or_path) - except KeyError as e: - raise KeyError( - f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" - ) - except RequestsConnectionError as e: - raise RequestsConnectionError( - f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" - ) - - if model is None: - model = TasksManager.get_model_from_task( - task, - model_name_or_path, - subfolder=subfolder, - revision=revision, - cache_dir=cache_dir, - use_auth_token=use_auth_token, - local_files_only=local_files_only, - force_download=force_download, - trust_remote_code=trust_remote_code, - framework=framework, - torch_dtype=torch_dtype, - device=device, - ) - - custom_architecture = False - is_stable_diffusion = "stable-diffusion" in task - model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-") - - if not is_stable_diffusion: - if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE: - raise ValueError( - f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. " - f"If you want to support {model_type} please propose a PR or open up an issue." - ) - if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task( - task, exporter="onnx" - ): - custom_architecture = True - - # TODO: support onnx_config.py in the model repo - if custom_architecture and custom_onnx_configs is None: - raise ValueError( - f"Trying to export a {model.config.model_type.replace('-', '_')} model, that is a custom or unsupported architecture for the task {task}, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. For the task {task}, the Optimum ONNX exporter supports natively the architectures: {TasksManager.get_supported_model_type_for_task(task, exporter='onnx')}." 
- ) - - if custom_architecture and original_task == "auto": - raise ValueError( - f'Automatic task detection is not supported with custom architectures. Please specify the `task` argument. Suggestion: task="{task}" (or task="{task}-with-past" if the model is decoder-based and supports KV cache)' - ) - - if ( - not custom_architecture - and not is_stable_diffusion - and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx") - ): - if original_task == "auto": # Make -with-past the default if --task was not explicitly specified - task = task + "-with-past" - else: - logger.info( - f"The task `{task}` was manually specified, and past key values will not be reused in the decoding." - f" if needed, please pass `--task {task}-with-past` to export using the past key values." - ) - - if task.endswith("-with-past") and monolith is True: - task_non_past = task.replace("-with-past", "") - raise ValueError( - f"The task {task} is not compatible with the --monolith argument. Please either use" - f" `--task {task_non_past} --monolith`, or `--task {task}` without the monolith argument." - ) - - if original_task == "auto": - synonyms_for_task = sorted(TasksManager.synonyms_for_task(task)) - if synonyms_for_task: - synonyms_for_task = ", ".join(synonyms_for_task) - possible_synonyms = f" (possible synonyms are: {synonyms_for_task})" - else: - possible_synonyms = "" - logger.info(f"Automatic task detection to {task}{possible_synonyms}.") - - # The preprocessors are loaded as they may be useful to export the model. Notably, some of the static input shapes may be stored in the - # preprocessors config. - preprocessors = maybe_load_preprocessors( - model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code - ) - onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs( - model=model, - task=task, - monolith=monolith, - custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {}, - custom_architecture=custom_architecture, - float_dtype=float_dtype, - fn_get_submodels=fn_get_submodels, - preprocessors=preprocessors, - _variant=_variant, - ) - - if not is_stable_diffusion: - needs_pad_token_id = ( - isinstance(onnx_config, OnnxConfigWithPast) - and getattr(model.config, "pad_token_id", None) is None - and task in ["text-classification"] - ) - if needs_pad_token_id: - if pad_token_id is not None: - model.config.pad_token_id = pad_token_id - else: - try: - tok = AutoTokenizer.from_pretrained(model_name_or_path) - model.config.pad_token_id = tok.pad_token_id - except Exception: - raise ValueError( - "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument" - ) - - # Ensure the requested opset is sufficient - if opset is None: - opset = onnx_config.DEFAULT_ONNX_OPSET - - if opset < onnx_config.DEFAULT_ONNX_OPSET: - raise ValueError( - f"Opset {opset} is not sufficient to export {model_type}. " - f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required." - ) - if atol is None: - atol = onnx_config.ATOL_FOR_VALIDATION - if isinstance(atol, dict): - atol = atol[task.replace("-with-past", "")] - - # Saving the model config and preprocessor as this is needed sometimes. 
- model.config.save_pretrained(output) - generation_config = getattr(model, "generation_config", None) - if generation_config is not None: - generation_config.save_pretrained(output) - maybe_save_preprocessors(model_name_or_path, output) - - if model.config.is_encoder_decoder and task.startswith("text-generation"): - raise ValueError( - f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report" - f"at https://github.com/huggingface/optimum, if --task was explicitly passed, make sure you selected the right task for the model," - f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`." - ) - - onnx_files_subpaths = [key + ".onnx" for key in models_and_onnx_configs.keys()] - else: - # save the subcomponent configuration - for model_name in models_and_onnx_configs: - subcomponent = models_and_onnx_configs[model_name][0] - if hasattr(subcomponent, "save_config"): - subcomponent.save_config(output / model_name) - elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"): - subcomponent.config.save_pretrained(output / model_name) - - onnx_files_subpaths = [os.path.join(name_dir, ONNX_WEIGHTS_NAME) for name_dir in models_and_onnx_configs] - - # Saving the additional components needed to perform inference. - model.scheduler.save_pretrained(output.joinpath("scheduler")) - - feature_extractor = getattr(model, "feature_extractor", None) - if feature_extractor is not None: - feature_extractor.save_pretrained(output.joinpath("feature_extractor")) - - tokenizer = getattr(model, "tokenizer", None) - if tokenizer is not None: - tokenizer.save_pretrained(output.joinpath("tokenizer")) - - tokenizer_2 = getattr(model, "tokenizer_2", None) - if tokenizer_2 is not None: - tokenizer_2.save_pretrained(output.joinpath("tokenizer_2")) - - model.save_config(output) - - _, onnx_outputs = export_models( - models_and_onnx_configs=models_and_onnx_configs, - opset=opset, - output_dir=output, - output_names=onnx_files_subpaths, - input_shapes=input_shapes, - device=device, - dtype="fp16" if fp16 is True else None, - model_kwargs=model_kwargs, - ) - - # for the post processing later we don't wanna keep models - for key in models_and_onnx_configs.keys(): - models_and_onnx_configs[key] = ("dummy_model", models_and_onnx_configs[key][1]) - - return onnx_config, models_and_onnx_configs diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 0cf2e9c9..1e948c77 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -8,19 +8,21 @@ from datasets import Dataset from safetensors.torch import save_file from transformers import TrainerCallback, TrainerState +from transformers.modeling_utils import no_init_weights from transformers.utils import ModelOutput from transformers.utils.logging import set_verbosity_error from ..base import Backend +from ..peft_utils import get_peft_config_class from .config import PyTorchConfig from .utils import TransformersDataParallel, randomize_weights -# bachend logger -LOGGER = getLogger("pytorch") - # disable transformers logging set_verbosity_error() +# bachend logger +LOGGER = getLogger("pytorch") + class PyTorchBackend(Backend[PyTorchConfig]): NAME: str = "pytorch" @@ -88,12 +90,9 @@ def configure(self, config: PyTorchConfig) -> None: self.load_model_from_pretrained() # Eval mode - if self.config.eval_mode: - if self.is_diffusion_pipeline(): - 
LOGGER.info("\t+ Diffusion pipeline is in eval mode") - else: - LOGGER.info("\t+ Turning on model's eval mode") - self.pretrained_model.eval() + if self.config.eval_mode and not self.is_diffusion_pipeline(): + LOGGER.info("\t+ Turning on model's eval mode") + self.pretrained_model.eval() # BetterTransformer if self.config.to_bettertransformer: @@ -117,11 +116,9 @@ def configure(self, config: PyTorchConfig) -> None: ) if self.config.peft_strategy is not None: - LOGGER.info("\t+ Applying PEFT") from peft import get_peft_model - from ..peft_utils import get_peft_config_class - + LOGGER.info("\t+ Using PEFT") peft_config_class = get_peft_config_class(self.config.peft_strategy) peft_config = peft_config_class(**self.config.peft_config) self.pretrained_model = get_peft_model(self.pretrained_model, peft_config=peft_config) @@ -153,7 +150,7 @@ def load_model_from_pretrained(self) -> None: LOGGER.info(f"\t+ Moving pipeline to device: {self.device}") self.pretrained_model.to(self.device) elif self.is_bnb_quantized(): - LOGGER.info("\t+ Loading BnB quantized model") + LOGGER.info("\t+ Loading quantized model") self.pretrained_model = self.automodel_class.from_pretrained( pretrained_model_name_or_path=self.model, device_map=self.config.device_map, @@ -193,38 +190,33 @@ def load_model_from_pretrained(self) -> None: ) def load_model_with_no_weights(self) -> None: - self.tmp_dir = TemporaryDirectory() - - original_model = self.model - no_weights_model = os.path.join(self.tmp_dir.name, "no_weights") + self.tmpdir = TemporaryDirectory() + no_weights_model = os.path.join(self.tmpdir.name, "no_weights") - LOGGER.info("\t+ Creating no weights model directory") if not os.path.exists(no_weights_model): + LOGGER.info("\t+ Creating no weights model directory") os.makedirs(no_weights_model) if self.is_quantized(): # tricking from_pretrained to load the model as if it was quantized self.pretrained_config.quantization_config = self.quantization_config.to_dict() - LOGGER.info(f"\t+ Saving pretrained config to {no_weights_model}") + LOGGER.info("\t+ Saving pretrained config") self.pretrained_config.save_pretrained(save_directory=no_weights_model) - LOGGER.info(f"\t+ Creating no weights model to {no_weights_model}") + LOGGER.info("\t+ Creating no weights model") state_dict = torch.nn.Linear(1, 1).state_dict() if self.is_exllamav2(): # for exllamav2 we need to add g_idx to the state_dict - LOGGER.info("\t+ Loading meta model") with torch.device("meta"): meta_model = self.automodel_class.from_config(self.pretrained_config) - LOGGER.info("\t+ Setting g_idx for ExllamaV2") for name, module in meta_model.named_modules(): - # loading to exllama v2's QuantLinear creates g_idx with bad values if hasattr(module, "in_features"): state_dict[name + ".g_idx"] = torch.ones((module.in_features,), dtype=torch.int32) - LOGGER.info(f"\t+ Saving no weights model to {no_weights_model}") + LOGGER.info("\t+ Saving no weights model") save_file( filename=os.path.join(no_weights_model, "model.safetensors"), metadata={"format": "pt"}, @@ -232,13 +224,14 @@ def load_model_with_no_weights(self) -> None: ) LOGGER.info("\t+ Loading no weights model") - self.model = no_weights_model - self.load_model_from_pretrained() - self.model = original_model + with no_init_weights(): + original_model = self.model + self.model = no_weights_model + self.load_model_from_pretrained() + self.model = original_model if not self.is_quantized(): # TODO: verify if this can be extended to quantized models - # (not sure how torch.Tensor.normal_ works on quantized tensors) 
LOGGER.info("\t+ Randomizing model weights") randomize_weights(self.pretrained_model) LOGGER.info("\t+ Tying model weights after randomization") @@ -316,11 +309,6 @@ def automodel_kwargs(self) -> Dict[str, Any]: # config by passing quantization_config to from_pretrained kwargs["quantization_config"] = self.quantization_config - if self.config.no_weights: - # when no_weights=True, the state_dict is empty so from_pretrained will try to randomly - # initialize every missing weights, we don't want that, so we set fast_init to False - kwargs["_fast_init"] = False - return kwargs def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput": @@ -382,8 +370,8 @@ def clean(self) -> None: LOGGER.info("\t+ Emptying CUDA cache") torch.cuda.empty_cache() - if hasattr(self, "tmp_dir"): + if hasattr(self, "tmpdir"): LOGGER.info("\t+ Cleaning temporary directory") - self.tmp_dir.cleanup() + self.tmpdir.cleanup() gc.collect() diff --git a/setup.py b/setup.py index 7ac8f60a..63329634 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,12 @@ "neural-compressor": [f"optimum[neural-compressor]>={OPTIMUM_VERSION}"], # gpu backends "onnxruntime-gpu": [f"optimum[onnxruntime-gpu]>={OPTIMUM_VERSION}"], - "onnxruntime-training": ["torch-ort", "onnxruntime-training"], + "onnxruntime-training": [ + "torch-ort", + "onnxruntime-training", + # # we use optimum from source, until the next release + "optimum@git+https://github.com/huggingface/optimum.git", + ], # docker-based backends "text-generation-inference": ["docker"], # specific settings