diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index 312b6fdb..1ac638e8 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -1,4 +1,5 @@
 import gc
+import logging
 from logging import getLogger
 from typing import TYPE_CHECKING, Any, Callable, Dict, List
 
@@ -19,6 +20,9 @@
 # bachend logger
 LOGGER = getLogger("pytorch")
 
+# disable numexpr.utils logger
+getLogger("numexpr.utils").setLevel(logging.CRITICAL)
+
 
 class PyTorchBackend(Backend[PyTorchConfig]):
     NAME: str = "pytorch"
@@ -64,14 +68,9 @@ def configure(self, config: PyTorchConfig) -> None:
             self.pretrained_model.eval()
 
         # BetterTransformer
-        if self.config.bettertransformer:
-            LOGGER.info("\t+ Using optimum.bettertransformer")
-            from optimum.bettertransformer import BetterTransformer
-
-            self.pretrained_model = BetterTransformer.transform(
-                self.pretrained_model,
-                keep_original_model=False,
-            )
+        if self.config.to_bettertransformer:
+            LOGGER.info("\t+ Enabling BetterTransformer")
+            self.pretrained_model.to_bettertransformer()
 
         # Compile model
         if self.config.torch_compile:
@@ -100,19 +99,19 @@ def configure(self, config: PyTorchConfig) -> None:
             self.pretrained_model = get_peft_model(self.pretrained_model, peft_config=peft_config)
 
     def load_model_from_pretrained(self) -> None:
-        # attempting inline quantization if possible
-        if self.config.quantization_scheme == "gptq" and self.config.quantization_config:
+        # inline quantization or quantization config modification
+        if self.config.quantization_scheme == "gptq":
             LOGGER.info("\t+ Processing GPTQ config")
             from transformers import GPTQConfig
 
             self.quantization_config = GPTQConfig(**self.config.quantization_config)
-        elif self.config.quantization_scheme == "awq" and self.config.quantization_config:
+        elif self.config.quantization_scheme == "awq":
             LOGGER.info("\t+ Processing AWQ config")
             from transformers import AwqConfig
 
             self.quantization_config = AwqConfig(**self.config.quantization_config)
         elif self.config.quantization_scheme == "bnb":
-            LOGGER.info("\t+ Processing BitsAndBytesConfig")
+            LOGGER.info("\t+ Processing BitsAndBytes config")
             from transformers import BitsAndBytesConfig
 
             self.quantization_config = self.config.quantization_config.copy()
@@ -120,7 +119,6 @@ def load_model_from_pretrained(self) -> None:
                 self.quantization_config["bnb_4bit_compute_dtype"] = getattr(
                     torch, self.quantization_config["bnb_4bit_compute_dtype"]
                 )
-                LOGGER.info(f"\t+ Using bnb_4bit_compute_dtype: {self.quantization_config['bnb_4bit_compute_dtype']}")
             self.quantization_config = BitsAndBytesConfig(**self.quantization_config)
         else:
             self.quantization_config = None
@@ -135,8 +133,8 @@ def load_model_from_pretrained(self) -> None:
             )
             if self.config.device_map is None:
                 LOGGER.info(f"\t+ Moving diffusion pipeline to device: {self.device}")
-                # Diffusers does not support loading with torch.device context manager
                 self.pretrained_model.to(self.device)
+
         elif self.config.device_map is not None:
             LOGGER.info(f"\t+ Loading model with device_map: {self.config.device_map}")
             self.pretrained_model = self.automodel_class.from_pretrained(
@@ -146,18 +144,17 @@ def load_model_from_pretrained(self) -> None:
                 **self.automodel_kwargs,
                 **self.hub_kwargs,
             )
-        elif hasattr(self.pretrained_config, "quantization_config"):
-            LOGGER.info("\t+ Loading quantized model")
+        elif hasattr(self.pretrained_config, "quantization_config") or self.quantization_config is not None:
+            LOGGER.info("\t+ Loading model with low cpu memory usage")
             self.pretrained_model = self.automodel_class.from_pretrained(
                 self.model,
-                device_map=self.device,
                 low_cpu_mem_usage=True,
                 torch_dtype=self.torch_dtype,
                 **self.automodel_kwargs,
                 **self.hub_kwargs,
-            )
+            ).to(self.device)
         else:
-            LOGGER.info(f"\t+ Loading model on device: {self.device}")
+            LOGGER.info(f"\t+ Loading model directly on device: {self.device}")
             with self.device:
                 self.pretrained_model = self.automodel_class.from_pretrained(
                     self.model,
@@ -168,13 +165,17 @@ def load_model_from_pretrained(self) -> None:
 
     @property
     def automodel_kwargs(self) -> Dict[str, Any]:
+        kwargs = {}
+
         if self.quantization_config is not None:
-            return {"quantization_config": self.quantization_config}
-        else:
-            return {}
+            kwargs["quantization_config"] = self.quantization_config
+
+        if self.config.use_flash_attention_2:
+            kwargs["use_flash_attention_2"] = True
+
+        return kwargs
 
     def load_model_from_config(self) -> None:
-        # TODO: create no_weights tests
        from accelerate import init_empty_weights
 
         LOGGER.info("\t+ Initializing empty weights model on device: meta")
@@ -238,10 +239,11 @@ def load_model_from_config(self) -> None:
 
     def prepare_for_inference(self, input_shapes: Dict[str, int], **kwargs) -> None:
         super().prepare_for_inference(input_shapes=input_shapes, **kwargs)
+
         if (self.config.quantization_scheme == "gptq" and self.config.quantization_config.get("desc_act", None)) or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
             and self.pretrained_config.quantization_config.get("desc_act", None)
+            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
         ):
             LOGGER.info("\t+ Setting GPTQ's max_input_length")
             from auto_gptq import exllama_set_max_input_length
diff --git a/optimum_benchmark/backends/pytorch/config.py b/optimum_benchmark/backends/pytorch/config.py
index ddf0d7cb..ec7e89fb 100644
--- a/optimum_benchmark/backends/pytorch/config.py
+++ b/optimum_benchmark/backends/pytorch/config.py
@@ -57,7 +57,8 @@ class PyTorchConfig(BackendConfig):
     torch_compile_config: Dict[str, Any] = field(default_factory=dict)
 
     # optimization options
-    bettertransformer: bool = False
+    to_bettertransformer: bool = False
+    use_flash_attention_2: bool = False
 
     # quantization options
     quantization_scheme: Optional[str] = None
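
Note (not part of the patch): a minimal sketch of the two transformers entry points the backend now relies on, the use_flash_attention_2 kwarg of from_pretrained and PreTrainedModel.to_bettertransformer(); the model id, dtype, and device below are illustrative assumptions.

    import torch
    from transformers import AutoModelForCausalLM

    model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; must be an architecture with flash-attention-2 support

    # Path taken when use_flash_attention_2=True in PyTorchConfig:
    # the flag is forwarded to from_pretrained through automodel_kwargs.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        use_flash_attention_2=True,
    ).to("cuda")

    # Path taken when to_bettertransformer=True in PyTorchConfig:
    # the transformers method replaces optimum.bettertransformer.BetterTransformer.transform.
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
    model = model.to_bettertransformer()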