-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathclassification.py
314 lines (271 loc) · 12.5 KB
/
classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# training loop implement here
# ---libs---
import argparse
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.utils import make_grid
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import numpy as np
# ---modules---
from dataset.dataset import TrainDataset
from net import SimpleModel
# ---misc---
from tqdm import tqdm
import os
from datetime import datetime
import shutil
import logging
from config import *
from lib import sample_images, count_parameters
# Set logging
# Log line layout: [timestamp][logger name][level]: message
FORMAT = '[%(asctime)s][%(name)s][%(levelname)s]: %(message)s'
logging.basicConfig(format=FORMAT, datefmt='%Y-%m-%d %H:%M:%S')
LOG = logging.getLogger('Classification')
def train(args):
    """
    Training process main function.

    Builds the network selected by the command-line flags, trains it for
    ``Epochs`` epochs (resuming from ``args.load`` when given), runs a
    validation pass after every epoch, logs progress to TensorBoard and
    writes checkpoints: one whenever the validation loss improves and one
    at the end of training.

    :param args: args from command line inputs; reads ``vgg``, ``inception``,
        ``dataset`` and ``load`` (plus whatever ``select_model`` reads)
    :return: -
    :raises ValueError: if the data loader yields no batches
    """
    # get CNN model
    model = select_model(args)
    model.cuda()
    # create optimizer (after .cuda() so it holds the GPU parameters)
    optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=L2, amsgrad=False)
    # loss function; 'sum' reduction, so reported losses are per-batch sums
    criteria = nn.CrossEntropyLoss(reduction='sum')
    # variables
    global_i = 0  # global steps counter
    best_eval = float('inf')  # best (lowest) validation loss so far; any real loss beats it
    res = -1  # validation loss of the current epoch
    graph_loaded = False  # tensorboard graph loading status
    n_samples = 500  # samples of features to project in tensorboard
    # read the flag from the argument, not from the module-level ``opt``
    # that only exists when this file is executed as a script
    use_inception = args.inception
    # dataset and loader; VGG and Inception are fed colour images
    dataset = TrainDataset(data_path=args.dataset, test_size=testSize, train_size=trainSize,
                           color=bool(args.vgg or args.inception), img_size=upSampling)
    loader = DataLoader(dataset, batch_size=BS, shuffle=SF, num_workers=numWorkers,
                        pin_memory=pinMem, drop_last=dropLast, timeout=timeOut)
    if len(loader) < 1:
        LOG.error('Dataset maybe empty.')
        raise ValueError('Dataset maybe empty.')
    # continue training check & prepare saving directory
    if args.load:
        trained_epoch, global_i = load_ckpt(model, optimizer, args.load)
        # checkpoints are stored as <save_root>/checkpoints/<name>.tar
        save_root = os.path.dirname(os.path.dirname(args.load))
        log_dir = os.path.join(save_root, 'log/')
        ckpt_dir = os.path.join(save_root, 'checkpoints/')
    else:
        trained_epoch = 0
        save_root = os.path.join('./results', datetime.now().strftime("%H%M_%d%m%Y"))
        log_dir = os.path.join(save_root, 'log/')
        ckpt_dir = os.path.join(save_root, 'checkpoints/')
        # makedirs also creates './results' on a fresh checkout; it still
        # raises if this run directory already exists, so runs never mix
        os.makedirs(save_root)
        os.makedirs(log_dir)
        os.makedirs(ckpt_dir)
    # save configuration to the training directory
    shutil.copy2('./config.py', save_root)
    # TensorBoard summary writer -- for training process visualization
    writer = SummaryWriter(log_dir=log_dir, max_queue=10, flush_secs=120)
    try:
        # sample images and show the overview of the input features
        sampled_images, sampled_labels = sample_images(dataset.trainset['image'],
                                                       dataset.trainset['label'], n_samples)
        # NOTE(review): the inserted channel axis assumes (N, H, W) grayscale
        # samples -- confirm this holds for the colour (VGG/Inception) datasets
        writer.add_embedding(sampled_images.reshape(n_samples, -1),
                             metadata=[dataset.c2l[lbl] for lbl in sampled_labels],
                             label_img=torch.from_numpy(sampled_images[:, np.newaxis, :, :]))
        # start training epoch
        for epoch in range(trained_epoch, trained_epoch + Epochs):
            model.train()  # switch model to training mode
            dataset.train()  # switch dataset to training mode
            running_loss = 0.0  # running loss for training set
            running_acc = 0.0  # running accuracy for training set
            running_loss_eval = 0.0  # running loss for validation set
            running_acc_eval = 0.0  # running accuracy for validation set
            print()
            LOG.warning('Start epoch %d.' % (epoch + 1))
            # iterating each batch
            for i, data in enumerate(tqdm(loader)):
                img, label = data[0].float().cuda(), data[1].long().cuda()  # push the data to GPU
                pred = model(img)  # model inference
                # loss for one output and tuple output
                if use_inception:
                    # train-mode Inception returns (logits, aux_logits)
                    loss = loss_batch(pred, label, criteria, optimizer, mode='inception')
                    logits = pred[0]
                else:
                    loss = loss_batch(pred, label, criteria, optimizer)
                    logits = pred
                # mean over the real batch size: the last batch may hold fewer than BS samples
                acc = (logits.argmax(1) == label).float().mean().item()
                loss_value = loss.item()
                # collect statistics
                running_loss += loss_value
                running_acc += acc  # accuracy of this batch
                # update training statistics
                if i % TBUpdate == 0:
                    if i == 0:
                        writer.add_image('Train/Batch', make_grid(img), global_i)
                        if not graph_loaded:  # only add once
                            writer.add_graph(model, img)
                            graph_loaded = True
                    writer.add_scalar('Train/Loss', loss_value, global_i)  # or optimizer, dropout info
                    writer.add_scalar('Train/Accuracy', acc, global_i)
                    writer.flush()
                # update global step
                global_i += 1
            # show epoch info
            LOG.warning('Epoch %d: running loss: %.4f running accuracy: %.2f' %
                        (epoch + 1, running_loss / len(loader), running_acc / len(loader)))
            # validation
            model.eval()  # switch model to validation mode
            dataset.eval()  # switch dataset to validation mode
            with torch.no_grad():  # no need to track the computation graph during evaluation
                LOG.warning('Evaluation on testing data...')
                for i, data in enumerate(tqdm(loader)):
                    img, label = data[0].float().cuda(), data[1].long().cuda()
                    # assumes eval-mode models return plain logits (no aux tuple)
                    pred = model(img)
                    loss = loss_batch(pred, label, criteria)
                    acc = (pred.argmax(1) == label).float().mean().item()
                    running_loss_eval += loss.item()
                    running_acc_eval += acc
                    # validation batch for visualization
                    if i == 0:
                        writer.add_image('Validation/Batch', make_grid(img), global_i)
            # validation results
            res = running_loss_eval / len(loader)
            acc_eval = running_acc_eval / len(loader)
            writer.add_scalar('Validation/Loss', res, global_i)
            writer.add_scalar('Validation/Accuracy', acc_eval, global_i)
            writer.flush()
            LOG.warning('Epoch %d: validation loss: %.4f accuracy: %.2f' % (epoch + 1, res, acc_eval))
            # store the best model
            if res < best_eval:
                best_eval = res
                save_ckpt(model, optimizer, (epoch + 1), global_i,
                          os.path.join(ckpt_dir, 'Epoch%dloss%.4f.tar' % (epoch + 1, res)))
        # save the trained model
        save_ckpt(model, optimizer, trained_epoch + Epochs, global_i,
                  os.path.join(ckpt_dir, 'Epoch%dloss%.4f.tar' % (trained_epoch + Epochs, res)))
    finally:
        # flush pending events and release the event file even on failure
        writer.close()
    return
