-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
185 lines (176 loc) · 6.65 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from general import (
CRITERIONS,
NeuralNet,
DEVICE,
IMAGE_SIZE,
Path,
SyntheticDataset,
current_time,
json,
relative_path,
sqrt,
torch,
torchvision,
)
# Minimum number of epochs before training may stop; also determines the StepLR step size
# (NUM_EPOCHS // 10) and the window used by required_loss (NUM_EPOCHS // 8).
# NOTE: main() overwrites this value with the reached epoch count on a keyboard interrupt.
NUM_EPOCHS = 53
# Batch size passed to both the train and the test DataLoader.
BATCH_SIZE = 16
# Learning rate passed to the Adam optimizer.
LEARNING_RATE = 0.000005
# Key used to look up the loss class in CRITERIONS.
CRITERION = 'IngredientScannerLoss'
# Passed as `alpha` / `beta` to the criterion, only when CRITERION is 'IngredientScannerLoss'.
LOSS_ALPHA = 2.0
LOSS_BETA = 1.2
# If True, main() additionally trains on the test loader for NUM_EPOCHS // 2 + 1 epochs.
ENABLE_TEST_TRAIN = False
def calculate_eval(model: NeuralNet, test_loader: torch.utils.data.DataLoader) -> float:
    """
    Calculates the score of the provided model on the test dataset.

    The score is the mean Euclidean distance between predicted and expected points, where
    the values at positions (2 * j, 2 * j + 1) of each output row are treated as one 2D point.
    :param model: Model to evaluate (it is left in evaluation mode afterwards)
    :param test_loader: loader for the test dataset
    :return: score as a float
    :raises ValueError: if the loader yields no points to compare
    """
    # set model to evaluation mode
    model.eval()
    total_distance = 0.0
    total_points = 0
    with torch.no_grad():
        for images, data in test_loader:
            images = images.to(DEVICE)
            data = data.to(DEVICE)
            outputs = model(images)
            # the sign of the difference is irrelevant because every component is squared below
            for row in (outputs - data).tolist():
                # pair consecutive values as (x, y); a trailing unpaired value is ignored
                for dx, dy in zip(row[0::2], row[1::2]):
                    # calculate distance in 2D space
                    total_distance += sqrt(dx ** 2 + dy ** 2)
                    total_points += 1
    if total_points == 0:
        raise ValueError('Cannot evaluate the model: the test loader yielded no points')
    # average over the points that were actually measured instead of assuming a fixed
    # number of points per image
    return total_distance / total_points
def required_loss(loss_history: list[float]) -> float:
    """
    Determines the loss threshold the latest epoch has to reach for the training to end.

    Looks at the most recent NUM_EPOCHS // 8 entries of the history and picks the
    second-smallest one (or the only one, if there is just a single entry).
    :param loss_history: list of all losses
    :return: maximum loss
    """
    window = NUM_EPOCHS // 8
    recent = list(loss_history[-window:])
    recent.sort()
    return recent[1] if len(recent) >= 2 else recent[0]
