From 9a1b899bab28e66d996789b32fbd6d2d190e75d2 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Fri, 24 Nov 2023 06:45:16 +0000
Subject: [PATCH] remove gptq exllama's max_input_length setting

---
 optimum_benchmark/backends/pytorch/backend.py | 17 +----------------
 1 file changed, 1 insertion(+), 16 deletions(-)

diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index 340a7a87..9df59781 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -184,7 +184,7 @@ def automodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}
 
         if hasattr(self.pretrained_config, "quantization_config") or self.quantization_config is not None:
-            kwargs["low_cpu_memory_usage"] = True
+            kwargs["low_cpu_mem_usage"] = True
 
         if self.quantization_config is not None:
             kwargs["quantization_config"] = self.quantization_config
@@ -256,21 +256,6 @@ def load_model_from_config(self) -> None:
             LOGGER.info("\t+ Tying weights")
             self.pretrained_model.tie_weights()
 
-    def prepare_for_inference(self, input_shapes: Dict[str, int], **kwargs) -> None:
-        super().prepare_for_inference(input_shapes=input_shapes, **kwargs)
-
-        if (self.config.quantization_scheme == "gptq" and self.config.quantization_config.get("desc_act", None)) or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.quant_method == "gptq"
-            and hasattr(self.pretrained_config.quantization_config, "desc_act")
-            and self.pretrained_config.quantization_config.desc_act
-        ):
-            LOGGER.info("\t+ Setting GPTQ's max_input_length")
-            from auto_gptq import exllama_set_max_input_length  # type: ignore
-
-            max_input_length = to_pow2(input_shapes["batch_size"] * input_shapes["sequence_length"])
-            self.pretrained_model = exllama_set_max_input_length(self.pretrained_model, max_input_length)
-
     def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
         if self.is_diffusion_pipeline():
             return super().forward(input, kwargs)
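
A note on the first hunk: `low_cpu_mem_usage` is the kwarg name that transformers' `from_pretrained` actually recognizes; the old `low_cpu_memory_usage` spelling is not a loading kwarg, so the low-memory loading path was presumably never being enabled. A minimal sketch of the corrected usage (the model id is a placeholder, not taken from this patch):

    from transformers import AutoModelForCausalLM

    # With low_cpu_mem_usage=True, transformers initializes the model with
    # empty (meta-device) weights and loads checkpoint tensors into it
    # directly, instead of first allocating a fully materialized,
    # randomly initialized model and then overwriting it.
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",  # placeholder model id, for illustration only
        low_cpu_mem_usage=True,
    )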