# MultiPruner

Official implementation of [Fine-Grained Training-Free Structure Removal in Foundation Models]().

This repo contains the code for **MultiPruner**, a novel pruning approach that surpasses recent training-free pruning
methods by adopting a multidimensional, iterative, fine-grained pruning strategy.
Please refer to our paper for more details.

## News
- **[2025.xx.xx]** Released the code for **MultiPruner**. :tada:

## Setup

The following commands set up the environment from scratch:

```bash
pip install virtualenv
virtualenv multipruner-env
source multipruner-env/bin/activate
pip install torch==2.3.1
# install the remaining dependencies
bash install.sh
```

## Run

We use the [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) model as an example.

### Prune

```bash
python run_multipruner.py \
  --model_path meta-llama/Llama-2-7b-hf \
  --output_path <path to pruning results> \
  --weight_reorder \
  --do_prune \
  --target_ratio 22.00 \
  --pruning_distribution 44:52:4 \
  --mlp_channel_group_size 1024 \
  --attn_channel_group_size 128 \
  --importance_metric ppl \
  --calibration_dataset alpaca \
  --num_calibration_samples_block 256 \
  --num_calibration_samples_width 128 \
  --do_eval
```

- `model_path`: Path to the pre-trained model.
- `output_path`: Directory to save the pruning and evaluation results.
- `weight_reorder`: Flag to indicate whether to perform weight reordering.
- `do_prune`: Flag to indicate whether to perform pruning.
- `target_ratio`: Target pruning ratio.
- `pruning_distribution`: Distribution of the pruning ratio across the different pruning granularities.
- `mlp_channel_group_size`: Number of channels in each pruning group for MLP layers.
- `attn_channel_group_size`: Number of channels in each pruning group for attention layers; generally a multiple of the head dimension.
- `importance_metric`: Metric for calculating block importance; currently only PPL is supported.
- `calibration_dataset`: Calibration dataset name ("alpaca", "c4", "ptb" or "wikitext2").
- `num_calibration_samples_block`: Number of calibration samples to use for depth (block) pruning (stage 1).
- `num_calibration_samples_width`: Number of calibration samples to use for width pruning (stages 2 and 3).
- `do_eval`: Flag to indicate whether to perform evaluation.

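The best configuration found by the search is saved as `pruning_config.json` in the output directory (this is the file consumed by the extraction step below). Based on the keys read by `extract/extract_model.py`, it records the block indices removed by depth pruning and the per-group channel widths from width pruning. A minimal sketch for inspecting it (the path is a placeholder):

```python
import json

# Placeholder path: the --output_path passed to run_multipruner.py.
with open("<path to pruning results>/pruning_config.json") as f:
    pruning_config = json.load(f)

# Keys as read by extract/extract_model.py.
for key in ("pruned_attn_idx", "pruned_mlp_idx", "pruned_attn_width", "pruned_mlp_width"):
    print(key, pruning_config.get(key))
```
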
### Extract the Compressed Model

The final compressed model can be extracted based on the optimal pruning configuration obtained from MultiPruner.
For more details, please refer to [this link](./extract).
Below is an example of how to extract a pruned Llama-2-7B:

```bash
python extract/extract_model.py \
  --model_path meta-llama/Llama-2-7b-hf \
  --weight_reorder \
  --pruned_model_config_file <path to pruning results>/pruning_config.json \
  --output_path <path to compressed model>
```

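Since `extract_model.py` saves the result with `save_pretrained`, the extracted checkpoint can be loaded back like any Transformers model. A minimal sketch to sanity-check the reduced parameter count (the path is a placeholder; `trust_remote_code` mirrors the loading code in `eval.py`):

```python
from transformers import AutoModelForCausalLM

# Placeholder path: the --output_path passed to extract/extract_model.py.
model = AutoModelForCausalLM.from_pretrained(
    "<path to compressed model>",
    torch_dtype="float16",
    trust_remote_code=True,
)

# The count should be noticeably below the roughly 6.7B parameters of the dense Llama-2-7B.
print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")
```
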
### Recovery Finetuning

After we have obtained the pruned model, we can use the Alpaca dataset for recovery fine-tuning.
More details can be found [here](./recovery).
The following is an example command for the compressed Llama-2-7B:

```bash
# Finetune the compressed model
python recovery/finetune.py \
  --model_path <path to compressed model> \
  --do_train \
  --batch_size 8 \
  --gradient_accumulation_steps 4 \
  --num_train_epochs 2 \
  --learning_rate 1e-4 \
  --lora \
  --lora_r 16 \
  --lora_alpha 32 \
  --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
  --output_path <path to finetuned compressed model> \
  --do_eval
```

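For reference, the LoRA-related flags above correspond to a configuration along the following lines. This is only a sketch that assumes the recovery script builds on Hugging Face PEFT; the actual setup lives in `recovery/finetune.py`:

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# Hypothetical mapping of the command-line flags to a PEFT LoRA configuration.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,             # --lora_r
    lora_alpha=32,    # --lora_alpha
    target_modules=[  # --lora_target_modules
        "q_proj", "k_proj", "v_proj", "o_proj",
        "down_proj", "up_proj", "gate_proj",
    ],
)

model = AutoModelForCausalLM.from_pretrained("<path to compressed model>", trust_remote_code=True)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```
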
## Results

Example run commands (for both pruning and recovery tuning) and the corresponding pruning results of MultiPruner can be found [here](./results).

In addition to the 22% pruning ratio reported in the paper, we also explored pruning ratios that result in approximately 1%, 2%, and 3%
accuracy degradation (compared to Dense), under both the `without finetune` and `with finetune` scenarios.
These operating points may be useful for practical applications. The results for Llama-2-7b-hf are shown in the following table:

| Method | Pruning Ratio | Acc. (%) | Acc. Drop | Relative Acc. |
|--------------------------|---------------|----------|-----------|---------------|
| Dense | / | 68.96 | / | 100% |
| MultiPruner w/o finetune | 7% | 67.94 | -1.02% | 98.52% |
| MultiPruner w/o finetune | 10% | 67.02 | -1.94% | 97.19% |
| MultiPruner w/o finetune | 14% | 65.93 | -3.03% | 95.61% |
| MultiPruner w/ finetune | 12% | 68.28 | -0.68% | 99.01% |
| MultiPruner w/ finetune | 15% | 67.41 | -1.55% | 97.75% |
| MultiPruner w/ finetune | 18% | 66.16 | -2.80% | 95.94% |

## Released Pruned Models 🤗

We have released several compressed models produced by MultiPruner:

| Source Model | Pruning Ratio | Recovery Tuning | Pruned Model |
|--------------|---------------|-----------------|--------------|
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 7% | ✘ | [IntelLabs/MultiPruner-Llama-2-6.3b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-6.3b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 10% | ✘ | [IntelLabs/MultiPruner-Llama-2-6.1b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-6.1b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 12% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.9b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.9b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 12% | ✔ | [IntelLabs/MultiPruner-Llama-2-5.9b-alpaca](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.9b-alpaca) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 14% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.8b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.8b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 15% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.7b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.7b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 15% | ✔ | [IntelLabs/MultiPruner-Llama-2-5.7b-alpaca](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.7b-alpaca) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 18% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.5b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.5b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 18% | ✔ | [IntelLabs/MultiPruner-Llama-2-5.5b-alpaca](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.5b-alpaca) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 22% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.3b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.3b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 22% | ✔ | [IntelLabs/MultiPruner-Llama-2-5.3b-alpaca](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.3b-alpaca) |
| [Qwen/Qwen1.5-7B](https://huggingface.co/Qwen/Qwen1.5-7B) | 22% | ✘ | [IntelLabs/MultiPruner-Qwen1.5-6b](https://huggingface.co/IntelLabs/MultiPruner-Qwen1.5-6b) |
| [baichuan-inc/Baichuan2-7B-Base](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base) | 22% | ✘ | [IntelLabs/MultiPruner-Baichuan2-5.8b](https://huggingface.co/IntelLabs/MultiPruner-Baichuan2-5.8b) |

### Loading the Compressed Model for Evaluation

```bash
python eval.py --model_path <path to compressed model> --output_path <path to evaluation results>
```

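The released checkpoints can also be pulled directly from the Hugging Face Hub. A minimal sketch using one of the repositories from the table above (the prompt and generation settings are illustrative; `trust_remote_code` mirrors `eval.py`):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "IntelLabs/MultiPruner-Llama-2-5.3b"  # any entry from the table above
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Structured pruning of large language models", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
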
## Citation
If you find MultiPruner's code and paper helpful, please cite:
```bibtex
@article{munoz2025multipruner,
  title  = {Fine-Grained Training-Free Structure Removal in Foundation Models},
  author = {J. Pablo Munoz and Jinjie Yuan and Nilesh Jain},
  year   = {2025},
  url    = {}
}
```

# Evaluation script, invoked in the README as:
#   python eval.py --model_path <path to compressed model> --output_path <path to evaluation results>
import os
import json
import logging
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

import utils

# Emit INFO-level messages so the metrics logged below are visible on the console.
logging.basicConfig(level=logging.INFO)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str)
    parser.add_argument("--output_path", type=str)
    args = parser.parse_args()
    model_path = args.model_path
    output_path = args.output_path

    # Ensure the output directory exists
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype="float16",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Evaluate on wikitext2 dataset
    dataset = utils.get_dataset("wikitext2")
    test_dataset = dataset["test"]
    test_loader = utils.prepare_test_dataloader(
        dataset=test_dataset,
        tokenizer=tokenizer,
        seqlen=2048,
        batch_size=1
    )
    dataset_ppl = utils.evaluate_ppl(
        model=model,
        dataloader=test_loader,
        pad_token_id=model.config.eos_token_id,
    )
    dataset_ppl = round(dataset_ppl, 2)
    logging.info(f'wikitext2 PPL: {dataset_ppl}')

    # Evaluate on selected tasks
    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=64)

    task_names = ["piqa", "winogrande", "hellaswag", "arc_easy", "arc_challenge"]
    logging.info(f"Selected Tasks: {task_names}")

    results = evaluator.simple_evaluate(hflm, tasks=task_names, num_fewshot=0, batch_size=64, log_samples=False)['results']

    metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) * 100 for task, result in results.items()}
    logging.info(json.dumps(metric_vals, indent=4))

    def calculate_avg_accuracy(task_names, results):
        n_tasks = len(task_names)
        acc_cumul = sum(result.get('acc_norm,none', result['acc,none']) for task, result in results.items())
        return round(acc_cumul / n_tasks, 4) * 100

    acc_avg = calculate_avg_accuracy(task_names, results)
    logging.info(f"Average accuracy across tasks: {acc_avg}")

    # Save evaluation results
    overall_results = {
        "ppl_wikitext2": dataset_ppl,
        "5cs_acc_avg": acc_avg,
        **metric_vals
    }
    eval_result_path = os.path.join(output_path, "eval.res.json")
    with open(eval_result_path, "w") as f:
        json.dump(overall_results, f, indent=4)


if __name__ == "__main__":
    main()

## Extract the Compressed Model from MultiPruner

The final compressed model can be extracted based on the optimal pruning configuration obtained from **MultiPruner**.
Here is an example command for the compressed Llama-2-7B:

```bash
python extract/extract_model.py \
  --model_path meta-llama/Llama-2-7b-hf \
  --weight_reorder \
  --pruned_model_config_file <path to pruning result>/pruning_config.json \
  --output_path <path to compressed model>
```

- `model_path`: Path to the pre-trained model.
- `weight_reorder`: Flag to indicate whether to perform weight reordering.
- `pruned_model_config_file`: JSON file for the pruned model configuration.
- `output_path`: Directory to save the compressed model.

Some extracted models can be found in [this table](../README.md#released-pruned-models-).

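For intuition, width pruning shrinks each grouped linear layer by slicing its weight matrix: most projections drop output rows, while the projections listed in `utils.DEPENDENCY_GROUPS` keep their output size and drop the matching input columns instead (see `get_pruned_weights` in `extract_model.py`). A minimal, self-contained sketch of that slicing on a toy layer pair (the sizes are made up for illustration):

```python
import torch

# Toy MLP-style pair: `up` produces 1024 intermediate channels, `down` consumes them.
up = torch.nn.Linear(4096, 1024, bias=False)
down = torch.nn.Linear(1024, 4096, bias=False)

kept = 768  # number of channels kept for this group after pruning
up_pruned = up.weight[:kept]          # shape (768, 4096): drop output rows
down_pruned = down.weight[:, :kept]   # shape (4096, 768): drop the matching input columns

print(up_pruned.shape, down_pruned.shape)
```
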
# Extraction script, invoked in the README as:
#   python extract/extract_model.py --model_path ... --pruned_model_config_file ... --output_path ...
import argparse
import json
import logging
import os
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make the project root importable so that `utils` can be resolved.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(project_root)

import utils

# Emit INFO-level messages so the log statements below are visible on the console.
logging.basicConfig(level=logging.INFO)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="meta-llama/Llama-2-7b-hf",
        help="Path to the pre-trained model."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="prune_result",
        help="Directory to save the compressed model."
    )
    parser.add_argument(
        "--weight_reorder",
        action="store_true",
        help="Flag to indicate whether to perform weight reorder."
    )
    parser.add_argument(
        "--pruned_model_config_file",
        type=str,
        default=None,
        help="Path to the pruned model configuration file."
    )

    args = parser.parse_args()
    model_path = args.model_path
    output_path = args.output_path
    weight_reorder = args.weight_reorder
    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)

    pruned_model_config_file = args.pruned_model_config_file

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map={"": 0},
        trust_remote_code=True,
        torch_dtype="float16",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    if weight_reorder:
        for layer in utils.get_layers(model):
            utils.reorder_in_attn_block(getattr(layer, utils.get_attn_key(model)), model=model)
            utils.reorder_in_mlp_block(getattr(layer, utils.get_mlp_key(model)))

    # Load pruning results
    with open(pruned_model_config_file, "r") as f:
        pruned_config = json.load(f)
    logging.info(f"Detected a pruned model config: {pruned_config}")
    state_dict = model.state_dict()

    def get_groups(model, key):
        groups = []
        for layer in utils.get_layers(model):
            modules = getattr(layer, key)
            groups.append({name: module for name, module in modules.named_children() if isinstance(module, torch.nn.Linear)})
        return groups

    def get_pruned_weights(groups, pruned_channels):
        module_to_weight = {}
        for group_idx, value in pruned_channels.items():
            group = groups[int(group_idx)]
            for name, module in group.items():
                if name in utils.DEPENDENCY_GROUPS:
                    module_to_weight[module] = module.weight[:, :value]
                else:
                    module_to_weight[module] = module.weight[:value]
        return module_to_weight

    mlp_groups = get_groups(model, utils.get_mlp_key(model))
    attn_groups = get_groups(model, utils.get_attn_key(model))
    module_to_weight = {}
    if pruned_config.get("pruned_attn_width"):
        module_to_weight.update(get_pruned_weights(attn_groups, pruned_config["pruned_attn_width"]))
    if pruned_config.get("pruned_mlp_width"):
        module_to_weight.update(get_pruned_weights(mlp_groups, pruned_config["pruned_mlp_width"]))

    linear_modules = {name: module for name, module in model.named_modules() if isinstance(module, torch.nn.Linear)}
    for name, module in linear_modules.items():
        if module in module_to_weight:
            sd_weight_key = name + ".weight"
            assert sd_weight_key in state_dict
            pruned_weight = module_to_weight[module]
            state_dict[sd_weight_key] = pruned_weight.clone()
            # bias
            sd_bias_key = name + ".bias"
            if sd_bias_key in state_dict:
                state_dict[sd_bias_key] = state_dict[sd_bias_key][:pruned_weight.size(0)].clone()

    def prune_modules(state_dict, idx, key):
        # Remove all parameters belonging to the pruned block, i.e. names containing ".<idx>.<attn or mlp key>".
        target = f".{idx}.{key}"
        keys_to_remove = [name for name in state_dict if target in name]
        for name in keys_to_remove:
            del state_dict[name]

    if pruned_config.get("pruned_attn_idx"):
        pruned_attn_idx = pruned_config["pruned_attn_idx"]
        for idx in pruned_attn_idx:
            prune_modules(state_dict, idx, utils.get_attn_key(model))
    if pruned_config.get("pruned_mlp_idx"):
        pruned_mlp_idx = pruned_config["pruned_mlp_idx"]
        for idx in pruned_mlp_idx:
            prune_modules(state_dict, idx, utils.get_mlp_key(model))

    model.save_pretrained(output_path, state_dict=state_dict)
    tokenizer.save_pretrained(output_path)


if __name__ == "__main__":
    main()
