training_app.py
"""
Main class to train, evaluate, and save different models on the CIFAR-10 dataset.
TODO: app_parameters, and hyperparameters would be better passed as parser arguments
TODO: Many of the savings could be done more efficiently by using Writters.
"""
import os
from time import time
from datetime import datetime as dt
from datetime import timedelta
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm # Just to log fancy training progress bar
import pandas as pd
import numpy as np
from utils.util import set_seed, get_accuracy
from dataset import CIFARDataset


class CIFARTrainingApp:
    def __init__(self, data_train, data_val, app_parameters, hyperparameters, transformations, design):
        # Dataset
        self.data_train = data_train
        self.data_val = data_val
        # App parameters
        self.app_parameters = app_parameters
        self.num_workers = app_parameters['num_workers']
        self.save_path = app_parameters['save_path']
        if app_parameters['seed'] is None:
            self.seed = int(dt.now().timestamp())
        else:
            self.seed = int(app_parameters['seed'])
        # Hyperparameters
        self.hyperparameters = hyperparameters
        self.batch_size = hyperparameters['batch_size']
        self.epochs = hyperparameters['epochs']
        self.model_parameters = hyperparameters['model_parameters']
        self.optimizer_parameters = hyperparameters['optimizer_parameters']
        if 'scheduler_parameters' in hyperparameters:
            self.scheduler_parameters = hyperparameters['scheduler_parameters']
        else:
            self.scheduler_parameters = None
        # Device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f'Using {self.device} device')
        # Metrics
        self.metrics_train = {'loss': [], 'acc': []}
        self.metrics_val = {'loss': [], 'acc': []}
        self.best_accuracy = 0.0
        # Initialize design (model, criterion, and optimizer)
        self.design = design
        self.init_design()
        # Transformations
        self.transformations = transformations
        self.transformation_train = transformations['train']
        self.transformation_val = transformations['val']
        # Timestamp to identify training runs
        time_str = dt.now().strftime('%Y-%m-%d_%H.%M.%S')
        self.file_path = f'{self.save_path}_{time_str}'
        # To continue training from a checkpoint
        self.start_epoch = 0
        self.elapsed_time = timedelta(0)
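
    # Illustrative sketch (not from the original file) of the dictionaries the
    # constructor above unpacks. The key names are taken from __init__; the
    # concrete values and the MyCNN/transform names are hypothetical placeholders:
    #
    #   app_parameters = {'num_workers': 2, 'save_path': 'runs/cifar', 'seed': 42}
    #   hyperparameters = {'batch_size': 128,
    #                      'epochs': 20,
    #                      'model_parameters': {'num_classes': 10},    # forwarded to the model class
    #                      'optimizer_parameters': {'lr': 1e-3},       # forwarded to the optimizer
    #                      'scheduler_parameters': {'step_size': 10}}  # optional
    #   transformations = {'train': train_transform, 'val': val_transform}  # hypothetical transforms
    #   design = {'model': MyCNN,                            # hypothetical nn.Module subclass
    #             'criterion': torch.nn.CrossEntropyLoss,    # classes, not instances:
    #             'optimizer': torch.optim.SGD,              # init_design() instantiates them
    #             'scheduler': torch.optim.lr_scheduler.StepLR}  # optional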

    def main(self, validate_bool=True, save_bool=False):
        """
        Here is where the magic happens: it trains the model for all epochs, stores
        the training and validation metrics, and displays the results.
        """
        print("Loading data...")
        loader_train, loader_val = self.init_dataloader()
        n_train = len(loader_train)
        n_val = len(loader_val)
        print(f"\nStart training with {n_train}/{n_val} batches (trn/val) "
              f"of size {self.batch_size} for {self.epochs} epochs\n")
        # Write the "header" of the txt file if it does not already exist (it may, when resuming from a checkpoint)
        if save_bool and not os.path.exists(self.file_path + '.txt'):
            with open(self.file_path + '.txt', 'w') as file:
                file.write(str(self.design) + '\n')
                file.write(str(self.transformations) + '\n')
                file.write(str(self.app_parameters) + '\n')
                file.write(str(self.hyperparameters) + '\n')
        # total_seconds() rather than .seconds, which drops whole days on long runs
        start = time() - self.elapsed_time.total_seconds()
        for epoch_ndx in range(1 + self.start_epoch, self.epochs + 1):
            # Train and store metrics for a single epoch
            self.train(epoch_ndx, loader_train)
            # Validation
            if validate_bool:
                loss_val, accuracy_val = self.validate(loader_val, n_val)
                # Save epoch metrics
                self.metrics_val['loss'].append(loss_val)
                self.metrics_val['acc'].append(accuracy_val)
                # Log metrics
                print(f'Val Loss: {loss_val:.3f} Val Acc: {(100 * accuracy_val):.2f}%', end="")
                # Check for best accuracy and update it
                if accuracy_val > self.best_accuracy:
                    self.best_accuracy = accuracy_val
                    print(' (Best)', end="")
                # Log elapsed time
                end = time()
                self.elapsed_time = timedelta(seconds=int(end - start))
                print(f' ({str(self.elapsed_time)})\n')
                if save_bool:
                    # Append epoch metrics to the txt file
                    with open(self.file_path + '.txt', 'a') as file:
                        file.write(f'\nEpoch {epoch_ndx}/{self.epochs} Val Loss: {loss_val:.3f} '
                                   f'Val Acc {(100 * accuracy_val):.2f}% ({str(self.elapsed_time)})\n')
                    # Save the model's state
                    self.save_model(epoch_ndx, stamp='')
                    # If best performance achieved, update the best-model file
                    if accuracy_val == self.best_accuracy:
                        self.save_model(epoch_ndx, stamp='_best')
            # Update scheduler
            if self.scheduler is not None:
                self.scheduler.step()
                print(f'lr: {self.scheduler.get_last_lr()[0]}')
        # Compute loss and accuracy over the entire training set
        loss_train, accuracy_train = self.validate(loader_train, n_train)
        # Log training metrics
        print(f'Train Loss: {loss_train:.3f} Train Acc: {(100 * accuracy_train):.2f}%', end="")
        # Log and save the best accuracy after all training epochs
        if validate_bool:
            print(f'\nBest accuracy: {(100 * self.best_accuracy):.2f}%')
        if save_bool:
            with open(self.file_path + '.txt', 'a') as file:
                file.write(f'\nTrain Loss: {loss_train:.3f} Train Acc {(100 * accuracy_train):.2f}%\n')
                file.write(f'\nBest accuracy: {(100 * self.best_accuracy):.2f}%')

    def init_design(self):
        set_seed(self.seed)
        self.model = self.design['model'](**self.model_parameters).to(self.device)
        self.criterion = self.design['criterion']()
        self.optimizer = self.design['optimizer'](self.model.parameters(), **self.optimizer_parameters)
        if 'scheduler' in self.design:
            self.scheduler = self.design['scheduler'](self.optimizer, **self.scheduler_parameters)
        else:
            self.scheduler = None

    def init_dataloader(self):
        dataset_train = CIFARDataset(self.data_train, transform=self.transformation_train)
        dataset_val = CIFARDataset(self.data_val, transform=self.transformation_val)
        loader_train = DataLoader(dataset_train,
                                  batch_size=self.batch_size,
                                  num_workers=self.num_workers,
                                  shuffle=True)
        loader_val = DataLoader(dataset_val,
                                batch_size=self.batch_size,
                                num_workers=self.num_workers,
                                shuffle=False)  # No need to shuffle for evaluation
        return loader_train, loader_val

    def train(self, epoch_ndx, trainloader):
        # To log a fancy progress bar
        loop = tqdm(enumerate(trainloader, start=1), total=len(trainloader))
        loop.set_description(f"Epoch {epoch_ndx}/{self.epochs}")
        self.model.train()
        for i, (x, y) in loop:
            # Send to device
            x = x.to(self.device)
            y = y.to(self.device)
            # Set gradients to zero
            self.optimizer.zero_grad()
            # Forward pass
            y_pred = self.model(x)
            # Compute loss
            loss = self.criterion(y_pred, y.squeeze())
            # Backward pass
            loss.backward()
            # Optimize
            self.optimizer.step()
            # Compute statistics
            loss_train = loss.item()
            accuracy_train = get_accuracy(y_pred, y)
            # Save batch metrics
            self.metrics_train['loss'].append(loss_train)
            self.metrics_train['acc'].append(accuracy_train)
            # Update progress bar
            loop.set_postfix_str(f"Loss = {loss_train:.3f}, Acc = {(100 * accuracy_train):.2f}%")
        print('')

    def validate(self, loader, n_loader):
        # Initialize running loss and accuracy
        loss = 0.0
        accuracy = 0.0
        self.model.eval()
        with torch.no_grad():
            for x, y in loader:
                # Send to device
                x = x.to(self.device)
                y = y.to(self.device)
                # Forward pass and loss
                y_pred = self.model(x)
                loss_batch = self.criterion(y_pred, y.squeeze())
                # Update statistics
                loss += loss_batch.item()
                accuracy += get_accuracy(y_pred, y)
        # Average over the number of batches
        loss /= n_loader
        accuracy /= n_loader
        return loss, accuracy

    def get_metrics(self):
        """
        Returns metrics as Pandas DataFrames ready to be plotted.
        """
        metrics_train_df = pd.DataFrame.from_dict(self.metrics_train)
        metrics_val_df = pd.DataFrame.from_dict(self.metrics_val)
        # Training metrics are stored per batch and validation metrics per epoch;
        # rescale the training index so both can be plotted on the same epoch axis
        step = len(metrics_val_df) / len(metrics_train_df)
        metrics_train_df.index = np.arange(step, len(metrics_val_df) + step, step)
        metrics_val_df.index = np.arange(1, len(metrics_val_df) + 1)
        return metrics_train_df, metrics_val_df
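
    # A minimal plotting sketch (not from the original file), assuming matplotlib
    # is available; `app` is an instance of this class after training:
    #
    #   import matplotlib.pyplot as plt
    #   train_df, val_df = app.get_metrics()
    #   train_df['loss'].plot(label='train loss')  # per-batch, rescaled to epochs
    #   val_df['loss'].plot(label='val loss')      # per-epoch
    #   plt.xlabel('epoch')
    #   plt.legend()
    #   plt.show()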

    def save_model(self, epoch_ndx, stamp):
        """
        Save the model state and hyperparameters.
        """
        state = {
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'app_parameters': self.app_parameters,
            'hyperparameters': self.hyperparameters,
            'design': self.design,
            'transformations': self.transformations,
            'metrics': {'train': self.metrics_train, 'val': self.metrics_val},
            'epoch': epoch_ndx,
            'file_path': self.file_path,
            'elapsed_time': self.elapsed_time,
            'best_accuracy': self.best_accuracy
        }
        if self.scheduler is not None:
            state['scheduler_state'] = self.scheduler.state_dict()
        torch.save(state, self.file_path + stamp + '.state')

    def load_state(self, path):
        # map_location keeps CUDA-saved checkpoints loadable on CPU-only machines
        state = torch.load(path, map_location=self.device)
        # Load model, optimizer, and scheduler states
        self.model.load_state_dict(state['model_state'])
        self.optimizer.load_state_dict(state['optimizer_state'])
        if 'scheduler_state' in state and self.scheduler is not None:
            self.scheduler.load_state_dict(state['scheduler_state'])
        # Load metrics
        self.metrics_train = state['metrics']['train']
        self.metrics_val = state['metrics']['val']
        # To continue training from the checkpoint correctly
        self.best_accuracy = state['best_accuracy']
        self.start_epoch = state['epoch']
        self.elapsed_time = state['elapsed_time']
        self.file_path = state['file_path']
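

# A hedged end-to-end sketch of how this class appears meant to be driven, using
# the dictionary shapes noted after __init__ above. `data_train`/`data_val` must
# be whatever CIFARDataset accepts, which is defined elsewhere in the repo, and
# the checkpoint path below is a hypothetical example:
#
#   app = CIFARTrainingApp(data_train, data_val, app_parameters,
#                          hyperparameters, transformations, design)
#   app.main(validate_bool=True, save_bool=True)    # fresh training run
#
#   # Resuming: load a saved .state file, then continue training
#   app.load_state('runs/cifar_2024-01-01_12.00.00.state')
#   app.main(validate_bool=True, save_bool=True)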