def test(args):
return
def loss_batch(pred, gt, func, optimizer=None, mode='normal', aux_weight=0.3):
    """
    Calculate the loss from the model's output and the ground truth, and
    optionally run one optimisation step.

    :param pred: model output; a single logits tensor for ``mode='normal'``,
        or a ``(logits, aux_logits)`` pair for ``mode='inception'``
    :param gt: ground truth, must be compatible with the logits and ``func``
    :param func: loss function, called as ``func(logits, gt)``
    :param optimizer: optimizer instance; when given, gradients are zeroed,
        back-propagated from the loss and applied with ``optimizer.step()``
    :param mode: str, 'normal' or 'inception' (inception has additional
        auxiliary logits)
    :param aux_weight: weight of the auxiliary-logits loss in 'inception'
        mode; defaults to 0.3, the value previously hard-coded here
    :return: loss value (a tensor, still attached to the autograd graph)
    :raises ValueError: if ``mode`` is not recognised
    """
    if mode == 'normal':
        loss = func(pred, gt)
    elif mode == 'inception':
        logits, aux_logits = pred
        main_loss, aux_loss = func(logits, gt), func(aux_logits, gt)
        loss = main_loss + aux_weight * aux_loss
    else:
        LOG.error("Unknown model output? Need a method to compute loss...")
        raise ValueError("Unknown model output? Need a method to compute loss...")
    if optimizer is not None:
        # zero the parameter gradients
        optimizer.zero_grad()
        # auto-calculate gradients
        loss.backward()
        # apply gradients
        optimizer.step()
    return loss
