import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
import torch.optim as optim
import logging
import argparse
from evaluate.evaluate import test_multi_task_learner_2
from transformers import ElectraTokenizerFast
from model import utils
from model.intensive_reading_ca import IntensiveReadingWithCrossAttention
from model.intensive_reading_ma import IntensiveReadingWithMatchAttention
from model.intensive_reading_cnn import IntensiveReadingWithConvolutionNet
from model.dataset import my_collate_fn, QuestionAnsweringDatasetConfiguration, QuestionAnsweringDataset
from model.multitask_loss_wrapper import DynamicWeightAveragingWrapper
from functools import partial
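

# Evaluation pass: reports the averaged multi-task loss, answerability ([CLS]) accuracy,
# and an approximate span F1 over the validation iterator.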
def test(valid_iterator, model, device, tokenizer, regression_loss=False):
model.eval()
cls_loss = nn.BCELoss()
if regression_loss:
start_end_loss = nn.MSELoss()
else:
start_end_loss = nn.CrossEntropyLoss()
loss_sum = 0 # loss
loss_count = 0
cls_correct_count = 0 # is impossible
cls_total_count = 0
f1_sum = 0 # F1-score
f1_count = 0
with torch.no_grad():
for data in valid_iterator:
batch_encoding, is_impossibles, start_position, end_position, _ = data
is_impossibles = utils.move_to_device(is_impossibles, device)
start_position = utils.move_to_device(start_position, device)
end_position = utils.move_to_device(end_position, device)
            # shift positions left by one: [CLS] is removed inside utils.generate_question_and_passage_hidden
start_position = torch.where(start_position > 1, start_position - 1, start_position)
end_position = torch.where(end_position > 1, end_position - 1, end_position)
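            # longest question and context present in this batch; the model presumably uses
            # these to split the concatenated input back into question and passage segments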
max_con_len, max_qus_len = utils.find_max_qus_con_length(attention_mask=batch_encoding['attention_mask'],
token_type_ids=batch_encoding['token_type_ids'],
max_length=batch_encoding['input_ids'].size(1),
)
cls_out, start_logits, end_logits = model(batch_encoding['input_ids'].to(device),
attention_mask=batch_encoding['attention_mask'].to(device),
token_type_ids=batch_encoding['token_type_ids'].to(device),
pad_idx=tokenizer.pad_token_id,
max_qus_length=max_qus_len,
max_con_length=max_con_len,
)
impossible_loss = cls_loss(cls_out, is_impossibles)
if regression_loss:
start_logits = F.softmax(start_logits, dim=-1)
start_one_hot = F.one_hot(start_position, start_logits.size(1)).float().to(device)
start_loss = start_end_loss(start_logits, start_one_hot)
end_logits = F.softmax(end_logits, dim=-1)
end_one_hot = F.one_hot(end_position, end_logits.size(1)).float().to(device)
end_loss = start_end_loss(end_logits, end_one_hot)
else:
start_loss = start_end_loss(start_logits, start_position)
end_loss = start_end_loss(end_logits, end_position)
loss = (start_loss + end_loss) / 2 + impossible_loss
loss_sum += loss.item()
loss_count += 1
predict_start = torch.argmax(start_logits, dim=-1)
predict_end = torch.argmax(end_logits, dim=-1)
cls_out = torch.argmax(cls_out, dim=-1)
cls_out = (cls_out == is_impossibles.argmax(dim=-1)).float()
cls_correct_count += torch.sum(cls_out)
cls_total_count += cls_out.size(0)
predict_start = predict_start.cpu().numpy()
predict_end = predict_end.cpu().numpy()
start_position = start_position.cpu().numpy()
end_position = end_position.cpu().numpy()
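            # span score per example: arithmetic mean of recall and precision
            # (calculate_recall with its arguments swapped yields precision); a simpler
            # stand-in for the usual harmonic-mean F1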
for ps, pe, rs, re in zip(predict_start, predict_end, start_position, end_position):
recall = utils.calculate_recall(ps, pe, rs, re)
precision = utils.calculate_recall(rs, re, ps, pe)
f1_sum += (recall + precision) / 2
f1_count += 1
return loss_sum / loss_count, cls_correct_count / cls_total_count, f1_sum / f1_count
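

# One full experiment: build the data pipeline, model, and optimizer for the requested
# config, fine-tune with the joint span/answerability objective, then evaluate.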
def main(epoch=4, which_config='cross-attention', which_dataset='small', multitask_weight=1.0, seed=2020,
dynamic_weight_averaging=False, regression_loss=False):
    torch.manual_seed(seed)  # fix the random seed for reproducibility
# log
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)
handler = logging.FileHandler("log.log")
handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
# load configuration
    if which_config == 'cnn-span-large' or which_config == 'cross-attention-large':
lr = 1e-5
batch_size = 24
hidden_dim = 768
which_model = 'google/electra-base-discriminator'
else:
lr = 1e-4
batch_size = 48
hidden_dim = 256
which_model = 'google/electra-small-discriminator'
# load dataset
tokenizer = ElectraTokenizerFast.from_pretrained(which_model)
if which_dataset == 'small':
config_train = QuestionAnsweringDatasetConfiguration(squad_train=True)
config_valid = QuestionAnsweringDatasetConfiguration(squad_dev=True)
else:
config_train = QuestionAnsweringDatasetConfiguration(squad_train=True, squad_dev=False, drop_train=True,
drop_dev=True, newsqa_train=True, newsqa_dev=True,
medhop_dev=True, medhop_train=True, quoref_dev=True,
quoref_train=True, wikihop_dev=True, wikihop_train=True)
config_valid = QuestionAnsweringDatasetConfiguration(squad_dev=True)
dataset_train = QuestionAnsweringDataset(config_train, tokenizer=tokenizer)
dataset_valid = QuestionAnsweringDataset(config_valid, tokenizer=tokenizer)
dataloader_train = tud.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, drop_last=True,
collate_fn=partial(my_collate_fn, tokenizer=tokenizer))
dataloader_valid = tud.DataLoader(dataset=dataset_valid, batch_size=batch_size, shuffle=False, drop_last=True,
collate_fn=partial(my_collate_fn, tokenizer=tokenizer))
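    # bind the tokenizer into the collate function so each batch can be encoded and padded consistently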
# initialize model
if which_config == 'cross-attention' or which_config == 'cross-attention-large':
retro_reader = IntensiveReadingWithCrossAttention(clm_model=which_model, hidden_dim=hidden_dim)
elif which_config == 'match-attention':
retro_reader = IntensiveReadingWithMatchAttention(clm_model=which_model, hidden_dim=hidden_dim)
elif which_config == 'cnn-span' or which_config == 'cnn-span-large':
retro_reader = IntensiveReadingWithConvolutionNet(clm_model=which_model, hidden_dim=hidden_dim, out_channel=100,
filter_size=3)
    else:
        raise ValueError('Unknown config: {}'.format(which_config))
retro_reader.train()
# GPU Config:
    if torch.cuda.device_count() > 1:
        device = torch.cuda.current_device()
        retro_reader.to(device)
        retro_reader = nn.DataParallel(module=retro_reader)
        print('Using multiple GPUs. Number of GPUs:', torch.cuda.device_count())
    elif torch.cuda.device_count() == 1:
        device = torch.cuda.current_device()
        retro_reader.to(device)
        print('Using 1 GPU')
    else:
        device = torch.device('cpu')  # fall back to CPU
        print('Using CPU')
    # The pre-trained encoder is fine-tuned with a small learning rate; the task-specific
    # heads are trained from scratch with a larger rate and weight decay. DataParallel
    # keeps the underlying parameters under .module, so unwrap it once here.
    model_core = retro_reader.module if torch.cuda.device_count() > 1 else retro_reader
    if which_config == 'match-attention':
        optimizer = optim.Adam(
            [{'params': model_core.pre_trained_clm.parameters(), 'lr': lr, 'eps': 1e-6},
             {'params': model_core.cls_head.parameters(), 'lr': 1e-3, 'weight_decay': 0.01},
             {'params': model_core.Hq_proj.parameters(), 'lr': 1e-3, 'weight_decay': 0.01},
             {'params': model_core.span_detect_layer.parameters(), 'lr': 1e-3, 'weight_decay': 0.01},
             ]
        )
    elif which_config == 'cross-attention' or which_config == 'cross-attention-large':
        optimizer = optim.Adam(
            [{'params': model_core.pre_trained_clm.parameters(), 'lr': lr, 'eps': 1e-6},
             {'params': model_core.cls_head.parameters(), 'lr': 1e-3, 'weight_decay': 0.01},
             {'params': model_core.attention.parameters(), 'lr': 1e-3, 'weight_decay': 0.01},
             {'params': model_core.span_detect_layer.parameters(), 'lr': 1e-3, 'weight_decay': 0.01},
             ]
        )
    elif which_config == 'cnn-span' or which_config == 'cnn-span-large':
        optimizer = optim.Adam(
            [{'params': model_core.pre_trained_clm.parameters(), 'lr': lr, 'eps': 1e-6},
             {'params': model_core.cls_head.parameters(), 'lr': 1e-3, 'weight_decay': 0.01},
             {'params': model_core.conv.parameters(), 'lr': 1e-3, 'weight_decay': 0.01},
             ]
        )
    else:
        raise ValueError('Unknown config: {}'.format(which_config))
    # Dynamic Weight Averaging re-balances the span and answerability losses over time
    dynamic_loss = DynamicWeightAveragingWrapper(scale=multitask_weight)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5) # learning rate decay (*0.5)
cls_loss = nn.BCELoss() # Binary Cross Entropy Loss
if regression_loss:
start_end_loss = nn.MSELoss() # regression loss
else:
start_end_loss = nn.CrossEntropyLoss() # classification loss
best_score = 0.25
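    # First-stage training loop, kept for reference but disabled: the run below instead
    # warm-starts from the checkpoint saved as 'model_parameters.pth'.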
'''
for e in range(epoch):
for i, data in enumerate(iter(dataloader_train)):
batch_encoding, is_impossibles, start_position, end_position, _ = data
retro_reader.train()
is_impossibles = utils.move_to_device(is_impossibles, device)
start_position = utils.move_to_device(start_position, device)
end_position = utils.move_to_device(end_position, device)
            # shift positions left by one: [CLS] is removed inside utils.generate_question_and_passage_hidden
start_position = torch.where(start_position > 1, start_position - 1, start_position)
end_position = torch.where(end_position > 1, end_position - 1, end_position)
max_con_len, max_qus_len = utils.find_max_qus_con_length(attention_mask=batch_encoding['attention_mask'],
token_type_ids=batch_encoding['token_type_ids'],
max_length=batch_encoding['input_ids'].size(1),
)
model_output = retro_reader(batch_encoding['input_ids'].to(device),
attention_mask=batch_encoding['attention_mask'].to(device),
token_type_ids=batch_encoding['token_type_ids'].to(device),
pad_idx=tokenizer.pad_token_id,
max_qus_length=max_qus_len,
max_con_length=max_con_len,
)
cls_output, start_logits, end_logits = model_output
start_loss = start_end_loss(start_logits, start_position)
end_loss = start_end_loss(end_logits, end_position)
answerable_loss = cls_loss(cls_output, is_impossibles)
printable = (((start_loss + end_loss) / 2).item(), answerable_loss.item())
loss = (start_loss + end_loss) / 2 + answerable_loss * multitask_weight
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % 1000 == 0:
logger.info('Epoch {}, Iteration {}, Span Loss: {:.4f}, Ans Loss {:.4f}'.format(e, i, printable[0],
printable[1]))
v_loss_intensive, acc, f1 = test(iter(dataloader_valid), retro_reader, device, tokenizer)
logger.info('Epoch {}, Iteration {}, Intensive valid loss {:.4f}, CLS acc {:.4f}, F1-score {:.4f}'
.format(e, i, v_loss_intensive, acc, f1))
score = acc * f1
if score >= best_score: # save the best model
best_score = score
torch.save(retro_reader.state_dict(), 'retro_reader.pth')
'''
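    # Warm start: copy every parameter from the earlier checkpoint whose name matches the
    # current architecture; parameters without a match keep their fresh initialisation.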
model_dict = retro_reader.state_dict()
previous_dict = torch.load('model_parameters.pth')
previous_dict = {k: v for k, v in previous_dict.items() if k in model_dict}
model_dict.update(previous_dict)
retro_reader.load_state_dict(model_dict)
# refine our model with cross-attention / match-attention
logger.info('-------------------------------------------------------------------------')
for e in range(epoch):
for i, data in enumerate(iter(dataloader_train)):
batch_encoding, is_impossibles, start_position, end_position, _ = data
retro_reader.train()
is_impossibles = utils.move_to_device(is_impossibles, device)
start_position = utils.move_to_device(start_position, device)
end_position = utils.move_to_device(end_position, device)
            # shift positions left by one: [CLS] is removed inside utils.generate_question_and_passage_hidden
start_position = torch.where(start_position > 1, start_position - 1, start_position)
end_position = torch.where(end_position > 1, end_position - 1, end_position)
max_con_len, max_qus_len = utils.find_max_qus_con_length(attention_mask=batch_encoding['attention_mask'],
token_type_ids=batch_encoding['token_type_ids'],
max_length=batch_encoding['input_ids'].size(1),
)
cls_output, start_logits, end_logits = retro_reader(batch_encoding['input_ids'].to(device),
attention_mask=batch_encoding['attention_mask']
.to(device),
token_type_ids=batch_encoding['token_type_ids']
.to(device),
pad_idx=tokenizer.pad_token_id,
max_qus_length=max_qus_len,
max_con_length=max_con_len,
)
if regression_loss:
start_logits = F.softmax(start_logits, dim=-1)
start_one_hot = F.one_hot(start_position, start_logits.size(1)).float().to(device)
start_loss = start_end_loss(start_logits, start_one_hot)
end_logits = F.softmax(end_logits, dim=-1)
end_one_hot = F.one_hot(end_position, end_logits.size(1)).float().to(device)
end_loss = start_end_loss(end_logits, end_one_hot)
else:
start_loss = start_end_loss(start_logits, start_position)
end_loss = start_end_loss(end_logits, end_position)
answerable_loss = cls_loss(cls_output, is_impossibles)
printable = (((start_loss + end_loss) / 2).item(), answerable_loss.item())
if dynamic_weight_averaging:
loss = dynamic_loss.loss((start_loss + end_loss) / 2, answerable_loss)
else:
loss = (start_loss + end_loss) / 2 + answerable_loss * multitask_weight
optimizer.zero_grad()
loss.backward()
optimizer.step()
            if i % 1000 == 0:
                logger.info('Epoch {}, Iteration {}, Span Loss: {:.4f}, Ans Loss: {:.4f}'
                            .format(e, i, printable[0], printable[1]))
                v_loss_intensive, acc, f1 = test(iter(dataloader_valid), retro_reader, device, tokenizer,
                                                 regression_loss=regression_loss)
                logger.info('Epoch {}, Iteration {}, Intensive valid loss {:.4f}, CLS acc {:.4f}, F1-score {:.4f}'
                            .format(e, i, v_loss_intensive, acc, f1))
score = acc * f1
if score >= best_score: # save the best model
best_score = score
torch.save(retro_reader.state_dict(), 'retro_reader.pth')
scheduler.step()
# test our model
logger.info('-------------------------------------------------------------------------')
retro_reader.load_state_dict(torch.load('retro_reader.pth'))
test_multi_task_learner_2(iter(dataloader_valid), retro_reader, device, tokenizer)
    # unwrap DataParallel (when used) so the weights can be loaded on a single GPU or CPU
    single_gpu_model = retro_reader.module if isinstance(retro_reader, nn.DataParallel) else retro_reader
    torch.save(single_gpu_model.state_dict(), 'single_gpu_model.pth')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, help='which config')
parser.add_argument('-d', '--dataset', type=str, help='train on which dataset')
parser.add_argument('-w', '--multitask-weight', type=float, default=1.0, help='learn [CLS] and span jointly, given '
'the loss weight')
parser.add_argument('-dw', '--dynamic-weight', action='store_true', default=False, help='dynamic weight averaging')
parser.add_argument('-rl', '--regression-loss', action='store_true', default=False, help='using MSE loss')
parser.add_argument('-s', '--seed', type=int, default=2020, help='random seed')
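    # Example invocation (flags as defined above):
    #   python train.py -c cross-attention -d small -w 1.0 -s 2020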
args = parser.parse_args()
config = args.config
dataset = args.dataset
weight = args.multitask_weight
seed = args.seed
dynamic_weight = args.dynamic_weight
regression_loss = args.regression_loss
CONFIG = ['cross-attention', 'match-attention', 'cnn-span', 'cnn-span-large', 'cross-attention-large']
DATASET = ['small', 'normal']
    assert config in CONFIG, 'Unknown config: {}'.format(config)
    assert dataset in DATASET, 'Unknown dataset: {}'.format(dataset)
    assert weight > 0, 'Multi-task weight must be greater than zero'
    print('Experiment config: {}'.format(config))
    print('Dataset size: {}'.format(dataset))
    print('Multi-task weight (span : answerability = 1 : {})'.format(weight))
    print('Random seed: {}'.format(seed))
    print('Dynamic weight averaging: {}'.format(dynamic_weight))
    print('Regression loss (MSE): {}'.format(regression_loss))
main(epoch=4, which_config=config, which_dataset=dataset, multitask_weight=weight, seed=seed,
dynamic_weight_averaging=dynamic_weight, regression_loss=regression_loss)