-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtraning.py
138 lines (113 loc) · 4.49 KB
/
traning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import torch.utils.data
import torch.optim as optim
from utils.dataset_processing import evaluation
from models.common import post_process_output
import logging
def validate(net, device, val_data, batches_per_epoch):
    """
    Run validation.

    Switches the network to eval mode and, with gradients disabled, feeds
    validation batches through ``net.compute_loss``. For every batch the
    predicted maps are post-processed and matched against the ground-truth
    grasp rectangles with an IoU test.

    :param net: Network; must provide ``compute_loss(x, y)`` returning a dict
        with the keys 'loss', 'losses' and 'pred' ('pos', 'cos', 'sin', 'width')
    :param device: Torch device
    :param val_data: Validation data loader yielding
        ``(x, y, didx, rot, zoom_factor)``; must expose ``.dataset.get_gtbb``
    :param batches_per_epoch: Number of batches to run. The loader is
        re-iterated as often as needed to reach this count. ``None`` runs
        exactly one full pass over ``val_data``.
    :return: dict with 'correct' and 'failed' counts, the mean 'loss' and the
        mean of every individual loss term in 'losses' (means are taken over
        the batches that were actually evaluated)
    """
    net.eval()

    results = {
        'correct': 0,
        'failed': 0,
        'loss': 0,
        'losses': {},
    }

    limited = batches_per_epoch is not None
    n_done = 0  # batches actually evaluated so far

    with torch.no_grad():
        while not limited or n_done < batches_per_epoch:
            pass_start = n_done

            for x, y, didx, rot, zoom_factor in val_data:
                # Check the budget *before* counting the batch so that exactly
                # batches_per_epoch batches are evaluated.
                if limited and n_done >= batches_per_epoch:
                    break
                n_done += 1

                xc = x.to(device)
                yc = [yy.to(device) for yy in y]
                lossd = net.compute_loss(xc, yc)

                results['loss'] += lossd['loss'].item()
                for ln, l in lossd['losses'].items():
                    results['losses'][ln] = results['losses'].get(ln, 0) + l.item()

                q_out, ang_out, w_out = post_process_output(
                    lossd['pred']['pos'], lossd['pred']['cos'],
                    lossd['pred']['sin'], lossd['pred']['width'])

                # NOTE(review): one ground-truth lookup per batch - presumably
                # the validation loader uses batch size 1; confirm with caller.
                s = evaluation.calculate_iou_match(
                    q_out, ang_out,
                    val_data.dataset.get_gtbb(didx, rot, zoom_factor),
                    no_grasps=2,
                    grasp_width=w_out,
                )

                if s:
                    results['correct'] += 1
                else:
                    results['failed'] += 1

            if n_done == pass_start:
                # The loader produced nothing; looping again would never end.
                logging.warning('validate: val_data yielded no batches')
                break
            if not limited:
                # Without a batch budget a single pass over the data is enough.
                break

    # Average over the number of evaluated batches, not len(val_data): the two
    # differ whenever batches_per_epoch is not equal to the loader length.
    if n_done:
        results['loss'] /= n_done
        for ln in results['losses']:
            results['losses'][ln] /= n_done

    return results
def train(epoch, net, device, train_data, optimizer, batches_per_epoch, vis=False):
    """
    Run one training epoch

    The loader is iterated in whole passes until at least
    ``batches_per_epoch`` batches have been processed. A pass is never cut
    short, so the epoch may run more batches than requested when
    ``batches_per_epoch`` is not a multiple of the loader length.

    :param epoch: Current epoch (only used in log messages)
    :param net: Network; must provide ``compute_loss(x, y)`` returning a dict
        with the keys 'loss', 'losses' and 'pred'
    :param device: Torch device
    :param train_data: Training data loader yielding 5-tuples whose first two
        items are the input batch and the list of target tensors
    :param optimizer: Optimizer
    :param batches_per_epoch: Minimum number of data batches to train on
    :param vis: Visualise training progress (image gathering only; the
        on-screen display call is currently disabled)
    :return: Average Losses for Epoch, as a dict with the mean 'loss' and the
        mean of every individual loss term in 'losses'
    """
    results = {
        'loss': 0,
        'losses': {},
    }

    net.train()

    batch_idx = 0
    # Use batches per epoch to make training on different sized datasets (cornell/jacquard) more equivalent.
    while batch_idx < batches_per_epoch:
        pass_start = batch_idx

        for x, y, _, _, _ in train_data:
            batch_idx += 1

            xc = x.to(device)
            yc = [yy.to(device) for yy in y]
            lossd = net.compute_loss(xc, yc)
            loss = lossd['loss']

            if batch_idx % 10 == 0:
                logging.info('Epoch: {}, Batch: {}, Loss: {:0.4f}'.format(epoch, batch_idx, loss.item()))

            results['loss'] += loss.item()
            for ln, l in lossd['losses'].items():
                results['losses'][ln] = results['losses'].get(ln, 0) + l.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Display the images
            if vis:
                # TODO: rendering is disabled (the display call was commented
                # out), so the images below are gathered but never shown.
                imgs = []
                n_img = min(4, x.shape[0])
                for idx in range(n_img):
                    imgs.extend(
                        [x[idx,].numpy().squeeze()]
                        + [yi[idx,].numpy().squeeze() for yi in y]
                        + [x[idx,].numpy().squeeze()]
                        + [pc[idx,].detach().cpu().numpy().squeeze() for pc in lossd['pred'].values()]
                    )

        if batch_idx == pass_start:
            # The loader produced nothing; looping again would never end.
            logging.warning('train: train_data yielded no batches')
            break

    # Guard the averaging: batch_idx stays 0 for an empty loader or a
    # non-positive batches_per_epoch.
    if batch_idx:
        results['loss'] /= batch_idx
        for l in results['losses']:
            results['losses'][l] /= batch_idx

    return results