-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathflower299.py
129 lines (99 loc) · 5.01 KB
/
flower299.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# -*- coding: utf-8 -*-
"""
Preprocess the flower data:
train:
    1. center-crop each image to 500 x 500
    2. resize the crop to 299 x 299 (nearest-neighbour resampling)
    3. save the image bytes and label into TFRecords
test:
    1. center-crop each image to 500 x 500
    2. resize the crop to 299 x 299 (nearest-neighbour resampling)
    3. save the test image bytes and label into TFRecords
读取原始数据,
将train数据集裁剪成500 x 500,然后最邻近重采样成 299 x 299,保存成TFRecords;
将test数据集裁剪成500 x 500,然后最邻近重采样成 299 x 299,保存成TFRecords;
"""
##################### load packages #####################
import numpy as np
import os
from PIL import Image
import scipy.io
import tensorflow as tf
##################### load flower data ##########################
def flower_preprocess(flower_folder,
                      label_path='/Users/shaoqi/Desktop/Googlenet/imagelabels.mat',
                      setid_path='/Users/shaoqi/Desktop/Googlenet/setid.mat',
                      train_record_path='/Users/shaoqi/Desktop/Googlenet/flower_train_299.tfrecords',
                      test_record_path='/Users/shaoqi/Desktop/Googlenet/flower_test_299.tfrecords',
                      crop_size=500,
                      out_size=299):
    '''
    Center-crop and resize the flower images and serialize them into TFRecords.

    Every image is center-cropped to crop_size x crop_size, resized with
    nearest-neighbour resampling to out_size x out_size, converted to RGB and
    stored as a tf.train.Example with the features:
        'label'  -- int64, label from the .mat file shifted by -1
        'img'    -- raw RGB bytes (out_size * out_size * 3 bytes)
        'height' -- int64, out_size (height of the stored bytes)
        'width'  -- int64, out_size (width of the stored bytes)

    flower_folder: directory containing the original flower images. The files
        are sorted by name; the split ids (shifted by -1) index that list.
    label_path: .mat file whose 'labels' array holds one label per image.
    setid_path: .mat file holding the 'trnid' / 'tstid' image-id arrays.
    train_record_path: output TFRecords file for the training split.
    test_record_path: output TFRecords file for the test split.
    crop_size: side length (pixels) of the center crop.
    out_size: side length (pixels) of the stored image.

    The training records are written in random order (np.random.shuffle);
    the test records in ascending id order.

    NOTE(review): as in the original script, the 'tstid' ids go to the *train*
    file and the 'trnid' ids to the *test* file. This is presumably a
    deliberate swap -- confirm before relying on the split names.
    '''
    ######## labels, shifted by -1 ########
    labels = np.array(scipy.io.loadmat(label_path)['labels'][0]) - 1
    ######## train / test image ids, shifted to 0-based list indices ########
    setid = scipy.io.loadmat(setid_path)
    train = np.array(setid['tstid'][0]) - 1
    np.random.shuffle(train)
    # The test split is written in ascending id order, so it is simply sorted.
    test = np.sort(np.array(setid['trnid'][0]) - 1)

    ######## writers for the output TFRecords ########
    # Context managers guarantee the buffered records are flushed and both
    # files are closed, even if an image fails to load midway.
    with tf.python_io.TFRecordWriter(train_record_path) as writer_train, \
            tf.python_io.TFRecordWriter(test_record_path) as writer_test:
        ######## absolute image paths, sorted by file name ########
        # Hidden entries (e.g. macOS '.DS_Store') and sub-directories are
        # skipped: a hidden file would sort first and shift every
        # id-to-file mapping by one.
        flower_dir = sorted(
            os.path.join(flower_folder, name)
            for name in os.listdir(flower_folder)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(flower_folder, name))
        )
        ######## flower train data ########
        _write_split(writer_train, train, flower_dir, labels, crop_size, out_size)
        ######## flower test data ########
        _write_split(writer_test, test, flower_dir, labels, crop_size, out_size)


def _write_split(writer, indices, flower_dir, labels, crop_size, out_size):
    '''Serialize the images at positions ``indices`` of ``flower_dir`` into ``writer``.'''
    for idx in indices:
        img_bytes = _crop_resize_to_bytes(flower_dir[idx], crop_size, out_size)
        example = _build_example(img_bytes, labels[idx], out_size)
        ######## serialize to a string and write to disk ########
        writer.write(example.SerializeToString())


def _crop_resize_to_bytes(path, crop_size, out_size):
    '''Return the raw RGB bytes of the center crop of ``path`` resized to out_size.'''
    # `with` releases the file handle that Image.open keeps open.
    with Image.open(path) as img:
        width, height = img.size
        # int(x / 2) (truncation toward zero) is kept from the original; when
        # the image is smaller than crop_size, PIL pads the out-of-range area.
        left = int((width - crop_size) / 2)
        top = int((height - crop_size) / 2)
        img_crop = img.crop([left, top, left + crop_size, top + crop_size])
        resized = img_crop.resize((out_size, out_size), Image.NEAREST)
        # Force 3-channel RGB so every record holds out_size*out_size*3 bytes,
        # whatever the source image mode is (grayscale, palette, CMYK, ...).
        return resized.convert('RGB').tobytes()


def _build_example(img_bytes, label, size):
    '''Build the tf.train.Example holding one ``size`` x ``size`` RGB image.'''
    return tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
        'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_bytes])),
        # height/width must describe the stored bytes so a reader can
        # reshape 'img' correctly -- identical for the train and test files.
        'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[size])),
        'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[size])),
    }))
################ main函数入口 ##################
if __name__ == '__main__':
    import argparse

    ######## flower path: directory holding the raw flower images ########
    # The folder can be overridden on the command line; the default is the
    # path this script originally hard-coded.
    parser = argparse.ArgumentParser(
        description='Crop, resize and save the flower images into TFRecords.')
    parser.add_argument(
        'flower_folder', nargs='?',
        default='/Users/shaoqi/Desktop/Googlenet/102flowers',
        help='directory containing the original flower images')
    args = parser.parse_args()
    ######## preprocess the data ########
    flower_preprocess(args.flower_folder)