def main() -> None:
    """
    Train the vision model and save it to disk.

    Training runs for at least NUM_EPOCHS epochs and continues until the latest epoch loss
    is no higher than `required_loss(loss_history)`. A keyboard interrupt ends the training
    early; the model is still evaluated and saved afterwards. If ENABLE_TEST_TRAIN is set,
    the model is additionally trained on the test loader. The weights are written to
    `models/<timestamp>.pt` and the metadata of the run to `models/<timestamp>.json`.
    :return: None
    """
    # NUM_EPOCHS is reassigned below (on keyboard interrupt). Without this declaration Python
    # would treat it as a local variable and the reads before that assignment would fail.
    global NUM_EPOCHS
    start_time = current_time()
    # 'models' is included so that saving at the end cannot fail on a missing directory
    for path in ['tmp/frames', 'tmp/synthetic_frames', 'models']:
        Path(relative_path(path)).mkdir(parents=True, exist_ok=True)

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    # load datasets
    train_dataset = SyntheticDataset(train=True, transform=transform)
    test_dataset = SyntheticDataset(train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
    model = NeuralNet().to(DEVICE)

    # define criterion and optimizer
    if CRITERION == 'IngredientScannerLoss':
        criterion = CRITERIONS[CRITERION](alpha=LOSS_ALPHA, beta=LOSS_BETA)
    else:
        criterion = CRITERIONS[CRITERION]()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0001)
    # StepLR needs a positive step size; NUM_EPOCHS // 10 would be 0 for fewer than 10 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, NUM_EPOCHS // 10), gamma=0.1)

    # train
    n_total_steps = len(train_loader)
    i = 0
    loss_history = []
    keyboard_interrupt = False
    train = True
    epoch = 0
    while train:
        try:
            # an interrupt caught during evaluation is re-raised here so that both
            # interrupt locations share the same shutdown path
            if keyboard_interrupt:
                raise KeyboardInterrupt()
            model.train()
            for i, (images, data) in enumerate(train_loader):
                images = images.to(DEVICE)
                data = data.to(DEVICE)
                outputs = model(images)
                loss = criterion(outputs, data)
                # reduce non-scalar losses to a scalar before back-propagation
                if loss.dim() > 0:
                    loss = loss.mean()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # one history entry per epoch: the loss of the last batch
            loss_history.append(loss.item())
            epoch += 1
            scheduler.step()
        except KeyboardInterrupt:
            if not keyboard_interrupt:
                print('Keyboard interrupt detected')
            train = False
            # record how many epochs were actually completed
            NUM_EPOCHS = epoch
        try:
            if loss_history:
                average_distance = calculate_eval(model, test_loader)
                print(f"Synthetic Dataset, Epoch [{epoch}/{NUM_EPOCHS}], Step [{i + 1}/{n_total_steps}], "
                      f"Loss: {loss_history[-1]:.6f}, Average distance: {average_distance:.4f}")
                # stop once the minimum epoch count is reached and the loss is low enough
                if epoch >= NUM_EPOCHS and loss_history[-1] <= required_loss(loss_history):
                    train = False
        except KeyboardInterrupt:
            print('Keyboard interrupt detected')
            keyboard_interrupt = True

    if ENABLE_TEST_TRAIN:
        # train on test dataset, if selected
        n_test_epochs = NUM_EPOCHS // 2 + 1
        n_test_steps = len(test_loader)
        for epoch2 in range(n_test_epochs):
            try:
                model.train()
                for i, (images, data) in enumerate(test_loader):
                    images = images.to(DEVICE)
                    data = data.to(DEVICE)
                    outputs = model(images)
                    loss = criterion(outputs, data)
                    if loss.dim() > 0:
                        loss = loss.mean()
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                loss_history.append(loss.item())
                scheduler.step()
            except KeyboardInterrupt:
                print('Keyboard interrupt detected')
                break
            try:
                if loss_history:
                    average_distance = calculate_eval(model, test_loader)
                    # report progress relative to the test loader, with a 1-based epoch counter
                    print(f"Test Dataset,Epoch [{epoch2 + 1}/{n_test_epochs}], Step [{i + 1}/{n_test_steps}], "
                          f"Loss: {loss_history[-1]:.6f}, Average distance: {average_distance:.4f}")
            except KeyboardInterrupt:
                print('Keyboard interrupt detected')
                break

    print('Finished Training')
    average_distance = calculate_eval(model, test_loader)
    print(f"Average distance: {average_distance:.4f}")
    time_stamp = current_time()

    # save model to disk
    torch.save(model.state_dict(), relative_path(f"models/{time_stamp}.pt"))

    # save metadata of the model to disk
    model_data = {
        'average_distance': average_distance,
        'batch_size': BATCH_SIZE,
        'criterion': CRITERION,
        'enable_test_train': ENABLE_TEST_TRAIN,
        'image_size_x': IMAGE_SIZE[0],
        'image_size_y': IMAGE_SIZE[1],
        # the history is empty if training was interrupted before the first epoch finished
        'last_loss': loss_history[-1] if loss_history else None,
        'learning_rate': LEARNING_RATE,
        'loss_alpha': LOSS_ALPHA,
        # average over the last 10 recorded losses (all 10 of them, including the latest)
        'loss_ao10': sum(loss_history[-10:]) / 10.0 if len(loss_history) >= 10 else None,
        'loss_beta': LOSS_BETA,
        'num_epochs': NUM_EPOCHS,
        'start_time': start_time,
    }
    with open(relative_path(f"models/{time_stamp}.json"), 'w', encoding='utf-8') as f:
        json.dump(model_data, f, indent=2)
    print(f"Model saved as {time_stamp}.pt")
# Start the training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()