adding initial code drop for llm finetune
itayhubara committed Jan 18, 2024
1 parent 00f04c5 commit ab55445
Showing 9 changed files with 1,347 additions and 0 deletions.
90 changes: 90 additions & 0 deletions llm_finetune/README.md
@@ -0,0 +1,90 @@
# LoRA benchmark

LoRA benchmark on GPU (Nvidia A100 80GB). Inspired by [this blog post](https://medium.com/@sourabmangrulkar/falcon-180b-finetuning-using-peft-and-deepspeed-b92643091d99) and [this script](https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/train.py).


## Setup

Run the following:
```bash
sudo ./run_docker.sh
cd lora
pip install -r requirements.txt
```

> The Docker run command mounts a host folder into the container with `-v path_to_my_folder:/root/workspace --workdir /root/workspace`. Feel free to change these flags as needed.

You will also need to run the following to install flash attention:
```
pip install flash-attn --no-build-isolation
```

> For flash attention, make sure that the following command returns 0:
> ```
> ninja --version >/dev/null && echo $?
> ```
> If not, run
> ```
> pip uninstall -y ninja && pip install ninja
> ```
> and install `flash-attn` again.
> More information [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).

Make sure you have requested permission to download the Llama2 weights on the Hugging Face Hub: https://huggingface.co/meta-llama/Llama-2-7b-hf

Then, log in to your Hugging Face account with a read token by running:
```
huggingface-cli login
```
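
To double-check that the token works and that access to the gated Llama2 repo has been granted, a quick sanity check along these lines can help (a sketch using the `huggingface_hub` Python API; not part of this repo):
```python
# Optional sanity check (illustrative, not part of this repo): confirm the token is
# active and that access to the gated Llama2 repository has been granted.
from huggingface_hub import whoami, model_info

print("Logged in as:", whoami()["name"])
# Raises an HTTP 401/403 error if access to the gated repo was not granted.
model_info("meta-llama/Llama-2-7b-hf")
print("Llama2 access OK")
```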


## Llama2-70B on 8 devices

Run:
```bash
accelerate launch --config_file configs/default_config.yaml scripts/train.py \
--model_name meta-llama/Llama-2-70b-hf \
--dataset_name "tau/scrolls" --dataset_config_name "gov_report" \
--max_seq_len 8192 \
--bf16 True \
--logging_steps 1 \
--eval_steps 22 \
--output_dir "/tmp/llama-70b" \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--dataset_text_field "input" \
--lr_scheduler_type "cosine" \
--learning_rate 1e-3 \
--warmup_ratio 0.03 \
--use_gradient_checkpointing True \
--use_peft_lora True \
--lora_r 16 \
--lora_alpha 32 \
--lora_dropout 0.1 \
--max_steps 440 \
--use_flash_attn \
--lora_target_modules "q_proj,v_proj,k_proj,o_proj"
```
where the Accelerate config file is [this one](https://github.com/regisss/lora/blob/main/configs/default_config.yaml).

> Using flash attention with `--use_flash_attn` is necessary for training on 8k-token sequences.

Learning curves of such a run can be found here: https://huggingface.co/regisss/test_5/tensorboard
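
For reference, the `--lora_*` flags above roughly correspond to the following `peft` configuration (an illustrative sketch; the exact construction in `scripts/train.py` may differ):
```python
# Rough peft equivalent of the --lora_* flags (sketch; see scripts/train.py for the real wiring).
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                                                     # --lora_r
    lora_alpha=32,                                            # --lora_alpha
    lora_dropout=0.1,                                         # --lora_dropout
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # --lora_target_modules
    bias="none",
    task_type="CAUSAL_LM",
)
```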


## Evaluation

To run evaluation for summarizing texts, you can run:
- Without LoRA adapter weights:
```
python scripts/eval.py --model_name meta-llama/Llama-2-70b-hf --max_new_tokens 900 --seq_length 8192 --do_sample --dataset_name "tau/scrolls" --dataset_config_name "gov_report"
```
- With LoRA adapter weights:
```
python scripts/eval.py --peft_model_name path_to_my_lora_model --max_new_tokens 900 --seq_length 8192 --do_sample --dataset_name "tau/scrolls" --dataset_config_name "gov_report"
```
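
`scripts/eval.py` merges the LoRA adapter into the base model at load time. If you instead want to export standalone merged weights, something along these lines works with `peft` (a sketch; the paths are placeholders):
```python
# Merge the LoRA adapter into the base weights and save a standalone checkpoint (sketch; paths are placeholders).
import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "path_to_my_lora_model",                 # placeholder: your adapter checkpoint
    torch_dtype=torch.bfloat16,
)
merged = model.merge_and_unload()            # fold the LoRA deltas into the base weights
merged.save_pretrained("llama-70b-merged")   # placeholder output directory
```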

## Expected outcome

A clean log (train and eval loss) of a single 440-step run can be found in
```
convergence_example.txt
```
22 changes: 22 additions & 0 deletions llm_finetune/configs/default_config.yaml
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
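# ZeRO stage 3: parameters, gradients, and optimizer states are sharded across the 8 processes.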
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
508 changes: 508 additions & 0 deletions llm_finetune/convergence_example.txt


9 changes: 9 additions & 0 deletions llm_finetune/requirements.txt
@@ -0,0 +1,9 @@
transformers
accelerate
peft
datasets
deepspeed
bitsandbytes
evaluate
nltk
rouge-score
2 changes: 2 additions & 0 deletions llm_finetune/run_docker.sh
@@ -0,0 +1,2 @@
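# Replace path_to_my_folder with the host directory to mount at /root/workspace (see the README note).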
docker pull nvcr.io/nvidia/pytorch:23.09-py3
docker run -v path_to_my_folder:/root/workspace --workdir /root/workspace --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 nvcr.io/nvidia/pytorch:23.09-py3
24 changes: 24 additions & 0 deletions llm_finetune/run_llama_70B_scrolls_r16.sh
@@ -0,0 +1,24 @@
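# Usage: bash run_llama_70B_scrolls_r16.sh <seed>
# $1 is used both as the training seed and as a suffix for the output directory.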
accelerate launch --config_file configs/default_config.yaml scripts/train.py \
--model_name meta-llama/Llama-2-70b-hf \
--dataset_name "tau/scrolls" --dataset_config_name "gov_report" \
--max_seq_len 8192 \
--bf16 True \
--logging_steps 1 \
--eval_steps 22 \
--save_steps 22 \
--output_dir "./results/llama-70b_scrolls_gov_report_r16_$1" \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--dataset_text_field "input" \
--lr_scheduler_type "cosine" \
--learning_rate 1e-3 \
--warmup_ratio 0.03 \
--use_gradient_checkpointing True \
--use_peft_lora True \
--lora_r 16 \
--lora_alpha 32 \
--lora_dropout 0.1 \
--max_steps 440 \
--use_flash_attn \
--seed "$1" \
--lora_target_modules "q_proj,v_proj,k_proj,o_proj"
190 changes: 190 additions & 0 deletions llm_finetune/scripts/eval.py
@@ -0,0 +1,190 @@
import argparse
import torch
from dataclasses import dataclass
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM
from peft.config import PeftConfigMixin
from datasets import load_dataset
import evaluate
import nltk
import numpy as np
from tqdm import tqdm
from typing import Any, Dict, List, Union, Optional
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

nltk.download("punkt")

# Arguments management
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default=None,
type=str,
help="Path to pre-trained model (on the HF Hub or locally).",
)
parser.add_argument(
"--peft_model_name",
default=None,
type=str,
help="Path to PEFT model (on the HF Hub or locally).",
)
parser.add_argument(
"--max_new_tokens", type=int, default=300, help="Number of tokens to generate."
)
parser.add_argument("--seq_length", type=int, default=8192, help="Sequence length.")
parser.add_argument("--do_sample", action="store_true", help="Wheter to generate doing multinomial sampling.")
parser.add_argument("--dataset_name", type=str, default="tau/scrolls", help= "The preference dataset to use.")
parser.add_argument("--dataset_config_name", type=str, default="gov_report", help= "The preference dataset config to use.")
args = parser.parse_args()

# Instantiate model
if args.peft_model_name is not None:
model = (
AutoPeftModelForCausalLM.from_pretrained(
args.peft_model_name,
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
.merge_and_unload()
.eval()
)
base_model_name = PeftConfigMixin.from_pretrained(
args.peft_model_name
).base_model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
).eval()
tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)

model.generation_config.pad_token_id = model.generation_config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

# Load dataset
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
use_auth_token=True,
num_proc=4,
split="validation"
)
column_names = dataset.features

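# Wrap each document in the summarization prompt and keep the reference summary
# so it can be compared against the generated text later.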
def tokenize_function(examples):
output_texts = []
for i in range(len(examples["input"])):
output_texts.append(
f"### Summarize the following text:\n {examples['input'][i]}\n ### Summary:\n "
)
input_ids = tokenizer(output_texts).input_ids

return {"input_ids": input_ids, "ground_truth": examples["output"]}


test_dataset = dataset.map(
tokenize_function,
batched=True,
num_proc=2,
remove_columns=column_names,
)


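# Drop examples whose prompt is too long to leave room for max_new_tokens
# generated tokens within the seq_length budget.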
def filter_function(examples):
to_keep = []
for i in range(len(examples["input_ids"])):
if len(examples["input_ids"][i]) > args.seq_length - args.max_new_tokens:
to_keep.append(False)
else:
to_keep.append(True)
return to_keep


test_dataset = test_dataset.filter(
filter_function,
batched=True,
num_proc=2,
)
print(f"Size of the test set: {len(test_dataset)}.")


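# Pads input_ids to a common length for batching and passes the reference
# summaries through unchanged (they are strings, not tensors).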
@dataclass
class CustomDataCollator:
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
input_ids = [{"input_ids": sample["input_ids"]} for sample in features]
batch = self.tokenizer.pad(
input_ids,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch["ground_truth"] = [sample["ground_truth"] for sample in features]
return batch


dataloader = DataLoader(
test_dataset,
batch_size=1,
collate_fn=CustomDataCollator(tokenizer),
)


def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]

# rougeLSum expects newline after each sentence
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

return preds, labels


metric = evaluate.load("rouge")


def compute_metrics(generated, ground_truth):
# Some simple post-processing
decoded_preds, decoded_labels = postprocess_text(generated, ground_truth)
result = metric.compute(
predictions=decoded_preds, references=decoded_labels, use_stemmer=True
)
result = {k: round(v * 100, 4) for k, v in result.items()}
    # `generated` holds decoded strings, so re-tokenize to measure generation length in tokens.
    prediction_lens = [len(tokenizer(gen).input_ids) for gen in generated]
result["gen_len"] = np.mean(prediction_lens)
return result


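# Generate a summary for each batch, keep only the text after the "### Summary:" marker,
# and report running ROUGE scores as evaluation progresses.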
generated_sequences = []
ground_truths = []
for batch in tqdm(dataloader):
outputs = model.generate(
        inputs=batch["input_ids"].to("cuda"), do_sample=args.do_sample, max_new_tokens=args.max_new_tokens
)
outputs = [
output.split("### Summary:\n ")[-1]
for output in tokenizer.batch_decode(outputs, skip_special_tokens=True)
]

print("Batch outputs:", outputs)
print("Batch ground truths:", batch["ground_truth"])
generated_sequences += outputs
ground_truths += batch["ground_truth"]
print("Current results:", compute_metrics(generated_sequences, ground_truths))

res = compute_metrics(generated_sequences, ground_truths)
print("Final results:", res)