Multi-dimensional pruning
jpablomch committed Dec 11, 2024
1 parent a5d3f6c commit 0b81417
Showing 12 changed files with 2,382 additions and 2 deletions.
153 changes: 153 additions & 0 deletions MultiPruner/README.md
@@ -0,0 +1,153 @@
# MultiPruner

Official implementation of [Fine-Grained Training-Free Structure Removal in Foundation Models]().

This repo contains the code for **MultiPruner**, a novel pruning approach that surpasses recent training-free pruning
methods by adopting a multidimensional, iterative, fine-grained pruning strategy.
Please refer to our paper for more details.

## News
- **[2025.xx.xx]** Released the code for **MultiPruner**. :tada:

## Setup

Here is an example of setting up the environment from scratch:

```bash
pip install virtualenv
virtualenv multipruner-env
source multipruner-env/bin/activate
pip install torch==2.3.1
# install dependencies
bash install.sh
```
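
As an optional sanity check (not part of the repository's scripts), you can confirm that the pinned PyTorch build is installed and that a GPU is visible:

```python
# Optional environment check after running the installation steps above.
import torch

print(torch.__version__)          # expected: 2.3.1 (the version pinned above)
print(torch.cuda.is_available())  # True if a CUDA-capable GPU is visible
```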

## Run

We use the [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) model as an example.

### Prune

```bash
python run_multipruner.py \
--model_path meta-llama/Llama-2-7b-hf \
--output_path <path to pruning results> \
--weight_reorder \
--do_prune \
--target_ratio 22.00 \
--pruning_distribution 44:52:4 \
--mlp_channel_group_size 1024 \
--attn_channel_group_size 128 \
--importance_metric ppl \
--calibration_dataset alpaca \
--num_calibration_samples_block 256 \
--num_calibration_samples_width 128 \
--do_eval
```

- `model_path`: Path to the pre-trained model.
- `output_path`: Directory to save the pruning and evaluation results.
- `weight_reorder`: Flag to indicate whether to perform weight reordering.
- `do_prune`: Flag to indicate whether to perform pruning.
- `target_ratio`: Target pruning ratio (as a percentage).
- `pruning_distribution`: Pruning ratio distribution across the different pruning granularities.
- `mlp_channel_group_size`: Number of channels per group (MLP).
- `attn_channel_group_size`: Number of channels per group (Attn); generally a multiple of the head dimension.
- `importance_metric`: Metric for calculating block importance; currently only perplexity (PPL) is supported (see the sketch below).
- `calibration_dataset`: Calibration dataset name ("alpaca", "c4", "ptb" or "wikitext2").
- `num_calibration_samples_block`: Number of calibration samples to use for depth (block) pruning (stage 1).
- `num_calibration_samples_width`: Number of calibration samples to use for width pruning (stages 2 and 3).
- `do_eval`: Flag to indicate whether to perform evaluation.
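
To make the `ppl` importance metric concrete, here is a minimal, illustrative sketch of scoring blocks by the perplexity increase on calibration samples when a block is bypassed. This is not the repository's implementation (see `run_multipruner.py` for that); the `skip` flag and helper names below are hypothetical.

```python
import math
import torch

@torch.no_grad()
def perplexity(model, calib_batches):
    # Mean negative log-likelihood over all calibration tokens, exponentiated.
    total_nll, total_tokens = 0.0, 0
    for batch in calib_batches:  # each batch: LongTensor of token ids
        out = model(input_ids=batch, labels=batch)
        total_nll += out.loss.item() * batch.numel()
        total_tokens += batch.numel()
    return math.exp(total_nll / total_tokens)

@torch.no_grad()
def rank_blocks(model, blocks, calib_batches):
    # Blocks whose removal barely changes PPL are the best pruning candidates.
    base_ppl = perplexity(model, calib_batches)
    scores = {}
    for idx, block in enumerate(blocks):
        block.skip = True   # hypothetical switch that bypasses the block (residual passthrough)
        scores[idx] = perplexity(model, calib_batches) - base_ppl
        block.skip = False
    return sorted(scores, key=scores.get)  # least important first
```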

### Extract the Compressed Model

The final compressed model can be extracted based on the optimal pruning configuration obtained from MultiPruner.
For more details, please refer to [this link](./extract).
Below is an example of how to extract a pruned Llama-2-7B:

```bash
python extract/extract_model.py \
--model_path meta-llama/Llama-2-7b-hf \
--weight_reorder \
--pruned_model_config_file <path to pruning results>/pruning_config.json \
--output_path <path to compressed model>
```
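
The `pruning_config.json` referenced above is produced by `run_multipruner.py`. Before extraction, it can be inspected with a few lines of Python; the keys below are the ones consumed by `extract/extract_model.py` (their exact contents depend on your pruning run):

```python
import json

# Placeholder path -- point this at the pruning_config.json from your pruning run.
with open("<path to pruning results>/pruning_config.json") as f:
    pruned_config = json.load(f)

# pruned_attn_idx / pruned_mlp_idx:     indices of attention / MLP blocks removed entirely
# pruned_attn_width / pruned_mlp_width: remaining channel widths for width-pruned blocks
for key in ("pruned_attn_idx", "pruned_mlp_idx", "pruned_attn_width", "pruned_mlp_width"):
    print(key, pruned_config.get(key))
```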

### Recovery Finetuning

Once the pruned model has been extracted, the Alpaca dataset can be used for recovery fine-tuning.
More details can be found [here](./recovery).
The following is an example command for the compressed Llama-2-7B:

```bash
# Finetune the compressed model
python recovery/finetune.py \
--model_path <path to compressed model> \
--do_train \
--batch_size 8 \
--gradient_accumulation_steps 4 \
--num_train_epochs 2 \
--learning_rate 1e-4 \
--lora \
--lora_r 16 \
--lora_alpha 32 \
--lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
--output_path <path to finetuned compressed model> \
--do_eval
```
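
For reference, the LoRA flags above roughly correspond to the following `peft` configuration. This is a hedged sketch of the setup implied by the command-line options, not a copy of `recovery/finetune.py`:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder path -- use the model produced by the extraction step.
model = AutoModelForCausalLM.from_pretrained("<path to compressed model>", trust_remote_code=True)

lora_config = LoraConfig(
    r=16,              # --lora_r
    lora_alpha=32,     # --lora_alpha
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "down_proj", "up_proj", "gate_proj"],  # --lora_target_modules
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters are trainable
```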

## Results

Example commands (for both pruning and recovery tuning) and the corresponding pruning results of MultiPruner can be found [here](./results).

In addition to the 22% pruning ratio reported in the paper, we also explored pruning ratios that result in roughly 1%, 2%, and 3%
accuracy degradation (compared to the dense model), under both `without finetune` and `with finetune` scenarios.
These operating points may be useful for practical applications. The results for Llama-2-7b-hf are shown in the following table:

| Method | Pruning Ratio | Acc. (%) | Acc. Drop | Relative Acc. |
|--------------------------|---------------|----------|-----------|---------------|
| Dense | / | 68.96 | / | 100% |
| MultiPruner w/o finetune | 7% | 67.94 | -1.02% | 98.52% |
| MultiPruner w/o finetune | 10% | 67.02 | -1.94% | 97.19% |
| MultiPruner w/o finetune | 14% | 65.93 | -3.03% | 95.61% |
| MultiPruner w/ finetune | 12% | 68.28 | -0.68% | 99.01% |
| MultiPruner w/ finetune | 15% | 67.41 | -1.55% | 97.75% |
| MultiPruner w/ finetune | 18% | 66.16 | -2.80% | 95.94% |


## Released Pruned Models 🤗

We have released several models compressed with MultiPruner (checkpoints with the `-alpaca` suffix include recovery tuning on the Alpaca dataset):

| Source Model | Pruning Ratio | Recovery Tuning | Pruned Model |
|--------------|---------------|-----------------|--------------|
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 7% | ✘ | [IntelLabs/MultiPruner-Llama-2-6.3b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-6.3b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 10% | ✘ | [IntelLabs/MultiPruner-Llama-2-6.1b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-6.1b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 12% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.9b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.9b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 12% | ✔ | [IntelLabs/MultiPruner-Llama-2-5.9b-alpaca](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.9b-alpaca) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 14% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.8b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.8b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 15% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.7b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.7b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 15% | ✔ | [IntelLabs/MultiPruner-Llama-2-5.7b-alpaca](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.7b-alpaca) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 18% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.5b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.5b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 18% | ✔ | [IntelLabs/MultiPruner-Llama-2-5.5b-alpaca](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.5b-alpaca) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 22% | ✘ | [IntelLabs/MultiPruner-Llama-2-5.3b](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.3b) |
| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 22% | ✔ | [IntelLabs/MultiPruner-Llama-2-5.3b-alpaca](https://huggingface.co/IntelLabs/MultiPruner-Llama-2-5.3b-alpaca) |
| [Qwen/Qwen1.5-7B](https://huggingface.co/Qwen/Qwen1.5-7B) | 22% | ✘ | [IntelLabs/MultiPruner-Qwen1.5-6b](https://huggingface.co/IntelLabs/MultiPruner-Qwen1.5-6b) |
| [baichuan-inc/Baichuan2-7B-Base](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base) | 22% | ✘ | [IntelLabs/MultiPruner-Baichuan2-5.8b](https://huggingface.co/IntelLabs/MultiPruner-Baichuan2-5.8b) |

### Loading the Compressed Model for Evaluation

```bash
python eval.py --model_path <path to compressed model> --output_path <path to evaluation results>
```
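
The released checkpoints can also be loaded programmatically with `transformers`, mirroring the loading code in `eval.py`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "IntelLabs/MultiPruner-Llama-2-5.3b"  # any model ID from the table above
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="float16",
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
```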

## Citation
If you find MultiPruner's code and paper helpful, please cite:
```bibtex
@article{munoz2025multipruner,
title = {Fine-Grained Training-Free Structure Removal in Foundation Models},
author = {J. Pablo Munoz and Jinjie Yuan and Nilesh Jain},
year = {2025},
url = {}
}
```
82 changes: 82 additions & 0 deletions MultiPruner/eval.py
@@ -0,0 +1,82 @@
import os
import json
import logging
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

import utils


def main():
    # Emit INFO-level log messages (the root logger defaults to WARNING, which would hide the output below)
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str)
    parser.add_argument("--output_path", type=str)
    args = parser.parse_args()
    model_path = args.model_path
    output_path = args.output_path

    # Ensure the output directory exists
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype="float16",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Evaluate on wikitext2 dataset
    dataset = utils.get_dataset("wikitext2")
    test_dataset = dataset["test"]
    test_loader = utils.prepare_test_dataloader(
        dataset=test_dataset,
        tokenizer=tokenizer,
        seqlen=2048,
        batch_size=1
    )
    dataset_ppl = utils.evaluate_ppl(
        model=model,
        dataloader=test_loader,
        pad_token_id=model.config.eos_token_id,
    )
    dataset_ppl = round(dataset_ppl, 2)
    logging.info(f'wikitext2 PPL: {dataset_ppl}')

    # Evaluate on selected tasks
    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=64)

    task_names = ["piqa", "winogrande", "hellaswag", "arc_easy", "arc_challenge"]
    logging.info(f"Selected Tasks: {task_names}")

    results = evaluator.simple_evaluate(hflm, tasks=task_names, num_fewshot=0, batch_size=64, log_samples=False)['results']

    metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) * 100 for task, result in results.items()}
    logging.info(json.dumps(metric_vals, indent=4))

    def calculate_avg_accuracy(task_names, results):
        n_tasks = len(task_names)
        acc_cumul = sum(result.get('acc_norm,none', result['acc,none']) for task, result in results.items())
        return round(acc_cumul / n_tasks, 4) * 100

    acc_avg = calculate_avg_accuracy(task_names, results)
    logging.info(f"Average accuracy across tasks: {acc_avg}")

    # Save evaluation results
    overall_results = {
        "ppl_wikitext2": dataset_ppl,
        "5cs_acc_avg": acc_avg,
        **metric_vals
    }
    eval_result_path = os.path.join(output_path, "eval.res.json")
    with open(eval_result_path, "w") as f:
        json.dump(overall_results, f, indent=4)


if __name__ == "__main__":
    main()
19 changes: 19 additions & 0 deletions MultiPruner/extract/README.md
@@ -0,0 +1,19 @@
## Extract the Compressed Model from MultiPruner

The final compressed model can be extracted based on the optimal pruning configuration obtained from **MultiPruner**.
Here is an example command for the compressed Llama-2-7B:

```bash
python extract/extract_model.py \
--model_path meta-llama/Llama-2-7b-hf \
--weight_reorder \
--pruned_model_config_file <path to pruning result>/pruning_config.json \
--output_path <path to compressed model>
```

- `model_path`: Path to the pre-trained model.
- `weight_reorder`: Flag to indicate whether to perform weight reordering.
- `pruned_model_config_file`: JSON file for the pruned model configuration.
- `output_path`: Directory to save the compressed model.

Some extracted models can be found in [this Table](../README.md#released-pruned-models-).
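
As a quick optional check that extraction worked (this snippet is not part of the extraction script), the parameter counts of the dense and compressed models can be compared:

```python
from transformers import AutoModelForCausalLM

# Placeholder output path -- use the directory written by extract_model.py.
dense = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="float16")
pruned = AutoModelForCausalLM.from_pretrained("<path to compressed model>",
                                              torch_dtype="float16", trust_remote_code=True)

n_dense, n_pruned = dense.num_parameters(), pruned.num_parameters()
print(f"dense: {n_dense / 1e9:.2f}B, pruned: {n_pruned / 1e9:.2f}B, "
      f"reduction: {100 * (1 - n_pruned / n_dense):.1f}%")
```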
132 changes: 132 additions & 0 deletions MultiPruner/extract/extract_model.py
@@ -0,0 +1,132 @@
import argparse
import json
import logging
import os
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(project_root)

import utils


def main():
    # Emit INFO-level log messages (the root logger defaults to WARNING, which would hide the output below)
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="meta-llama/Llama-2-7b-hf",
        help="Path to the pre-trained model."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="prune_result",
        help="Directory to save the compressed model."
    )
    parser.add_argument(
        "--weight_reorder",
        action="store_true",
        help="Flag to indicate whether to perform weight reorder."
    )
    parser.add_argument(
        "--pruned_model_config_file",
        type=str,
        default=None,
        help="Path to the pruned model configuration file."
    )

    args = parser.parse_args()
    model_path = args.model_path
    output_path = args.output_path
    weight_reorder = args.weight_reorder
    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)

    pruned_model_config_file = args.pruned_model_config_file

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map={"": 0},
        trust_remote_code=True,
        torch_dtype="float16",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    if weight_reorder:
        for layer in utils.get_layers(model):
            utils.reorder_in_attn_block(getattr(layer, utils.get_attn_key(model)), model=model)
            utils.reorder_in_mlp_block(getattr(layer, utils.get_mlp_key(model)))

    # Load pruning results
    with open(pruned_model_config_file, "r") as f:
        pruned_config = json.load(f)
    logging.info(f"Detected a pruned model config: {pruned_config}")
    state_dict = model.state_dict()

    def get_groups(model, key):
        # Collect the linear sub-modules of each attention/MLP block, one dict per layer
        groups = []
        for layer in utils.get_layers(model):
            modules = getattr(layer, key)
            groups.append({name: module for name, module in modules.named_children() if isinstance(module, torch.nn.Linear)})
        return groups

    def get_pruned_weights(groups, pruned_channels):
        # Map each linear module to its width-pruned weight tensor
        module_to_weight = {}
        for group_idx, value in pruned_channels.items():
            group = groups[int(group_idx)]
            for name, module in group.items():
                if name in utils.DEPENDENCY_GROUPS:
                    module_to_weight[module] = module.weight[:, :value]
                else:
                    module_to_weight[module] = module.weight[:value]
        return module_to_weight

    mlp_groups = get_groups(model, utils.get_mlp_key(model))
    attn_groups = get_groups(model, utils.get_attn_key(model))
    module_to_weight = {}
    if pruned_config.get("pruned_attn_width"):
        module_to_weight.update(get_pruned_weights(attn_groups, pruned_config["pruned_attn_width"]))
    if pruned_config.get("pruned_mlp_width"):
        module_to_weight.update(get_pruned_weights(mlp_groups, pruned_config["pruned_mlp_width"]))

    linear_modules = {name: module for name, module in model.named_modules() if isinstance(module, torch.nn.Linear)}
    for name, module in linear_modules.items():
        if module in module_to_weight:
            sd_weight_key = name + ".weight"
            assert sd_weight_key in state_dict
            pruned_weight = module_to_weight[module]
            state_dict[sd_weight_key] = pruned_weight.clone()
            # bias
            sd_bias_key = name + ".bias"
            if sd_bias_key in state_dict:
                state_dict[sd_bias_key] = state_dict[sd_bias_key][:pruned_weight.size(0)].clone()

    def prune_modules(state_dict, idx, key):
        # Drop all state_dict entries belonging to the removed block (layer `idx`, submodule `key`)
        target = f".{idx}.{key}"
        remove_keys = [name for name in state_dict if target in name]
        for name in remove_keys:
            del state_dict[name]

    if pruned_config.get("pruned_attn_idx"):
        pruned_attn_idx = pruned_config["pruned_attn_idx"]
        for idx in pruned_attn_idx:
            prune_modules(state_dict, idx, utils.get_attn_key(model))
    if pruned_config.get("pruned_mlp_idx"):
        pruned_mlp_idx = pruned_config["pruned_mlp_idx"]
        for idx in pruned_mlp_idx:
            prune_modules(state_dict, idx, utils.get_mlp_key(model))

    model.save_pretrained(output_path, state_dict=state_dict)
    tokenizer.save_pretrained(output_path)


if __name__ == "__main__":
    main()