Commit 9a1b899

remove gptq exllama's max_input_length setting
IlyasMoutawwakil committed Nov 24, 2023
1 parent 6b0c247 commit 9a1b899
Showing 1 changed file with 1 addition and 16 deletions.
optimum_benchmark/backends/pytorch/backend.py:
@@ -184,7 +184,7 @@ def automodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}

         if hasattr(self.pretrained_config, "quantization_config") or self.quantization_config is not None:
-            kwargs["low_cpu_memory_usage"] = True
+            kwargs["low_cpu_mem_usage"] = True

         if self.quantization_config is not None:
             kwargs["quantization_config"] = self.quantization_config
@@ -256,21 +256,6 @@ def load_model_from_config(self) -> None:
         LOGGER.info("\t+ Tying weights")
         self.pretrained_model.tie_weights()

-    def prepare_for_inference(self, input_shapes: Dict[str, int], **kwargs) -> None:
-        super().prepare_for_inference(input_shapes=input_shapes, **kwargs)
-
-        if (self.config.quantization_scheme == "gptq" and self.config.quantization_config.get("desc_act", None)) or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.quant_method == "gptq"
-            and hasattr(self.pretrained_config.quantization_config, "desc_act")
-            and self.pretrained_config.quantization_config.desc_act
-        ):
-            LOGGER.info("\t+ Setting GPTQ's max_input_length")
-            from auto_gptq import exllama_set_max_input_length  # type: ignore
-
-            max_input_length = to_pow2(input_shapes["batch_size"] * input_shapes["sequence_length"])
-            self.pretrained_model = exllama_set_max_input_length(self.pretrained_model, max_input_length)
-
     def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
         if self.is_diffusion_pipeline():
             return super().forward(input, kwargs)
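For context, the deleted prepare_for_inference override fired only for GPTQ models quantized with desc_act (act-order), the case where the ExLlama kernels' fixed-size input buffer can overflow; it grew that buffer to cover batch_size * sequence_length tokens, rounded up to a power of two. A rough sketch of the pieces involved; to_pow2 is defined elsewhere in the module and its round-up-to-a-power-of-two behavior is assumed here, and resize_exllama_buffer is a hypothetical wrapper, not code from the repo:

    from auto_gptq import exllama_set_max_input_length  # utility the removed code imported


    def to_pow2(x: int) -> int:
        # Assumed behavior of the module's helper: round x up to the next
        # power of two, e.g. 5 -> 8, 2048 -> 2048.
        return 1 << max(x - 1, 0).bit_length()


    def resize_exllama_buffer(model, batch_size: int, sequence_length: int):
        # Hypothetical wrapper mirroring the removed logic: size the ExLlama
        # input buffer to the total token count of one benchmark forward pass.
        return exllama_set_max_input_length(model, to_pow2(batch_size * sequence_length))

After this commit, GPTQ act-order models run with auto-gptq's default maximum input length; the backend no longer resizes the buffer before inference.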
