Skip to content

Commit

Permalink
adding flag to store images
Browse files Browse the repository at this point in the history
  • Loading branch information
gvanhorn38 committed Mar 4, 2018
1 parent 95489d6 commit e25fb36
Showing 1 changed file with 32 additions and 18 deletions.
50 changes: 32 additions & 18 deletions create_tfrecords.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _process_image(filename, coder):


def _process_image_files_batch(coder, thread_index, ranges, name, output_directory,
dataset, num_shards, error_queue):
dataset, num_shards, store_images, error_queue):
"""Processes and saves list of images as TFRecord in 1 thread.
Args:
coder: instance of ImageCoder to provide TensorFlow image coding utils.
Expand All @@ -230,6 +230,7 @@ def _process_image_files_batch(coder, thread_index, ranges, name, output_directo
output_directory: string, file path to store the tfrecord files.
dataset: list, a list of image example dicts
num_shards: integer number of shards for this data set.
store_images: bool, should the image be stored in the tfrecord
error_queue: Queue, a queue to place image examples that failed.
"""
# Each thread produces N shards where N = int(num_shards / num_threads).
Expand Down Expand Up @@ -262,26 +263,34 @@ def _process_image_files_batch(coder, thread_index, ranges, name, output_directo
filename = str(image_example['filename'])

try:
if 'encoded' in image_example:
image_buffer = image_example['encoded']
height = image_example['height']
width = image_example['width']
colorspace = image_example['colorspace']
image_format = image_example['format']
num_channels = image_example['channels']
example = _convert_to_example(image_example, image_buffer, height,
width, colorspace, num_channels,
image_format)

if store_images:
if 'encoded' in image_example:
image_buffer = image_example['encoded']
height = image_example['height']
width = image_example['width']
colorspace = image_example['colorspace']
image_format = image_example['format']
num_channels = image_example['channels']
example = _convert_to_example(image_example, image_buffer, height,
width, colorspace, num_channels,
image_format)

else:
image_buffer, height, width = _process_image(filename, coder)
example = _convert_to_example(image_example, image_buffer, height,
width)
else:
image_buffer, height, width = _process_image(filename, coder)
image_buffer=''
height = int(image_example['height'])
width = int(image_example['width'])
example = _convert_to_example(image_example, image_buffer, height,
width)
width)

writer.write(example.SerializeToString())
shard_counter += 1
counter += 1
except Exception as e:
raise
error_counter += 1
error_msg = repr(e)
image_example['error_msg'] = error_msg
Expand All @@ -302,7 +311,7 @@ def _process_image_files_batch(coder, thread_index, ranges, name, output_directo
sys.stdout.flush()


def create(dataset, dataset_name, output_directory, num_shards, num_threads, shuffle=True):
def create(dataset, dataset_name, output_directory, num_shards, num_threads, shuffle=True, store_images=True):
"""Create the tfrecord files to be used to train or test a model.
Args:
Expand Down Expand Up @@ -365,7 +374,7 @@ def create(dataset, dataset_name, output_directory, num_shards, num_threads, shu
threads = []
for thread_index in xrange(len(ranges)):
args = (coder, thread_index, ranges, dataset_name, output_directory, dataset,
num_shards, error_queue)
num_shards, store_images, error_queue)
t = threading.Thread(target=_process_image_files_batch, args=args)
t.start()
threads.append(t)
Expand Down Expand Up @@ -410,7 +419,11 @@ def parse_args():

parser.add_argument('--shuffle', dest='shuffle',
help='Shuffle the records before saving them.',
required=False, action='store_true', default=True)
required=False, action='store_true', default=False)

parser.add_argument('--store_images', dest='store_images',
help='Store the images in the tfrecords.',
required=False, action='store_true', default=False)

parsed_args = parser.parse_args()

Expand All @@ -429,7 +442,8 @@ def main():
output_directory=args.output_dir,
num_shards=args.num_shards,
num_threads=args.num_threads,
shuffle=args.shuffle
shuffle=args.shuffle,
store_images=args.store_images
)

return errors
Expand Down

0 comments on commit e25fb36

Please sign in to comment.