Multi-dimensional pruning #23

Open · wants to merge 3 commits into base: main
222 changes: 222 additions & 0 deletions MultiPruner/README.md
@@ -0,0 +1,222 @@
# MultiPruner

Official implementation of [Fine-Grained Training-Free Structure Removal in Foundation Models]().

This repo contains the code for **MultiPruner**, a novel pruning approach that surpasses recent training-free pruning
methods, e.g., BlockPruner (Zhong et al., 2024) and ShortGPT (Men et al., 2024), by adopting a multi-dimensional, iterative, fine-grained pruning strategy.
Please refer to our paper for more details.

## News
- **[2024.12.14]** Release the code for **MultiPruner**. :tada:

## Supported Models 🤗

- **Llama**
- [x] [meta-llama/Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)
- [x] [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B)
- [x] [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
- [x] [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- [x] [meta-llama/Llama-2-13b-hf](https://huggingface.co/meta-llama/Llama-2-13b-hf)
- **Qwen**
- [x] [Qwen/Qwen2.5-7B](https://huggingface.co/Qwen/Qwen2.5-7B)
- [x] [Qwen/Qwen1.5-7B](https://huggingface.co/Qwen/Qwen1.5-7B)
- [x] [Qwen/Qwen1.5-14B](https://huggingface.co/Qwen/Qwen1.5-14B)
- **Baichuan**
- [x] [baichuan-inc/Baichuan2-7B-Base](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base)
- [x] [baichuan-inc/Baichuan2-13B-Base](https://huggingface.co/baichuan-inc/Baichuan2-13B-Base)

**All pruning result configurations and pruning commands are available [here](./results).**

## Setup

Use the following instructions to create a virtual environment with the required dependencies.

```bash
# install dependencies
bash install.sh
```

## Run

We use the [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) model as an example.

### Prune

```bash
python run_multipruner.py \
--model_path meta-llama/Llama-2-7b-hf \
--output_path <path to pruning results> \
--weight_reorder \
--do_prune \
--target_ratio 22.00 \
--pruning_distribution 44:52:4 \
--mlp_channel_group_size 1024 \
--attn_channel_group_size 128 \
--importance_metric ppl \
--calibration_dataset alpaca \
--num_calibration_samples_block 256 \
--num_calibration_samples_width 128 \
--do_eval
```

- `model_path`: Path to the pre-trained model.
- `output_path`: Directory to save the pruning and evaluation results.
- `weight_reorder`: Flag to indicate whether to perform weight reordering in the Attn and MLP modules.
- `do_prune`: Flag to indicate whether to perform pruning.
- `target_ratio`: Target pruning ratio.
- `pruning_distribution`: Pruning ratio distribution across the different pruning granularities (see the sketch after this list).
- `mlp_channel_group_size`: Number of channels for each group (MLP).
- `attn_channel_group_size`: Number of channels for each group (Attn), generally a multiple of the head dimension.
- `importance_metric`: Metric for calculating block importance; currently only PPL is supported.
- `calibration_dataset`: Calibration dataset name ("alpaca", "c4", "ptb" or "wikitext2").
- `num_calibration_samples_block`: Number of calibration samples to use for depth (block) pruning (stage 1).
- `num_calibration_samples_width`: Number of calibration samples to use for width pruning (stage 2 and 3).
- `do_eval`: Flag to indicate whether to perform evaluation.
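
The sketch below shows one plausible reading of how `target_ratio` and `pruning_distribution` combine into per-granularity budgets. The variable names and the split are assumptions for illustration only; the authoritative mapping to the pruning stages lives in `run_multipruner.py`.

```python
# Illustrative only (assumed interpretation, not the script's internals):
# split an overall pruning target across granularities according to the
# "a:b:c" distribution string.
target_ratio = 22.00          # overall percentage of parameters to remove
distribution = "44:52:4"      # relative share assigned to each granularity

shares = [float(s) for s in distribution.split(":")]
budgets = [target_ratio * s / sum(shares) for s in shares]
print(budgets)                # -> [9.68, 11.44, 0.88] (percent per granularity)
```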

### Extract the Compressed Model

The final compressed model can be extracted based on the optimal pruning configuration obtained from MultiPruner.
For more details, please refer to [this link](./extract).
Below is an example of how to extract a pruned Llama-2-7B:

```bash
python extract/extract_model.py \
--model_path meta-llama/Llama-2-7b-hf \
--weight_reorder \
--pruned_model_config_file <path to pruning results>/pruning_config.json \
--output_path <path to compressed model>
```
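
As a quick check that the extraction hit the requested budget, the parameter counts of the dense and compressed checkpoints can be compared. A small sketch (paths are placeholders):

```python
# Compare parameter counts of the dense and extracted (compressed) models.
# Paths are placeholders; loading both checkpoints needs roughly 2x model memory.
from transformers import AutoModelForCausalLM

dense = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="float16")
pruned = AutoModelForCausalLM.from_pretrained("<path to compressed model>", torch_dtype="float16", trust_remote_code=True)

n_dense = sum(p.numel() for p in dense.parameters())
n_pruned = sum(p.numel() for p in pruned.parameters())
print(f"pruning ratio achieved: {1 - n_pruned / n_dense:.2%}")
```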

### Recovery Finetuning

After we have obtained the pruned model, we can use the Alpaca dataset for recovery fine-tuning.
More details can be found [here](./recovery).
The following is an example command for the compressed Llama-2-7B:

```bash
# Finetune the compressed model
python recovery/finetune.py \
--model_path <path to compressed model> \
--do_train \
--batch_size 8 \
--gradient_accumulation_steps 4 \
--num_train_epochs 2 \
--learning_rate 1e-4 \
--lora \
--lora_r 16 \
--lora_alpha 32 \
--lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
--output_path <path to finetuned compressed model> \
--do_eval
```
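
If `finetune.py` leaves the LoRA adapter unmerged (we have not verified its output format), the standard `peft` API can fold the adapter back into the compressed base model for deployment. A minimal sketch with placeholder paths:

```python
# Hypothetical post-processing: merge a LoRA adapter into the compressed base
# model so it can be served without peft. Paths are placeholders.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained(
    "<path to compressed model>", torch_dtype="float16", trust_remote_code=True
)
model = PeftModel.from_pretrained(base, "<path to finetuned compressed model>")
merged = model.merge_and_unload()  # bake the LoRA deltas into the base weights
merged.save_pretrained("<path to merged model>")
AutoTokenizer.from_pretrained(
    "<path to compressed model>", trust_remote_code=True
).save_pretrained("<path to merged model>")
```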

## Results

Example MultiPruner commands (covering both pruning and recovery tuning) together with the resulting pruning configurations can be found [here](./results).

#### Llama-3.2-3B

| Method | Pruning Ratio | Acc. (%) | WikiText2 PPL |
|------------------|---------------|----------|---------------|
| Dense | / | 67.67 | 7.81 |
| BlockPruner | 9% | 62.31 | 13.07 |
| **MultiPruner** | 9% | 64.04 | 10.46 |

#### Llama-3.1-8B

| Method | Pruning Ratio | Acc. (%) | WikiText2 PPL |
|------------------|---------------|----------|---------------|
| Dense | / | 73.75 | 6.24 |
| BlockPruner | 10% | 66.75 | 10.58 |
| **MultiPruner** | 10% | 69.27 | 8.93 |
| BlockPruner | 20% | 59.08 | 15.37 |
| **MultiPruner** | 20% | 63.07 | 13.86 |

Compared to [LLM-Pruner](https://arxiv.org/abs/2305.11627) (Pruning ratio: ~17%):

| Method | WikiText2 PPL (Seq Len: 2048) |
|---------------------------|-------------------------------|
| Dense | 6.24 |
| BlockPruner | 13.78 |
| LLM-Pruner (L2) | 49.09 |
| LLM-Pruner (Taylor) | 12.71 |
| **MultiPruner** (10:90:0) | **11.64** |

#### Meta-Llama-3-8B

| Method | Pruning Ratio | Acc. (%) | WikiText2 PPL |
|------------------|---------------|----------|---------------|
| Dense | / | 72.73 | 6.14 |
| BlockPruner | 10% | 66.46 | 10.88 |
| **MultiPruner** | 10% | 69.03 | 8.19 |
| BlockPruner | 20% | 57.59 | 22.36 |
| **MultiPruner** | 20% | 63.02 | 16.01 |

Compared to [LLM-Pruner](https://arxiv.org/abs/2305.11627) (Pruning ratio: ~17%):

| Method | WikiText2 PPL (Seq Len: 2048) |
|----------------------------|-------------------------------|
| Dense | 6.14 |
| BlockPruner | 16.15 |
| LLM-Pruner (L2) | 34.13 |
| LLM-Pruner (Taylor) | 12.86 |
| **MultiPruner** (10:90:0) | **11.11** |

#### Qwen2.5-7B

| Method | Pruning Ratio | Acc. (%) | WikiText2 PPL |
|------------------|---------------|----------|---------------|
| Dense | / | 72.04 | 6.85 |
| BlockPruner | 10% | 67.44 | 9.88 |
| **MultiPruner** | 10% | 69.71 | 9.15 |
| BlockPruner | 20% | 57.44 | 17.17 |
| **MultiPruner** | 20% | 62.82 | 13.37 |


For additional results and discussions on other models, please refer to the paper.

We also explored pruning ratios that result in 1%, 2%, and 3% accuracy degradation (compared to Dense), under both `without finetune` and `with finetune` scenarios.
This investigation may facilitate practical applications. `Relative Acc.` is the pruned model's accuracy as a fraction of the Dense baseline (e.g., 67.94 / 68.96 ≈ 98.52%).
The results for Llama-2-7B are shown in the following table:

| Method | Pruning Ratio | Acc. (%) | Acc. Drop | Relative Acc. |
|--------------------------|---------------|----------|-----------|---------------|
| Dense | / | 68.96 | / | 100% |
| MultiPruner w/o finetune | 7% | 67.94 | -1.02% | 98.52% |
| MultiPruner w/o finetune | 10% | 67.02 | -1.94% | 97.19% |
| MultiPruner w/o finetune | 14% | 65.93 | -3.03% | 95.61% |
| MultiPruner w/ finetune | 12% | 68.28 | -0.68% | 99.01% |
| MultiPruner w/ finetune | 15% | 67.41 | -1.55% | 97.75% |
| MultiPruner w/ finetune | 18% | 66.16 | -2.80% | 95.94% |

*In all tables, `Acc.(%)` represents the average accuracy score across the five tasks: `piqa`, `winogrande`, `hellaswag`, `arc_easy`, and `arc_challenge`.*

### Loading the compressed model for evaluation

```bash
python eval.py --model_path <path to compressed model> --output_path <path to evaluation results>
```
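
For a quick sanity check outside of `eval.py`, the compressed checkpoint can be loaded like any Hugging Face model. A minimal sketch (paths are placeholders; `trust_remote_code=True` mirrors how `eval.py` loads the model):

```python
# Minimal smoke test of the compressed model (paths are placeholders).
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "<path to compressed model>",
    torch_dtype="float16",
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("<path to compressed model>", trust_remote_code=True)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))
```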

## Acknowledgement

MultiPruner benefits from the following work:

```bibtex
@article{zhong2024blockpruner,
title={BlockPruner: Fine-grained Pruning for Large Language Models},
author={Zhong, Longguang and Wan, Fanqi and Chen, Ruijun and Quan, Xiaojun and Li, Liangzhi},
journal={arXiv preprint arXiv:2406.10594},
year={2024}
}
```

## Citation
If you find MultiPruner's code and paper helpful, please kindly cite:
```bibtex
@article{munoz2025multipruner,
title = {Fine-Grained Training-Free Structure Removal in Foundation Models},
author = {J. Pablo Munoz and Jinjie Yuan and Nilesh Jain},
year = {2025},
url = {}
}
```
82 changes: 82 additions & 0 deletions MultiPruner/eval.py
@@ -0,0 +1,82 @@
import os
import json
import logging
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

import utils


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str)
    parser.add_argument("--output_path", type=str)
    args = parser.parse_args()
    model_path = args.model_path
    output_path = args.output_path

    # Ensure the output directory exists
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype="float16",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Evaluate on wikitext2 dataset
    dataset = utils.get_dataset("wikitext2")
    test_dataset = dataset["test"]
    test_loader = utils.prepare_test_dataloader(
        dataset=test_dataset,
        tokenizer=tokenizer,
        seqlen=2048,
        batch_size=1
    )
    dataset_ppl = utils.evaluate_ppl(
        model=model,
        dataloader=test_loader,
        pad_token_id=model.config.eos_token_id,
    )
    dataset_ppl = round(dataset_ppl, 2)
    logging.info(f'wikitext2 PPL: {dataset_ppl}')

    # Evaluate on selected tasks
    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=64)

    task_names = ["piqa", "winogrande", "hellaswag", "arc_easy", "arc_challenge"]
    logging.info(f"Selected Tasks: {task_names}")

    results = evaluator.simple_evaluate(hflm, tasks=task_names, num_fewshot=0, batch_size=64, log_samples=False)['results']

    metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) * 100 for task, result in results.items()}
    logging.info(json.dumps(metric_vals, indent=4))

    def calculate_avg_accuracy(task_names, results):
        n_tasks = len(task_names)
        acc_cumul = sum(result.get('acc_norm,none', result['acc,none']) for task, result in results.items())
        return round(acc_cumul / n_tasks, 4) * 100

    acc_avg = calculate_avg_accuracy(task_names, results)
    logging.info(f"Average accuracy across tasks: {acc_avg}")

    # Save evaluation results
    overall_results = {
        "ppl_wikitext2": dataset_ppl,
        "5cs_acc_avg": acc_avg,
        **metric_vals
    }
    eval_result_path = os.path.join(output_path, "eval.res.json")
    with open(eval_result_path, "w") as f:
        json.dump(overall_results, f, indent=4)


if __name__ == "__main__":
    main()
17 changes: 17 additions & 0 deletions MultiPruner/extract/README.md
@@ -0,0 +1,17 @@
## Extract the Compressed Model from MultiPruner

The final compressed model can be extracted based on the optimal pruning configuration obtained from **MultiPruner**.
Here is an example command for the compressed Llama-2-7B:

```bash
python extract/extract_model.py \
--model_path meta-llama/Llama-2-7b-hf \
--weight_reorder \
--pruned_model_config_file <path to pruning result>/pruning_config.json \
--output_path <path to compressed model>
```

- `model_path`: Path to the pre-trained model.
- `weight_reorder`: Flag to indicate whether to perform weight reordering.
- `pruned_model_config_file`: JSON file for the pruned model configuration.
- `output_path`: Directory to save the compressed model.