On merging hoitrainer. #1

Open: wants to merge 4 commits into master
37 changes: 25 additions & 12 deletions src/loops.py
@@ -89,6 +89,8 @@ def training_loop(
     if flag_save:
         atexit.register(goodbye, save_dir)
 
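+    # Split the model's parameters into encoder ('bert*') and task-specific groups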
+    bert_params, task_params = model.get_params()
+
     # Epoch level
     for e in range(epochs_last_run + 1, epochs + epochs_last_run + 1):
 
@@ -106,7 +108,7 @@
         for i, instance in enumerate(tqdm(trn_dataset)):
 
             # Reset the gradients.
-            opt.zero_grad()
+            model.zero_grad()
 
             instance["prep_coref_eval"] = True
 
@@ -128,11 +130,13 @@
 
             # Clip Gradients
             if clip_grad_norm > 0:
-                torch.nn.utils.clip_grad_norm_([param for group in opt.param_groups for param in group['params']],
-                                               clip_grad_norm)
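+                # Clip encoder and task gradients separately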
+                torch.nn.utils.clip_grad_norm_(bert_params, clip_grad_norm)
+                torch.nn.utils.clip_grad_norm_(task_params, clip_grad_norm)
 
             # Backward Pass
-            opt.step()
+            # opt.step()
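+            # `opt` is now a pair of optimizers (encoder, task); step each in turn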
+            for optimizer in opt:
+                optimizer.step()
 
             # Throw the outputs to the eval benchmark also
             train_eval.update(instance=instance, outputs=outputs)
@@ -141,8 +145,9 @@
                 per_epoch_loss[instance['domain']][task_nm].append(outputs["loss"][task_nm].item())
 
             # If LR scheduler is provided, run it
-            if scheduler_per_iter is not None:
-                scheduler_per_iter.step()
+            if scheduler_per_iter:
+                scheduler_per_iter[0].step()
+                scheduler_per_iter[1].step()
 
             trn_dataset[i] = change_device(instance, 'cpu')
 
@@ -151,8 +156,9 @@
             dev_eval.run()
 
         # If LR scheduler (per epoch) is provided, run it
-        if scheduler_per_epoch is not None:
-            scheduler_per_epoch.step()
+        if scheduler_per_epoch:
+            scheduler_per_epoch[0].step()
+            scheduler_per_epoch[1].step()
 
         # Bookkeeping (summarise the train and valid evaluations, and the loss)
         train_metrics = train_eval.aggregate_reports(train_metrics, train_eval.report())
@@ -167,7 +173,7 @@ def training_loop(
 
         for k in skipped.keys():
             skipped[k].append(per_epoch_skipped[k])
-        lrs = [param_group['lr'] for param_group in opt.param_groups]
+        lrs = [pg['lr'] for pg in opt[0].param_groups] + [pg['lr'] for pg in opt[1].param_groups]
         if flag_wandb:
             wandb.log({"train": train_eval.report(), "valid": dev_eval.report()}, step=e)
             wandb.log({f'lr_{i}': lrs[i] for i in range(len(lrs))}, step=e)
@@ -204,9 +210,16 @@ def training_loop(
             torch.save({
                 'epochs_last_run': e,
                 'model_state_dict': model.state_dict(),
-                'optimizer_state_dict': opt.state_dict(),
-                'scheduler_per_epoch_state_dict': scheduler_per_epoch.state_dict() if scheduler_per_epoch else None,
-                'scheduler_per_iter_state_dict': scheduler_per_iter.state_dict() if scheduler_per_iter else None,
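+                # Persist each optimizer's and each scheduler's state under its own key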
+                'optimizer_bert_state_dict': opt[0].state_dict(),
+                'optimizer_task_state_dict': opt[1].state_dict(),
+                'scheduler_per_epoch_bert_state_dict': scheduler_per_epoch[0].state_dict() if scheduler_per_epoch else None,
+                'scheduler_per_epoch_task_state_dict': scheduler_per_epoch[1].state_dict() if scheduler_per_epoch else None,
+                'scheduler_per_iter_bert_state_dict': scheduler_per_iter[0].state_dict() if scheduler_per_iter else None,
+                'scheduler_per_iter_task_state_dict': scheduler_per_iter[1].state_dict() if scheduler_per_iter else None,
             }, Path(save_dir) / 'torch.save')
             print(f"Model saved on Epoch {e} at {save_dir}.")
 
13 changes: 12 additions & 1 deletion src/models/multitask.py
@@ -219,6 +219,17 @@ def __init__(
                 self.ner_loss[task.dataset] = nn.functional.cross_entropy
         self.pos_loss = nn.functional.cross_entropy
 
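+    # Partition named parameters by the 'bert' prefix; feeds the two-optimizer setup in run.py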
+    def get_params(self, named=False):
+        bert_based_param, task_param = [], []
+        for name, param in self.named_parameters():
+            if name.startswith('bert'):
+                to_add = (name, param) if named else param
+                bert_based_param.append(to_add)
+            else:
+                to_add = (name, param) if named else param
+                task_param.append(to_add)
+        return bert_based_param, task_param
+
     def task_separate_gradient_clipping(self):
         # noinspection PyAttributeOutsideInit
         self.clip_grad_norm_ = self.separate_max_norm_base_task
Expand Down Expand Up @@ -313,7 +324,7 @@ def coref_loss(top_antecedent_scores, top_antecedent_labels):
log_norm = torch.logsumexp(top_antecedent_scores, 1) # [top_cand]
return log_norm - marginalized_gold_scores # [top_cand]

def get_coref_loss(
def todel_get_coref_loss(
self,
candidate_starts: torch.tensor,
candidate_ends: torch.tensor,
98 changes: 87 additions & 11 deletions src/run.py
@@ -12,6 +12,8 @@
 import transformers
 import wandb
 from mytorch.utils.goodies import mt_save_dir, FancyDict
+from torch.optim import Adam, AdamW
+from torch.optim.lr_scheduler import _LRScheduler
 
 # Local imports
 try:
@@ -40,6 +42,44 @@
 # torch.backends.cudnn.deterministic = True
 
 
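+# Build two optimizers: AdamW over grouped encoder parameters, Adam over task parameters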
+def make_optimizer_hoi(
+        model: MTLModel,
+        base_keyword: str,
+        task_weight_decay: Optional[float],
+        task_learning_rate: Optional[float],
+        adam_beta1: float = 0.9,
+        adam_beta2: float = 0.999,
+        adam_epsilon: float = 1e-6,
+        encoder_learning_rate: float = 2e-05,
+        encoder_weight_decay: float = 0.0,
+        freeze_encoder: bool = False,
+        optimizer_class: Callable = torch.optim.AdamW,
+):
+    no_decay = ['bias', 'LayerNorm.weight']
+    bert_param, task_param = model.get_params(named=True)
+    if task_learning_rate is None:
+        task_learning_rate = encoder_learning_rate
+    if task_weight_decay is None:
+        task_weight_decay = encoder_weight_decay
+
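+    # Exempt biases and LayerNorm weights from weight decay, as in standard BERT fine-tuning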
+    grouped_bert_param = [
+        {
+            'params': [p for n, p in bert_param if not any(nd in n for nd in no_decay)],
+            'lr': encoder_learning_rate,
+            'weight_decay': encoder_weight_decay
+        }, {
+            'params': [p for n, p in bert_param if any(nd in n for nd in no_decay)],
+            'lr': encoder_learning_rate,
+            'weight_decay': 0.0
+        }
+    ]
+    optimizers = [
+        AdamW(grouped_bert_param, lr=encoder_learning_rate, eps=adam_epsilon),
+        Adam(model.get_params()[1], lr=task_learning_rate, eps=adam_epsilon, weight_decay=0)
+    ]
+    return optimizers
+
+
 def make_optimizer(
         model: MTLModel,
         base_keyword: str,
@@ -97,17 +137,18 @@ def make_optimizer(
     return optimizer_class(optimizer_grouped_parameters, **optimizer_kwargs)
 
 
-# noinspection PyProtectedMember
-def make_scheduler(opt, lr_schedule: Optional[str], lr_schedule_val: Optional[float], n_updates: int) \
-        -> Optional[Type[torch.optim.lr_scheduler._LRScheduler]]:
+def make_scheduler_hoi(opts, lr_schedule: Optional[str], lr_schedule_val: Optional[float], n_updates: int) \
+        -> Optional[Tuple[List[Type[_LRScheduler]], List[Type[_LRScheduler]]]]:
+    # TODO: implement gamma and other things as well
     if not lr_schedule:
         return None, None
 
     if lr_schedule == 'gamma':
         hyperparam = lr_schedule_val if lr_schedule_val >= 0 else SCHEDULER_CONFIG['gamma']['decay_rate']
         lambda_1 = lambda epoch: hyperparam ** epoch
-        scheduler_per_epoch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda_1)
+        scheduler_per_epoch = torch.optim.lr_scheduler.LambdaLR(opts, lr_lambda=lambda_1)
         scheduler_per_iter = None
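+        # LambdaLR cannot take a list of optimizers; the gamma recipe is not yet ported to the two-optimizer setup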
+        raise NotImplementedError
     elif lr_schedule == 'warmup':
+        # TODO: model both optimizers here
         warmup_ratio = lr_schedule_val if lr_schedule_val >= 0 else SCHEDULER_CONFIG['warmup']['warmup']
@@ -120,6 +161,38 @@ def lr_lambda_bert(current_step):
                 0.0, float(n_updates - current_step) / float(max(1, n_updates - warmup_steps))
             )
 
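+        # Task parameters get plain linear decay, with no warmup phase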
+        def lr_lambda_task(current_step):
+            return max(0.0, float(n_updates - current_step) / float(max(1, n_updates)))
+
+        scheduler_per_iter = [
+            torch.optim.lr_scheduler.LambdaLR(opts[0], lr_lambda_bert),
+            torch.optim.lr_scheduler.LambdaLR(opts[1], lr_lambda_task)
+        ]
+        scheduler_per_epoch = None
+    else:
+        raise BadParameters(f"Unknown LR Schedule Recipe Name - {lr_schedule}")
+
+    if scheduler_per_iter is not None and scheduler_per_epoch is not None:
+        raise ValueError(f"Both Scheduler per iter and Scheduler per epoch are non-none. This won't fly.")
+
+    return scheduler_per_epoch, scheduler_per_iter
+
+
+# noinspection PyProtectedMember
+def make_scheduler(opt, lr_schedule: Optional[str], lr_schedule_val: Optional[float], n_updates: int) \
+        -> Optional[Type[torch.optim.lr_scheduler._LRScheduler]]:
+    if not lr_schedule:
+        return None, None
+
+    if lr_schedule == 'gamma':
+        hyperparam = lr_schedule_val if lr_schedule_val >= 0 else SCHEDULER_CONFIG['gamma']['decay_rate']
+        lambda_1 = lambda epoch: hyperparam ** epoch
+        scheduler_per_epoch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda_1)
+        scheduler_per_iter = None
+    elif lr_schedule == 'warmup':
+        warmup_ratio = lr_schedule_val if lr_schedule_val >= 0 else SCHEDULER_CONFIG['warmup']['warmup']
+        warmup_steps = int(n_updates * warmup_ratio)
+
+        def lr_lambda_task(current_step):
+            return max(0.0, float(n_updates - current_step) / float(max(1, n_updates)))
+
@@ -609,7 +682,7 @@ def train(ctx):
 
     # Make the optimizer
     # opt_base = torch.optim.Adam
-    opt = make_optimizer(
+    opts = make_optimizer_hoi(
         model=model,
         task_learning_rate=config.trainer.learning_rate,
         freeze_encoder=config.freeze_encoder,
@@ -621,8 +694,8 @@
         adam_beta2=config.trainer.adam_beta2,
         adam_epsilon=config.trainer.adam_epsilon,
     )
-    scheduler_per_epoch, scheduler_per_iter = make_scheduler(
-        opt=opt,
+    scheduler_per_epoch, scheduler_per_iter = make_scheduler_hoi(
+        opts=opts,
         lr_schedule=config.trainer.lr_schedule,
         lr_schedule_val=config.trainer.lr_schedule_param,
         n_updates=len_train * config.trainer.epochs)
@@ -692,11 +765,14 @@ def train(ctx):
         """ We're actually resuming a run. So now we need to load params, state dicts"""
         checkpoint = torch.load(savedir / 'torch.save')
         model.load_state_dict(checkpoint['model_state_dict'])
-        opt.load_state_dict(checkpoint['optimizer_state_dict'])
+        opts[0].load_state_dict(checkpoint['optimizer_bert_state_dict'])
+        opts[1].load_state_dict(checkpoint['optimizer_task_state_dict'])
         if scheduler_per_epoch:
-            scheduler_per_epoch.load_state_dict(checkpoint['scheduler_per_epoch_state_dict'])
+            scheduler_per_epoch[0].load_state_dict(checkpoint['scheduler_per_epoch_bert_state_dict'])
+            scheduler_per_epoch[1].load_state_dict(checkpoint['scheduler_per_epoch_task_state_dict'])
         if scheduler_per_iter:
-            scheduler_per_iter.load_state_dict(checkpoint['scheduler_per_iter_state_dict'])
+            scheduler_per_iter[0].load_state_dict(checkpoint['scheduler_per_iter_bert_state_dict'])
+            scheduler_per_iter[1].load_state_dict(checkpoint['scheduler_per_iter_task_state_dict'])
     else:
         config.params = n_params

@@ ... @@ def train(ctx):
         device=device,
         train_eval=train_eval,
         dev_eval=dev_eval,
-        opt=opt,
+        opt=opts,
         tasks=[tasks, tasks_2] if _is_multidomain else [tasks],
         # This is used only for bookkeeping. We're assuming empty entries in logs are fine.
         flag_wandb=config.wandb,