refactor: set config into weights for quantization feature support more easily (#400)

Co-authored-by: LS <LS>
thincal authored Apr 10, 2024
1 parent 67d5357 commit 70db455
Showing 18 changed files with 42 additions and 71 deletions.
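
The loader-side change is the same across the model classes below: rather than each loader checking for a quantized method and calling weights._set_gptq_params(model_id), the loader now always hands the model config to the Weights object, which resolves the quantization parameters itself. A minimal before/after sketch of that pattern, using only names that appear in this diff:

    # before: every loader special-cased the quantized methods
    if config.quantize in ["gptq", "awq", "eetq"]:
        weights._set_gptq_params(model_id)

    # after: the config is always attached to the weights
    weights._set_config(model_id, config)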
3 changes: 1 addition & 2 deletions server/lorax_server/models/bloom.py
@@ -87,8 +87,7 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = BloomForCausalLM(config, weights)
 
(additional changed file; file header not captured in this view)
@@ -63,7 +63,7 @@ def _load_multi_mqa_gptq(config, prefix: str, weights, bias: bool, head_size, nu
 
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
         g_idx = g_idx.to(device=weights.device)
-        bits, groupsize = weights._get_gptq_params()
+        bits, groupsize = weights._get_bits_and_groupsize()
 
         from lorax_server.utils.layers import HAS_EXLLAMA
 
4 changes: 1 addition & 3 deletions server/lorax_server/models/flash_gemma.py
@@ -64,9 +64,7 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = GemmaForCausalLM(config, weights)
 
4 changes: 1 addition & 3 deletions server/lorax_server/models/flash_gpt2.py
@@ -66,9 +66,7 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashGPT2ForCausalLM(config, weights)
 
4 changes: 1 addition & 3 deletions server/lorax_server/models/flash_llama.py
@@ -73,9 +73,7 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashLlamaForCausalLM(config, weights)
 
4 changes: 1 addition & 3 deletions server/lorax_server/models/flash_mistral.py
@@ -71,9 +71,7 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashMistralForCausalLM(config, weights)
 
4 changes: 1 addition & 3 deletions server/lorax_server/models/flash_mixtral.py
@@ -71,9 +71,7 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashMixtralForCausalLM(config, weights)
 
3 changes: 1 addition & 2 deletions server/lorax_server/models/flash_neox.py
@@ -49,8 +49,7 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashGPTNeoXForCausalLM(config, weights)
 
4 changes: 1 addition & 3 deletions server/lorax_server/models/flash_phi.py
@@ -69,9 +69,7 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashPhiForCausalLM(config, weights)
         self.config = config
4 changes: 1 addition & 3 deletions server/lorax_server/models/flash_qwen.py
@@ -69,9 +69,7 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashQwenForCausalLM(config, weights)
         self.config = config
4 changes: 1 addition & 3 deletions server/lorax_server/models/flash_qwen2.py
@@ -79,9 +79,7 @@ def __init__(
             dtype,
             process_group=self.process_group,
         )
-
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashQwen2ForCausalLM(config, weights)
 
6 changes: 2 additions & 4 deletions server/lorax_server/models/flash_rw.py
@@ -45,6 +45,7 @@ def __init__(
         )
 
         config = RWConfig.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code)
+        config.quantize = quantize
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -55,10 +56,7 @@ def __init__(
             process_group=self.process_group,
             aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
         )
-
-        config.quantize = quantize
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashRWForCausalLM(config, weights)
 
3 changes: 1 addition & 2 deletions server/lorax_server/models/flash_santacoder.py
@@ -60,8 +60,7 @@ def __init__(
             process_group=self.process_group,
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = FlashSantacoderForCausalLM(config, weights)
 
3 changes: 1 addition & 2 deletions server/lorax_server/models/galactica.py
@@ -189,8 +189,7 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = OPTForCausalLM(config, weights)
 
3 changes: 1 addition & 2 deletions server/lorax_server/models/gpt_neox.py
@@ -58,8 +58,7 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = GPTNeoxForCausalLM(config, weights)
 
3 changes: 1 addition & 2 deletions server/lorax_server/models/mpt.py
@@ -81,8 +81,7 @@ def __init__(
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         config.quantize = quantize
         model = MPTForCausalLM(config, weights)
3 changes: 1 addition & 2 deletions server/lorax_server/models/opt.py
@@ -56,8 +56,7 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "eetq"]:
-            weights._set_gptq_params(model_id)
+        weights._set_config(model_id, config)
 
         model = OPTForCausalLM(config, weights)
 
52 changes: 24 additions & 28 deletions server/lorax_server/utils/weights.py
@@ -9,7 +9,7 @@
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, LocalEntryNotFoundError
 from loguru import logger
-from safetensors import SafetensorError, safe_open
+from safetensors import safe_open
 
 
 class AbstractWeights(ABC):
@@ -224,7 +224,7 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str
             else:
                 g_idx = None
 
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize = self._get_bits_and_groupsize()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             w = self.get_sharded_list("weight", prefixes, dim=0)
@@ -234,7 +234,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize = self._get_bits_and_groupsize()
 
             if bits != 4:
                 use_exllama = False
@@ -298,7 +298,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize = self._get_bits_and_groupsize()
 
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -314,31 +314,29 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
-    def _get_gptq_params(self) -> Tuple[int, int]:
+    def _get_bits_and_groupsize(self) -> Tuple[int, int]:
         try:
-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
-        except (SafetensorError, RuntimeError) as e:
+            bits = self.config.quantization_config["bits"]
+            groupsize = self.config.quantization_config["group_size"]
+        except KeyError:
+            # be compatible with old behavior for gptq
             try:
-                bits = self.gptq_bits
-                groupsize = self.gptq_groupsize
-            except Exception:
-                raise e
+                bits = self.config.quantization_config["gptq_bits"]
+                groupsize = self.config.quantization_config["gptq_groupsize"]
+            except KeyError:
+                try:
+                    bits = self.get_tensor("gptq_bits").item()
+                    groupsize = self.get_tensor("gptq_groupsize").item()
+                except Exception as e:
+                    raise e
 
         return bits, groupsize
 
-    def _set_gptq_params(self, model_id):
-        filename = "config.json"
-        try:
-            if os.path.exists(os.path.join(model_id, filename)):
-                filename = os.path.join(model_id, filename)
-            else:
-                filename = hf_hub_download(model_id, filename=filename)
-            with open(filename, "r") as f:
-                data = json.load(f)
-            self.gptq_bits = data["quantization_config"]["bits"]
-            self.gptq_groupsize = data["quantization_config"]["group_size"]
-        except Exception:
+    def _set_config(self, model_id, config):
+        self.config = config
+
+        if not hasattr(self.config, "quantization_config"):
+            # fill from other config file
             filename = "quantize_config.json"
             try:
                 if os.path.exists(os.path.join(model_id, filename)):
@@ -347,8 +345,7 @@ def _set_gptq_params(self, model_id):
                     filename = hf_hub_download(model_id, filename=filename)
                 with open(filename, "r") as f:
                     data = json.load(f)
-                self.gptq_bits = data["bits"]
-                self.gptq_groupsize = data["group_size"]
+                self.config.quantization_config = data["quantization_config"]
             except Exception:
                 filename = "quant_config.json"
                 try:
@@ -358,8 +355,7 @@ def _set_gptq_params(self, model_id):
                         filename = hf_hub_download(model_id, filename=filename)
                     with open(filename, "r") as f:
                         data = json.load(f)
-                    self.gptq_bits = data["w_bit"]
-                    self.gptq_groupsize = data["q_group_size"]
+                    self.config.quantization_config = data["quantization_config"]
                 except Exception:
                     pass
 

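For reference, the two helpers in weights.py now work together as follows; this is a usage sketch of the lookup order implied by the diff above, not additional code from the commit:

    # _set_config attaches the HF config to the Weights object; if the config has no
    # quantization_config attribute, it tries to fill one in from quantize_config.json,
    # then quant_config.json (local path first, otherwise hf_hub_download).
    weights._set_config(model_id, config)

    # _get_bits_and_groupsize reads bits/group_size from config.quantization_config,
    # falling back to the legacy gptq_bits / gptq_groupsize tensors in the checkpoint.
    bits, groupsize = weights._get_bits_and_groupsize()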