
Commit

last
Jinxiaolong1129 committed Mar 9, 2024
1 parent 4cb9a1c commit 17dfecf
Showing 9 changed files with 1,060 additions and 86 deletions.
1 change: 1 addition & 0 deletions awq/__init__.py
@@ -1,2 +1,3 @@
__version__ = "0.2.3"
from awq.models.auto import AutoAWQForCausalLM
+from awq.models.auto import AutoAWQForSeq2SeqLM
74 changes: 74 additions & 0 deletions awq/models/auto.py
@@ -109,3 +109,77 @@ def from_quantized(
            offload_folder=offload_folder,
            **config_kwargs,
        )
+
+
+
+
+class AutoAWQForSeq2SeqLM:
+    def __init__(self):
+        raise EnvironmentError(
+            "You must instantiate AutoAWQForSeq2SeqLM with\n"
+            "AutoAWQForSeq2SeqLM.from_quantized or AutoAWQForSeq2SeqLM.from_pretrained"
+        )
+
+    @classmethod
+    def from_pretrained(
+        self,
+        model_path,
+        trust_remote_code=True,
+        safetensors=True,
+        device_map=None,
+        **model_init_kwargs,
+    ) -> BaseAWQForCausalLM:
+        model_type = check_and_get_model_type(
+            model_path, trust_remote_code, **model_init_kwargs
+        )
+
+        # TODO (xiaolong): build model add key value to AWQ_CAUSAL_LM_MODEL_MAP and add new py to awq/models like lamma_moe.py
+        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
+            model_path,
+            model_type,
+            trust_remote_code=trust_remote_code,
+            safetensors=safetensors,
+            device_map=device_map,
+            **model_init_kwargs,
+        )
+
+    @classmethod
+    def from_quantized(
+        self,
+        quant_path,
+        quant_filename="",
+        max_seq_len=2048,
+        trust_remote_code=True,
+        fuse_layers=True,
+        use_exllama=False,
+        use_exllama_v2=False,
+        batch_size=1,
+        safetensors=True,
+        device_map="balanced",
+        offload_folder=None,
+        **config_kwargs,
+    ) -> BaseAWQForCausalLM:
+        os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
+        model_type = check_and_get_model_type(quant_path, trust_remote_code)
+
+        if config_kwargs.get("max_new_tokens") is not None:
+            max_seq_len = config_kwargs["max_new_tokens"]
+            logging.warning(
+                "max_new_tokens argument is deprecated... gracefully "
+                "setting max_seq_len=max_new_tokens."
+            )
+
+        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
+            quant_path,
+            model_type,
+            quant_filename,
+            max_seq_len,
+            trust_remote_code=trust_remote_code,
+            fuse_layers=fuse_layers,
+            use_exllama=use_exllama,
+            use_exllama_v2=use_exllama_v2,
+            safetensors=safetensors,
+            device_map=device_map,
+            offload_folder=offload_folder,
+            **config_kwargs,
+        )
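
For orientation, a minimal usage sketch of the two entry points added above, mirroring the existing AutoAWQForCausalLM workflow. The checkpoint id, output path, and quant_config values are illustrative assumptions, not part of this commit, and the commit's own TODOs note the seq2seq path is still incomplete.

from awq import AutoAWQForSeq2SeqLM
from transformers import AutoTokenizer

# Assumed example checkpoint, output path, and AWQ settings (not from the commit).
model_path = "google/switch-base-8"
quant_path = "switch-base-8-awq"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the float seq2seq model, quantize it with AWQ, and save the result.
model = AutoAWQForSeq2SeqLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)

# Reload the quantized checkpoint through the second entry point.
model = AutoAWQForSeq2SeqLM.from_quantized(quant_path, fuse_layers=False)

Both classmethods dispatch through the same AWQ_CAUSAL_LM_MODEL_MAP as the causal-LM class, so a seq2seq architecture still needs its own entry in that map (the TODO above) before a sketch like this can run end to end.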
6 changes: 3 additions & 3 deletions awq/models/base.py
@@ -43,7 +43,7 @@

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
-from awq.quantize.quantizer import AwqQuantizer
+from awq.quantize.quantizer import AwqQuantizer, AwqQuantizerForSeq2SeqLM
from awq.utils.module import get_named_linears, set_op_by_name

# Since we support different `AutoModelForxxx` from transformers
@@ -670,7 +670,7 @@ def quantize(
if hasattr(self, "modules_to_not_convert"):
self.quant_config.modules_to_not_convert = self.modules_to_not_convert

-        self.quantizer = AwqQuantizer(
+        self.quantizer = AwqQuantizerForSeq2SeqLM(
self,
self.model,
tokenizer,
@@ -992,7 +992,7 @@ def _load_config(
else:
ignore_patterns.append("*.safetensors*")

-        model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
+        model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns, cache_dir=config_kwargs.get('cache_dir', None))

if model_filename != "":
model_weights_path = model_path + f"/{model_filename}"
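With the cache_dir passthrough added to _load_config above, a Hugging Face cache location passed as an extra keyword should now reach snapshot_download. A small hedged sketch; the Hub repo id and cache path are assumptions, not from the commit:

# cache_dir travels via **config_kwargs into _load_config and on to snapshot_download
# when the quantized weights are fetched from the Hub.
model = AutoAWQForSeq2SeqLM.from_quantized(
    "someuser/switch-base-8-awq", cache_dir="/data/hf_cache"
)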
11 changes: 6 additions & 5 deletions awq/models/switch.py
@@ -1,7 +1,7 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForSeq2SeqLM

import torch
from transformers.models.switch_transformers.modeling_switch_transformers import (
SwitchTransformersBlock as OldSwitchTransformersBlock,
SwitchTransformersModel as OldSwitchTransformersModel,
@@ -21,20 +21,21 @@ class SwitchAWQ(BaseAWQForSeq2SeqLM):

@staticmethod
def get_model_layers(model: OldSwitchTransformersModel):
-        model.decoder.block
-        model.encoder.block
-        return model.model.layers
+        layers = model.encoder.block + model.decoder.block
+        return layers

@staticmethod
def get_act_for_scaling(module: OldSwitchTransformersBlock):
return dict(is_scalable=False)

@staticmethod
def move_embed(model: OldSwitchTransformersModel, device: str):
-        model.model.embed_tokens = model.model.embed_tokens.to(device)
+        # TODO just for the encoder
+        model.encoder.embed_tokens = model.encoder.embed_tokens.to(device)

@staticmethod
def get_layers_for_scaling(module: OldSwitchTransformersBlock, input_feat, module_kwargs):
+        # TODO (xiaolong): just for encoder
layers = []

# attention input
(Diffs for the remaining 5 of the 9 changed files are not shown.)
