Explore moe #26

Open · wants to merge 13 commits into base: main
82 changes: 82 additions & 0 deletions cli/chat_moe.py
@@ -0,0 +1,82 @@
import torch
import transformers
from deltazip import AutoDeltaZipModelForCausalLM, BaseCompressionConfig
from deltazip.modeling._const import EXPERT_ID_PLACEHOLDER
from loguru import logger

def to_chatml(prompt):
    return f"<human>: {prompt}<|endoftext|><assistant>:"

def to_lmsys(prompt):
    return f"User: {prompt} Assistant:"

def chat(base_model: str, model_path: str):
    # print("[deltazip] Loading base model...")
    logger.info("Loading tokenizer")
    tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    logger.info("Tokenizer loaded")

    logger.info("Loading base_model")
    base_model = transformers.AutoModelForCausalLM.from_pretrained(f"{model_path}/base/base_model.pt", trust_remote_code=True)
Copilot AI Dec 4, 2024

The path should be a directory containing model configurations, not a .pt file.

Suggested change:
-    base_model = transformers.AutoModelForCausalLM.from_pretrained(f"{model_path}/base/base_model.pt", trust_remote_code=True)
+    base_model = transformers.AutoModelForCausalLM.from_pretrained(f"{model_path}/base", trust_remote_code=True)

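A minimal sketch of the distinction this comment draws, with hypothetical paths (not from this PR): from_pretrained expects a directory holding config.json and weight files, while a bare .pt checkpoint would instead be loaded with torch.load and applied as a state dict.

import torch
import transformers

# Directory with config.json and weight shards: use from_pretrained.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "path/to/model_dir", trust_remote_code=True
)

# Bare .pt checkpoint: load it as a state dict instead.
state_dict = torch.load("path/to/checkpoint.pt", map_location="cpu")
model.load_state_dict(state_dict, strict=False)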
    # torch.load(f"{model_path}/base_model.pt")
    base_model = base_model.half()
    logger.info("Loading base weights")
    base_weights = torch.load(f"{model_path}/base/base_weights.pt")

    delta_model = AutoDeltaZipModelForCausalLM.from_compressed(
        args.model_path, strict=True, device="cpu", unpack=True, trust_remote_code=True
    )
    delta_model = delta_model.half()

    print("base:")
    print([name for name, param in base_model.named_parameters()])

    print("delta:")
    print([name for name, param in delta_model.named_parameters()])

    print(f"base_weights: {base_weights.keys()}")

    for expert_name, expert_weight in base_weights.items():
        prefix, suffix = expert_name.split(EXPERT_ID_PLACEHOLDER)
Copilot AI Dec 4, 2024

The split method might fail if EXPERT_ID_PLACEHOLDER is not found in expert_name. This should be handled properly.

Suggested change:
-        prefix, suffix = expert_name.split(EXPERT_ID_PLACEHOLDER)
+        if EXPERT_ID_PLACEHOLDER not in expert_name: continue

        for name_base, param_base in base_model.named_parameters():
            if name_base.startswith(prefix) and name_base.endswith(suffix):
                for name_delta, param_delta in delta_model.named_parameters():
                    if name_delta.endswith(name_base):
                        param_base.data = param_delta.data + expert_weight

    delta_model = base_model
    delta_model.to(torch.device("cuda"))
    print("[deltazip] models loaded")
    pipe = transformers.TextGenerationPipeline(
        model=delta_model, tokenizer=tokenizer, device="cuda"
    )
    dialogs = ""
    while True:
        user_input = input("User: ")
        if user_input == "\exit":
            break
        if user_input == "\reset":
            dialogs = ""
            continue
        model_input = dialogs + to_lmsys(user_input)
        outputs = pipe(
            [model_input],
            max_new_tokens=128,
            do_sample=True,
            temperature=0.6,
            top_k=50,
            top_p=0.9,
            return_full_text=False,
        )[0][0]['generated_text']
        print(f"Assistant: {outputs}")
        dialogs += outputs

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model", type=str, help="Type of model")
    parser.add_argument("--model-path", type=str, help="Location of model")
    args = parser.parse_args()
    chat(args.base_model, args.model_path)
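The two comments above both concern the expert-merge loop in chat_moe.py. A self-contained sketch of that loop with the suggested guard, using made-up tensors and a stand-in placeholder value (deltazip defines the real EXPERT_ID_PLACEHOLDER):

import torch

EXPERT_ID_PLACEHOLDER = "EXPERT_ID"  # stand-in; deltazip.modeling._const defines the real value

# Toy inputs: one shared expert weight, one base parameter, one compressed delta.
base_weights = {f"layers.0.mlp.{EXPERT_ID_PLACEHOLDER}.weight": torch.zeros(4, 4)}
base_params = {"layers.0.mlp.0.weight": torch.ones(4, 4)}
delta_params = {"model.layers.0.mlp.0.weight": torch.full((4, 4), 0.5)}

for expert_name, expert_weight in base_weights.items():
    # Guard against names that do not contain the placeholder at all.
    if EXPERT_ID_PLACEHOLDER not in expert_name:
        continue
    prefix, suffix = expert_name.split(EXPERT_ID_PLACEHOLDER)
    for name_base, param_base in base_params.items():
        if name_base.startswith(prefix) and name_base.endswith(suffix):
            for name_delta, param_delta in delta_params.items():
                if name_delta.endswith(name_base):
                    # Reconstructed expert = shared base weight + per-expert delta.
                    base_params[name_base] = param_delta + expert_weight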
177 changes: 177 additions & 0 deletions cli/compress_moe.py
@@ -0,0 +1,177 @@
import accelerate
import os
import json
import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTNeoXTokenizerFast
from deltazip import AutoDeltaZipModelForCausalLM, BaseCompressionConfig, base_generation_strategies, modelling_gpt_neox_moe, modeling_llama_moe
from deltazip.modeling._const import EXPERT_ID_PLACEHOLDER
from loguru import logger
from safetensors.torch import save_file
import safetensors
Copilot AI Dec 4, 2024

[nitpick] The safetensors module is imported but not used directly. Use it directly or remove the import.

from transformers import GPTNeoXConfig, LlamaConfig


def main(args):
    print(args)
    compress_config = BaseCompressionConfig(
        bits=args.bits,
        sparsity=args.sparsity,
        # prunen=args.prunen,
        block_size=args.block_size,
        # prunem=args.prunem,
        lossless=args.lossless,
        damp_percent=args.perc_damp,
        sym=True,
        prunen=2,
        prunem=4
    )
    print("[info] compress config:", compress_config)
    if args.target_model == "gpt_neox_moe":
        tokenizer = GPTNeoXTokenizerFast.from_pretrained(
            args.tokenizer, use_fast=args.fast_tokenizer
        )
        with open(f"{args.model_path}/config.json", "r") as fp:
            config = GPTNeoXConfig(**json.load(fp))
        with accelerate.init_empty_weights():
            model = modelling_gpt_neox_moe.GPTNeoXForCausalLM(config)
        model = model.half()
        model = accelerate.load_checkpoint_and_dispatch(
            model, checkpoint=f"{args.model_path}/model.safetensors.index.json", device_map="auto", no_split_module_classes=['GPTNeoXLayer']
        )
        model.requires_grad_(False)
        target_model = AutoDeltaZipModelForCausalLM.from_model(
            model, compress_config=compress_config
        )
    elif args.target_model == 'llama_moe':
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer, use_fast=args.fast_tokenizer
        )
        with open(f"{args.model_path}/config.json", "r") as fp:
            config = LlamaConfig(**json.load(fp))
        with accelerate.init_empty_weights():
            model = modeling_llama_moe.LlamaForCausalLM(config)
        model = model.half()
        model = accelerate.load_checkpoint_and_dispatch(
            model, checkpoint=f"{args.model_path}/model.safetensors.index.json", device_map="auto", no_split_module_classes=['LlamaDecoderLayer']
        )
        model.requires_grad_(False)
        target_model = AutoDeltaZipModelForCausalLM.from_model(
            model, compress_config=compress_config
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            args.target_model, use_fast=args.fast_tokenizer
        )
        target_model = AutoDeltaZipModelForCausalLM.from_pretrained(
            args.target_model, compress_config=compress_config, torch_dtype=torch.float16
        )

    target_model.requires_grad_(False)
    torch.cuda.empty_cache()
    # now time to prepare inspect dataset
    with open(args.dataset, "r") as fp:
        examples = [json.loads(line)["text"] for line in fp.readlines()]
    if args.n_samples <= 0:
        examples = examples
    else:
        if args.shuffle_dataset:
            import random

            random.seed(42)
            random.shuffle(examples)
        examples = examples[: args.n_samples]
    examples = [tokenizer(x, truncation=True) for x in examples]
    # examples = [e for e in examples if len(e['attention_mask']) != 0]
    os.makedirs(args.outdir, exist_ok=True)
    os.makedirs(f"{args.outdir}/base", exist_ok=True)

    logger.info("Saving base expert weights:")
    base_weights = target_model.get_moe_base_weights(base_generation_strategies.take_first)
    save_file(base_weights, f"{args.outdir}/base/base_weights.safetensors")
    logger.info("Saving base weights finished")
    del base_weights

    target_model.lossy_compress(
        examples,
        batch_size=1,
        is_moe=True
    )
    # write to folder
    logger.info("Saving experts' delta weights:")
    target_model.save_compressed(args.outdir)

    if args.target_model == "gpt_neox_moe":
        model = modelling_gpt_neox_moe.GPTNeoXForCausalLM(config)
        model = model.half()
        files = os.listdir(args.model_path)
        files = [f for f in files if f.endswith("safetensors")]
        for f in files:
            print(f"Loading: {args.model_path}/{f}")
            safetensors.torch.load_model(model, f"{args.model_path}/{f}", strict=False)
    elif args.target_model == "llama_moe":
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer, use_fast=args.fast_tokenizer
        )
        with open(f"{args.model_path}/config.json", "r") as fp:
            config = LlamaConfig(**json.load(fp))
        with accelerate.init_empty_weights():
            model = modeling_llama_moe.LlamaForCausalLM(config)
        model = model.half()
        model = accelerate.load_checkpoint_and_dispatch(
            model, checkpoint=f"{args.model_path}/model.safetensors.index.json", device_map="auto", no_split_module_classes=['LlamaDecoderLayer', 'LlamaMoE']
        )
        model.requires_grad_(False)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.target_model, torch_dtype=torch.float16, trust_remote_code=True
        )

    logger.info("Saving non-fc layers:")
    sd = model.state_dict()
    to_remove = []
    for name in sd.keys():
        if name.startswith(target_model.layers_block_name):
            for inside_layer_module in sum(target_model.inside_layer_modules, []):
                prefix, suffix = inside_layer_module.split(EXPERT_ID_PLACEHOLDER)
                if prefix in name and suffix in name and name.endswith(".weight"):
                    to_remove.append(name)

    # Make sure we only save the non-fc layers (i.e the layers where MoE isn't applied)
    for name in to_remove:
        del sd[name]
    model.save_pretrained(f"{args.outdir}/base/base_model", state_dict=sd)
Copilot AI Dec 4, 2024

The save_pretrained method does not accept a state_dict parameter. Set the state_dict directly on the model before calling save_pretrained.

Suggested change:
-    model.save_pretrained(f"{args.outdir}/base/base_model", state_dict=sd)
+    model.load_state_dict(sd)

logger.info("Saving base model finished")

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
default="answer_verification",
help="The dataset to use for training, must be a path to a jsonl file.",
)
parser.add_argument(
"--n-samples",
type=int,
default=-1,
help="How many data samples used for calibration, -1 means all.",
)
parser.add_argument("--target-model", type=str)
parser.add_argument("--model-path", type=str)
parser.add_argument("--tokenizer", type=str)
parser.add_argument("--sparsity", type=float, default=0.5)
parser.add_argument("--bits", type=int, default=4)
parser.add_argument("--block-size", type=int, default=128)
parser.add_argument("--prunen", type=int, default=0)
parser.add_argument("--prunem", type=int, default=0)
parser.add_argument(
"--lossless", type=str, default="gdeflate", choices=["gdeflate"]
)
parser.add_argument("--delta", type=str, choices=["subtract", "xor"], default="")
parser.add_argument("--perc-damp", type=float, default=0.01)
parser.add_argument("--outdir", type=str, default=".cache/compressed_models")
parser.add_argument("--fast-tokenizer", action="store_true")
parser.add_argument("--shuffle-dataset", action="store_false")
args = parser.parse_args()
main(args)
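The key-filtering step near the end of main() is what decides which weights stay in the shared base model. A toy sketch of that filter, with invented module names standing in for target_model.layers_block_name and target_model.inside_layer_modules:

EXPERT_ID_PLACEHOLDER = "EXPERT_ID"  # stand-in for the deltazip constant

layers_block_name = "gpt_neox.layers"
inside_layer_modules = [[f"mlp.experts.{EXPERT_ID_PLACEHOLDER}.dense_h_to_4h"]]

state_dict_keys = [
    "gpt_neox.layers.0.mlp.experts.0.dense_h_to_4h.weight",  # expert FC weight -> dropped
    "gpt_neox.layers.0.attention.query_key_value.weight",    # shared weight -> kept
    "embed_out.weight",                                       # outside the block -> kept
]

to_remove = []
for name in state_dict_keys:
    if name.startswith(layers_block_name):
        for module in sum(inside_layer_modules, []):  # flatten the nested list
            prefix, suffix = module.split(EXPERT_ID_PLACEHOLDER)
            if prefix in name and suffix in name and name.endswith(".weight"):
                to_remove.append(name)

kept = [n for n in state_dict_keys if n not in to_remove]
print(kept)  # only the shared (non-expert) weights remain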
68 changes: 68 additions & 0 deletions cli/save_moe.py
@@ -0,0 +1,68 @@
import json
import torch
import transformers
from deltazip import AutoDeltaZipModelForCausalLM, BaseCompressionConfig, modelling_gpt_neox_moe
from deltazip.modeling._const import EXPERT_ID_PLACEHOLDER
from loguru import logger
from safetensors.torch import load_file, load_model

def save(model_type, model_path):
    logger.info("Loading tokenizer")
    if model_type == "gpt-neox-moe":
        pass
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_type, trust_remote_code=True)
    logger.info("Tokenizer loaded")
    logger.info("Loading base_model")

    delta_model = None
    config = None
    if model_type == "gpt-neox-moe":
        with open(f"{args.model_path}/base/base_model/config.json", "r") as fp:
Copilot AI Dec 4, 2024

The variable 'args' is used instead of 'model_path'. It should be 'model_path' instead of 'args.model_path'.

Suggested change:
-        with open(f"{args.model_path}/base/base_model/config.json", "r") as fp:
+        with open(f"{model_path}/base/base_model/config.json", "r") as fp:

            config = transformers.GPTNeoXConfig(**json.load(fp))
        base_model = modelling_gpt_neox_moe.GPTNeoXForCausalLM(config)
        base_model = base_model.half()
        delta_model = modelling_gpt_neox_moe.GPTNeoXForCausalLM(config)
        delta_model = delta_model.half()
        load_model(base_model, f"{args.model_path}/base/base_model/model.safetensors", strict=False)
    else:
        base_model = transformers.AutoModelForCausalLM.from_pretrained(f"{model_path}/base/base_model", trust_remote_code=True)

    base_model = base_model.half()
    logger.info("Loading base weights")
    base_weights = load_file(f"{model_path}/base/base_weights.safetensors")

    delta_model = AutoDeltaZipModelForCausalLM.from_compressed(
        args.model_path, strict=True, device="cpu", unpack=True, trust_remote_code=True, model_config=config, custom_model = delta_model
Copilot AI Dec 4, 2024

The variable 'args' is used instead of 'model_path'. It should be 'model_path' instead of 'args.model_path'.

Suggested change:
-        args.model_path, strict=True, device="cpu", unpack=True, trust_remote_code=True, model_config=config, custom_model = delta_model
+        model_path, strict=True, device="cpu", unpack=True, trust_remote_code=True, model_config=config, custom_model = delta_model

    )
    delta_model = delta_model.half()
    logger.info("Loading delta weights")
    # print([n for n, _ in delta_model.named_parameters()])
    for expert_name, expert_weight in base_weights.items():
        prefix, suffix = expert_name.split(EXPERT_ID_PLACEHOLDER)
        for name_base, param_base in base_model.named_parameters():
            if name_base.startswith(prefix) and name_base.endswith(suffix):
                # print(expert_name, name_base)
                for name_delta, param_delta in delta_model.named_parameters():
                    # print(expert_name, name_base, name_delta)
                    if name_delta.endswith(name_base):
                        print("Merging weights: ", name_base, name_delta)
                        param_base.data = param_delta.data + expert_weight
                        param_base.data = param_base.data.contiguous()

    delta_model = base_model
    if model_type == "gpt-neox-moe":
        pass
    else:
        tokenizer.save_pretrained(f"{model_path}/complete_model")
    logger.info("Saving complete model")
    delta_model.save_pretrained(f"{model_path}/complete_model")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-type", type=str, help="Type of model")
    parser.add_argument("--model-path", type=str, help="Directory of compressed model")
    args = parser.parse_args()
    save(args.model_type, args.model_path)
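One detail in the merge loop above worth a note: the .contiguous() call before saving. safetensors-based serialization rejects non-contiguous tensors, so any view (for example a transpose) has to be materialized first. A small sketch with an illustrative file name:

import torch
from safetensors.torch import save_file

w = torch.randn(4, 8).t()   # a transposed view is not contiguous
print(w.is_contiguous())    # False

# save_file would reject w as-is, so write a contiguous copy.
save_file({"w": w.contiguous()}, "merged_example.safetensors")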
Empty file added cli/temp.txt
Empty file.
3 changes: 3 additions & 0 deletions deltazip/__init__.py
@@ -2,3 +2,6 @@
from .modeling import BaseCompressionConfig
from .modeling import AutoDeltaZipModelForCausalLM
from .modeling import AutoCompressionConfig
from .modeling import base_generation_strategies
from .modeling import modelling_gpt_neox_moe
from .modeling import modeling_llama_moe
2 changes: 1 addition & 1 deletion deltazip/core/sparsegpt.py
@@ -39,7 +39,7 @@ def add_batch(self, inp, out):
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())
        sparsity_H = calculate_sparsity(self.H)
-       if sparsity_H == 1:
+       if torch.numel(inp) != 0 and sparsity_H == 1:
            raise ValueError("sparsity of H == 1, something is off, aborting")

    def fasterprune(
6 changes: 6 additions & 0 deletions deltazip/lossless/compressor.py
@@ -14,13 +14,15 @@
"fp16": torch.float16,
"fp32": torch.float32,
"int32": torch.int32,
"bool": torch.bool,
}

cp_dtype_maps = {
"int8": cp.int8,
"fp16": cp.float16,
"fp32": cp.float32,
"int32": cp.int32,
"bool": cp.bool_
}


@@ -60,6 +62,9 @@ def compress_tensor(self, tensor: torch.Tensor):
        elif tensor.dtype == torch.float32:
            dtype = "fp32"
            self.comp_manager.input_type = cp.float32
+       elif tensor.dtype == torch.bool:
+           dtype = "bool"
+           self.comp_manager.input_type = cp.bool_
        else:
            raise ValueError(f"Unsupported dtype: {tensor.dtype}")
        compressed_tensor = self.comp_manager.compress(to_compress_tensor)
@@ -84,6 +89,7 @@ def compress_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        tensors_shape = {}
        tensors_dtype = {}
        for key in state_dict:
+           print(key)
            tensors[key], tensors_shape[key], tensors_dtype[key] = self.compress_tensor(
                state_dict[key]
            )
2 changes: 2 additions & 0 deletions deltazip/modeling/__init__.py
@@ -3,6 +3,8 @@
    BaseCompressionConfig,
    BaseDeltaZipModelForCausalLM
)

+from .moe import base_generation_strategies, modelling_gpt_neox_moe, modeling_llama_moe
from .auto import *
from .bloom import *
from .gpt2 import *