# deepseek_moe_w8a8_int8.py
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

# Select a Mixture-of-Experts (MoE) model to quantize.
MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"

# Adjust num_gpus based on the number of available GPUs.
# If not enough GPU memory is available, some layers will automatically be
# offloaded to the CPU.
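# reserve_for_hessians leaves headroom for the per-layer Hessians that GPTQ
# accumulates during calibration, on top of the weights themselves.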
device_map = calculate_offload_device_map(
    MODEL_ID,
    reserve_for_hessians=True,
    num_gpus=2,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select the calibration dataset.
# It's recommended to use more calibration samples for MoE models so that
# every expert is exercised during calibration.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 2048
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
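
# Render each conversation with the model's chat template; tokenization
# happens in a separate pass below.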
def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }

ds = ds.map(preprocess)

# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

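# Drop the raw columns so only the tokenizer outputs (input_ids,
# attention_mask) are passed to calibration.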
ds = ds.map(tokenize, remove_columns=ds.column_names)

# Define an llm-compressor recipe for INT8 W8A8 quantization.
# Since the MoE gate layers are sensitive to quantization, we add them to the
# ignore list so they remain at full precision.
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W8A8",
        ignore=["lm_head", "re:.*mlp.gate$"],
    ),
]
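# The "re:.*mlp.gate$" entry matches module paths ending in "mlp.gate" (the
# MoE router gates in DeepSeek-V2-style models), keeping the expert-routing
# weights at full precision.
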
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8"
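
# Apply the recipe in one shot: GPTQ calibrates layer by layer over the
# dataset, and save_compressed=True writes the checkpoint in the
# compressed-tensors format.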
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    save_compressed=True,
    output_dir=SAVE_DIR,
)
print("========== SAMPLE GENERATION ==============")
SAMPLE_INPUT = ["I love quantization because"]
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device)
output = model.generate(**inputs, max_length=50)
text_output = tokenizer.batch_decode(output)
print(text_output)
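
# A follow-up sketch (not part of the original script): the compressed
# checkpoint can typically be served with vLLM, which supports the
# compressed-tensors W8A8 format. Uncomment to try after quantization:
#
# from vllm import LLM, SamplingParams
#
# llm = LLM(model=SAVE_DIR, trust_remote_code=True)
# outputs = llm.generate(SAMPLE_INPUT, SamplingParams(max_tokens=50))
# print(outputs[0].outputs[0].text)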