#!/usr/bin/env python
"""Chainer example: train a multi-layer perceptron on MNIST.

This is a minimal example of writing a feed-forward net.
"""
from __future__ import print_function
import argparse

import numpy as np
import six

import chainer
from chainer import computational_graph
from chainer import cuda
import chainer.links as L
from chainer import optimizers
from chainer import serializers

import data
import net
import weight_clip
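
# data, net and weight_clip are local modules in this repository: data loads
# the MNIST arrays, net defines the MnistMLP model, and weight_clip provides
# the WeightClip optimizer hook installed below.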
parser = argparse.ArgumentParser(description='Chainer example: MNIST')
parser.add_argument('--initmodel', '-m', default='',
                    help='Initialize the model from given file')
parser.add_argument('--resume', '-r', default='',
                    help='Resume the optimization from snapshot')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
args = parser.parse_args()
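
# Training hyperparameters: minibatch size, number of passes over the
# training data, and the hidden-layer width handed to net.MnistMLP below.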
batchsize = 100
n_epoch = 20
n_units = 1000

# Prepare dataset
print('load MNIST dataset')
mnist = data.load_mnist_data()
mnist['data'] = mnist['data'].astype(np.float32)
mnist['data'] /= 255  # scale pixel values from [0, 255] to [0, 1]
mnist['target'] = mnist['target'].astype(np.int32)
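# MNIST ships 70,000 labeled digits; keep the standard split of 60,000 for
# training and 10,000 for testing.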
N = 60000
x_train, x_test = np.split(mnist['data'], [N])
y_train, y_test = np.split(mnist['target'], [N])
N_test = y_test.size

# Prepare multi-layer perceptron model, defined in net.py
model = L.Classifier(net.MnistMLP(784, n_units, 10))
if args.gpu >= 0:
    cuda.get_device(args.gpu).use()
    model.to_gpu()
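# xp is the array module for the selected device (numpy on CPU, cupy on GPU),
# so the minibatch construction below works on either.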
xp = np if args.gpu < 0 else cuda.cupy

# Setup optimizer
optimizer = optimizers.Adam()
optimizer.setup(model)
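
# WeightClip comes from the local weight_clip module; see weight_clip.py for
# the exact rule. A hook like this typically clamps every parameter into a
# fixed interval on every optimizer step. A minimal sketch of such a hook,
# with a hypothetical [-1, 1] interval (an assumption, not necessarily what
# weight_clip.py actually does):
#
#     class WeightClip(object):
#         name = 'WeightClip'
#
#         def __call__(self, opt):
#             for param in opt.target.params():
#                 xp = cuda.get_array_module(param.data)
#                 param.data = xp.clip(param.data, -1.0, 1.0)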
optimizer.add_hook(weight_clip.WeightClip())

# Init/Resume
if args.initmodel:
    print('Load model from', args.initmodel)
    serializers.load_hdf5(args.initmodel, model)
if args.resume:
    print('Load optimizer state from', args.resume)
    serializers.load_hdf5(args.resume, optimizer)

# Learning loop
for epoch in six.moves.range(1, n_epoch + 1):
    print('epoch', epoch)

    # training
    perm = np.random.permutation(N)
    sum_accuracy = 0
    sum_loss = 0
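    # net.train is a module-level flag in net.py; presumably the model checks
    # it to put stochastic layers (dropout, batch normalization) into
    # training mode.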
    net.train = True
    for i in six.moves.range(0, N, batchsize):
        x = chainer.Variable(xp.asarray(x_train[perm[i:i + batchsize]]))
        t = chainer.Variable(xp.asarray(y_train[perm[i:i + batchsize]]))
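
        # optimizer.update(model, x, t) evaluates model(x, t) to obtain the
        # loss and accuracy, backpropagates, and applies the Adam step;
        # installed hooks (WeightClip here) run inside this call.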
        # Pass the loss function (Classifier defines it) and its arguments
        optimizer.update(model, x, t)

        if epoch == 1 and i == 0:
            with open('graph.dot', 'w') as o:
                g = computational_graph.build_computational_graph(
                    (model.loss, ), remove_split=True)
                o.write(g.dump())
            print('graph generated')

        sum_loss += float(model.loss.data) * len(t.data)
        sum_accuracy += float(model.accuracy.data) * len(t.data)

    print('train mean loss={}, accuracy={}'.format(
        sum_loss / N, sum_accuracy / N))

    # evaluation
    sum_accuracy = 0
    sum_loss = 0
    # net.train = False
    for i in six.moves.range(0, N_test, batchsize):
        # These should use volatile='on' for inference, but the current
        # Chainer release has a batch-normalization bug with volatile
        # variables, so volatile='off' is used instead.
        x = chainer.Variable(xp.asarray(x_test[i:i + batchsize]),
                             volatile='off')
        t = chainer.Variable(xp.asarray(y_test[i:i + batchsize]),
                             volatile='off')
        loss = model(x, t)
        sum_loss += float(loss.data) * len(t.data)
        sum_accuracy += float(model.accuracy.data) * len(t.data)

    print('test mean loss={}, accuracy={}'.format(
        sum_loss / N_test, sum_accuracy / N_test))

# Save the model and the optimizer
print('save the model')
serializers.save_hdf5('mlp.model', model)
print('save the optimizer')
serializers.save_hdf5('mlp.state', optimizer)
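
# The saved files can be passed back to this script to restart training:
#     python train_mnist.py --initmodel mlp.model --resume mlp.state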