Finetuning feature added for setting vision_lr and resampler_lr #521

Open · wants to merge 1 commit into base: main
30 changes: 22 additions & 8 deletions finetune/finetune.py
@@ -53,6 +53,8 @@ class TrainingArguments(transformers.TrainingArguments):
llm_type: str = field(default="minicpm")
use_lora: Optional[bool] = field(default=False)
max_slice_nums: Optional[int] = field(default=9)
vision_lr: Optional[float] = None
resampler_lr: Optional[float] = None


@dataclass
@@ -74,12 +76,25 @@ def rank0_print(*args):
if local_rank == 0:
print(*args)


def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
output_dir: str):
"""Collects the state dict and dump to disk."""
if trainer.args.should_save and trainer.args.local_rank == 0:
trainer.save_model(output_dir,)

if trainer.deepspeed:
trainer.accelerator.wait_for_everyone()
torch.cuda.synchronize()
trainer.save_model(output_dir)
return

state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {
key: value.cpu()
for key, value in state_dict.items()
}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
trainer.model.config.save_pretrained(output_dir)

def make_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer,
@@ -202,6 +217,7 @@ def train():
trust_remote_code=True,
torch_dtype=compute_dtype,
device_map=device_map,
attn_implementation="flash_attention_2"
)

tokenizer = AutoTokenizer.from_pretrained(
@@ -250,7 +266,6 @@ def get_input_embeddings(self):

rank0_print(f'llm_type={llm_type}')


# Load data
if hasattr(model.config, "slice_config"):
model.config.slice_config.max_slice_nums = training_args.max_slice_nums
@@ -291,9 +306,8 @@ def get_input_embeddings(self):

safe_save_model_for_hf_trainer(
trainer=trainer,
output_dir=training_args.output_dir,
bias=lora_args.lora_bias)
output_dir=training_args.output_dir)


if __name__ == "__main__":
train()
train()
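For context, a minimal sketch (assuming `finetune.py` keeps its existing `HfArgumentParser` flow) of how the two new dataclass fields surface as the `--vision_lr` / `--resampler_lr` flags used by the updated shell scripts; the other fields are stripped for brevity:

```python
# Sketch only: the two new Optional[float] fields become --vision_lr / --resampler_lr
# flags when parsed with HfArgumentParser; names follow the diff above.
from dataclasses import dataclass
from typing import Optional

import transformers
from transformers import HfArgumentParser


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Only the two new fields are shown; finetune.py keeps its other fields.
    vision_lr: Optional[float] = None      # learning rate for the vision tower ("vpm")
    resampler_lr: Optional[float] = None   # learning rate for the resampler


parser = HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out",
     "--learning_rate", "1e-5",
     "--vision_lr", "2e-6",
     "--resampler_lr", "2e-6"]
)
print(args.learning_rate, args.vision_lr, args.resampler_lr)  # 1e-05 2e-06 2e-06
```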
4 changes: 3 additions & 1 deletion finetune/finetune_ds.sh
@@ -53,7 +53,9 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-6 \
--learning_rate 1e-5 \
--vision_lr 2e-6 \
--resampler_lr 2e-6 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
8 changes: 6 additions & 2 deletions finetune/finetune_lora.sh
@@ -41,7 +41,9 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--tune_vision true \
--tune_llm false \
--use_lora true \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \
--lora_r 64 \
--lora_alpha 128 \
--lora_target_modules "llm\..*layers\.\d+\.(self_attn|mlp)\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)" \
--model_max_length $MODEL_MAX_Length \
--max_slice_nums 9 \
--max_steps 10000 \
@@ -56,7 +58,9 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-6 \
--learning_rate 1e-4 \
--vision_lr 2e-6 \
--resampler_lr 2e-6 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
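The widened `--lora_target_modules` value is a regular expression over module names (PEFT matches a string `target_modules` as a regex against each module path). A quick check with illustrative stand-ins for MiniCPM-V's module paths, not taken from the repo, shows what the new pattern now covers:

```python
import re

# The new pattern from finetune_lora.sh (string form, as passed to --lora_target_modules).
pattern = r"llm\..*layers\.\d+\.(self_attn|mlp)\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)"

# Illustrative stand-in module names (not taken from the repo).
candidates = [
    "llm.model.layers.0.self_attn.q_proj",    # attention projection: matched before and after
    "llm.model.layers.0.mlp.gate_proj",       # MLP projection: newly matched by this PR
    "llm.model.layers.0.mlp.down_proj",       # MLP projection: newly matched by this PR
    "vpm.encoder.layers.0.self_attn.q_proj",  # vision tower: not matched (no "llm." prefix)
    "resampler.attn",                         # resampler: not matched
]

for name in candidates:
    print(f"{name}: {bool(re.fullmatch(pattern, name))}")
```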
81 changes: 81 additions & 0 deletions finetune/trainer.py
@@ -10,6 +10,87 @@


class CPMTrainer(Trainer):

def create_optimizer(self):
"""
Setup the optimizer.

We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()

opt_model = self.model

if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
lr_mapper = {}
if self.args.resampler_lr is not None:
lr_mapper["resampler"] = self.args.resampler_lr
if self.args.vision_lr is not None:
lr_mapper["vpm"] = self.args.vision_lr
if len(lr_mapper) > 0:
special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
optimizer_grouped_parameters = [
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
"weight_decay": 0.0,
},
]
for module_keyword, lr in lr_mapper.items():
module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
optimizer_grouped_parameters.extend(
[
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
"lr": lr,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)],
"weight_decay": 0.0,
"lr": lr,
},
]
)
else:
optimizer_grouped_parameters = [
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
"weight_decay": 0.0,
},
]

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")

return self.optimizer


def compute_loss(self, model, inputs, return_outputs=False):
if "labels" in inputs:
labels = inputs.pop("labels")
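To make the grouping in `create_optimizer` easier to follow in isolation, here is a self-contained sketch of the same idea: parameters whose names contain `vpm` or `resampler` get their own optimizer groups and learning rates, and everything else stays on the base `--learning_rate`. The toy model and its submodule names are invented for illustration, and the weight-decay split is omitted for brevity.

```python
import torch
import torch.nn as nn


class ToyMiniCPMV(nn.Module):
    """Invented toy model; submodule names mirror the keywords create_optimizer looks for."""
    def __init__(self):
        super().__init__()
        self.vpm = nn.Linear(8, 8)        # stands in for the vision tower
        self.resampler = nn.Linear(8, 8)  # stands in for the resampler
        self.llm = nn.Linear(8, 8)        # stands in for the language model


model = ToyMiniCPMV()
base_lr = 1e-5
lr_mapper = {"vpm": 2e-6, "resampler": 2e-6}   # mirrors --vision_lr / --resampler_lr

# Parameters whose names contain a mapped keyword get their own group and LR;
# everything else falls through to the base learning rate.
special = [n for n, _ in model.named_parameters()
           if any(k in n for k in lr_mapper)]
groups = [{"params": [p for n, p in model.named_parameters() if n not in special],
           "lr": base_lr}]
for keyword, lr in lr_mapper.items():
    groups.append({"params": [p for n, p in model.named_parameters() if keyword in n],
                   "lr": lr})

optimizer = torch.optim.AdamW(groups)
for group in optimizer.param_groups:
    print(len(group["params"]), "params at lr", group["lr"])
# Expected: 2 params at lr 1e-05, then 2 params at lr 2e-06 twice
# (each nn.Linear contributes a weight and a bias).
```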