train.py
__author__ = 'thiagocastroferreira'

import argparse
import json
import os

import nltk
import torch
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from torch import optim
from torch.utils.data import DataLoader, Dataset

from models.bartgen import BARTGen
from models.bert import BERTGen
from models.blenderbot import Blenderbot
from models.gportuguesegen import GPorTugueseGen
from models.gpt2 import GPT2
from models.t5gen import T5Gen

nltk.download('punkt')


class Trainer:
    '''
    Module for training a generative neural model
    '''
    def __init__(self, model, trainloader, devdata, optimizer, epochs,
                 batch_status, device, write_path, early_stop=5, verbose=True, language='english'):
        '''
        params:
        ---
            model: model to be trained
            trainloader: training data
            devdata: dev data
            optimizer: optimizer used to update the model weights
            epochs: number of epochs
            batch_status: report the running loss after every 'batch_status' batches
            device: cpu or gpu
            write_path: folder to save the best model
            early_stop: number of epochs without BLEU improvement before training stops
            verbose: whether to print example predictions during evaluation
            language: language used by the NLTK tokenizer during evaluation
        '''
        self.model = model
        self.optimizer = optimizer
        self.epochs = epochs
        self.batch_status = batch_status
        self.device = device
        self.early_stop = early_stop
        self.verbose = verbose
        self.trainloader = trainloader
        self.devdata = devdata
        self.write_path = write_path
        self.language = language

        if not os.path.exists(write_path):
            os.mkdir(write_path)

    def train(self):
        '''
        Train model based on the parameters specified in __init__ function
        '''
        max_bleu, repeat = 0, 0
        for epoch in range(self.epochs):
            self.model.model.train()
            losses = []
            for batch_idx, inp in enumerate(self.trainloader):
                intents, texts = inp['X'], inp['y']
                self.optimizer.zero_grad()

                # generating
                output = self.model(intents, texts)

                # Calculate loss
                loss = output.loss
                losses.append(float(loss))

                # Backpropagation
                loss.backward()
                self.optimizer.step()

                # Display
                if (batch_idx + 1) % self.batch_status == 0:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTotal Loss: {:.6f}'.format(
                        epoch, batch_idx + 1, len(self.trainloader), 100. * batch_idx / len(self.trainloader),
                        float(loss), round(sum(losses) / len(losses), 5)))

            bleu, acc = self.evaluate()
            checkpoint = {'epoch': epoch + 1, 'bleu': bleu, 'acc': acc, 'best_model': False}
            print('BLEU: ', bleu, 'Accuracy: ', acc)
            if bleu > max_bleu:
                self.model.model.save_pretrained(os.path.join(self.write_path, 'model'))
                max_bleu = bleu
                repeat = 0
                checkpoint['best_model'] = True
                print('Saving best model...')
            else:
                repeat += 1

            if repeat == self.early_stop:
                break

            # saving checkpoint
            if os.path.exists(f"{self.write_path}/checkpoint.json"):
                checkpoints = json.load(open(f"{self.write_path}/checkpoint.json"))
                checkpoints['checkpoints'].append(checkpoint)
            else:
                checkpoints = {'checkpoints': [checkpoint]}
            json.dump(checkpoints, open(f"{self.write_path}/checkpoint.json", 'w'), separators=(',', ':'), sort_keys=True, indent=4)

    def evaluate(self):
        '''
        Evaluate the model on the dev set after each epoch
        '''
        self.model.model.eval()
        results = {}
        for batch_idx, inp in enumerate(self.devdata):
            intent, text = inp['X'], inp['y']
            if intent not in results:
                results[intent] = {'hyp': '', 'refs': []}

                # predict
                output = self.model([intent])
                results[intent]['hyp'] = output[0]

                # Display
                if (batch_idx + 1) % self.batch_status == 0:
                    print('Evaluation: [{}/{} ({:.0f}%)]'.format(
                        batch_idx + 1, len(self.devdata), 100. * batch_idx / len(self.devdata)))

            results[intent]['refs'].append(text)

        hyps, refs, acc = [], [], 0
        for i, intent in enumerate(results.keys()):
            if i < 20 and self.verbose:
                print('Real: ', results[intent]['refs'][0])
                print('Pred: ', results[intent]['hyp'])
                print()

            if self.language != 'english':
                hyps.append(nltk.word_tokenize(results[intent]['hyp'], language=self.language))
                refs.append([nltk.word_tokenize(ref, language=self.language) for ref in results[intent]['refs']])
            else:
                hyps.append(nltk.word_tokenize(results[intent]['hyp']))
                refs.append([nltk.word_tokenize(ref) for ref in results[intent]['refs']])

            if results[intent]['hyp'] in results[intent]['refs'][0]:
                acc += 1

        chencherry = SmoothingFunction()
        bleu = corpus_bleu(refs, hyps, smoothing_function=chencherry.method3)
        return bleu, float(acc) / len(results)
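
# Note (inferred from how Trainer uses self.model above, not stated elsewhere in
# this file): each generator wrapper is expected to expose
#   - model.model            -> the underlying network (train() / eval() / save_pretrained())
#   - model(intents, texts)  -> a training forward pass returning an object with a .loss
#   - model([intent])        -> generation for a batch of inputs, returning decoded strings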

class NewsDataset(Dataset):
    def __init__(self, data):
        """
        Args:
            data (list): list of {'X': source, 'y': target} instances
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def load_data(src_fname, trg_fname):
    # read the aligned source and target files (one instance per line)
    with open(src_fname) as f:
        src = f.read().split('\n')
    with open(trg_fname) as f:
        trg = f.read().split('\n')
    assert len(src) == len(trg)

    data = [{'X': src[i], 'y': trg[i]} for i in range(len(src))]
    return data
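
# Illustrative usage of load_data (the file names are hypothetical): line i of the
# source file is paired with line i of the target file, so both files must have
# the same number of lines.
#
#   data = load_data('data/src_train.txt', 'data/trg_train.txt')
#   # data[0] == {'X': '<first source line>', 'y': '<first target line>'}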

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--tokenizer", help="path to the tokenizer", required=True)
    parser.add_argument("--model", help="path to the model", required=True)
    parser.add_argument("--src_train", help="path to the source train data", required=True)
    parser.add_argument("--trg_train", help="path to the target train data", required=True)
    parser.add_argument("--src_dev", help="path to the source dev data", required=True)
    parser.add_argument("--trg_dev", help="path to the target dev data", required=True)
    parser.add_argument("--epochs", help="number of epochs", type=int, default=5)
    parser.add_argument("--learning_rate", help="learning rate", type=float, default=1e-5)
    parser.add_argument("--batch_size", help="batch size", type=int, default=16)
    parser.add_argument("--early_stop", help="early stop", type=int, default=3)
    parser.add_argument("--max_length", help="maximum length to be processed by the network", type=int, default=180)
    parser.add_argument("--write_path", help="path to write best model", required=True)
    parser.add_argument("--language", help="language", default='english')
    parser.add_argument("--verbose", help="should display example predictions?", action="store_true")
    parser.add_argument("--batch_status", help="number of batches between loss reports", type=int)
    parser.add_argument("--cuda", help="use CUDA", action="store_true")
    parser.add_argument("--src_lang", help="source language of mBART tokenizer", default='pt_XX')
    parser.add_argument("--trg_lang", help="target language of mBART tokenizer", default='pt_XX')
    args = parser.parse_args()

    # settings
    learning_rate = args.learning_rate
    epochs = args.epochs
    batch_size = args.batch_size
    batch_status = args.batch_status
    early_stop = args.early_stop
    language = args.language
    verbose = args.verbose
    device = 'cuda' if args.cuda else 'cpu'
    write_path = args.write_path

    # model
    max_length = args.max_length
    tokenizer_path = args.tokenizer
    model_path = args.model

    if 'mbart' in tokenizer_path:
        src_lang = args.src_lang
        trg_lang = args.trg_lang
        generator = BARTGen(tokenizer_path, model_path, max_length, device, True, src_lang, trg_lang)
    elif 'bart' in tokenizer_path:
        generator = BARTGen(tokenizer_path, model_path, max_length, device, False)
    elif 'bert' in tokenizer_path:
        generator = BERTGen(tokenizer_path, model_path, max_length, device)
    elif 'mt5' in tokenizer_path:
        generator = T5Gen(tokenizer_path, model_path, max_length, device, True)
    elif 't5' in tokenizer_path:
        generator = T5Gen(tokenizer_path, model_path, max_length, device, False)
    elif 'gpt2-small-portuguese' in tokenizer_path:
        generator = GPorTugueseGen(tokenizer_path, model_path, max_length, device)
    elif tokenizer_path == 'gpt2':
        generator = GPT2(tokenizer_path, model_path, max_length, device)
    elif 'blenderbot' in tokenizer_path:
        generator = Blenderbot(tokenizer_path, model_path, max_length, device)
    else:
        raise Exception("Invalid model")

    # train data
    src_fname = args.src_train
    trg_fname = args.trg_train
    data = load_data(src_fname, trg_fname)
    dataset = NewsDataset(data)
    trainloader = DataLoader(dataset, batch_size=batch_size)

    # dev data
    src_fname = args.src_dev
    trg_fname = args.trg_dev
    devdata = load_data(src_fname, trg_fname)

    # optimizer
    optimizer = optim.AdamW(generator.model.parameters(), lr=learning_rate)

    # trainer
    trainer = Trainer(generator, trainloader, devdata, optimizer, epochs, batch_status, device, write_path, early_stop, verbose, language)
    trainer.train()
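
# Example invocation (a sketch; the file paths are placeholders and 't5-base' is just
# one tokenizer/model pair that the name matching above would route to T5Gen):
#
#   python train.py --tokenizer t5-base --model t5-base \
#       --src_train data/src_train.txt --trg_train data/trg_train.txt \
#       --src_dev data/src_dev.txt --trg_dev data/trg_dev.txt \
#       --write_path checkpoints/ --batch_status 16 --cuda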