Always use multiple processes for pre-processing
igv committed Oct 4, 2019
1 parent 8f30c0f commit 85c1912
Showing 3 changed files with 10 additions and 84 deletions.
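
In short: this commit deletes the threads flag and the single-process train_input_setup path, so pre-processing now always runs through a multiprocessing pool sized to the machine's CPU count minus one, with a floor of one worker.
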
main.py (1 change: 0 additions & 1 deletion)
@@ -18,7 +18,6 @@
flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
flags.DEFINE_integer("threads", 1, "Number of processes to pre-process data with [1]")
flags.DEFINE_boolean("distort", False, "Distort some images with JPEG compression artifacts after downscaling [False]")
flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")

model.py (11 changes: 3 additions & 8 deletions)
@@ -1,6 +1,5 @@
 from utils import (
-    thread_train_setup,
-    train_input_setup,
+    multiprocess_train_setup,
     test_input_setup,
     save_params,
     merge,
@@ -31,7 +30,6 @@ def __init__(self, sess, config):
         self.radius = config.radius
         self.batch_size = config.batch_size
         self.learning_rate = config.learning_rate
-        self.threads = config.threads
         self.distort = config.distort
         self.params = config.params

@@ -94,11 +92,8 @@ def run(self):
     def run_train(self):
         start_time = time.time()
         print("Beginning training setup...")
-        if self.threads == 1:
-            train_data, train_label = train_input_setup(self)
-        else:
-            train_data, train_label = thread_train_setup(self)
-        print("Training setup took {} seconds with {} threads".format(time.time() - start_time, self.threads))
+        train_data, train_label = multiprocess_train_setup(self)
+        print("Training setup took {} seconds".format(time.time() - start_time))

         print("Training...")
         start_time = time.time()
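Because the pool size is now computed inside multiprocess_train_setup itself, run_train no longer needs to branch on a configured thread count, and the setup-timing message drops its "{} threads" placeholder.
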
utils.py (82 changes: 7 additions & 75 deletions)
@@ -12,7 +12,7 @@
 import tensorflow as tf
 from PIL import Image
 import numpy as np
-from multiprocessing import Pool, Lock, active_children
+import multiprocessing

 FLAGS = tf.app.flags.FLAGS

@@ -145,46 +145,19 @@ def train_input_worker(args):

     return [single_input_sequence, single_label_sequence]

-
-def thread_train_setup(config):
+def multiprocess_train_setup(config):
     """
-    Spawns |config.threads| worker processes to pre-process the data
-    This has not been extensively tested so use at your own risk.
-    Also this is technically multiprocessing not threading, I just say thread
-    because it's shorter to type.
+    Spawns several processes to pre-process the data
     """
     if downsample == False:
         import sys
         sys.exit()

-    sess = config.sess
-
-    # Load data path
-    data = prepare_data(sess, dataset=config.data_dir)
-
-    # Initialize multiprocessing pool with # of processes = config.threads
-    pool = Pool(config.threads)
-
-    # Distribute |images_per_thread| images across each worker process
-    config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
-    images_per_thread = len(data) // config.threads
-    workers = []
-    for thread in range(config.threads):
-        args_list = [(data[i], config_values) for i in range(thread * images_per_thread, (thread + 1) * images_per_thread)]
-        worker = pool.map_async(train_input_worker, args_list)
-        workers.append(worker)
-    print("{} worker processes created".format(config.threads))
-
-    pool.close()
+    data = prepare_data(config.sess, dataset=config.data_dir)

-    results = []
-    for i in range(len(workers)):
-        print("Waiting for worker process {}".format(i))
-        results.extend(workers[i].get(timeout=240))
-        print("Worker process {} done".format(i))
-
-    print("All worker processes done!")
+    with multiprocessing.Pool(max(multiprocessing.cpu_count() - 1, 1)) as pool:
+        config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
+        results = pool.map(train_input_worker, [(data[i], config_values) for i in range(len(data))])

     sub_input_sequence, sub_label_sequence = [], []

@@ -198,47 +171,6 @@ def thread_train_setup(config):

     return (arrdata, arrlabel)

-def train_input_setup(config):
-    """
-    Read image files, make their sub-images, and save them as a h5 file format.
-    """
-    if downsample == False:
-        import sys
-        sys.exit()
-
-    sess = config.sess
-    image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2
-
-    # Load data path
-    data = prepare_data(sess, dataset=config.data_dir)
-
-    sub_input_sequence, sub_label_sequence = [], []
-
-    for i in range(len(data)):
-        input_, label_ = preprocess(data[i], scale, distort=config.distort)
-
-        if len(input_.shape) == 3:
-            h, w, _ = input_.shape
-        else:
-            h, w = input_.shape
-
-        for x in range(0, h - image_size + 1, stride):
-            for y in range(0, w - image_size + 1, stride):
-                sub_input = input_[x : x + image_size, y : y + image_size]
-                x_loc, y_loc = x + padding, y + padding
-                sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]
-
-                sub_input = sub_input.reshape([image_size, image_size, 1])
-                sub_label = sub_label.reshape([label_size, label_size, 1])
-
-                sub_input_sequence.append(sub_input)
-                sub_label_sequence.append(sub_label)
-
-    arrdata = np.asarray(sub_input_sequence)
-    arrlabel = np.asarray(sub_label_sequence)
-
-    return (arrdata, arrlabel)
-
 def test_input_setup(config):
     sess = config.sess
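
For readers skimming the diff, here is a minimal, self-contained sketch of the pooling idiom the new multiprocess_train_setup relies on. toy_worker, the integer work items, and shared_config are hypothetical stand-ins for train_input_worker, the image list returned by prepare_data, and config_values:

import multiprocessing

def toy_worker(args):
    # Stand-in for train_input_worker: unpack one (item, shared_config)
    # tuple and return an [inputs, labels] pair, the same per-item shape
    # of result that pool.map collects in the real code.
    item, shared_config = args
    return [item * 2, item * 3]

if __name__ == "__main__":
    data = list(range(8))          # stand-in for the list of image paths
    shared_config = None           # stand-in for config_values
    # Size the pool the way the commit does: leave one core free for the
    # parent process, but never drop below a single worker.
    n_workers = max(multiprocessing.cpu_count() - 1, 1)
    with multiprocessing.Pool(n_workers) as pool:
        results = pool.map(toy_worker, [(d, shared_config) for d in data])
    # Flatten the per-item pairs, mirroring how utils.py builds
    # sub_input_sequence and sub_label_sequence after pool.map returns.
    inputs = [r[0] for r in results]
    labels = [r[1] for r in results]
    print(inputs, labels)

Unlike the old map_async scheme, pool.map both distributes the work and blocks until every result is back, so the manual chunking, the get(timeout=240) polling loop, and the per-worker progress prints are no longer needed, and the with statement closes the pool automatically.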
