diff --git a/src/loops.py b/src/loops.py
index 8c167f0..f7bb7e3 100644
--- a/src/loops.py
+++ b/src/loops.py
@@ -89,6 +89,8 @@ def training_loop(
     if flag_save:
         atexit.register(goodbye, save_dir)
 
+    bert_params, task_params = model.get_params()
+
     # Epoch level
     for e in range(epochs_last_run + 1, epochs + epochs_last_run + 1):
 
@@ -106,7 +108,7 @@ def training_loop(
         for i, instance in enumerate(tqdm(trn_dataset)):
 
             # Reset the gradients.
-            opt.zero_grad()
+            model.zero_grad()
 
             instance["prep_coref_eval"] = True
 
@@ -128,11 +130,13 @@ def training_loop(
 
             # Clip Gradients
             if clip_grad_norm > 0:
-                torch.nn.utils.clip_grad_norm_([param for group in opt.param_groups for param in group['params']],
-                                               clip_grad_norm)
+                torch.nn.utils.clip_grad_norm_(bert_params, clip_grad_norm)
+                torch.nn.utils.clip_grad_norm_(task_params, clip_grad_norm)
 
             # Backward Pass
-            opt.step()
+            # opt.step()
+            for optimizer in opt:
+                optimizer.step()
 
             # Throw the outputs to the eval benchmark also
             train_eval.update(instance=instance, outputs=outputs)
@@ -141,8 +145,9 @@ def training_loop(
                 per_epoch_loss[instance['domain']][task_nm].append(outputs["loss"][task_nm].item())
 
             # If LR scheduler is provided, run it
-            if scheduler_per_iter is not None:
-                scheduler_per_iter.step()
+            if scheduler_per_iter:
+                scheduler_per_iter[0].step()
+                scheduler_per_iter[1].step()
 
             trn_dataset[i] = change_device(instance, 'cpu')
 
@@ -151,8 +156,9 @@ def training_loop(
         dev_eval.run()
 
         # If LR scheduler (per epoch) is provided, run it
-        if scheduler_per_epoch is not None:
-            scheduler_per_epoch.step()
+        if scheduler_per_epoch:
+            scheduler_per_epoch[0].step()
+            scheduler_per_epoch[1].step()
 
         # Bookkeeping (summarise the train and valid evaluations, and the loss)
         train_metrics = train_eval.aggregate_reports(train_metrics, train_eval.report())
@@ -167,7 +173,7 @@ def training_loop(
         for k in skipped.keys():
             skipped[k].append(per_epoch_skipped[k])
 
-        lrs = [param_group['lr'] for param_group in opt.param_groups]
+        lrs = [pg['lr'] for pg in opt[0].param_groups] + [pg['lr'] for pg in opt[1].param_groups]
        if flag_wandb:
             wandb.log({"train": train_eval.report(), "valid": dev_eval.report()}, step=e)
             wandb.log({f'lr_{i}': lrs[i] for i in range(len(lrs))}, step=e)
@@ -204,9 +210,16 @@ def training_loop(
             torch.save({
                 'epochs_last_run': e,
                 'model_state_dict': model.state_dict(),
-                'optimizer_state_dict': opt.state_dict(),
-                'scheduler_per_epoch_state_dict': scheduler_per_epoch.state_dict() if scheduler_per_epoch else None,
-                'scheduler_per_iter_state_dict': scheduler_per_iter.state_dict() if scheduler_per_iter else None,
+                'optimizer_bert_state_dict': opt[0].state_dict(),
+                'optimizer_task_state_dict': opt[1].state_dict(),
+                'scheduler_per_epoch_bert_state_dict': scheduler_per_epoch[
+                    0].state_dict() if scheduler_per_epoch else None,
+                'scheduler_per_epoch_task_state_dict': scheduler_per_epoch[
+                    1].state_dict() if scheduler_per_epoch else None,
+                'scheduler_per_iter_bert_state_dict': scheduler_per_iter[
+                    0].state_dict() if scheduler_per_iter else None,
+                'scheduler_per_iter_task_state_dict': scheduler_per_iter[
+                    1].state_dict() if scheduler_per_iter else None,
             }, Path(save_dir) / 'torch.save')
 
             print(f"Model saved on Epoch {e} at {save_dir}.")
diff --git a/src/models/multitask.py b/src/models/multitask.py
index a18fccd..25fe100 100644
--- a/src/models/multitask.py
+++ b/src/models/multitask.py
@@ -219,6 +219,17 @@ def __init__(
                 self.ner_loss[task.dataset] = nn.functional.cross_entropy
         self.pos_loss = nn.functional.cross_entropy
 
+    def get_params(self, named=False):
+        bert_based_param, task_param = [], []
+        for name, param in self.named_parameters():
+            if name.startswith('bert'):
+                to_add = (name, param) if named else param
+                bert_based_param.append(to_add)
+            else:
+                to_add = (name, param) if named else param
+                task_param.append(to_add)
+        return bert_based_param, task_param
+
     def task_separate_gradient_clipping(self):
         # noinspection PyAttributeOutsideInit
         self.clip_grad_norm_ = self.separate_max_norm_base_task
@@ -313,7 +324,7 @@ def coref_loss(top_antecedent_scores, top_antecedent_labels):
         log_norm = torch.logsumexp(top_antecedent_scores, 1)  # [top_cand]
         return log_norm - marginalized_gold_scores  # [top_cand]
 
-    def get_coref_loss(
+    def todel_get_coref_loss(
             self,
             candidate_starts: torch.tensor,
             candidate_ends: torch.tensor,
diff --git a/src/run.py b/src/run.py
index 454627b..8918ae3 100644
--- a/src/run.py
+++ b/src/run.py
@@ -12,6 +12,8 @@
 import transformers
 import wandb
 from mytorch.utils.goodies import mt_save_dir, FancyDict
+from torch.optim import Adam, AdamW
+from torch.optim.lr_scheduler import _LRScheduler
 
 # Local imports
 try:
@@ -40,6 +42,44 @@
 # torch.backends.cudnn.deterministic = True
 
 
+def make_optimizer_hoi(
+        model: MTLModel,
+        base_keyword: str,
+        task_weight_decay: Optional[float],
+        task_learning_rate: Optional[float],
+        adam_beta1: float = 0.9,
+        adam_beta2: float = 0.999,
+        adam_epsilon: float = 1e-6,
+        encoder_learning_rate: float = 2e-05,
+        encoder_weight_decay: float = 0.0,
+        freeze_encoder: bool = False,
+        optimizer_class: Callable = torch.optim.AdamW,
+):
+    no_decay = ['bias', 'LayerNorm.weight']
+    bert_param, task_param = model.get_params(named=True)
+    if task_learning_rate is None:
+        task_learning_rate = encoder_learning_rate
+    if task_weight_decay is None:
+        task_weight_decay = encoder_weight_decay
+
+    grouped_bert_param = [
+        {
+            'params': [p for n, p in bert_param if not any(nd in n for nd in no_decay)],
+            'lr': encoder_learning_rate,
+            'weight_decay': encoder_weight_decay
+        }, {
+            'params': [p for n, p in bert_param if any(nd in n for nd in no_decay)],
+            'lr': encoder_learning_rate,
+            'weight_decay': 0.0
+        }
+    ]
+    optimizers = [
+        AdamW(grouped_bert_param, lr=encoder_learning_rate, eps=adam_epsilon),
+        Adam(model.get_params()[1], lr=task_learning_rate, eps=adam_epsilon, weight_decay=0)
+    ]
+    return optimizers
+
+
 def make_optimizer(
         model: MTLModel,
         base_keyword: str,
@@ -97,17 +137,18 @@ def make_optimizer(
     return optimizer_class(optimizer_grouped_parameters, **optimizer_kwargs)
 
 
-# noinspection PyProtectedMember
-def make_scheduler(opt, lr_schedule: Optional[str], lr_schedule_val: Optional[float], n_updates: int) \
-        -> Optional[Type[torch.optim.lr_scheduler._LRScheduler]]:
+def make_scheduler_hoi(opts, lr_schedule: Optional[str], lr_schedule_val: Optional[float], n_updates: int) \
+        -> Optional[Tuple[List[Type[_LRScheduler]], List[Type[_LRScheduler]]]]:
+    # TODO: implement gamma and other things as well
     if not lr_schedule:
         return None, None
 
     if lr_schedule == 'gamma':
         hyperparam = lr_schedule_val if lr_schedule_val >= 0 else SCHEDULER_CONFIG['gamma']['decay_rate']
         lambda_1 = lambda epoch: hyperparam ** epoch
-        scheduler_per_epoch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda_1)
+        scheduler_per_epoch = torch.optim.lr_scheduler.LambdaLR(opts, lr_lambda=lambda_1)
         scheduler_per_iter = None
+        raise NotImplementedError
 
     elif lr_schedule == 'warmup':
         # TODO: model both optimizers here
         warmup_ratio = lr_schedule_val if lr_schedule_val >= 0 else SCHEDULER_CONFIG['warmup']['warmup']
         warmup_steps = int(n_updates * warmup_ratio)
@@ -120,6 +161,38 @@ def lr_lambda_bert(current_step):
                 0.0, float(n_updates - current_step) / float(max(1, n_updates - warmup_steps))
             )
 
+        def lr_lambda_task(current_step):
+            return max(0.0, float(n_updates - current_step) / float(max(1, n_updates)))
+
+        scheduler_per_iter = [
+            torch.optim.lr_scheduler.LambdaLR(opts[0], lr_lambda_bert),
+            torch.optim.lr_scheduler.LambdaLR(opts[1], lr_lambda_task)
+        ]
+        scheduler_per_epoch = None
+    else:
+        raise BadParameters(f"Unknown LR Schedule Recipe Name - {lr_schedule}")
+
+    if scheduler_per_iter is not None and scheduler_per_epoch is not None:
+        raise ValueError(f"Both Scheduler per iter and Scheduler per epoch are non-none. This won't fly.")
+
+    return scheduler_per_epoch, scheduler_per_iter
+
+
+# noinspection PyProtectedMember
+def make_scheduler(opt, lr_schedule: Optional[str], lr_schedule_val: Optional[float], n_updates: int) \
+        -> Optional[Type[torch.optim.lr_scheduler._LRScheduler]]:
+    if not lr_schedule:
+        return None, None
+
+    if lr_schedule == 'gamma':
+        hyperparam = lr_schedule_val if lr_schedule_val >= 0 else SCHEDULER_CONFIG['gamma']['decay_rate']
+        lambda_1 = lambda epoch: hyperparam ** epoch
+        scheduler_per_epoch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda_1)
+        scheduler_per_iter = None
+
+    elif lr_schedule == 'warmup':
+        warmup_ratio = lr_schedule_val if lr_schedule_val >= 0 else SCHEDULER_CONFIG['warmup']['warmup']
+        warmup_steps = int(n_updates * warmup_ratio)
 
         def lr_lambda_task(current_step):
             return max(0.0, float(n_updates - current_step) / float(max(1, n_updates)))
@@ -609,7 +682,7 @@ def train(ctx):
 
     # Make the optimizer
     # opt_base = torch.optim.Adam
-    opt = make_optimizer(
+    opts = make_optimizer_hoi(
         model=model,
         task_learning_rate=config.trainer.learning_rate,
         freeze_encoder=config.freeze_encoder,
@@ -621,8 +694,8 @@ def train(ctx):
         adam_beta2=config.trainer.adam_beta2,
         adam_epsilon=config.trainer.adam_epsilon,
     )
-    scheduler_per_epoch, scheduler_per_iter = make_scheduler(
-        opt=opt,
+    scheduler_per_epoch, scheduler_per_iter = make_scheduler_hoi(
+        opts=opts,
         lr_schedule=config.trainer.lr_schedule,
         lr_schedule_val=config.trainer.lr_schedule_param,
         n_updates=len_train * config.trainer.epochs)
@@ -692,11 +765,14 @@ def train(ctx):
         """ We're actually resuming a run. So now we need to load params, state dicts"""
         checkpoint = torch.load(savedir / 'torch.save')
         model.load_state_dict(checkpoint['model_state_dict'])
-        opt.load_state_dict(checkpoint['optimizer_state_dict'])
+        opts[0].load_state_dict(checkpoint['optimizer_bert_state_dict'])
+        opts[1].load_state_dict(checkpoint['optimizer_task_state_dict'])
         if scheduler_per_epoch:
-            scheduler_per_epoch.load_state_dict(checkpoint['scheduler_per_epoch_state_dict'])
+            scheduler_per_epoch[0].load_state_dict(checkpoint['scheduler_per_epoch_bert_state_dict'])
+            scheduler_per_epoch[1].load_state_dict(checkpoint['scheduler_per_epoch_task_state_dict'])
         if scheduler_per_iter:
-            scheduler_per_iter.load_state_dict(checkpoint['scheduler_per_iter_state_dict'])
+            scheduler_per_iter[0].load_state_dict(checkpoint['scheduler_per_iter_bert_state_dict'])
+            scheduler_per_iter[1].load_state_dict(checkpoint['scheduler_per_iter_task_state_dict'])
     else:
         config.params = n_params
 
@@ -711,7 +787,7 @@ def train(ctx):
         device=device,
         train_eval=train_eval,
         dev_eval=dev_eval,
-        opt=opt,
+        opt=opts,
         tasks=[tasks, tasks_2] if _is_multidomain else [tasks],
         # This is used only for bookkeeping. We're assuming empty entries in logs are fine.
         flag_wandb=config.wandb,
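
For reference, below the patch: a minimal, self-contained sketch of the split-optimizer pattern this diff moves to (encoder params and task params each get their own optimizer and LambdaLR schedule, and both are clipped and stepped every iteration). The toy model, learning rates, and schedule lambdas here are illustrative assumptions, not values or code taken from this repo.

import torch
from torch import nn
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import LambdaLR


class ToyMTLModel(nn.Module):
    """Stand-in for MTLModel: one 'bert' encoder module plus a task head."""

    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(8, 8)       # stands in for the transformer encoder
        self.task_head = nn.Linear(8, 2)  # stands in for the task-specific layers

    def get_params(self, named=False):
        # Split parameters by prefix, mirroring the get_params added in multitask.py.
        bert_params, task_params = [], []
        for name, param in self.named_parameters():
            bucket = bert_params if name.startswith('bert') else task_params
            bucket.append((name, param) if named else param)
        return bert_params, task_params

    def forward(self, x):
        return self.task_head(self.bert(x))


model = ToyMTLModel()
bert_params, task_params = model.get_params()

# One optimizer per parameter group (AdamW for the encoder, Adam for the task head).
opts = [AdamW(bert_params, lr=2e-5), Adam(task_params, lr=1e-4)]

# One scheduler per optimizer: warmup for the encoder, linear decay for the task head.
n_updates, warmup_steps = 100, 10
schedulers = [
    LambdaLR(opts[0], lambda step: min(1.0, step / max(1, warmup_steps))),
    LambdaLR(opts[1], lambda step: max(0.0, (n_updates - step) / n_updates)),
]

for step in range(3):
    model.zero_grad()                       # clears grads for params held by both optimizers
    loss = model(torch.randn(4, 8)).sum()   # dummy forward pass and loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(bert_params, 1.0)   # clip each group separately
    torch.nn.utils.clip_grad_norm_(task_params, 1.0)
    for opt in opts:
        opt.step()
    for sched in schedulers:
        sched.step()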