-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_utils.py
69 lines (50 loc) · 2.03 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# This script contains functions for loading the dataset and labels from the MNIST Dataset
# Importing the required libraries
import gzip
import random
import numpy as np
import matplotlib.pyplot as plt
def load_data(data_dir, verbose = False):
    """Load the images of the MNIST Dataset from a gzip-compressed file.

    Args:
        data_dir: Path to the gzip-compressed image file.
        verbose: If True, one randomly chosen image is displayed with
            matplotlib (skipped when the file contains no images).

    Returns:
        A float32 numpy array of shape (num_imgs, 28, 28, 1).

    Raises:
        ValueError: If the payload size is not a multiple of 28 * 28 bytes
            (raised by the reshape).
    """
    # MNIST Image Size
    IMAGE_SIZE = 28
    # Unzipping the dataset and reading it
    with gzip.open(data_dir, mode = "r") as f:
        # The first 16 bytes are the file header and are not pixel data
        f.read(16)
        buf = f.read()
    # Using np.frombuffer() function to decode the data
    data = np.frombuffer(buf, dtype = np.uint8).astype(np.float32)
    # Reshaping it to dimension(num_imgs, IMAGE_SIZE, IMAGE_SIZE, 1)
    data = data.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 1)
    # If verbose, then the loaded data is randomly visualized
    if verbose and len(data) > 0:
        # randrange excludes the upper bound; random.randint(0, len(data))
        # could return len(data) and index past the end of the array
        idx = random.randrange(len(data))
        # Removing Channel Dimension
        image = np.asarray(data[idx]).squeeze()
        plt.imshow(image, cmap = "gray")
        plt.show()
    return data
def load_labels(labels_dir, verbose = False):
    """Read the labels of the MNIST Dataset from a gzip-compressed file.

    Args:
        labels_dir: Path to the gzip-compressed label file.
        verbose: If True, the first 10 decoded labels are printed.

    Returns:
        A 1-D int64 numpy array with one entry per byte after the header.
    """
    # Decompress everything, then drop the leading 8 header bytes
    with gzip.open(labels_dir, mode = "r") as archive:
        payload = archive.read()[8:]
    # Each remaining byte is one label; widen from uint8 to int64
    labels = np.frombuffer(payload, dtype = np.uint8).astype(np.int64)
    # Optional peek at the first 10 labels
    if verbose:
        print(labels[:10])
    return labels
if __name__ == "__main__":
    # Loading the data
    train_data_dir = "data/train-images-idx3-ubyte.gz"
    train_data = load_data(train_data_dir, verbose = True)
    # Loading the label
    train_labels_dir = "data/train-labels-idx1-ubyte.gz"
    train_labels = load_labels(train_labels_dir, verbose = True)
    # Randomly visualizing the loaded dataset and its corresponding label
    # randrange excludes the upper bound; random.randint(0, len(train_data))
    # could return len(train_data) and raise an IndexError below
    idx = random.randrange(len(train_data))
    plt.imshow(train_data[idx].squeeze(), cmap = "gray")
    plt.title(train_labels[idx])
    plt.show()