Commit
Merge pull request #27 from eth-easl/feature/larger_models
Feature/larger models
xzyaoi authored Jun 3, 2024
2 parents ffb5f53 + 9fd712a commit c2c2481
Showing 15 changed files with 2,561 additions and 144 deletions.
artifact/benchmarks/utils/compression/compress_llamas.sh (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
python cli/compress.py --target-model FlagAlpha/Llama2-Chinese-7b-Chat --outdir .cache/compressed_models/7b-parameters/llama2-chinese-7b-chat-2bits --dataset .cache/datasets/meta.jsonl --n-samples 256 --bits 2 --sparsity 0.5 --lossless gdeflate --delta subtract --base-model meta-llama/Llama-2-7b-hf --shuffle-dataset --fast-tokenizer --perc-damp 0.01 --block-size 128

python cli/compress.py --target-model lmsys/vicuna-7b-v1.5 --outdir .cache/compressed_models/7b-parameters/vicuna-7b-v1.5-2bits --dataset .cache/datasets/lmsys.jsonl --n-samples 256 --bits 2 --sparsity 0.5 --lossless gdeflate --delta subtract --base-model meta-llama/Llama-2-7b-hf --shuffle-dataset --fast-tokenizer --perc-damp 0.01 --block-size 128
python cli/compress.py --target-model lmsys/vicuna-7b-v1.5 --outdir .cache/compressed_models/vicuna-7b.2b50s --dataset .local/datasets/lmsys.jsonl --n-samples 256 --bits 2 --sparsity 0.5 --lossless gdeflate --delta subtract --base-model meta-llama/Llama-2-7b-hf --shuffle-dataset --fast-tokenizer --perc-damp 0.01 --block-size 128

python cli/compress.py --target-model meta-llama/Llama-2-7b-chat-hf --outdir .cache/compressed_models/7b-parameters/llama-2-7b-chat --dataset .cache/datasets/meta.jsonl --n-samples 256 --bits 2 --sparsity 0.5 --lossless gdeflate --delta subtract --base-model meta-llama/Llama-2-7b-hf --shuffle-dataset --fast-tokenizer --perc-damp 0.01 --block-size 128

cli/chat.py (24 changes: 15 additions & 9 deletions)
@@ -2,6 +2,12 @@
import transformers
from deltazip import AutoDeltaZipModelForCausalLM, BaseCompressionConfig

ignore_keywords = [
'norm',
'embed',
'lm_head'
]

compress_config = BaseCompressionConfig(
bits=4,
group_size=128,
@@ -16,7 +22,7 @@ def to_chatml(prompt):
return f"<human>: {prompt}<|endoftext|><assistant>:"

def to_lmsys(prompt):
return f"User: {prompt} Assistant:"
return f"USER: {prompt}\nASSISTANT:"

def chat(base_model:str, model_path: str):
print("[deltazip] Loading base model...")
@@ -27,17 +33,17 @@ def chat(base_model:str, model_path: str):
base_model = base_model.half()
print("[deltazip] Loading target model...")
delta_model = AutoDeltaZipModelForCausalLM.from_compressed(
args.model_path, strict=True, device="cpu", unpack=True
model_path, strict=True, device="cpu", unpack=True
)
delta_model = delta_model.half()

compressed_modules = []
for x in base_model.inside_layer_modules:
compressed_modules.extend(x)
for name, param in base_model.model.named_parameters():
delta_model.model.state_dict()[name].copy_(
param + delta_model.model.state_dict()[name]
)
if any([kw in name for kw in ignore_keywords]):
#delta_model.model.state_dict()[name].copy_(param)
pass
else:
delta_model.model.state_dict()[name].copy_(
param + delta_model.model.state_dict()[name]
)
delta_model = delta_model.to(torch.device("cuda"))
print("[deltazip] models loaded")
pipe = transformers.TextGenerationPipeline(
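
The updated merge loop in cli/chat.py adds the base weights back onto the decompressed delta, but now skips any parameter whose name contains 'norm', 'embed', or 'lm_head', presumably because those modules already hold their full values rather than deltas. A minimal sketch of that add-back pattern, with illustrative names rather than the repository's API:

import torch

IGNORE_KEYWORDS = ("norm", "embed", "lm_head")

def add_base_into_delta(base: torch.nn.Module, delta: torch.nn.Module) -> None:
    # Reconstruct target weights in place: delta += base, except for ignored modules,
    # which are assumed to already contain their full (non-delta) weights.
    delta_sd = delta.state_dict()
    with torch.no_grad():
        for name, param in base.named_parameters():
            if any(kw in name for kw in IGNORE_KEYWORDS):
                continue
            delta_sd[name].add_(param)
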
cli/compress.py (95 changes: 85 additions & 10 deletions)
@@ -2,9 +2,19 @@
import json
import torch
import argparse
import safetensors as st
from transformers import AutoTokenizer
from deltazip import AutoDeltaZipModelForCausalLM, BaseCompressionConfig
import os
import math

max_threads = str(min(8, os.cpu_count()))
os.environ['OMP_NUM_THREADS'] = max_threads
os.environ['OPENBLAS_NUM_THREADS'] = max_threads
os.environ['MKL_NUM_THREADS'] = max_threads
os.environ['VECLIB_MAXIMUM_THREADS'] = max_threads
os.environ['NUMEXPR_NUM_THREADS'] = max_threads
os.environ['NUMEXPR_MAX_THREADS'] = max_threads

def main(args):
print(args)
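
The environment variables added above cap the thread pools of OpenMP, OpenBLAS, MKL, vecLib, and numexpr at 8 threads; they must be set before those libraries initialize, which is why this happens at import time. A small sketch of the same idea, using setdefault so a value the user already exported is respected (an assumption, not what the diff does):

import os

n_threads = str(min(8, os.cpu_count() or 1))
for var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS", "NUMEXPR_NUM_THREADS", "NUMEXPR_MAX_THREADS"):
    os.environ.setdefault(var, n_threads)  # only set if not already configured
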
@@ -14,20 +24,39 @@ def main(args):
compress_config = BaseCompressionConfig(
bits=args.bits,
sparsity=args.sparsity,
prunen=args.prunen,
block_size=args.block_size,
prunen=args.prunen,
prunem=args.prunem,
lossless=args.lossless,
damp_percent=args.perc_damp,
sym=False,
desc_act=args.desc_act,
sym=args.sym,
)
print("[info] compress config:", compress_config)
target_model = AutoDeltaZipModelForCausalLM.from_pretrained(
args.target_model,
compress_config=compress_config,
torch_dtype=torch.float16,
max_memory = {0: "2GIB", 1: "48GIB", 2: "48GIB", 3:"48GIB"}
# max_memory = {
# 0: "60GIB",
# 1: "60GIB",
# 2: "60GIB",
# 3: "60GIB",
# 4: "60GIB",
# 5: "60GIB",
# 6: "60GIB",
# 7: "60GIB",
# "cpu": "140GIB"
# }
)
ignore_keywords = [
'norm',
'embed',
'lm_head'
]
not_save_keywords = [
'norm',
]
target_model.requires_grad_(False)
if args.base_model != "" and args.delta != "":
print("[info] base model is defined, delta mode enabled")
@@ -64,19 +93,62 @@ def main(args):
# write to folder
os.makedirs(args.outdir, exist_ok=True)
# for weights that are not compressed, we calculate delta afterward compression
if args.large_model:
# for large models - save a temporary results to avoid re-run
tensors = {}
for name, param in target_model.named_parameters():
if not param.is_meta:
tensors[name] = param.data.cpu().clone().detach()
st.torch.save_file(
tensors, os.path.join(args.outdir, "temp.safetensors")
)
target_model_ref = AutoDeltaZipModelForCausalLM.from_pretrained(
args.target_model,
compress_config=compress_config,
torch_dtype=torch.float16,
)
missing_state_dict = target_model_ref.state_dict()
missing_state_dict = {
k: v for k, v in missing_state_dict.items() if k not in tensors
}
print(f"[info] loaded keys: {missing_state_dict.keys()}")
missing_key, unexpected_key = target_model.load_state_dict(missing_state_dict, strict = False, assign=True)
print(f"[info] missing keys: {missing_key}")
print(f"[info] unexpected keys: {unexpected_key}")
for name, param in target_model.named_parameters():
if param.is_meta:
print(f"[info] {name} is on meta")
del target_model_ref

if args.base_model != "" and args.delta != "":
compressed_modules = []
for x in base_model.inside_layer_modules:
compressed_modules.extend(x)
for name, param in target_model.named_parameters():
if "bias" in name or all(
[modules not in name for modules in compressed_modules]
):
target_model.state_dict()[name].copy_(
param - base_model.state_dict()[name]
)
# if all([module not in name for module in compressed_modules]):
# print(f"[info] {name} is compressed, saving in full...")

# target_model.state_dict()[name] = param
# else:
# print(f"[info] {name} is not compressed, saving in full...")
# target_model.state_dict()[name] = param
if any([keyword in name for keyword in not_save_keywords]):
print(f"[info] {name} is not saved")
del target_model.state_dict()[name]
# if "bias" in name or all(
# [modules not in name for modules in compressed_modules]
# ):
# base_weight = base_model.state_dict()[name]
# if base_weight.device != param.device:
# base_weight = base_weight.to(param.device)
# target_model.state_dict()[name] = param - base_weight

if args.base_model != "" and args.delta != "":
del base_model
# run a forward pass to make sure the model is working
target_model.save_compressed(args.outdir)

with open(os.path.join(args.outdir, "compressed_modules.json"), "w") as fp:
json.dump(compressed_modules, fp)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -103,6 +175,9 @@ def main(args):
"--lossless", type=str, default="gdeflate", choices=["gdeflate"]
)
parser.add_argument("--delta", type=str, choices=["subtract", "xor"], default="")
parser.add_argument("--sym", action="store_true")
parser.add_argument("--desc-act", action="store_true")
parser.add_argument("--large-model", action="store_true")
parser.add_argument("--perc-damp", type=float, default=0.01)
parser.add_argument("--outdir", type=str, default=".cache/compressed_models")
parser.add_argument("--fast-tokenizer", action="store_true")
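
The new --large-model path works around parameters that remain on the meta device when the target model cannot be fully materialized: the tensors that were materialized are checkpointed to temp.safetensors, and the remaining meta parameters are filled in from a reference copy via load_state_dict(..., strict=False, assign=True). A rough sketch of that pattern, with illustrative helper names (assumes PyTorch >= 2.1 for the assign= argument and the safetensors package):

import os
import torch
import safetensors.torch as st

def checkpoint_materialized(model: torch.nn.Module, outdir: str) -> dict:
    # Save only the parameters that actually hold data (i.e. are not on the meta device).
    tensors = {name: p.data.cpu().clone() for name, p in model.named_parameters() if not p.is_meta}
    st.save_file(tensors, os.path.join(outdir, "temp.safetensors"))
    return tensors

def fill_missing_from_reference(model: torch.nn.Module, reference: torch.nn.Module, have: dict) -> None:
    # assign=True swaps the reference tensors in, instead of copying into meta placeholders.
    missing = {k: v for k, v in reference.state_dict().items() if k not in have}
    model.load_state_dict(missing, strict=False, assign=True)
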
cli/merge.py (25 changes: 20 additions & 5 deletions)
@@ -4,6 +4,7 @@
from loguru import logger
from deltazip import AutoDeltaZipModelForCausalLM, BaseCompressionConfig
from transformers import AutoTokenizer

compress_config = BaseCompressionConfig(
bits=4,
group_size=128,
@@ -13,7 +14,11 @@
lossless="gdeflate",
damp_percent=0.02,
)

ignore_keywords = [
'norm',
'embed',
'lm_head'
]
def merge(args):
print(args)
with torch.inference_mode():
@@ -29,14 +34,23 @@ def merge(args):
compressed_modules = []
for x in base_model.inside_layer_modules:
compressed_modules.extend(x)

for name, param in base_model.model.named_parameters():
delta_model.model.state_dict()[name].copy_(
param + delta_model.model.state_dict()[name]
)
# save model to output directory
if args.delta == "subtract":
if any([kw in name for kw in ignore_keywords]):
#delta_model.model.state_dict()[name].copy_(param)
print(f"Ignoring {name}")
else:
delta_model.model.state_dict()[name].copy_(
param + delta_model.model.state_dict()[name]
)
else:
logger.warning("Skipping due to unknown delta mode")

for name, param in delta_model.model.state_dict().items():
param = param.contiguous()
delta_model.model.save_pretrained(args.output_dir, safe_serialization=False, max_shard_size="10GB")

os.makedirs(args.output_dir, exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
tokenizer.save_pretrained(args.output_dir)
@@ -46,5 +60,6 @@
parser.add_argument("--base-model", type=str, default="gpt2")
parser.add_argument("--target-model", type=str, default="gpt2")
parser.add_argument("--output-dir", type=str, default="output")
parser.add_argument("--delta", type=str, default="")
args = parser.parse_args()
merge(args)
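
One detail in cli/merge.py worth noting: the loop "for name, param in delta_model.model.state_dict().items(): param = param.contiguous()" only rebinds the local variable, so the stored tensors are left unchanged (torch.save still handles non-contiguous tensors, so saving works regardless). A version that actually replaces the weights with contiguous copies could look like this sketch, which is not the repository's code:

import torch

def make_contiguous(model: torch.nn.Module) -> None:
    with torch.no_grad():
        for p in model.parameters():
            if not p.is_contiguous():
                p.data = p.data.contiguous()  # replace the storage with a compact copy
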
deltazip/core/sparsegpt.py (7 changes: 6 additions & 1 deletion)
@@ -77,6 +77,7 @@ def fasterprune(
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
invperm = torch.argsort(perm)
Losses = torch.zeros(self.rows, device=self.dev)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
@@ -120,6 +121,11 @@ def fasterprune(
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]

if prunen != 0 and i % prunem == 0:
tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2
mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)

q = w.clone()
q[mask1[:, i]] = 0
if hasattr(self, "quantizer"):
@@ -157,7 +163,6 @@ def fasterprune(
g_idx = [i // self.columns for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=W.device)
if actorder:
invperm = torch.argsort(perm)
Q = Q[:, invperm]
g_idx = g_idx[invperm]
W = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
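
The lines added to fasterprune implement n:m structured sparsity (for example 2:4): at the start of every block of prunem columns, the prunen weights with the smallest saliency w**2 / diag(Hinv)**2 are marked in the pruning mask, and invperm is now computed right after the columns are permuted instead of only in the actorder branch at the end. A standalone sketch of the selection rule, with illustrative names (the diff applies it incrementally inside the existing loop):

import torch

def nm_prune_mask(W1: torch.Tensor, Hinv1_diag: torch.Tensor, prunen: int, prunem: int) -> torch.Tensor:
    # Within every block of prunem consecutive columns, mark the prunen weights with the
    # smallest saliency w**2 / diag(Hinv)**2 for pruning (True = pruned).
    mask = torch.zeros_like(W1, dtype=torch.bool)
    for i in range(0, W1.shape[1], prunem):
        block = W1[:, i:i + prunem] ** 2 / (Hinv1_diag[i:i + prunem].reshape(1, -1)) ** 2
        idx = torch.topk(block, prunen, dim=1, largest=False)[1]
        mask.scatter_(1, i + idx, True)
    return mask
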