-
Notifications
You must be signed in to change notification settings - Fork 1
/
digits_mnist_demo.py
93 lines (68 loc) · 2.63 KB
/
digits_mnist_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
import gzip
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from analysis.confustion_matrix import plot_confusion_matrix
from analysis.one_hot_encoder import indices_to_one_hot
from network.network import MultiLayerNetwork
from network.preprocessor import Preprocessor
from network.trainer import Trainer
# Dataset file table: each entry is a [result key, gzip-compressed IDX file
# name] pair. The two image files come first, followed by the two label files.
filename = [
    ["training_images", "train-images-idx3-ubyte.gz"],
    ["test_images", "t10k-images-idx3-ubyte.gz"],
    ["training_labels", "train-labels-idx1-ubyte.gz"],
    ["test_labels", "t10k-labels-idx1-ubyte.gz"],
]

# Directory containing the files above (a relative path, so it is resolved
# against the current working directory).
prefix_path = "dataset/digits_mnist"

# Template that joins a directory (``prefix``) and a file name (``file``).
file_path = "{prefix}/{file}"
def main():
    """Train and evaluate a small multi-layer network on digits MNIST.

    Steps:
      1. Load the four MNIST arrays with ``load_mnist``.
      2. One-hot encode the labels.
      3. Scale the uint8 pixels into [0, 1].
      4. Train a 784-128-64-10 ``MultiLayerNetwork`` (ReLU hidden layers,
         identity output layer) with a cross-entropy ``Trainer``.
      5. Print the train/test losses and the test accuracy.
      6. Plot the confusion matrix of the test predictions.

    The figures printed as "Validation" are computed on the held-out test
    split.
    """
    num_classes = 10
    class_labels = list(map(str, range(num_classes)))

    train_x, train_labels, test_x, test_labels = load_mnist()

    # The Trainer is fed one-hot targets rather than raw class indices.
    train_y = indices_to_one_hot(train_labels, num_classes)
    test_y = indices_to_one_hot(test_labels, num_classes)

    # Pixels are stored as uint8, so dividing by 255 maps them into [0, 1].
    pixel_max = 255
    train_x = train_x / pixel_max
    test_x = test_x / pixel_max

    # Two ReLU hidden layers, then an identity output layer with one unit
    # per class.
    hidden_sizes = [128, 64]
    net = MultiLayerNetwork(
        28 * 28,
        hidden_sizes + [num_classes],
        ["relu"] * len(hidden_sizes) + ["identity"],
    )

    trainer = Trainer(
        network=net,
        batch_size=512,
        nb_epoch=256,
        learning_rate=0.007,
        loss_fun="cross_entropy",
        shuffle_flag=True,
    )
    trainer.train(train_x, train_y)

    print("Train loss = ", trainer.eval_loss(train_x, train_y))
    print("Validation loss = ", trainer.eval_loss(test_x, test_y))

    # Class indices: the argmax over axis 1 of the network outputs and of
    # the one-hot targets.
    scores = net(test_x)
    preds = scores.argmax(axis=1).squeeze()
    targets = test_y.argmax(axis=1).squeeze()
    print("Validation accuracy: {}".format((preds == targets).mean()))

    plot_confusion_matrix(confusion_matrix(targets, preds), class_labels)
