From 8f30c0f221b6c1224a51c42e087227710a06f4ac Mon Sep 17 00:00:00 2001 From: igv Date: Wed, 2 Oct 2019 20:03:22 +0300 Subject: [PATCH 1/2] Support for scale=1 --- FSRCNN.py | 3 ++- gen.py | 62 +++++++++++++++++++++++++++++-------------------------- model.py | 4 ++-- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/FSRCNN.py b/FSRCNN.py index f12435f..93acbac 100644 --- a/FSRCNN.py +++ b/FSRCNN.py @@ -72,7 +72,8 @@ def model(self): deconv_biases = tf.get_variable('deconv_b', initializer=tf.zeros([self.scale**2])) deconv = tf.nn.conv2d(conv, deconv_weights, strides=[1,1,1,1], padding='SAME', data_format='NHWC') deconv = tf.nn.bias_add(deconv, deconv_biases, data_format='NHWC') - deconv = tf.depth_to_space(deconv, self.scale, name='pixel_shuffle', data_format='NHWC') + if self.scale > 1: + deconv = tf.depth_to_space(deconv, self.scale, name='pixel_shuffle', data_format='NHWC') return deconv diff --git a/gen.py b/gen.py index 9b1b5cd..f6c538b 100644 --- a/gen.py +++ b/gen.py @@ -28,7 +28,8 @@ def format_weights(weights, n, length=4): def base_header(file): file.write('//!HOOK LUMA\n') - file.write('//!WHEN OUTPUT.w LUMA.w / {0}.400 > OUTPUT.h LUMA.h / {0}.400 > *\n'.format(scale - 1)) + if scale > 1: + file.write('//!WHEN OUTPUT.w LUMA.w / {0}.400 > OUTPUT.h LUMA.h / {0}.400 > *\n'.format(scale - 1)) def header1(file, n, d): base_header(file) @@ -75,8 +76,9 @@ def header5(file, n, d, inp): file.write('//!DESC sub-pixel convolution {}\n'.format((n//comps) + 1)) for i in range(d//4): file.write('//!BIND {}{}\n'.format(inp, i + 1)) - file.write('//!SAVE SUBCONV{}\n'.format((n//comps) + 1)) - file.write('//!COMPONENTS {}\n'.format(comps)) + if scale > 1: + file.write('//!SAVE SUBCONV{}\n'.format((n//comps) + 1)) + file.write('//!COMPONENTS {}\n'.format(comps)) def header6(file): base_header(file) @@ -219,45 +221,47 @@ def main(): ln = get_line_number("deconv_b", fname) biases = read_weights(fname, ln) inp = "EXPANDED" if shrinking else "RES" - comps = 3 if scale == 3 else 4 + comps = scale if scale % 2 == 1 else 4 for n in range(0, scale**2, comps): header5(file, n, d, inp) file.write('vec4 hook()\n') file.write('{\n') - file.write('vec{0} res = vec{0}({1});\n'.format(comps, format_weights(biases[0], n, length=comps))) + if scale == 1: + file.write('float res = {};\n'.format(format_weights(biases[0], n, length=comps))) + else: + file.write('vec{0} res = vec{0}({1});\n'.format(comps, format_weights(biases[0], n, length=comps))) p = 0 for l in range(0, len(weights), 4): if l % d == 0: y, x = p%(radius*2+1)-radius, p//(radius*2+1)-radius p += 1 idx = (l//4)%(d//4) - file.write('res += mat4x{}({},{},{},{}) * {}{}_texOff(vec2({},{}));\n'.format( - comps, format_weights(weights[l], n, length=comps), format_weights(weights[l+1], n, length=comps), + file.write('res += {}{}({},{},{},{}){} {}{}_texOff(vec2({},{})){};\n'.format( + "mat4x" if scale > 1 else "dot(", comps if scale > 1 else "vec4", + format_weights(weights[l], n, length=comps), format_weights(weights[l+1], n, length=comps), format_weights(weights[l+2], n, length=comps), format_weights(weights[l+3], n, length=comps), - inp, idx + 1, x, y)) - if comps == 4: - file.write('return res;\n') - else: - file.write('return vec4(res, 0);\n') + " *" if scale > 1 else ",", inp, idx + 1, x, y, "" if scale > 1 else ")")) + file.write('return vec4(res{});\n'.format(", 0" * (4 - comps))) file.write('}\n\n') - # Aggregation - header6(file) - file.write('vec4 hook()\n') - file.write('{\n') - file.write('vec2 fcoord = fract(SUBCONV1_pos * SUBCONV1_size);\n') - file.write('vec2 base = SUBCONV1_pos + (vec2(0.5) - fcoord) * SUBCONV1_pt;\n') - file.write('ivec2 index = ivec2(fcoord * vec2({}));\n'.format(scale)) - if scale > 2: - file.write('mat{0} res = mat{0}(SUBCONV1_tex(base).{1}'.format(scale, "rgba"[:comps])) - for i in range(scale-1): - file.write(',SUBCONV{}_tex(base).{}'.format(i + 2, "rgba"[:comps])) - file.write(');\n') - file.write('return vec4(res[index.x][index.y], 0, 0, 1);\n') - else: - file.write('vec4 res = SUBCONV1_tex(base);\n') - file.write('return vec4(res[index.x * {} + index.y], 0, 0, 1);\n'.format(scale)) - file.write('}\n') + if scale > 1: + # Aggregation + header6(file) + file.write('vec4 hook()\n') + file.write('{\n') + file.write('vec2 fcoord = fract(SUBCONV1_pos * SUBCONV1_size);\n') + file.write('vec2 base = SUBCONV1_pos + (vec2(0.5) - fcoord) * SUBCONV1_pt;\n') + file.write('ivec2 index = ivec2(fcoord * vec2({}));\n'.format(scale)) + if scale > 2: + file.write('mat{0} res = mat{0}(SUBCONV1_tex(base).{1}'.format(scale, "rgba"[:comps])) + for i in range(scale-1): + file.write(',SUBCONV{}_tex(base).{}'.format(i + 2, "rgba"[:comps])) + file.write(');\n') + file.write('return vec4(res[index.x][index.y], 0, 0, 1);\n') + else: + file.write('vec4 res = SUBCONV1_tex(base);\n') + file.write('return vec4(res[index.x * {} + index.y], 0, 0, 1);\n'.format(scale)) + file.write('}\n') else: print("Missing argument: You must specify a file name") diff --git a/model.py b/model.py index 5db5488..7cfb375 100644 --- a/model.py +++ b/model.py @@ -37,8 +37,8 @@ def __init__(self, sess, config): self.padding = 4 # Different image/label sub-sizes for different scaling factors x2, x3, x4 - scale_factors = [[20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]] - self.image_size, self.label_size = scale_factors[self.scale - 2] + scale_factors = [[40 + self.padding, 40], [20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]] + self.image_size, self.label_size = scale_factors[self.scale - 1] self.stride = self.image_size - self.padding From 85c1912989fbf606cee79918c23e116c9657f68f Mon Sep 17 00:00:00 2001 From: igv Date: Wed, 2 Oct 2019 23:25:59 +0300 Subject: [PATCH 2/2] Always use multiple processes for pre-processing --- main.py | 1 - model.py | 11 +++----- utils.py | 82 +++++--------------------------------------------------- 3 files changed, 10 insertions(+), 84 deletions(-) diff --git a/main.py b/main.py index 97a03d3..be7ba29 100644 --- a/main.py +++ b/main.py @@ -18,7 +18,6 @@ flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]") flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]") flags.DEFINE_boolean("train", True, "True for training, false for testing [True]") -flags.DEFINE_integer("threads", 1, "Number of processes to pre-process data with [1]") flags.DEFINE_boolean("distort", False, "Distort some images with JPEG compression artifacts after downscaling [False]") flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]") diff --git a/model.py b/model.py index 7cfb375..c4c3240 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,5 @@ from utils import ( - thread_train_setup, - train_input_setup, + multiprocess_train_setup, test_input_setup, save_params, merge, @@ -31,7 +30,6 @@ def __init__(self, sess, config): self.radius = config.radius self.batch_size = config.batch_size self.learning_rate = config.learning_rate - self.threads = config.threads self.distort = config.distort self.params = config.params @@ -94,11 +92,8 @@ def run(self): def run_train(self): start_time = time.time() print("Beginning training setup...") - if self.threads == 1: - train_data, train_label = train_input_setup(self) - else: - train_data, train_label = thread_train_setup(self) - print("Training setup took {} seconds with {} threads".format(time.time() - start_time, self.threads)) + train_data, train_label = multiprocess_train_setup(self) + print("Training setup took {} seconds".format(time.time() - start_time)) print("Training...") start_time = time.time() diff --git a/utils.py b/utils.py index 94382a7..2663281 100644 --- a/utils.py +++ b/utils.py @@ -12,7 +12,7 @@ import tensorflow as tf from PIL import Image import numpy as np -from multiprocessing import Pool, Lock, active_children +import multiprocessing FLAGS = tf.app.flags.FLAGS @@ -145,46 +145,19 @@ def train_input_worker(args): return [single_input_sequence, single_label_sequence] - -def thread_train_setup(config): +def multiprocess_train_setup(config): """ - Spawns |config.threads| worker processes to pre-process the data - - This has not been extensively tested so use at your own risk. - Also this is technically multiprocessing not threading, I just say thread - because it's shorter to type. + Spawns several processes to pre-process the data """ if downsample == False: import sys sys.exit() - sess = config.sess - - # Load data path - data = prepare_data(sess, dataset=config.data_dir) - - # Initialize multiprocessing pool with # of processes = config.threads - pool = Pool(config.threads) - - # Distribute |images_per_thread| images across each worker process - config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort] - images_per_thread = len(data) // config.threads - workers = [] - for thread in range(config.threads): - args_list = [(data[i], config_values) for i in range(thread * images_per_thread, (thread + 1) * images_per_thread)] - worker = pool.map_async(train_input_worker, args_list) - workers.append(worker) - print("{} worker processes created".format(config.threads)) - - pool.close() + data = prepare_data(config.sess, dataset=config.data_dir) - results = [] - for i in range(len(workers)): - print("Waiting for worker process {}".format(i)) - results.extend(workers[i].get(timeout=240)) - print("Worker process {} done".format(i)) - - print("All worker processes done!") + with multiprocessing.Pool(max(multiprocessing.cpu_count() - 1, 1)) as pool: + config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort] + results = pool.map(train_input_worker, [(data[i], config_values) for i in range(len(data))]) sub_input_sequence, sub_label_sequence = [], [] @@ -198,47 +171,6 @@ def thread_train_setup(config): return (arrdata, arrlabel) -def train_input_setup(config): - """ - Read image files, make their sub-images, and save them as a h5 file format. - """ - if downsample == False: - import sys - sys.exit() - - sess = config.sess - image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2 - - # Load data path - data = prepare_data(sess, dataset=config.data_dir) - - sub_input_sequence, sub_label_sequence = [], [] - - for i in range(len(data)): - input_, label_ = preprocess(data[i], scale, distort=config.distort) - - if len(input_.shape) == 3: - h, w, _ = input_.shape - else: - h, w = input_.shape - - for x in range(0, h - image_size + 1, stride): - for y in range(0, w - image_size + 1, stride): - sub_input = input_[x : x + image_size, y : y + image_size] - x_loc, y_loc = x + padding, y + padding - sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size] - - sub_input = sub_input.reshape([image_size, image_size, 1]) - sub_label = sub_label.reshape([label_size, label_size, 1]) - - sub_input_sequence.append(sub_input) - sub_label_sequence.append(sub_label) - - arrdata = np.asarray(sub_input_sequence) - arrlabel = np.asarray(sub_label_sequence) - - return (arrdata, arrlabel) - def test_input_setup(config): sess = config.sess