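"""train.py: train a Gluon network on pre-saved NDArray datasets, plot the loss curves, and save the parameters."""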
import datetime
import matplotlib
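# Select the non-interactive 'agg' backend before importing pyplot
# so figures can be saved without a display.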
matplotlib.use('agg')
import matplotlib.pyplot as plt
from mxnet import autograd
from mxnet import gluon
from mxnet import nd
import mxnet as mx
import pickle
from model import Net
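
# Load the pre-saved feature/label NDArray pairs.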
train_nd = nd.load('train.nd')
valid_nd = nd.load('valid.nd')
input_nd = nd.load('input.nd')
with open('ids_synsets', 'rb') as f:
    ids_synsets = pickle.load(f)
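
# Hyperparameters and output paths.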
num_epochs = 100
batch_size = 128
learning_rate = 1e-4
weight_decay = 1e-4
pngname = 'train.png'
modelparams = 'train.params'

train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(train_nd[0], train_nd[1]),
                                   batch_size=batch_size, shuffle=True)
valid_data = gluon.data.DataLoader(gluon.data.ArrayDataset(valid_nd[0], valid_nd[1]),
                                   batch_size=batch_size, shuffle=True)
input_data = gluon.data.DataLoader(gluon.data.ArrayDataset(input_nd[0], input_nd[1]),
                                   batch_size=batch_size, shuffle=True)
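# Note: input_data is prepared here but not consumed anywhere in this script.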

softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()


def get_loss(data, net, ctx):
    """Return the average softmax cross-entropy loss over a DataLoader."""
    loss = 0.0
    for feas, label in data:
        label = label.as_in_context(ctx)
        output = net(feas.as_in_context(ctx))
        cross_entropy = softmax_cross_entropy(output, label)
        loss += nd.mean(cross_entropy).asscalar()
    return loss / len(data)

def train(net, train_data, valid_data, num_epochs, lr, wd, ctx):
    """Train with Adam, track per-epoch losses, then save a loss plot and the parameters."""
    trainer = gluon.Trainer(
        net.collect_params(), 'adam', {'learning_rate': lr, 'wd': wd})
    train_loss = []
    if valid_data is not None:
        test_loss = []
    prev_time = datetime.datetime.now()
    for epoch in range(num_epochs):
        epoch_loss = 0.
        for data, label in train_data:
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data.as_in_context(ctx))
                loss = softmax_cross_entropy(output, label)
            loss.backward()
            trainer.step(batch_size)
            epoch_loss += nd.mean(loss).asscalar()
        cur_time = datetime.datetime.now()
        # Use total_seconds() so epochs longer than a day are timed correctly.
        h, remainder = divmod(int((cur_time - prev_time).total_seconds()), 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        avg_loss = epoch_loss / len(train_data)
        train_loss.append(avg_loss)
        if valid_data is not None:
            valid_loss = get_loss(valid_data, net, ctx)
            epoch_str = ("Epoch %d. Train loss: %f, Valid loss %f, "
                         % (epoch, avg_loss, valid_loss))
            test_loss.append(valid_loss)
        else:
            epoch_str = ("Epoch %d. Train loss: %f, "
                         % (epoch, avg_loss))
        prev_time = cur_time
        print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))
    # Plot the loss curves and persist the trained parameters.
    plt.plot(train_loss, 'r')
    if valid_data is not None:
        plt.plot(test_loss, 'g')
        plt.legend(['Train_Loss', 'Test_Loss'], loc=2)
    plt.savefig(pngname, dpi=1000)
    net.collect_params().save(modelparams)

ctx = mx.gpu()  # switch to mx.cpu() if no GPU is available
net = Net(ctx).output
net.hybridize()
train(net, train_data, valid_data, num_epochs, learning_rate, weight_decay, ctx)