def load_mnist(prefix=None):
    """Load the MNIST train/test splits from gzip-compressed IDX files.

    Every entry of the module-level ``filename`` table is read from the
    directory ``prefix``. Array shapes are taken from each file's IDX header
    instead of being hard-coded. The headers are validated, so a missing,
    truncated, corrupted or swapped file fails loudly rather than silently
    producing garbage.

    Args:
        prefix: Directory containing the four ``*-ubyte.gz`` files. Defaults
            to the module-level ``prefix_path``. Any dataset stored in the
            same IDX layout can be loaded by pointing this elsewhere.

    Returns:
        A tuple ``(train_images, train_labels, test_images, test_labels)``.
        - Image arrays are ``uint8`` with shape ``(n, rows * cols)``, which
          is ``(n, 784)`` for MNIST.
        - Label arrays are ``uint8`` with shape ``(n,)``.
        - All arrays are read-only views over the decompressed file bytes.

    Raises:
        ValueError: If a file is not a uint8 IDX file or is truncated. Also
            raised if a split's image and label files have unexpected ranks
            or disagree on the number of samples.
    """
    if prefix is None:
        prefix = prefix_path

    # Look entries up by key, so correctness does not depend on the order of
    # the ``filename`` table.
    mnist = {}
    for key, fname in filename:
        path = file_path.format(prefix=prefix, file=fname)
        mnist[key] = _read_idx_ubyte(path)

    for split in ("training", "test"):
        images = mnist[split + "_images"]
        labels = mnist[split + "_labels"]
        if images.ndim != 2 or labels.ndim != 1:
            raise ValueError(
                "{} split: expected multi-dimensional image data and 1-D "
                "label data, got arrays of rank {} and {}".format(
                    split, images.ndim, labels.ndim))
        if len(images) != len(labels):
            raise ValueError(
                "{} split: {} images but {} labels".format(
                    split, len(images), len(labels)))

    return (mnist["training_images"], mnist["training_labels"],
            mnist["test_images"], mnist["test_labels"])


def _read_idx_ubyte(path):
    """Read one gzip-compressed, unsigned-byte IDX file into a NumPy array.

    IDX header layout:
      - two zero bytes;
      - a dtype code (0x08 means unsigned byte);
      - the number of dimensions ``ndim``;
      - ``ndim`` big-endian uint32 sizes.
    The payload follows the header. 1-D data is returned as-is.
    Higher-rank data is flattened to ``(dims[0], prod(dims[1:]))``.
    """
    with gzip.open(path, "rb") as f:
        data = f.read()

    if len(data) < 4 or data[0] != 0 or data[1] != 0 or data[2] != 0x08:
        raise ValueError("{}: not an unsigned-byte IDX file".format(path))
    ndim = data[3]
    header_size = 4 + 4 * ndim
    if ndim == 0 or len(data) < header_size:
        raise ValueError("{}: truncated or invalid IDX header".format(path))

    dims = [int(d) for d in
            np.frombuffer(data, dtype=">u4", count=ndim, offset=4)]
    count = int(np.prod(dims))
    if len(data) - header_size < count:
        raise ValueError(
            "{}: header declares {} values but the file is truncated".format(
                path, count))

    values = np.frombuffer(data, dtype=np.uint8, count=count,
                           offset=header_size)
    if ndim == 1:
        return values
    return values.reshape(dims[0], int(np.prod(dims[1:])))
def visualise_image(label, x_set, y_set, cmap=None):
    """Show the first sample in ``x_set`` whose label equals ``label``.

    Args:
        label: Class value to search for in ``y_set``.
        x_set: Samples indexable by position. The chosen one is reshaped to
            a 28x28 image.
        y_set: Labels aligned with ``x_set``. These should be integer class
            indices (e.g. the label arrays returned by ``load_mnist``), not
            one-hot rows. Lists are accepted as well as arrays.
        cmap: Optional matplotlib colormap passed to ``imshow`` (e.g.
            ``"gray"``). ``None`` keeps matplotlib's default colormap.

    Raises:
        IndexError: If no entry of ``y_set`` equals ``label``. Checked before
            any figure is created.
    """
    # asarray makes the comparison elementwise for plain lists too. Comparing
    # a list to a scalar gives False, which would never match.
    matches = np.where(np.asarray(y_set) == label)[0]
    if matches.size == 0:
        raise IndexError("no sample with label {!r} in y_set".format(label))

    img = np.reshape(x_set[matches[0]], (28, 28))
    plt.figure()
    plt.imshow(img, cmap=cmap)
    plt.show()
if __name__ == "__main__":
    # Run the training demo only when executed as a script, not on import.
    main()