Commit e53e378

fixed segfault when loading no weights exllama v2 (#105)
IlyasMoutawwakil authored Jan 5, 2024
1 parent 9b3fc4d commit e53e378
Showing 1 changed file with 32 additions and 12 deletions.
optimum_benchmark/backends/pytorch/backend.py: 44 changes (32 additions & 12 deletions)
@@ -6,7 +6,7 @@
 
 import torch
 from datasets import Dataset
-from safetensors.torch import save_model
+from safetensors.torch import save_file
 from transformers import TrainerCallback, TrainerState
 from transformers.utils import ModelOutput
 from transformers.utils.logging import set_verbosity_error
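
The import swap above enables the rest of the commit: safetensors' save_model serializes an nn.Module's state dict for you, while save_file writes an arbitrary {name: tensor} dict, so extra tensors can be injected before serialization. A minimal sketch of the difference, assuming torch and safetensors are installed (the file and tensor names are illustrative, not from this repository):

    import torch
    from safetensors.torch import save_file, save_model

    model = torch.nn.Linear(1, 1)

    # save_model takes the module itself and saves model.state_dict()
    save_model(model, "linear_a.safetensors", metadata={"format": "pt"})

    # save_file takes a plain dict, so entries can be added first,
    # which is what the no-weights fix below needs for g_idx
    tensors = model.state_dict()
    tensors["fake_layer.g_idx"] = torch.ones((1,), dtype=torch.int32)
    save_file(tensors, "linear_b.safetensors", metadata={"format": "pt"})
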
@@ -78,6 +78,8 @@ def configure(self, config: PyTorchConfig) -> None:
         self.quantization_config = None
 
         # Load model
+        if self.config.no_weights and self.is_diffusion_pipeline():
+            raise ValueError("Diffusion Pipelines are not supported with no_weights=True")
         if self.config.no_weights:
             LOGGER.info("\t+ Loading model with no weights")
             self.load_model_with_no_weights()
@@ -161,7 +163,7 @@ def load_model_from_pretrained(self) -> None:
                 **self.hub_kwargs,
             )
         elif self.is_gptq_quantized() or self.is_awq_quantized():
-            LOGGER.info("\t+ Loading GPTQ quantized model")
+            LOGGER.info("\t+ Loading quantized model")
             self.pretrained_model = self.automodel_class.from_pretrained(
                 pretrained_model_name_or_path=self.model,
                 # for gptq, we need to specify the device_map to either auto
@@ -172,6 +174,7 @@
                 **self.automodel_kwargs,
                 **self.hub_kwargs,
             )
+            print(torch.cuda.max_memory_allocated())
         elif self.config.device_map is not None:
             LOGGER.info(f"\t+ Loading model with device map: {self.config.device_map}")
             self.pretrained_model = self.automodel_class.from_pretrained(
@@ -193,31 +196,40 @@ def load_model_from_pretrained(self) -> None:
     def load_model_with_no_weights(self) -> None:
         self.tmp_dir = TemporaryDirectory()
 
-        if self.is_diffusion_pipeline():
-            raise ValueError("Diffusion pipelines are not supported with no_weights=True")
-
         original_model = self.model
         no_weights_model = os.path.join(self.tmp_dir.name, "no_weights")
 
         LOGGER.info("\t+ Creating no weights model directory")
-        os.makedirs(no_weights_model, exist_ok=True)
+        if not os.path.exists(no_weights_model):
+            os.makedirs(no_weights_model)
 
         if self.is_quantized():
-            # so that from_pretrained acts as if the model is quantized
+            # tricking from_pretrained to load the model as if it was quantized
            self.pretrained_config.quantization_config = self.quantization_config.to_dict()
 
         LOGGER.info(f"\t+ Saving pretrained config to {no_weights_model}")
         self.pretrained_config.save_pretrained(save_directory=no_weights_model)
 
-        if self.pretrained_processor is not None:
-            LOGGER.info(f"\t+ Saving pretrained processor to {no_weights_model}")
-            self.pretrained_processor.save_pretrained(save_directory=no_weights_model)
         LOGGER.info(f"\t+ Creating no weights model to {no_weights_model}")
+        state_dict = torch.nn.Linear(1, 1).state_dict()
+
+        if self.is_exllamav2():
+            # for exllamav2 we need to add g_idx to the state_dict
+            LOGGER.info("\t+ Loading meta model")
+            with torch.device("meta"):
+                meta_model = self.automodel_class.from_config(self.pretrained_config)
+
+            LOGGER.info("\t+ Setting g_idx for ExllamaV2")
+            for name, module in meta_model.named_modules():
+                # loading to exllama v2's QuantLinear creates g_idx with bad values
+                if hasattr(module, "in_features"):
+                    state_dict[name + ".g_idx"] = torch.ones((module.in_features,), dtype=torch.int32)
+
         LOGGER.info(f"\t+ Saving no weights model to {no_weights_model}")
-        save_model(
+        save_file(
             filename=os.path.join(no_weights_model, "model.safetensors"),
-            model=torch.nn.Linear(1, 1),
             metadata={"format": "pt"},
+            tensors=state_dict,
         )
 
         LOGGER.info("\t+ Loading no weights model")
@@ -282,6 +294,14 @@ def is_awq_quantized(self) -> bool:
             and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
         )
 
+    def is_exllamav2(self) -> bool:
+        return (
+            self.is_quantized()
+            and self.is_gptq_quantized()
+            and "exllama_config" in self.config.quantization_config
+            and self.config.quantization_config["exllama_config"]["version"] == 2
+        )
+
     @property
     def automodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}
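
For reference, a quantization config that the new is_exllamav2() check would match might look like the following; the field names follow transformers' GPTQConfig, and the values are illustrative rather than taken from this repository:

    from transformers import GPTQConfig

    # exllama_config selects which exllama kernel version runs the
    # quantized matmuls; version 2 is what this commit special-cases
    quantization_config = GPTQConfig(bits=4, exllama_config={"version": 2})

    config_dict = quantization_config.to_dict()
    assert "exllama_config" in config_dict
    assert config_dict["exllama_config"]["version"] == 2
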