From 11f3aa47575a8c935316fec9e0f29e76e427b07e Mon Sep 17 00:00:00 2001 From: girinman Date: Sun, 11 Feb 2024 22:34:45 +0900 Subject: [PATCH 1/2] feat: Add callback saving trainable params --- fine-tune.py | 2 ++ save_callback.py | 44 +++++++++++++++++++++++++++++++++++ supervised-fine-tune-qlora.py | 2 ++ supervised-fine-tune.py | 2 ++ 4 files changed, 50 insertions(+) create mode 100644 save_callback.py diff --git a/fine-tune.py b/fine-tune.py index 5e738a75..b8580563 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -27,6 +27,7 @@ from gptneox_attn_replace import replace_gpt_neox_attn from peft import LoraConfig, get_peft_model from torch.distributed import barrier +from save_callback import SavePeftModelCallback from datasets import load_dataset @@ -202,6 +203,7 @@ def train(): train_dataset=dataset["train"], eval_dataset=None, data_collator=data_collator) + trainer.add_callback(SavePeftModelCallback) trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) diff --git a/save_callback.py b/save_callback.py new file mode 100644 index 00000000..a0f38a89 --- /dev/null +++ b/save_callback.py @@ -0,0 +1,44 @@ +import os +import logging +import torch + +from transformers import ( + TrainerCallback, + TrainingArguments, + TrainerState, + TrainerControl, +) + +PREFIX_CHECKPOINT_DIR = "step" + +class SavePeftModelCallback(TrainerCallback): + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") + os.makedirs(checkpoint_folder, exist_ok=True) + + modules_to_save = [] + for module_name in args.trainable_params.split(","): + if len(module_name.strip()) > 0: + modules_to_save.append(module_name) + + # Save trainable parameters if exist + if modules_to_save: + state_dict = kwargs["model"].state_dict() + to_save = {} + for key, value in state_dict.items(): + if any(module_name in key 
for module_name in modules_to_save): + to_save[key.replace("base_model.model.", "")] = value + torch.save(to_save, os.path.join(checkpoint_folder, "trainable_params.bin")) + logging.info(f"Trainable parameters saved at: {checkpoint_folder}") + + # Save LoRA adapter weight + kwargs["model"].save_pretrained(checkpoint_folder) + logging.info(f"LoRA adapter weights saved at: {checkpoint_folder}") + + return control diff --git a/supervised-fine-tune-qlora.py b/supervised-fine-tune-qlora.py index 0a648d18..5a8c9f22 100644 --- a/supervised-fine-tune-qlora.py +++ b/supervised-fine-tune-qlora.py @@ -31,6 +31,7 @@ from gptneox_attn_replace import replace_gpt_neox_attn from peft import LoraConfig, get_peft_model from torch.distributed import barrier +from save_callback import SavePeftModelCallback IGNORE_INDEX = -100 DEFAULT_PAD_TOKEN = "[PAD]" @@ -350,6 +351,7 @@ def forward(self, x): model.gradient_checkpointing_enable() # enable gradient checkpointing trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) + trainer.add_callback(SavePeftModelCallback) trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py index ab314139..08af3fbb 100644 --- a/supervised-fine-tune.py +++ b/supervised-fine-tune.py @@ -30,6 +30,7 @@ from gptneox_attn_replace import replace_gpt_neox_attn from peft import LoraConfig, get_peft_model from torch.distributed import barrier +from save_callback import SavePeftModelCallback IGNORE_INDEX = -100 DEFAULT_PAD_TOKEN = "[PAD]" @@ -316,6 +317,7 @@ def train(): model.gradient_checkpointing_enable() # enable gradient checkpointing trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) + trainer.add_callback(SavePeftModelCallback) trainer.train() trainer.save_state() trainer.save_model(output_dir=training_args.output_dir) From db876685e47d6c97582c91b0d981e2746eb78ce5 Mon Sep 17 00:00:00 2001 From: 
girinman
Date: Sun, 11 Feb 2024 23:12:59 +0900
Subject: [PATCH 2/2] feat: save & load model configs

---
 merge_lora_weights_and_save_hf_model.py | 9 +++++++++
 save_callback.py                        | 4 ++++
 2 files changed, 13 insertions(+)

diff --git a/merge_lora_weights_and_save_hf_model.py b/merge_lora_weights_and_save_hf_model.py
index 7f20da48..de574467 100644
--- a/merge_lora_weights_and_save_hf_model.py
+++ b/merge_lora_weights_and_save_hf_model.py
@@ -64,9 +64,18 @@ def main(args):
     print("base model", args.base_model)
     print("peft model", args.peft_model)
 
+    # Load config from peft model dir if exists
+    # In order to reuse the rope scaling configurations
+    config_path = os.path.join(args.peft_model, "config.json")
+    if os.path.isfile(config_path):
+        config = transformers.AutoConfig.from_pretrained(config_path)
+    else:
+        config = transformers.AutoConfig.from_pretrained(args.base_model)
+
     # Load model and tokenizer
     model = transformers.AutoModelForCausalLM.from_pretrained(
         args.base_model,
+        config=config,
         cache_dir=args.cache_dir,
         torch_dtype=torch.float16,
         device_map="auto",
diff --git a/save_callback.py b/save_callback.py
index a0f38a89..634e194e 100644
--- a/save_callback.py
+++ b/save_callback.py
@@ -38,7 +38,11 @@ def on_save(
         logging.info(f"Trainable parameters saved at: {checkpoint_folder}")
 
         # Save LoRA adapter weight
         kwargs["model"].save_pretrained(checkpoint_folder)
         logging.info(f"LoRA adapter weights saved at: {checkpoint_folder}")
 
+        # Save model config in order to reuse rope scaling settings
+        kwargs["model"].config.save_pretrained(checkpoint_folder)
+        logging.info(f"Model config saved at: {checkpoint_folder}")
+
         return control