-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathceleba.py
70 lines (56 loc) · 2.27 KB
/
celeba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import random
from os import listdir

import numpy as np
import tensorflow as tf
from scipy.misc import imread, imresize
def create_pipeline(filename, width=64, height=64, depth=3, batch_size=64, name=None,
                    num_threads=4, min_after_dequeue=1000):
    """Build a TF1 queue-based pipeline yielding shuffled batches of images.

    Serialized examples are read from the TFRecord file *filename*. Each
    record's ``image_raw`` bytes are decoded as uint8, reshaped to
    ``(height, width, depth)``, cast to float32 and mapped from [0, 255]
    to [-1, 1].

    Args:
        filename: path of the TFRecord file to read.
        width, height, depth: image dimensions; each record's ``image_raw``
            must hold exactly ``height * width * depth`` bytes or the
            reshape fails.
        batch_size: number of images per dequeued batch.
        name: variable scope for the pipeline ops. When None, a uniquified
            "pipeline" scope is used.
        num_threads: number of enqueue threads for ``tf.train.shuffle_batch``.
        min_after_dequeue: minimum number of examples kept in the shuffle
            queue. The queue capacity is ``min_after_dequeue + 3 * batch_size``.

    Returns:
        The output of ``tf.train.shuffle_batch`` for the single image tensor.
        This should be a float32 tensor of shape
        ``(batch_size, height, width, depth)``.
    """
    # tf.variable_scope raises TypeError when both name_or_scope and
    # default_name are None, so the name=None default used to always fail.
    with tf.variable_scope(name, default_name="pipeline"):
        filename_queue = tf.train.string_input_producer([filename])
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                # height/width/depth must be present in every record, but
                # only the static width/height/depth arguments are used below.
                'height': tf.FixedLenFeature([], tf.int64),
                'width': tf.FixedLenFeature([], tf.int64),
                'depth': tf.FixedLenFeature([], tf.int64),
                'image_raw': tf.FixedLenFeature([], tf.string),
            })
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        # decode_raw yields a 1-D tensor of statically unknown length;
        # declare the expected flat length before reshaping.
        image.set_shape([height * width * depth])
        image = tf.reshape(image, (height, width, depth))
        # Map uint8 pixel values from [0, 255] to [-1, 1].
        image = tf.cast(image, tf.float32) / 127.5 - 1
        return tf.train.shuffle_batch(
            [image], batch_size=batch_size, num_threads=num_threads,
            capacity=min_after_dequeue + 3 * batch_size,
            min_after_dequeue=min_after_dequeue)
class Reader:
    """Load batches of images from a flat directory, rescaled to [-1, 1].

    Every entry returned by ``listdir(path)`` is treated as an image file,
    so the directory should contain nothing else.

    Args:
        path: directory holding the image files.
        width, height: size every image is resized to.
        batch_size: number of images returned by each ``next_batch`` call.
        shuffle: when True, each batch slot gets a file drawn uniformly at
            random (with replacement). When False, files are read in
            ``listdir`` order, wrapping around at the end of the list.
    """

    def __init__(self, path, width=64, height=64, batch_size=64, shuffle=True):
        self.file_list = listdir(path)
        self.path = path
        self.file_count = len(self.file_list)
        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.shuffle = shuffle
        # Position in file_list of the next file to read when shuffle is False.
        self.last_read = 0

    def _next_indices(self):
        """Return the ``file_list`` indices that make up the next batch.

        Raises:
            ValueError: if the directory contained no files.
        """
        if self.file_count == 0:
            raise ValueError("no files to read in %r" % (self.path,))
        if self.shuffle:
            # Same draw as random.randint(0, file_count - 1).
            return [random.randrange(self.file_count)
                    for _ in range(self.batch_size)]
        start = self.last_read
        indices = [(start + k) % self.file_count
                   for k in range(self.batch_size)]
        self.last_read = (start + self.batch_size) % self.file_count
        return indices

    def _load_image(self, filename):
        """Read one file, resize it to (height, width) and map 0..255 to -1..1."""
        # NOTE(review): scipy.misc.imread/imresize are absent from newer SciPy
        # releases; confirm the SciPy version this is run against.
        image = imread(os.path.join(self.path, filename))
        image = imresize(image, (self.height, self.width))
        return (image - 127.5) / 127.5

    def next_batch(self):
        """Return a float64 array of shape (batch_size, height, width, 3).

        Assumes every image decodes to exactly 3 channels (e.g. RGB).
        Presumably grayscale or RGBA files will fail to fit a batch slot.
        Confirm this if the dataset may contain them.
        """
        batch = np.empty([self.batch_size, self.height, self.width, 3])
        for slot, index in enumerate(self._next_indices()):
            batch[slot] = self._load_image(self.file_list[index])
        return batch
def test(path="/mnt/DataBlock/CelebA/Img/img_align_celeba", num_batches=10000):
    """Smoke-test Reader by drawing many 64x64 batches from an image directory.

    Args:
        path: directory of images. Defaults to the original hard-coded
            CelebA location.
        num_batches: number of batches to draw before printing 'done'.
    """
    reader = Reader(path, 64, 64)
    for _ in range(num_batches):
        reader.next_batch()
    # Parenthesized form prints the same text on Python 2 and is valid Python 3.
    print('done')
# Script entry point: run the smoke test only when executed directly,
# never when this module is imported.
if __name__ == "__main__":
    test()