Commit 1043cc3

fix exllama_set_max_input_length

IlyasMoutawwakil committed Nov 14, 2023
1 parent: e04d256
Showing 1 changed file with 5 additions and 5 deletions.
optimum_benchmark/backends/pytorch/backend.py (10 changes: 5 additions & 5 deletions)
@@ -108,9 +108,9 @@ def load_model_from_pretrained(self) -> None:
             self.quantization_config = GPTQConfig(**self.config.quantization_config)
         elif self.config.quantization_scheme == "awq" and self.config.quantization_config:
             LOGGER.info("\t+ Processing AWQ config")
-            from transformers import AWQConfig
+            from transformers import AwqConfig
 
-            self.quantization_config = AWQConfig(**self.config.quantization_config)
+            self.quantization_config = AwqConfig(**self.config.quantization_config)
         elif self.config.quantization_scheme == "bnb":
             LOGGER.info("\t+ Processing BitsAndBytesConfig")
             from transformers import BitsAndBytesConfig
@@ -238,10 +238,10 @@ def load_model_from_config(self) -> None:
 
     def prepare_for_inference(self, input_shapes: Dict[str, int], **kwargs) -> None:
         super().prepare_for_inference(input_shapes=input_shapes, **kwargs)
-        if (self.config.quantization_scheme == "gptq" and self.config.quantization_config["desc_act"]) or (
+        if (self.config.quantization_scheme == "gptq" and self.config.quantization_config.get("desc_act", None)) or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config["desc_act"]
-            and self.pretrained_config.quantization_config["quant_method"] == "gptq"
+            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
+            and self.pretrained_config.quantization_config.get("desc_act", None)
         ):
             LOGGER.info("\t+ Setting GPTQ's max_input_length")
             from auto_gptq import exllama_set_max_input_length
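Note on the first hunk: the class exposed by transformers is named AwqConfig, not AWQConfig, so the old import raised an ImportError whenever the AWQ branch was taken. A minimal sketch of the corrected usage; the field value is an illustrative assumption, since real runs build the config from self.config.quantization_config:

    from transformers import AwqConfig

    # illustrative value; not taken from this commit
    quantization_config = AwqConfig(bits=4)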

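Note on the second hunk: quantization_config here is a plain dict whose keys depend on the quantization method, so indexing with ["desc_act"] raised a KeyError for any non-GPTQ checkpoint. The fix switches to .get() and tests quant_method before desc_act, so the condition short-circuits for other methods. A standalone sketch of the difference, with hypothetical config values:

    # hypothetical AWQ-style config: no "desc_act" key at all
    quantization_config = {"quant_method": "awq", "bits": 4}

    # old check: raises KeyError("desc_act")
    try:
        old = quantization_config["desc_act"] and quantization_config["quant_method"] == "gptq"
    except KeyError as exc:
        print(f"old check fails: KeyError({exc})")

    # new check: quant_method is tested first and .get() tolerates missing keys
    new = (
        quantization_config.get("quant_method", None) == "gptq"
        and quantization_config.get("desc_act", None)
    )
    print(new)  # False, so the exllama adjustment is correctly skipped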
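For context, the call this condition gates: auto_gptq exposes exllama_set_max_input_length(model, max_input_length), which resizes the exllama kernel buffers for act-order (desc_act=True) GPTQ models so that longer inputs fit. A hedged usage sketch; the checkpoint name and length are assumptions, not taken from this commit:

    from auto_gptq import exllama_set_max_input_length
    from transformers import AutoModelForCausalLM

    # hypothetical act-order GPTQ checkpoint and sequence length
    model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-GPTQ", device_map="auto")
    model = exllama_set_max_input_length(model, max_input_length=4096)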