forked from bat67/pytorch-FCN-easiest-demo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
123 lines (93 loc) · 4.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import visdom
from BagData import test_dataloader, train_dataloader
from FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNet
def _channel_argmin(batch):
    """Turn a batched multi-channel tensor into per-pixel argmin indices.

    The tensor is detached, moved to the CPU and converted to NumPy, then
    reduced with ``np.argmin`` over axis 1. A singleton axis is re-inserted at
    position 1 so the result can be passed directly to ``vis.images``.
    """
    return np.argmin(batch.detach().cpu().numpy(), axis=1)[:, None, :, :]


def train(epo_num=50, show_vgg_params=False):
    """Train an ``FCNs`` model on ``train_dataloader`` and evaluate it each epoch.

    Every epoch runs one pass over ``train_dataloader`` (with SGD updates) and
    one pass over ``test_dataloader`` (under ``torch.no_grad()``). Every 15th
    batch of either pass pushes the current predictions, labels and the
    running per-iteration loss curve to a visdom server. After each epoch the
    mean train/test loss and the epoch duration are printed, and every 5th
    epoch (starting with epoch 0) the whole model object is saved to
    ``checkpoints/fcn_model_<epoch>.pt``.

    Args:
        epo_num: Number of epochs to run.
        show_vgg_params: Forwarded to ``VGGNet`` as ``show_params``.

    Returns:
        None.
    """
    import os  # local import: only needed here to prepare the checkpoint directory

    vis = visdom.Visdom()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    vgg_model = VGGNet(requires_grad=True, show_params=show_vgg_params)
    fcn_model = FCNs(pretrained_net=vgg_model, n_class=2)
    fcn_model = fcn_model.to(device)
    # NOTE(review): the loss is applied to sigmoid outputs; nn.BCEWithLogitsLoss
    # on the raw logits would be the fused, numerically safer equivalent.
    criterion = nn.BCELoss().to(device)
    optimizer = optim.SGD(fcn_model.parameters(), lr=1e-2, momentum=0.7)

    # torch.save does not create missing parent directories, so make sure the
    # target exists up front instead of failing after the first epoch.
    os.makedirs('checkpoints', exist_ok=True)

    all_train_iter_loss = []
    all_test_iter_loss = []

    # start timing
    prev_time = datetime.now()
    for epo in range(epo_num):

        train_loss = 0
        fcn_model.train()
        for index, (bag, bag_msk) in enumerate(train_dataloader):
            # Per the original author's notes:
            # bag.shape is torch.Size([4, 3, 160, 160])
            # bag_msk.shape is torch.Size([4, 2, 160, 160])
            bag = bag.to(device)
            bag_msk = bag_msk.to(device)

            optimizer.zero_grad()
            output = torch.sigmoid(fcn_model(bag))
            loss = criterion(output, bag_msk)
            loss.backward()
            optimizer.step()

            iter_loss = loss.item()
            all_train_iter_loss.append(iter_loss)
            train_loss += iter_loss

            if index % 15 == 0:
                print('epoch {}, {}/{},train loss is {}'.format(epo, index, len(train_dataloader), iter_loss))
                # Only pay for the device-to-host copy when something is displayed.
                vis.images(_channel_argmin(output), win='train_pred', opts=dict(title='train prediction'))
                vis.images(_channel_argmin(bag_msk), win='train_label', opts=dict(title='label'))
                vis.line(all_train_iter_loss, win='train_iter_loss', opts=dict(title='train iter loss'))

        test_loss = 0
        fcn_model.eval()
        with torch.no_grad():
            for index, (bag, bag_msk) in enumerate(test_dataloader):
                bag = bag.to(device)
                bag_msk = bag_msk.to(device)

                # No backward pass happens here, so no gradient reset is needed.
                output = torch.sigmoid(fcn_model(bag))
                loss = criterion(output, bag_msk)

                iter_loss = loss.item()
                all_test_iter_loss.append(iter_loss)
                test_loss += iter_loss

                if index % 15 == 0:
                    print(r'Testing... Open http://localhost:8097/ to see test result.')
                    vis.images(_channel_argmin(output), win='test_pred', opts=dict(title='test prediction'))
                    vis.images(_channel_argmin(bag_msk), win='test_label', opts=dict(title='label'))
                    vis.line(all_test_iter_loss, win='test_iter_loss', opts=dict(title='test iter loss'))

        cur_time = datetime.now()
        # total_seconds() covers the whole interval; .seconds would drop full days.
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time

        # max(..., 1) keeps the report from raising when a loader is empty.
        mean_train_loss = train_loss / max(len(train_dataloader), 1)
        mean_test_loss = test_loss / max(len(test_dataloader), 1)
        print('epoch train loss = %f, epoch test loss = %f, %s'
              % (mean_train_loss, mean_test_loss, time_str))

        if epo % 5 == 0:
            # The full model object (not just its state_dict) is serialized.
            torch.save(fcn_model, 'checkpoints/fcn_model_{}.pt'.format(epo))
            print('saveing checkpoints/fcn_model_{}.pt'.format(epo))
if __name__ == "__main__":
    # Script entry point: run a 100-epoch training session, passing
    # show_vgg_params=False through to train().
    train(show_vgg_params=False, epo_num=100)