trainer.py
import functools
import logging
import os
import torch
from torch.optim import Adam
from tqdm import tqdm
from transformers import WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup
from torch.cuda.amp import autocast, GradScaler

logger = logging.getLogger(__name__)
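
# Trainer implements a fine-tuning loop with AdamW, a linear warmup/decay
# learning-rate schedule, optional fp16 mixed precision (torch.cuda.amp),
# gradient accumulation, gradient clipping at max-norm 1.0, periodic
# checkpointing via save_pretrained(), and an optional per-step callback.
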
class Trainer(object):
    def __init__(self, args, model, dataloader, num_train_steps, writer=None, step_callback=None):
        self.args = args
        self.model = model
        self.dataloader = dataloader
        self.num_train_steps = num_train_steps
        self.writer = writer
        self.step_callback = step_callback

        self.optimizer = self._create_optimizer(model)
        self.scheduler = self._create_scheduler(self.optimizer)
    def train(self):
        model = self.model
        dataloader = self.dataloader
        scaler = GradScaler()

        epoch = 0
        global_step = 0
        tr_loss = 0.0

        model.train()
        with tqdm(total=self.num_train_steps, disable=self.args.local_rank not in (-1, 0)) as pbar:
            while True:
                for step, batch in enumerate(dataloader):
                    model.train()
                    # Select the model inputs from the batch and move them to the target device.
                    inputs = {k: v.to(self.args.device) for k, v in self._create_model_arguments(batch).items()
                              if k not in ["history_entities", "raw_user_input", "triplets_ids", "triplets_label"]}

                    if self.args.fp16:
                        with autocast():
                            outputs = model(**inputs)
                            loss = outputs[0]
                    else:
                        outputs = model(**inputs)
                        loss = outputs[0]
                    if isinstance(loss, dict):
                        loss = loss["total_loss"]
                    if self.args.gradient_accumulation_steps > 1:
                        loss = loss / self.args.gradient_accumulation_steps

                    if self.args.fp16:
                        scaler.scale(loss).backward()  # fp16
                    else:
                        loss.backward()
                    tr_loss += loss.item()

                    # Gradient accumulation: only step the optimizer every N micro-batches.
                    if (step + 1) % self.args.gradient_accumulation_steps == 0:
                        if self.args.fp16:
                            scaler.unscale_(self.optimizer)  # fp16
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                        if self.args.fp16:
                            scaler.step(self.optimizer)  # fp16
                        else:
                            self.optimizer.step()

                        if self.args.use_contrastive:
                            # Clamp the learnable temperature scales to [0, ln(100)].
                            if hasattr(model, "module"):
                                model.module.batch_scale.data = \
                                    torch.clamp(model.module.batch_scale.data, 0, 4.6052)
                                model.module.self_scale.data = \
                                    torch.clamp(model.module.self_scale.data, 0, 4.6052)
                            else:
                                model.batch_scale.data = torch.clamp(model.batch_scale.data, 0, 4.6052)
                                model.self_scale.data = torch.clamp(model.self_scale.data, 0, 4.6052)

                        model.zero_grad()
                        if self.args.fp16:
                            scaler.update()
                        self.scheduler.step()

                        pbar.set_description("epoch: %d loss: %.7f" % (epoch, loss.item()))
                        pbar.update()
                        global_step += 1

                        if self.step_callback is not None:
                            self.step_callback(model, global_step)

                        if (
                            self.args.local_rank in (-1, 0)
                            and self.args.output_dir
                            and self.args.save_steps > 0
                            and global_step % self.args.save_steps == 0
                        ):
                            output_dir = os.path.join(self.args.output_dir, "checkpoint-{}".format(global_step))
                            # Unwrap DistributedDataParallel before saving.
                            if hasattr(model, "module"):
                                model.module.save_pretrained(output_dir)
                            else:
                                model.save_pretrained(output_dir)

                        if global_step == self.num_train_steps:
                            break
                if global_step == self.num_train_steps:
                    break
                epoch += 1

        logger.info("global_step = %s, average loss = %s", global_step, tr_loss / global_step)
        return model, global_step, tr_loss / global_step
    def _create_optimizer(self, model):
        model_params = [(n, p) for n, p in model.named_parameters()]
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer = AdamW(
            [
                {
                    "params": [p for n, p in model_params
                               if not any(nd in n for nd in no_decay)
                               and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate,
                },
                {
                    # Bias and LayerNorm parameters are optimized without weight decay.
                    "params": [p for n, p in model_params
                               if any(nd in n for nd in no_decay)
                               and p.requires_grad],
                    "weight_decay": 0.0,
                    "lr": self.args.learning_rate,
                },
            ],
            eps=self.args.adam_epsilon,
        )
        return optimizer
    def _create_scheduler(self, optimizer):
        # Linear warmup over the first 6% of training steps, then linear decay.
        warmup_steps = int(self.num_train_steps * 0.06)
        return get_linear_schedule_with_warmup(optimizer, warmup_steps, self.num_train_steps)
    def _create_model_arguments(self, batch):
        # Hook for subclasses to reshape a batch into model keyword arguments.
        return batch
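
# Illustrative usage sketch (not part of the original file): the attribute
# names on `args` are exactly those the class reads above, but the values,
# the model, the dataset, and the epoch count are placeholders you would
# supply yourself; the model is assumed to expose save_pretrained()
# (e.g. a transformers PreTrainedModel).
#
#     from argparse import Namespace
#     from torch.utils.data import DataLoader
#
#     args = Namespace(device="cuda", fp16=True, local_rank=-1,
#                      gradient_accumulation_steps=2, use_contrastive=False,
#                      output_dir="checkpoints", save_steps=500,
#                      weight_decay=0.01, learning_rate=2e-5, adam_epsilon=1e-8)
#     dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
#     num_train_steps = (len(dataloader) // args.gradient_accumulation_steps) * num_epochs
#     trainer = Trainer(args, model, dataloader, num_train_steps)
#     model, global_step, avg_loss = trainer.train()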