diff --git a/examples/neural_compressor_ptq_bert.yaml b/examples/neural_compressor_ptq_bert.yaml
new file mode 100644
index 00000000..6268fbb0
--- /dev/null
+++ b/examples/neural_compressor_ptq_bert.yaml
@@ -0,0 +1,31 @@
+defaults:
+  - launcher: process
+  - benchmark: inference
+  - backend: neural-compressor
+  - experiment # inheriting experiment schema
+  - _self_ # for hydra 1.1 compatibility
+  - override hydra/job_logging: colorlog # colorful logging
+  - override hydra/hydra_logging: colorlog # colorful logging
+
+experiment_name: neural_compressor_ptq_bert
+model: bert-base-uncased
+device: cpu
+
+backend:
+  no_weights: true
+  ptq_quantization: true
+  calibration: true
+
+benchmark:
+  input_shapes:
+    batch_size: 1
+
+hydra:
+  run:
+    dir: runs/${experiment_name}
+  sweep:
+    dir: sweeps/${experiment_name}
+  job:
+    chdir: true
+    env_set:
+      OVERRIDE_BENCHMARKS: 1
diff --git a/examples/onnxruntime_static_quant_vit.yaml b/examples/onnxruntime_static_quant_vit.yaml
new file mode 100644
index 00000000..478cb1f7
--- /dev/null
+++ b/examples/onnxruntime_static_quant_vit.yaml
@@ -0,0 +1,32 @@
+defaults:
+  - launcher: process
+  - benchmark: inference
+  - backend: onnxruntime
+  - experiment # inheriting experiment schema
+  - _self_ # for hydra 1.1 compatibility
+  - override hydra/job_logging: colorlog # colorful logging
+  - override hydra/hydra_logging: colorlog # colorful logging
+
+experiment_name: onnxruntime_static_quant_vit
+model: google/vit-base-patch16-224
+device: cpu
+
+backend:
+  quantization: true
+  quantization_config:
+    is_static: true
+    per_channel: false
+
+  calibration: true
+
+hydra:
+  run:
+    dir: runs/${experiment_name}
+  sweep:
+    dir: sweeps/${experiment_name}
+  job:
+    chdir: true
+    env_set:
+      OVERRIDE_BENCHMARKS: 1
+      CUDA_VISIBLE_DEVICES: 0
+      CUDA_DEVICE_ORDER: PCI_BUS_ID
diff --git a/examples/openvino_diffusion.yaml b/examples/openvino_diffusion.yaml
index 5cd3a758..0c42eb0c 100644
--- a/examples/openvino_diffusion.yaml
+++ b/examples/openvino_diffusion.yaml
@@ -1,22 +1,19 @@
 defaults:
-  - backend: openvino # default backend
-  - launcher: inline # default launcher
-  - benchmark: inference # default benchmark
+  - backend: openvino
+  - launcher: process
+  - benchmark: inference
   - experiment # inheriting experiment schema
   - _self_ # for hydra 1.1 compatibility
   - override hydra/job_logging: colorlog # colorful logging
   - override hydra/hydra_logging: colorlog # colorful logging
 
-experiment_name: openvino_diffusion
 model: stabilityai/stable-diffusion-2-1
+experiment_name: openvino_diffusion
 device: cpu
 
-launcher:
-  device_isolation: true
-
 backend:
-  export: true
   reshape: true
+  export: true
   half: true
 
 benchmark:
diff --git a/examples/openvino_static_quant_bert.yaml b/examples/openvino_static_quant_bert.yaml
new file mode 100644
index 00000000..ba5f2051
--- /dev/null
+++ b/examples/openvino_static_quant_bert.yaml
@@ -0,0 +1,33 @@
+defaults:
+  - backend: openvino
+  - launcher: process
+  - benchmark: inference
+  - experiment # inheriting experiment schema
+  - _self_ # for hydra 1.1 compatibility
+  - override hydra/job_logging: colorlog # colorful logging
+  - override hydra/hydra_logging: colorlog # colorful logging
+
+experiment_name: openvino_static_quant_bert
+model: bert-base-uncased
+device: cpu
+
+backend:
+  export: true
+  no_weights: true
+  quantization: true
+  calibration: true
+  reshape: true
+
+benchmark:
+  input_shapes:
+    batch_size: 1
+
+hydra:
+  run:
+    dir: runs/${experiment_name}
+  sweep:
+    dir: sweeps/${experiment_name}
+  job:
+    chdir: true
+    env_set:
+      OVERRIDE_BENCHMARKS: 1
diff --git a/examples/pytorch_bert.yaml b/examples/pytorch_bert.yaml
index c26af141..41dc9d8c 100644
--- a/examples/pytorch_bert.yaml
+++ b/examples/pytorch_bert.yaml
@@ -1,7 +1,7 @@
 defaults:
-  - backend: pytorch # default backend
-  - launcher: torchrun # default launcher
-  - benchmark: inference # default benchmark
+  - backend: pytorch
+  - launcher: process
+  - benchmark: inference
   - experiment # inheriting experiment schema
   - _self_ # for hydra 1.1 compatibility
   - override hydra/job_logging: colorlog # colorful logging
diff --git a/optimum_benchmark/backends/neural_compressor/backend.py b/optimum_benchmark/backends/neural_compressor/backend.py
index cbace02b..67e794c4 100644
--- a/optimum_benchmark/backends/neural_compressor/backend.py
+++ b/optimum_benchmark/backends/neural_compressor/backend.py
@@ -1,8 +1,10 @@
 import gc
+import os
 from logging import getLogger
 from tempfile import TemporaryDirectory
 from typing import Any, Dict
 
+import torch
 from hydra.utils import get_class
 from neural_compressor.config import (
     AccuracyCriterion,
@@ -10,13 +12,19 @@
     TuningCriterion,
 )
 from optimum.intel.neural_compressor.quantization import INCQuantizer
+from transformers.modeling_utils import no_init_weights
+from transformers.utils.logging import set_verbosity_error
 
+from ...generators.dataset_generator import DatasetGenerator
 from ..base import Backend
 from .config import INCConfig
 from .utils import TASKS_TO_INCMODELS
 
 LOGGER = getLogger("neural-compressor")
 
+# disable transformers logging
+set_verbosity_error()
+
 
 class INCBackend(Backend[INCConfig]):
     NAME: str = "neural-compressor"
@@ -45,20 +53,65 @@ def configure(self, config: INCConfig) -> None:
         self.tmpdir = TemporaryDirectory()
 
         if self.config.ptq_quantization:
-            self.load_automodel_from_pretrained()
+            if self.config.no_weights:
+                self.load_automodel_with_no_weights()
+            else:
+                self.load_automodel_from_pretrained()
             self.quantize_automodel()
             self.delete_pretrained_model()
+            self.load_incmodel_from_pretrained()
+        elif self.config.no_weights:
+            self.load_incmodel_with_no_weights()
+        else:
+            self.load_incmodel_from_pretrained()
 
-        self.load_incmodel_from_pretrained()
+        self.tmpdir.cleanup()
 
     def load_automodel_from_pretrained(self) -> None:
-        LOGGER.info("\t+ Loading AutoModel")
+        LOGGER.info("\t+ Loading AutoModel from pretrained")
         self.pretrained_model = self.automodel_class.from_pretrained(self.model, **self.hub_kwargs)
 
+    def load_automodel_with_no_weights(self) -> None:
+        no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
+
+        if not os.path.exists(no_weights_model):
+            LOGGER.info("\t+ Creating no weights model directory")
+            os.makedirs(no_weights_model)
+
+        LOGGER.info("\t+ Saving pretrained config")
+        self.pretrained_config.save_pretrained(save_directory=no_weights_model)
+
+        LOGGER.info("\t+ Creating no weights model")
+        state_dict = torch.nn.Linear(1, 1).state_dict()
+
+        LOGGER.info("\t+ Saving no weights model")
+        torch.save(state_dict, os.path.join(no_weights_model, "pytorch_model.bin"))
+
+        LOGGER.info("\t+ Loading no weights model")
+        with no_init_weights():
+            original_model = self.model
+            self.model = no_weights_model
+            self.load_automodel_from_pretrained()
+            self.model = original_model
+
     def load_incmodel_from_pretrained(self) -> None:
-        LOGGER.info("\t+ Loading INCModel")
+        LOGGER.info("\t+ Loading INCModel from pretrained")
         self.pretrained_model = self.incmodel_class.from_pretrained(self.model, **self.hub_kwargs)
 
+    def load_incmodel_with_no_weights(self) -> None:
+        no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
+
+        LOGGER.info("\t+ Loading AutoModel with no weights")
+        self.load_automodel_with_no_weights()
+        self.delete_pretrained_model()
+
+        LOGGER.info("\t+ Loading INCModel with no weights")
+        with no_init_weights():
+            original_model = self.model
+            self.model = no_weights_model
+            self.load_incmodel_from_pretrained()
+            self.model = original_model
+
     def quantize_automodel(self) -> None:
         LOGGER.info("\t+ Attempting to quantize model")
         quantized_model_path = f"{self.tmpdir.name}/quantized"
@@ -71,34 +124,33 @@ def quantize_automodel(self) -> None:
             ptq_quantization_config = PostTrainingQuantConfig(**ptq_quantization_config)
         LOGGER.info("\t+ Creating quantizer")
         quantizer = INCQuantizer.from_pretrained(
-            self.pretrained_model,
             task=self.task,
             seed=self.config.seed,
+            model=self.pretrained_model,
             # TODO: add support for these
-            eval_fn=None,
             calibration_fn=None,
+            eval_fn=None,
         )
 
         if self.config.calibration:
-            LOGGER.info("\t+ Processing calibration config")
-            calibration_config = self.config.calibration_config.copy()
-            preprocess_class = get_class(calibration_config.pop("preprocess_class"))
-            calibration_config["preprocess_function"] = preprocess_class(model_name_or_path=self.model)
-            LOGGER.info("\t+ Loading calibration dataset")
-            calibration_dataset = quantizer.get_calibration_dataset(**calibration_config)
+            LOGGER.info("\t+ Generating calibration dataset")
+            dataset_shapes = {"dataset_size": 1, "sequence_length": 1, **self.model_shapes}
+            calibration_dataset = DatasetGenerator(task=self.task, dataset_shapes=dataset_shapes).generate()
+            columns_to_be_removed = list(set(calibration_dataset.column_names) - set(quantizer._signature_columns))
+            calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
         else:
            calibration_dataset = None
 
         LOGGER.info("\t+ Quantizing model")
         quantizer.quantize(
-            quantization_config=ptq_quantization_config,
             save_directory=quantized_model_path,
             calibration_dataset=calibration_dataset,
+            quantization_config=ptq_quantization_config,
             # TODO: add support for these
             remove_unused_columns=True,
             data_collator=None,
             file_name=None,
-            batch_size=8,
+            batch_size=1,
         )
 
         self.model = quantized_model_path
diff --git a/optimum_benchmark/backends/neural_compressor/config.py b/optimum_benchmark/backends/neural_compressor/config.py
index 4ae8e953..ff0dc456 100644
--- a/optimum_benchmark/backends/neural_compressor/config.py
+++ b/optimum_benchmark/backends/neural_compressor/config.py
@@ -47,22 +47,15 @@
 }
 
 
-CALIBRATION_CONFIG = {
-    "dataset_name": "glue",
-    "num_samples": 300,
-    "dataset_config_name": "sst2",
-    "dataset_split": "train",
-    "preprocess_batch": True,
-    "preprocess_class": "optimum_benchmark.preprocessors.glue.GluePreprocessor",
-}
-
-
 @dataclass
 class INCConfig(BackendConfig):
     name: str = "neural_compressor"
     version: str = "${neural_compressor_version:}"
     _target_: str = "optimum_benchmark.backends.neural_compressor.backend.INCBackend"
 
+    # load options
+    no_weights: bool = False
+
     # post-training quantization options
     ptq_quantization: bool = False
     ptq_quantization_config: Dict[str, Any] = field(default_factory=dict)
@@ -80,6 +73,3 @@ def __post_init__(self):
             )
         if self.ptq_quantization_config["approach"] == "static" and not self.calibration:
             raise ValueError("Calibration must be enabled when using static quantization.")
-
-        if self.calibration:
-            self.calibration_config = OmegaConf.to_object(OmegaConf.merge(CALIBRATION_CONFIG, self.calibration_config))
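Note on the "no weights" path added above (the OpenVINO backend later in this diff uses the same trick): only the model config and a dummy state dict are written to disk, and weight initialization is skipped at load time. A rough standalone sketch of the idea, with an assumed checkpoint id, not code taken from this repository:

import os
import tempfile

import torch
from transformers import AutoConfig, AutoModel
from transformers.modeling_utils import no_init_weights

model_id = "bert-base-uncased"  # assumed checkpoint; any hub id with a config works

with tempfile.TemporaryDirectory() as tmpdir:
    no_weights_model = os.path.join(tmpdir, "no_weights")
    os.makedirs(no_weights_model)

    # only the config is real; the "weights" file is a 1x1 linear layer's state dict
    AutoConfig.from_pretrained(model_id).save_pretrained(no_weights_model)
    torch.save(torch.nn.Linear(1, 1).state_dict(), os.path.join(no_weights_model, "pytorch_model.bin"))

    # from_pretrained warns about the (all) missing keys, and no_init_weights skips
    # their random initialization, so the model is materialized without downloading weights
    with no_init_weights():
        model = AutoModel.from_pretrained(no_weights_model)

The OpenVINO backend additionally randomizes and re-ties the weights after loading, so the benchmark still runs on well-formed tensors.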
diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py
index 096d0971..187a9c12 100644
--- a/optimum_benchmark/backends/onnxruntime/backend.py
+++ b/optimum_benchmark/backends/onnxruntime/backend.py
@@ -18,6 +18,7 @@
     AutoCalibrationConfig,
     AutoOptimizationConfig,
     AutoQuantizationConfig,
+    CalibrationConfig,
     OptimizationConfig,
     QuantizationConfig,
 )
@@ -26,11 +27,17 @@
 from transformers.modeling_utils import no_init_weights
 from transformers.utils.logging import set_verbosity_error
 
+from ...generators.dataset_generator import DatasetGenerator
 from ..base import Backend
 from ..peft_utils import get_peft_config_class
 from ..pytorch.utils import randomize_weights
 from .config import ORTConfig
-from .utils import TASKS_TO_ORTMODELS, TASKS_TO_ORTSD, format_quantization_config
+from .utils import (
+    TASKS_TO_ORTMODELS,
+    TASKS_TO_ORTSD,
+    format_calibration_config,
+    format_quantization_config,
+)
 
 # disable transformers logging
 set_verbosity_error()
@@ -104,7 +111,7 @@ def configure(self, config: ORTConfig) -> None:
             original_export = self.export
             self.model = self.pretrained_model.model_save_dir
             # self.model will point to a directory from here on
-            self.export = False  # we disable export because we want to load the optimized/quantized onnx files
+            self.export = False  # we disable export because we'll load the optimized/quantized model now
 
             if self.is_optimized:
                 self.optimize_onnx_files()
@@ -192,11 +199,15 @@ def load_ortmodel_from_pretrained(self) -> None:
 
     @property
     def is_optimized(self) -> bool:
-        return self.config.auto_optimization or self.config.optimization
+        return (self.config.auto_optimization is not None) or self.config.optimization
 
     @property
     def is_quantized(self) -> bool:
-        return self.config.auto_quantization or self.config.quantization
+        return (self.config.auto_quantization is not None) or self.config.quantization
+
+    @property
+    def is_calibrated(self) -> bool:
+        return (self.config.auto_calibration is not None) or self.config.calibration
 
     @property
     def automodel_kwargs(self) -> Dict[str, Any]:
@@ -244,8 +255,8 @@ def optimize_onnx_files(self) -> None:
         LOGGER.info("\t+ Processing optimization config")
         if self.config.auto_optimization is not None:
             optimization_config = AutoOptimizationConfig.with_optimization_level(
-                optimization_level=self.config.auto_optimization,
                 for_gpu=(self.device == "cuda"),
+                optimization_level=self.config.auto_optimization,
                 **self.config.auto_optimization_config,
             )
         elif self.config.optimization:
@@ -259,10 +270,10 @@ def optimize_onnx_files(self) -> None:
         optimizer.optimize(
             optimization_config,
             save_dir=optimized_model_path,
-            file_suffix="",
             # TODO: add support for these
             use_external_data_format=None,
             one_external_file=True,
+            file_suffix="",
         )
 
         if self.pretrained_processor is not None:
@@ -277,42 +288,62 @@ def quantize_onnx_files(self) -> None:
         LOGGER.info("\t+ Attempting quantization")
         quantized_model_path = f"{self.tmpdir.name}/quantized"
 
-        if self.config.calibration and len(self.onnx_files_names_to_quantize) > 1:
+        if self.is_calibrated and len(self.onnx_files_names_to_quantize) > 1:
             raise NotImplementedError(
-                "Calibration is not supported for models with multiple components. "
+                "Calibrated/Static Quantization is not supported for models with multiple components. "
                 f"Found {len(self.onnx_files_names_to_quantize)} components."
             )
 
         LOGGER.info("\t+ Processing quantization config")
         if self.config.auto_quantization is not None:
-            self.config.auto_quantization_config = format_quantization_config(self.config.auto_quantization_config)
-            auto_quantization_config_class = getattr(AutoQuantizationConfig, self.config.auto_quantization)
-            quantization_config = auto_quantization_config_class(**self.config.auto_quantization_config)
+            auto_quantization_config = format_quantization_config(self.config.auto_quantization_config)
+            auto_quantization_class = getattr(AutoQuantizationConfig, self.config.auto_quantization)
+            quantization_config = auto_quantization_class(**auto_quantization_config)
         elif self.config.quantization:
-            self.config.quantization_config = format_quantization_config(self.config.quantization_config)
-            quantization_config = QuantizationConfig(**self.config.quantization_config)
+            quantization_config = format_quantization_config(self.config.quantization_config)
+            quantization_config = QuantizationConfig(**quantization_config)
+
+        if self.is_calibrated:
+            LOGGER.info("\t+ Generating calibration dataset")
+            dataset_shapes = {"dataset_size": 1, "sequence_length": 1, **self.model_shapes}
+            calibration_dataset = DatasetGenerator(task=self.task, dataset_shapes=dataset_shapes).generate()
+            columns_to_be_removed = list(set(calibration_dataset.column_names) - set(self.inputs_names))
+            calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
+
+            LOGGER.info("\t+ Processing calibration config")
+            if self.config.auto_calibration is not None:
+                LOGGER.info("\t+ Processing calibration config")
+                auto_calibration_method = getattr(AutoCalibrationConfig, self.config.auto_calibration)
+                calibration_config = auto_calibration_method(
+                    calibration_dataset,
+                    **self.config.auto_calibration_config,
+                )
+            elif self.config.calibration:
+                LOGGER.info("\t+ Processing calibration config")
+                calibration_config = format_calibration_config(self.config.calibration_config)
+                calibration_config = CalibrationConfig(
+                    dataset_name="calibration_dataset",
+                    dataset_split=calibration_dataset.split,
+                    dataset_num_samples=calibration_dataset.num_rows,
+                    dataset_config_name=calibration_dataset.config_name,
+                    **self.config.calibration_config,
+                )
 
         for onnx_file_name_to_quantize in self.onnx_files_names_to_quantize:
             LOGGER.info(f"\t+ Creating quantizer for {onnx_file_name_to_quantize}")
             quantizer = ORTQuantizer.from_pretrained(self.model, file_name=onnx_file_name_to_quantize)
-            if self.config.calibration:
-                LOGGER.info("\t+ Processing calibration config")
-                preprocess_class = get_class(self.config.calibration_config.pop("preprocess_class"))
-                self.config.calibration_config["preprocess_function"] = preprocess_class(model_name_or_path=self.model)
-                LOGGER.info("\t+ Loading calibration dataset")
-                calibration_dataset = quantizer.get_calibration_dataset(**self.config.calibration_config)
-                LOGGER.info("\t+ Creating calibration config")
-                calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)
+
+            if self.is_calibrated:
                 LOGGER.info("\t+ Fitting calibration tensors range")
                 calibration_tensors_range = quantizer.fit(
                     dataset=calibration_dataset,
+                    use_gpu=(self.device == "cuda"),
                     calibration_config=calibration_config,
                     operators_to_quantize=quantization_config.operators_to_quantize,
-                    use_gpu=self.device == "cuda",
-                    # TODO: add support for these
-                    batch_size=1,
+                    # TODO: add support for these (maybe)
                     use_external_data_format=False,
                     force_symmetric_range=False,
+                    batch_size=1,
                 )
             else:
                 calibration_tensors_range = None
@@ -322,9
+353,10 @@ def quantize_onnx_files(self) -> None:
                 save_dir=quantized_model_path,
                 quantization_config=quantization_config,
                 calibration_tensors_range=calibration_tensors_range,
-                # TODO: add support for these
+                # TODO: add support for these (maybe)
                 use_external_data_format=False,
                 preprocessor=None,
+                file_suffix="",
             )
 
         if self.pretrained_processor is not None:
@@ -348,15 +380,12 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.library == "diffusers":
             return {"prompt": inputs["prompt"]}
 
-        for key in list(inputs.keys()):
-            # sometimes optimum onnx exported models don't have inputs
-            # that their pytorch counterparts have, for instance token_type_ids
-            if key not in self.inputs_names:
-                inputs.pop(key)
-
         LOGGER.info(f"\t+ Moving inputs tensors to device {self.device}")
-        for key, value in inputs.items():
-            inputs[key] = value.to(self.device)
+        for key, value in list(inputs.items()):
+            if key in self.inputs_names:
+                inputs[key] = value.to(self.device)
+            else:
+                inputs.pop(key)
 
         return inputs
diff --git a/optimum_benchmark/backends/onnxruntime/config.py b/optimum_benchmark/backends/onnxruntime/config.py
index 4f4b5dd4..7bca0ee0 100644
--- a/optimum_benchmark/backends/onnxruntime/config.py
+++ b/optimum_benchmark/backends/onnxruntime/config.py
@@ -8,65 +8,25 @@
 from ..config import BackendConfig
 from ..peft_utils import PEFT_CONFIGS, PEFT_TASKS_TYPES
 
-OPTIMIZATION_CONFIG = {
-    "optimization_level": 1,
-    "fp16": False,
-    "enable_transformers_specific_optimizations": True,
-    "enable_gelu_approximation": False,
-    "disable_gelu_fusion": False,
-    "disable_layer_norm_fusion": False,
-    "disable_attention_fusion": False,
-    "disable_skip_layer_norm_fusion": True,
-    "disable_bias_skip_layer_norm_fusion": False,
-    "disable_bias_gelu_fusion": False,
-    "use_mask_index": False,
-    "no_attention_mask": False,
-    "disable_embed_layer_norm_fusion": True,
-    "disable_shape_inference": False,
-    "use_multi_head_attention": False,
-    "enable_gemm_fast_gelu_fusion": False,
-    "use_raw_attention_mask": False,
-    "disable_group_norm_fusion": True,
-    "disable_packed_kv": True,
-}
-
-AUTO_OPTIMIZATION_CONFIG = {
-    # auto optimization config depends on the level so we keep it minimal
-}
-
 QUANTIZATION_CONFIG = {
     "is_static": False,
-    "format": "QOperator",  # QOperator, QDQ
-    "mode": "IntegerOps",  # QLinearOps, IntegerOps
-    "activations_dtype": "QUInt8",  # QInt8, QUInt8
-    "activations_symmetric": False,
-    "weights_dtype": "QInt8",  # QInt8, QUInt8
-    "weights_symmetric": True,
-    "per_channel": False,
-    "reduce_range": False,
-    "operators_to_quantize": [
-        "MatMul",
-        "Add",
-    ],
+    "format": "QOperator",
+    # is_static and format are mandatory
 }
 
-AUTO_QUANTIZATION_CONFIG = {
-    "is_static": False,
-    # full auto quantization config depends on the strategy so we keep it minimal
+CALIBRATION_CONFIG = {
+    "method": "MinMax"
+    # method is mandatory
 }
 
-CALIBRATION_CONFIG = {
-    "dataset_name": "glue",
-    "num_samples": 300,
-    "dataset_config_name": "sst2",
-    "dataset_split": "train",
-    "preprocess_batch": True,
-    "preprocess_class": "optimum_benchmark.preprocessors.glue.GluePreprocessor",
+AUTO_QUANTIZATION_CONFIG = {
+    "is_static": False,
+    # is_static is mandatory
 }
 
 TRT_PROVIDER_OPTIONS = {
     "trt_engine_cache_enable": True,
-    "trt_engine_cache_path": "tmp/trt_cache",
+    "trt_engine_cache_path": "/tmp/trt_cache",
 }
 
 DEVICE_PROVIDER_MAP = {"cpu": "CPUExecutionProvider", "cuda": "CUDAExecutionProvider"}
@@ -102,26 +62,30 @@ class ORTConfig(BackendConfig):
         default_factory=lambda: {"enable_profiling": "${is_profiling:${benchmark.name}}"}
"${is_profiling:${benchmark.name}}"} ) - # optimization options + # null, O1, O2, O3, O4 + auto_optimization: Optional[str] = None + auto_optimization_config: Dict[str, Any] = field(default_factory=dict) + + # null, arm64, avx2, avx512, avx512_vnni, tensorrt + auto_quantization: Optional[str] = None + auto_quantization_config: Dict[str, Any] = field(default_factory=dict) + + # minmax, entropy, l2norm, percentiles + auto_calibration: Optional[str] = None + auto_calibration_config: Dict[str, Any] = field(default_factory=dict) + + # manual optimization options optimization: bool = False optimization_config: Dict[str, Any] = field(default_factory=dict) - # quantization options + # manual quantization options quantization: bool = False quantization_config: Dict[str, Any] = field(default_factory=dict) - # calibration options + # manual calibration options calibration: bool = False calibration_config: Dict[str, Any] = field(default_factory=dict) - # null, O1, O2, O3, O4 - auto_optimization: Optional[str] = None - auto_optimization_config: Dict[str, Any] = field(default_factory=dict) - - # null, arm64, avx2, avx512, avx512_vnni, tensorrt - auto_quantization: Optional[str] = None - auto_quantization_config: Dict[str, Any] = field(default_factory=dict) - # ort-training is basically a different package so we might need to separate these two backends in the future use_inference_session: bool = "${is_inference:${benchmark.name}}" @@ -139,31 +103,25 @@ def __post_init__(self): self.provider_options = OmegaConf.to_object(OmegaConf.merge(TRT_PROVIDER_OPTIONS, self.provider_options)) os.makedirs(self.provider_options["trt_engine_cache_path"], exist_ok=True) - if self.optimization: - self.optimization_config = OmegaConf.to_object( - OmegaConf.merge(OPTIMIZATION_CONFIG, self.optimization_config) - ) if self.quantization: self.quantization_config = OmegaConf.to_object( OmegaConf.merge(QUANTIZATION_CONFIG, self.quantization_config) ) # raise ValueError if the quantization is static but calibration is not enabled - if self.quantization_config["is_static"] and not self.calibration: + if self.quantization_config["is_static"] and self.auto_calibration is None and not self.calibration: raise ValueError( - "Quantization is static but calibration is not enabled. Please enable calibration or disable static quantization." + "Quantization is static but calibration is not enabled. " + "Please enable calibration or disable static quantization." ) - if self.auto_optimization is not None: - self.auto_optimization_config = OmegaConf.to_object( - OmegaConf.merge(AUTO_OPTIMIZATION_CONFIG, self.auto_optimization_config) - ) if self.auto_quantization is not None: self.auto_quantization_config = OmegaConf.to_object( OmegaConf.merge(AUTO_QUANTIZATION_CONFIG, self.auto_quantization_config) ) - if self.auto_quantization_config["is_static"] and not self.calibration: + if self.auto_quantization_config["is_static"] and self.auto_calibration is None and not self.calibration: raise ValueError( - "Quantization is static but calibration is not enabled. Please enable calibration or disable static quantization." + "Quantization is static but calibration is not enabled. " + "Please enable calibration or disable static quantization." 
diff --git a/optimum_benchmark/backends/onnxruntime/utils.py b/optimum_benchmark/backends/onnxruntime/utils.py
index ea2fac2d..5fe1f1aa 100644
--- a/optimum_benchmark/backends/onnxruntime/utils.py
+++ b/optimum_benchmark/backends/onnxruntime/utils.py
@@ -1,6 +1,11 @@
 from typing import Any, Dict
 
-from onnxruntime.quantization import QuantFormat, QuantizationMode, QuantType
+from onnxruntime.quantization import (
+    CalibrationMethod,
+    QuantFormat,
+    QuantizationMode,
+    QuantType,
+)
 from optimum.pipelines import ORT_SUPPORTED_TASKS
 
 TASKS_TO_ORTSD = {
@@ -11,6 +16,13 @@
 TASKS_TO_ORTMODELS = {task: task_dict["class"][0] for task, task_dict in ORT_SUPPORTED_TASKS.items()}
 
 
+def format_calibration_config(calibration_config: Dict[str, Any]) -> None:
+    if calibration_config.get("method", None) is not None:
+        calibration_config["method"] = CalibrationMethod[calibration_config["method"]]
+
+    return calibration_config
+
+
 def format_quantization_config(quantization_config: Dict[str, Any]) -> None:
     """Format the quantization dictionary for onnxruntime."""
     # the conditionals are here because some quantization strategies don't have all the options
diff --git a/optimum_benchmark/backends/openvino/backend.py b/optimum_benchmark/backends/openvino/backend.py
index e10ae0b8..e5eca700 100644
--- a/optimum_benchmark/backends/openvino/backend.py
+++ b/optimum_benchmark/backends/openvino/backend.py
@@ -1,18 +1,28 @@
 import gc
 import inspect
+import os
 from logging import getLogger
 from tempfile import TemporaryDirectory
 from typing import Any, Dict
 
+import torch
 from hydra.utils import get_class
 from openvino.runtime import properties
 from optimum.intel.openvino import OVConfig as OVQuantizationConfig  # naming conflict
 from optimum.intel.openvino import OVQuantizer
+from safetensors.torch import save_file
+from transformers.modeling_utils import no_init_weights
+from transformers.utils.logging import set_verbosity_error
 
+from ...generators.dataset_generator import DatasetGenerator
 from ..base import Backend
+from ..pytorch.utils import randomize_weights
 from .config import OVConfig
 from .utils import TASKS_TO_OVMODEL
 
+# disable transformers logging
+set_verbosity_error()
+
 LOGGER = getLogger("openvino")
 
@@ -36,10 +46,8 @@ def configure(self, config: OVConfig) -> None:
         super().configure(config)
 
         self.ovmodel_class = get_class(TASKS_TO_OVMODEL[self.task])
-        ortmodel_name = self.ovmodel_class.__name__
-        LOGGER.info(
-            f"\t+ Inferred OVModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}"
-        )
+        ovmodel_name = self.ovmodel_class.__name__
+        LOGGER.info(f"\t+ Inferred class {ovmodel_name} for task {self.task} and model_type {self.model_type}")
 
         self.openvino_config = self.config.openvino_config.copy()
         if self.config.inter_op_num_threads is not None:
@@ -49,38 +57,96 @@ def configure(self, config: OVConfig) -> None:
         if self.config.intra_op_num_threads is not None:
             raise NotImplementedError("OVBackend does not support intra_op_num_threads")
 
+        if self.library == "diffusers" and self.config.no_weights:
+            raise NotImplementedError("Diffusers models can't be loaded with no weights")
+
         self.tmpdir = TemporaryDirectory()
 
         if self.config.quantization:
-            self.load_automodel()
+            if self.config.no_weights:
+                self.load_automodel_with_no_weights()
+            else:
+                self.load_automodel_from_pretrained()
             self.quantize_automodel()
-            self.delete_pretrained_model()  # deletes automodel
-            self.export = False  # quantized model is already exported
+            self.delete_pretrained_model()
+            self.load_ovmodel_from_pretrained()
+        elif self.config.no_weights:
+            self.load_ovmodel_with_no_weights()
         else:
-            self.export = self.config.export  # to not change the config's values
+            self.load_ovmodel_from_pretrained()
 
-        self.load_ovmodel()
         self.tmpdir.cleanup()
 
-    def load_automodel(self) -> None:
-        self.pretrained_model = self.automodel_class.from_pretrained(self.model, **self.hub_kwargs)
-
-    @property
-    def ovmodel_kwargs(self) -> Dict[str, Any]:
-        if self.is_text_generation_model():
-            return {"use_cache": self.config.use_cache, "use_merged": self.config.use_merged}
-        else:
-            return {}
-
-    def load_ovmodel(self) -> None:
+    def load_ovmodel_from_pretrained(self) -> None:
         self.pretrained_model = self.ovmodel_class.from_pretrained(
             self.model,
-            export=self.export,
             ov_config=self.openvino_config,
+            export=self.config.export and not self.config.quantization,
+            # in case of quantization, the model will be exported by the quantizer
             **self.ovmodel_kwargs,
             **self.hub_kwargs,
         )
 
+    def load_ovmodel_with_no_weights(self) -> None:
+        no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
+
+        LOGGER.info("\t+ Loading AutoModel with no weights")
+        self.load_automodel_with_no_weights()
+        self.delete_pretrained_model()
+
+        LOGGER.info("\t+ Loading OVModel with no weights")
+        with no_init_weights():
+            original_model = self.model
+            self.model = no_weights_model
+            self.load_ovmodel_from_pretrained()
+            self.model = original_model
+
+    def load_automodel_from_pretrained(self) -> None:
+        LOGGER.info("\t+ Loading AutoModel from pretrained")
+        self.pretrained_model = self.automodel_class.from_pretrained(self.model, **self.hub_kwargs)
+
+    def load_automodel_with_no_weights(self) -> None:
+        original_model = self.model
+        no_weights_model = os.path.join(self.tmpdir.name, "no_weights")
+
+        if not os.path.exists(no_weights_model):
+            LOGGER.info("\t+ Creating no weights model directory")
+            os.makedirs(no_weights_model)
+
+        LOGGER.info("\t+ Saving pretrained config")
+        self.pretrained_config.save_pretrained(save_directory=no_weights_model)
+
+        LOGGER.info("\t+ Creating no weights model")
+        state_dict = torch.nn.Linear(1, 1).state_dict()
+
+        LOGGER.info("\t+ Saving no weights model")
+        save_file(
+            filename=os.path.join(no_weights_model, "model.safetensors"),
+            metadata={"format": "pt"},
+            tensors=state_dict,
+        )
+
+        LOGGER.info("\t+ Loading no weights model")
+        with no_init_weights():
+            self.model = no_weights_model
+            self.load_automodel_from_pretrained()
+            self.model = original_model
+
+        LOGGER.info("\t+ Randomizing weights")
+        randomize_weights(self.pretrained_model)
+        LOGGER.info("\t+ Tying model weights after randomization")
+        self.pretrained_model.tie_weights()
+
+    @property
+    def ovmodel_kwargs(self) -> Dict[str, Any]:
+        kwargs = {}
+
+        if self.is_text_generation_model():
+            kwargs["use_cache"] = self.config.use_cache
+            kwargs["use_merged"] = self.config.use_merged
+
+        return kwargs
+
     def quantize_automodel(self) -> None:
         LOGGER.info("\t+ Attempting quantization")
         quantized_model_path = f"{self.tmpdir.name}/quantized"
@@ -88,18 +154,22 @@ def quantize_automodel(self) -> None:
         quantization_config = OVQuantizationConfig(**self.config.quantization_config)
         LOGGER.info("\t+ Creating quantizer")
         quantizer = OVQuantizer.from_pretrained(self.pretrained_model, task=self.task, seed=self.config.seed)
-        LOGGER.info("\t+ Processing calibration config")
-        calibration_config = self.config.calibration_config.copy()
-        preprocess_class = get_class(calibration_config.pop("preprocess_class"))
-        calibration_config["preprocess_function"] = preprocess_class(model_name_or_path=self.model)
-        LOGGER.info("\t+ Loading calibration dataset")
-        calibration_dataset = quantizer.get_calibration_dataset(**calibration_config)
+
+        if self.config.calibration:
+            LOGGER.info("\t+ Generating calibration dataset")
+            dataset_shapes = {"dataset_size": 1, "sequence_length": 1, **self.model_shapes}
+            calibration_dataset = DatasetGenerator(task=self.task, dataset_shapes=dataset_shapes).generate()
+            columns_to_be_removed = list(set(calibration_dataset.column_names) - set(quantizer._export_input_names))
+            calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
+        else:
+            calibration_dataset = None
+
         LOGGER.info("\t+ Quantizing model")
         quantizer.quantize(
-            quantization_config=quantization_config,
             save_directory=quantized_model_path,
+            quantization_config=quantization_config,
             calibration_dataset=calibration_dataset,
-            # TODO: add support for these
+            # TODO: add support for these (maybe)
             remove_unused_columns=True,
             data_collator=None,
             weights_only=False,
@@ -121,6 +191,9 @@ def prepare_for_inference(self, **kwargs) -> None:
             for key, value in kwargs.items()
             if key in inspect.getfullargspec(self.pretrained_model.reshape).args
         }
+        if (static_shapes.get("height", None) is not None) and ("sequence_length" in static_shapes):
+            static_shapes["sequence_length"] = kwargs.get("num_channels", 3)
+
         LOGGER.info(f"\t+ Reshaping model with static shapes: {static_shapes}")
         self.pretrained_model.reshape(**static_shapes)
diff --git a/optimum_benchmark/backends/openvino/config.py b/optimum_benchmark/backends/openvino/config.py
index 9077d623..378999bb 100644
--- a/optimum_benchmark/backends/openvino/config.py
+++ b/optimum_benchmark/backends/openvino/config.py
@@ -8,22 +8,6 @@
 OmegaConf.register_new_resolver("openvino_version", openvino_version)
 
 
-# https://github.com/huggingface/optimum-intel/blob/main/optimum/intel/openvino/configuration.py#L81
-QUANTIZATION_CONFIG = {
-    "compression": None,
-    "input_info": None,
-    "save_onnx_model": False,
-}
-
-CALIBRATION_CONFIG = {
-    "dataset_name": "glue",
-    "num_samples": 300,
-    "dataset_config_name": "sst2",
-    "dataset_split": "train",
-    "preprocess_batch": True,
-    "preprocess_class": "optimum_benchmark.preprocessors.glue.GluePreprocessor",
-}
-
 
 @dataclass
 class OVConfig(BackendConfig):
@@ -31,6 +15,9 @@ class OVConfig(BackendConfig):
     version: str = "${openvino_version:}"
     _target_: str = "optimum_benchmark.backends.openvino.backend.OVBackend"
 
+    # load options
+    no_weights: bool = False
+
     # export options
     export: bool = True
     use_cache: bool = True
@@ -40,8 +27,8 @@ class OVConfig(BackendConfig):
     openvino_config: Dict[str, Any] = field(default_factory=dict)
 
     # compilation options
-    reshape: bool = False
     half: bool = False
+    reshape: bool = False
 
     # quantization options
     quantization: bool = False
@@ -54,13 +41,5 @@ def __post_init__(self):
     def __post_init__(self):
         super().__post_init__()
 
-        if self.quantization:
-            self.quantization_config = OmegaConf.to_object(
-                OmegaConf.merge(QUANTIZATION_CONFIG, self.quantization_config)
-            )
-            if not self.calibration:
-                raise ValueError("OpenVINO quantization requires enabling calibration.")
-            else:
-                self.calibration_config = OmegaConf.to_object(
-                    OmegaConf.merge(CALIBRATION_CONFIG, self.calibration_config)
-                )
+        if self.quantization and not self.calibration:
+            raise ValueError("OpenVINO quantization requires enabling calibration.")
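The OpenVINO quantization flow above maps onto optimum-intel's OVQuantizer in much the same way. A minimal sketch with an assumed checkpoint and a hand-built calibration set standing in for the generated one (exact signatures vary slightly across optimum-intel versions):

from datasets import Dataset
from optimum.intel.openvino import OVConfig as OVQuantizationConfig
from optimum.intel.openvino import OVModelForSequenceClassification, OVQuantizer
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "bert-base-uncased"  # assumed checkpoint
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# a one-sample calibration set; the backend generates random data instead
calibration_dataset = Dataset.from_dict(
    tokenizer(["a single calibration sentence"], padding="max_length", max_length=16, truncation=True)
)

quantizer = OVQuantizer.from_pretrained(model, task="text-classification")
quantizer.quantize(
    save_directory="quantized_model_dir",  # hypothetical output path
    quantization_config=OVQuantizationConfig(),  # defaults, like an empty quantization_config dict
    calibration_dataset=calibration_dataset,
)

# the backend then reloads the quantized model through the task's OVModel class
ov_model = OVModelForSequenceClassification.from_pretrained("quantized_model_dir")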
diff --git a/optimum_benchmark/preprocessors/__init__.py b/optimum_benchmark/preprocessors/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/optimum_benchmark/preprocessors/glue.py b/optimum_benchmark/preprocessors/glue.py
deleted file mode 100644
index 8e741359..00000000
--- a/optimum_benchmark/preprocessors/glue.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from transformers import AutoTokenizer
-
-
-class GluePreprocessor:
-    def __init__(self, model_name_or_path):
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-
-    def __call__(self, examples):
-        return self.tokenizer(
-            examples["sentence"],
-            padding="max_length",
-            truncation=True,
-        )
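The deleted GLUE preprocessor is what the DatasetGenerator calls introduced throughout this diff replace: calibration data is now generated from the task and model shapes rather than tokenized from a real dataset. A sketch of that pattern, with assumed shape keys (inside the backends, dataset_shapes is completed from self.model_shapes):

from optimum_benchmark.generators.dataset_generator import DatasetGenerator

# the backends merge in model-derived shapes via **self.model_shapes;
# any extra keys needed by a given task are assumptions here
dataset_shapes = {"dataset_size": 1, "sequence_length": 1}
calibration_dataset = DatasetGenerator(task="text-classification", dataset_shapes=dataset_shapes).generate()

# as in the backends, keep only the columns the quantizer / exported model consumes
kept_columns = {"input_ids", "attention_mask", "token_type_ids"}
calibration_dataset = calibration_dataset.remove_columns(
    [name for name in calibration_dataset.column_names if name not in kept_columns]
)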