def select_model(args, num_classes=71):
    """
    Return the CNN selected by the command-line flags.

    Flags are checked in the order ``--vgg``, ``--inception``, ``--simple``;
    the first one that is set wins.

    :param args: command line args (reads ``vgg``, ``inception`` and ``simple``)
    :param num_classes: number of output classes of the VGG / Inception heads
        (default 71, the value previously hard-coded)
    :return: the untrained model
    :raises ValueError: if the configured input size or batch size is invalid
        for the chosen network, or if no model flag is set
    """
    if args.vgg:
        # use the base vgg19 with batch normalization model
        if upSampling < 224:
            LOG.error("Minimum input size for VGG net is 224!")
            raise ValueError("Minimum input size for VGG net is 224!")
        LOG.warning('Loading VGG19 with Batch Normalization model...')
        LOG.warning('It may take few minutes to load the PyTorch model...please wait patiently...')
        model = models.vgg19_bn()  # 5 maxpooling layers
        # replace the classifier to match our dataset;
        # 25088 = 512 * 7 * 7, torchvision VGG's pooled feature size
        model.classifier = nn.Sequential(
            nn.Linear(in_features=25088, out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),  # keep the same dropout rate
            nn.Linear(in_features=4096, out_features=768),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=768, out_features=num_classes),
        )
    elif args.inception:
        # load inception v3 model; BOTH requirements must hold, so reject the
        # config as soon as EITHER one is violated
        if upSampling < 299 or BS < 2:
            LOG.error("Minimum input size for Inception v3 net is 299, batch size must > 1!")
            raise ValueError("Minimum input size for Inception v3 net is 299, batch size must > 1!")
        LOG.warning('Loading Inception v3 model...')
        LOG.warning('It may take few minutes to load the PyTorch model...please wait patiently...')
        model = models.inception_v3(num_classes=num_classes)
    elif args.simple:
        # load customized model
        model = SimpleModel.CNN()
    else:
        LOG.error("You must select a model to train.")
        raise ValueError("No model selected. use --vgg, --inception or --simple")
    LOG.warning('Trainable parameters: %d' % count_parameters(model))
    return model
def save_ckpt(model, optimizer, epoch, global_step, path):
    """
    Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with the epoch number, the global TensorBoard
    step, and the state dicts of the model and of the optimizer.

    :param model: pytorch model whose ``state_dict()`` is saved
    :param optimizer: optimizer whose ``state_dict()`` is saved
    :param epoch: epoch value to record
    :param global_step: global step counter (used for tensorboard) to record
    :param path: destination file path
    """
    checkpoint = dict(
        epoch=epoch,
        global_step=global_step,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(checkpoint, path)
def load_ckpt(model, optimizer, path, map_location=None):
    """
    Load a checkpoint written by ``save_ckpt`` into a model (and optimizer)
    for continued training or for evaluation.

    :param model: pytorch model object to receive the trained parameters
    :param optimizer: optimizer object used in the last training; pass
        ``None`` to restore only the model weights (e.g. for evaluation)
    :param path: path to the saved checkpoint
    :param map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` loads a
        GPU-saved checkpoint on a machine without CUDA. ``None`` (default)
        matches the previous ``torch.load(path)`` call
    :return: ``(epoch, global_step)`` recorded in the checkpoint
    """
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    epoch = ckpt['epoch']
    global_step = ckpt['global_step']
    return epoch, global_step
if __name__ == '__main__':
    # Command-line entry point: choose a mode (--train / --test) and a model flag.
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', type=str, default='', help="path to the saved 'tar' checkpoint file.")
    parser.add_argument('--dataset', type=str, default='./dataset/data.npz', help="Path to the dataset.")
    # boolean switches, registered in the same order as before so --help is unchanged
    for flag, help_text in (
            ('--train', 'Start training process.'),
            ('--test', 'Start testing process.'),
            ('--vgg', 'Use VGG 19 with Batch Normalization.'),
            ('--simple', 'Use customized CNN.'),
            ('--inception', 'Use Inception V3 Net.'),
    ):
        parser.add_argument(flag, action='store_true', help=help_text)
    # kept as the global name ``opt``: other code in this module may refer to it
    opt = parser.parse_args()
    print(opt)
    # guard clause: at least one mode must be requested
    if not (opt.train or opt.test):
        LOG.warning("Please specify whether to train or test")
        raise ValueError("Please specify whether to train or test using --train or --test")
    if opt.train:  # --train takes precedence when both are given
        train(opt)
    else:
        test(opt)
        raise NotImplementedError('testing mode is not implemented')