diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index c48e2d1b..d1b474ba 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -6,7 +6,7 @@
 import torch
 from datasets import Dataset
-from safetensors.torch import save_model
+from safetensors.torch import save_file
 from transformers import TrainerCallback, TrainerState
 from transformers.utils import ModelOutput
 from transformers.utils.logging import set_verbosity_error
@@ -78,6 +78,8 @@ def configure(self, config: PyTorchConfig) -> None:
             self.quantization_config = None
 
         # Load model
+        if self.config.no_weights and self.is_diffusion_pipeline():
+            raise ValueError("Diffusion Pipelines are not supported with no_weights=True")
         if self.config.no_weights:
             LOGGER.info("\t+ Loading model with no weights")
             self.load_model_with_no_weights()
@@ -161,7 +163,7 @@ def load_model_from_pretrained(self) -> None:
                 **self.hub_kwargs,
             )
         elif self.is_gptq_quantized() or self.is_awq_quantized():
-            LOGGER.info("\t+ Loading GPTQ quantized model")
+            LOGGER.info("\t+ Loading quantized model")
             self.pretrained_model = self.automodel_class.from_pretrained(
                 pretrained_model_name_or_path=self.model,
                 # for gptq, we need to specify the device_map to either auto
@@ -172,6 +174,7 @@
                 **self.automodel_kwargs,
                 **self.hub_kwargs,
             )
+            print(torch.cuda.max_memory_allocated())
         elif self.config.device_map is not None:
             LOGGER.info(f"\t+ Loading model with device map: {self.config.device_map}")
             self.pretrained_model = self.automodel_class.from_pretrained(
@@ -193,31 +196,40 @@ def load_model_from_pretrained(self) -> None:
     def load_model_with_no_weights(self) -> None:
         self.tmp_dir = TemporaryDirectory()
 
-        if self.is_diffusion_pipeline():
-            raise ValueError("Diffusion pipelines are not supported with no_weights=True")
-
         original_model = self.model
         no_weights_model = os.path.join(self.tmp_dir.name, "no_weights")
 
         LOGGER.info("\t+ Creating no weights model directory")
-        os.makedirs(no_weights_model, exist_ok=True)
+        if not os.path.exists(no_weights_model):
+            os.makedirs(no_weights_model)
 
         if self.is_quantized():
-            # so that from_pretrained acts as if the model is quantized
+            # tricking from_pretrained to load the model as if it was quantized
             self.pretrained_config.quantization_config = self.quantization_config.to_dict()
 
         LOGGER.info(f"\t+ Saving pretrained config to {no_weights_model}")
         self.pretrained_config.save_pretrained(save_directory=no_weights_model)
 
-        if self.pretrained_processor is not None:
-            LOGGER.info(f"\t+ Saving pretrained processor to {no_weights_model}")
-            self.pretrained_processor.save_pretrained(save_directory=no_weights_model)
+        LOGGER.info(f"\t+ Creating no weights model to {no_weights_model}")
+        state_dict = torch.nn.Linear(1, 1).state_dict()
+
+        if self.is_exllamav2():
+            # for exllamav2 we need to add g_idx to the state_dict
+            LOGGER.info("\t+ Loading meta model")
+            with torch.device("meta"):
+                meta_model = self.automodel_class.from_config(self.pretrained_config)
+
+            LOGGER.info("\t+ Setting g_idx for ExllamaV2")
+            for name, module in meta_model.named_modules():
+                # loading to exllama v2's QuantLinear creates g_idx with bad values
+                if hasattr(module, "in_features"):
+                    state_dict[name + ".g_idx"] = torch.ones((module.in_features,), dtype=torch.int32)
 
         LOGGER.info(f"\t+ Saving no weights model to {no_weights_model}")
-        save_model(
+        save_file(
             filename=os.path.join(no_weights_model, "model.safetensors"),
-            model=torch.nn.Linear(1, 1),
             metadata={"format": "pt"},
+            tensors=state_dict,
         )
 
         LOGGER.info("\t+ Loading no weights model")
@@ -282,6 +294,14 @@ def is_awq_quantized(self) -> bool:
             and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
         )
 
+    def is_exllamav2(self) -> bool:
+        return (
+            self.is_quantized()
+            and self.is_gptq_quantized()
+            and "exllama_config" in self.config.quantization_config
+            and self.config.quantization_config["exllama_config"]["version"] == 2
+        )
+
     @property
     def automodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}
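
For reference, a minimal standalone sketch of the no-weights trick the last hunk implements. It shows why the diff swaps save_model for save_file: save_file serializes an arbitrary name-to-tensor dict, while save_model walks a module's state_dict, so only save_file lets the backend inject the extra g_idx tensors that ExLlamaV2's QuantLinear expects. The output directory, the tensor name "model.layers.0.self_attn.q_proj.g_idx", and the in_features value 4096 are hypothetical placeholders, not taken from a real checkpoint.

import os

import torch
from safetensors.torch import save_file

no_weights_model = "no_weights"  # hypothetical output directory
os.makedirs(no_weights_model, exist_ok=True)

# A tiny placeholder state dict; from_pretrained re-materializes weights
# with the right shapes later, so the contents here don't matter.
state_dict = torch.nn.Linear(1, 1).state_dict()

# Inject a g_idx tensor the way the backend does for ExLlamaV2
# (made-up module name and in_features value).
state_dict["model.layers.0.self_attn.q_proj.g_idx"] = torch.ones((4096,), dtype=torch.int32)

# save_file accepts any dict of tensors, unlike save_model(model, ...).
save_file(
    tensors=state_dict,
    filename=os.path.join(no_weights_model, "model.safetensors"),
    metadata={"format": "pt"},
)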