-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprint_examples.py
43 lines (32 loc) · 1.07 KB
/
print_examples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np
import _pickle
import png
def _load_pickle(path):
    """Unpickle the file at ``path`` and return the object stored in it.

    The file is opened in binary mode and decoded with ``encoding='latin1'``.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only point this
    # at dataset files that come from a trusted source.
    with open(path, 'rb') as fh:
        return _pickle.load(fh, encoding='latin1')


# Training split: the 'data' and 'fine_labels' entries of the train file.
train_batch = _load_pickle('cifar-100-python/train')
predata = train_batch['data']
labels = train_batch['fine_labels']

# The test file is cut in two: its first 5000 entries become the validation
# set and everything after them becomes the test set.
test_batch = _load_pickle('cifar-100-python/test')
preval_data = test_batch['data'][:5000]
val_labels = test_batch['fine_labels'][:5000]
pretest_data = test_batch['data'][5000:]
test_labels = test_batch['fine_labels'][5000:]

# Label names, taken from the 'fine_label_names' entry of the meta file.
names = _load_pickle('cifar-100-python/meta')['fine_label_names']

# Drop the temporaries; the extracted pieces above are all that is kept.
del train_batch, test_batch
def fit(arr, *, channels=3, height=32, width=32):
    """Reshape flat, channel-first image data into channel-last images.

    The input is interpreted as a sequence of images, each stored as
    ``channels`` consecutive planes of ``height * width`` values (row-major
    within a plane).

    Args:
        arr: Array-like whose total size is a multiple of
            ``channels * height * width``.
        channels: Number of planes per image (default 3).
        height: Rows per plane (default 32).
        width: Columns per plane (default 32).

    Returns:
        A NumPy array of shape ``(N, height, width, channels)``.

    Raises:
        ValueError: If the input cannot be reshaped to the requested geometry.
    """
    # asarray lets plain lists through as well as ndarrays.
    planes = np.asarray(arr).reshape((-1, channels, height, width))
    # (N, C, H, W) -> (N, H, W, C); same result as swapaxes(1, 3) followed by
    # swapaxes(1, 2).
    return planes.transpose(0, 2, 3, 1)
# Convert each split from flat rows into the image layout produced by fit().
predata = fit(predata)
preval_data = fit(preval_data)
pretest_data = fit(pretest_data)

# choose numbers to print
numbers_to_print = [1598, 13]
for num in numbers_to_print:
    # File name is "<index>_<label name>.png", written to the working directory.
    out_name = '{}_{}.png'.format(num, names[labels[num]])
    # Flatten the image so each of its rows holds 96 numbers per line.
    rows = np.reshape(predata[num], (-1, 96))
    w = png.Writer(32, 32)
    # binary mode is important; the with-block closes the file even if the
    # write raises, which the previous open()/close() pair did not.
    with open(out_name, 'wb') as f:
        w.write(f, rows)