Merge pull request #2 from igv/master
Pull recent igv commits (2019-10-04)
HelpSeeker authored Oct 4, 2019
2 parents e30e472 + 85c1912 commit 24ba21f
Showing 5 changed files with 47 additions and 116 deletions.
FSRCNN.py: 2 additions & 1 deletion
@@ -72,7 +72,8 @@ def model(self):
         deconv_biases = tf.get_variable('deconv_b', initializer=tf.zeros([self.scale**2]))
         deconv = tf.nn.conv2d(conv, deconv_weights, strides=[1,1,1,1], padding='SAME', data_format='NHWC')
         deconv = tf.nn.bias_add(deconv, deconv_biases, data_format='NHWC')
-        deconv = tf.depth_to_space(deconv, self.scale, name='pixel_shuffle', data_format='NHWC')
+        if self.scale > 1:
+            deconv = tf.depth_to_space(deconv, self.scale, name='pixel_shuffle', data_format='NHWC')
 
         return deconv
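For context, tf.depth_to_space is the sub-pixel (pixel-shuffle) upscaling step, and TensorFlow's DepthToSpace op requires a block size of at least 2, so a 1x (restoration-only) model has to skip the call rather than treat it as a no-op. A minimal sketch of what the op does at scale 2, assuming TensorFlow 1.x as used by this repo:

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x

    # scale = 2: each group of 4 channels becomes a 2x2 spatial block,
    # so a (1, 2, 2, 4) tensor is rearranged into (1, 4, 4, 1).
    x = tf.constant(np.arange(16, dtype=np.float32).reshape([1, 2, 2, 4]))
    y = tf.depth_to_space(x, 2, data_format='NHWC')

    with tf.Session() as sess:
        print(sess.run(y).shape)  # (1, 4, 4, 1)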
gen.py: 33 additions & 29 deletions
@@ -28,7 +28,8 @@ def format_weights(weights, n, length=4):
 
 def base_header(file):
     file.write('//!HOOK LUMA\n')
-    file.write('//!WHEN OUTPUT.w LUMA.w / {0}.400 > OUTPUT.h LUMA.h / {0}.400 > *\n'.format(scale - 1))
+    if scale > 1:
+        file.write('//!WHEN OUTPUT.w LUMA.w / {0}.400 > OUTPUT.h LUMA.h / {0}.400 > *\n'.format(scale - 1))
 
 def header1(file, n, d):
     base_header(file)
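For reference, //!WHEN is an mpv user-shader condition in RPN. A quick sketch of the emitted text (plain Python, hypothetical scale value):

    scale = 2
    print('//!WHEN OUTPUT.w LUMA.w / {0}.400 > OUTPUT.h LUMA.h / {0}.400 > *'.format(scale - 1))
    # -> //!WHEN OUTPUT.w LUMA.w / 1.400 > OUTPUT.h LUMA.h / 1.400 > *
    # i.e. run the hook only when both width and height are upscaled by more
    # than 1.4x; for scale == 1 no //!WHEN is written, so the shader always runs.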
@@ -75,8 +76,9 @@ def header5(file, n, d, inp):
     file.write('//!DESC sub-pixel convolution {}\n'.format((n//comps) + 1))
     for i in range(d//4):
         file.write('//!BIND {}{}\n'.format(inp, i + 1))
-    file.write('//!SAVE SUBCONV{}\n'.format((n//comps) + 1))
-    file.write('//!COMPONENTS {}\n'.format(comps))
+    if scale > 1:
+        file.write('//!SAVE SUBCONV{}\n'.format((n//comps) + 1))
+        file.write('//!COMPONENTS {}\n'.format(comps))
 
 def header6(file):
     base_header(file)
@@ -219,45 +221,47 @@ def main():
         ln = get_line_number("deconv_b", fname)
         biases = read_weights(fname, ln)
         inp = "EXPANDED" if shrinking else "RES"
-        comps = 3 if scale == 3 else 4
+        comps = scale if scale % 2 == 1 else 4
         for n in range(0, scale**2, comps):
             header5(file, n, d, inp)
             file.write('vec4 hook()\n')
             file.write('{\n')
-            file.write('vec{0} res = vec{0}({1});\n'.format(comps, format_weights(biases[0], n, length=comps)))
+            if scale == 1:
+                file.write('float res = {};\n'.format(format_weights(biases[0], n, length=comps)))
+            else:
+                file.write('vec{0} res = vec{0}({1});\n'.format(comps, format_weights(biases[0], n, length=comps)))
             p = 0
             for l in range(0, len(weights), 4):
                 if l % d == 0:
                     y, x = p%(radius*2+1)-radius, p//(radius*2+1)-radius
                     p += 1
                 idx = (l//4)%(d//4)
-                file.write('res += mat4x{}({},{},{},{}) * {}{}_texOff(vec2({},{}));\n'.format(
-                    comps, format_weights(weights[l], n, length=comps), format_weights(weights[l+1], n, length=comps),
+                file.write('res += {}{}({},{},{},{}){} {}{}_texOff(vec2({},{})){};\n'.format(
+                    "mat4x" if scale > 1 else "dot(", comps if scale > 1 else "vec4",
+                    format_weights(weights[l], n, length=comps), format_weights(weights[l+1], n, length=comps),
                     format_weights(weights[l+2], n, length=comps), format_weights(weights[l+3], n, length=comps),
-                    inp, idx + 1, x, y))
-            if comps == 4:
-                file.write('return res;\n')
-            else:
-                file.write('return vec4(res, 0);\n')
+                    " *" if scale > 1 else ",", inp, idx + 1, x, y, "" if scale > 1 else ")"))
+            file.write('return vec4(res{});\n'.format(", 0" * (4 - comps)))
             file.write('}\n\n')
 
-        # Aggregation
-        header6(file)
-        file.write('vec4 hook()\n')
-        file.write('{\n')
-        file.write('vec2 fcoord = fract(SUBCONV1_pos * SUBCONV1_size);\n')
-        file.write('vec2 base = SUBCONV1_pos + (vec2(0.5) - fcoord) * SUBCONV1_pt;\n')
-        file.write('ivec2 index = ivec2(fcoord * vec2({}));\n'.format(scale))
-        if scale > 2:
-            file.write('mat{0} res = mat{0}(SUBCONV1_tex(base).{1}'.format(scale, "rgba"[:comps]))
-            for i in range(scale-1):
-                file.write(',SUBCONV{}_tex(base).{}'.format(i + 2, "rgba"[:comps]))
-            file.write(');\n')
-            file.write('return vec4(res[index.x][index.y], 0, 0, 1);\n')
-        else:
-            file.write('vec4 res = SUBCONV1_tex(base);\n')
-            file.write('return vec4(res[index.x * {} + index.y], 0, 0, 1);\n'.format(scale))
-        file.write('}\n')
+        if scale > 1:
+            # Aggregation
+            header6(file)
+            file.write('vec4 hook()\n')
+            file.write('{\n')
+            file.write('vec2 fcoord = fract(SUBCONV1_pos * SUBCONV1_size);\n')
+            file.write('vec2 base = SUBCONV1_pos + (vec2(0.5) - fcoord) * SUBCONV1_pt;\n')
+            file.write('ivec2 index = ivec2(fcoord * vec2({}));\n'.format(scale))
+            if scale > 2:
+                file.write('mat{0} res = mat{0}(SUBCONV1_tex(base).{1}'.format(scale, "rgba"[:comps]))
+                for i in range(scale-1):
+                    file.write(',SUBCONV{}_tex(base).{}'.format(i + 2, "rgba"[:comps]))
+                file.write(');\n')
+                file.write('return vec4(res[index.x][index.y], 0, 0, 1);\n')
+            else:
+                file.write('vec4 res = SUBCONV1_tex(base);\n')
+                file.write('return vec4(res[index.x * {} + index.y], 0, 0, 1);\n'.format(scale))
+            file.write('}\n')
 
     else:
         print("Missing argument: You must specify a file name")
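The new comps expression generalizes the old 3x special case and also covers 1x: each sub-pixel pass writes comps components, and the scale**2 output positions are split across the passes. A small sanity check, as a plain-Python sketch of the values the generator would see:

    for scale in (1, 2, 3, 4):
        comps = scale if scale % 2 == 1 else 4
        passes = len(range(0, scale**2, comps))
        print(scale, comps, passes)
    # 1 1 1  (single float result, no aggregation pass)
    # 2 4 1  (one vec4 pass)
    # 3 3 3  (three vec3 passes)
    # 4 4 4  (four vec4 passes)

Under the old expression, scale 1 would have produced comps = 4 for a single output value, which is why the scale == 1 path now emits a float accumulator, a dot() per tap, and no SUBCONV aggregation pass.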
main.py: 0 additions & 1 deletion
@@ -23,7 +23,6 @@
 flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
 flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
 flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
-flags.DEFINE_integer("threads", 1, "Number of processes to pre-process data with [1]")
 flags.DEFINE_boolean("distort", False, "Distort some images with JPEG compression artifacts after downscaling [False]")
 flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")
 
model.py: 5 additions & 10 deletions
@@ -1,6 +1,5 @@
 from utils import (
-    thread_train_setup,
-    train_input_setup,
+    multiprocess_train_setup,
     test_input_setup,
     save_params,
     merge,
@@ -31,14 +30,13 @@ def __init__(self, sess, config):
         self.radius = config.radius
         self.batch_size = config.batch_size
         self.learning_rate = config.learning_rate
-        self.threads = config.threads
         self.distort = config.distort
         self.params = config.params
 
         self.padding = 4
         # Different image/label sub-sizes for different scaling factors x2, x3, x4
-        scale_factors = [[20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]]
-        self.image_size, self.label_size = scale_factors[self.scale - 2]
+        scale_factors = [[40 + self.padding, 40], [20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]]
+        self.image_size, self.label_size = scale_factors[self.scale - 1]
 
         self.stride = self.image_size - self.padding
 
@@ -94,11 +92,8 @@ def run(self):
     def run_train(self):
         start_time = time.time()
         print("Beginning training setup...")
-        if self.threads == 1:
-            train_data, train_label = train_input_setup(self)
-        else:
-            train_data, train_label = thread_train_setup(self)
-        print("Training setup took {} seconds with {} threads".format(time.time() - start_time, self.threads))
+        train_data, train_label = multiprocess_train_setup(self)
+        print("Training setup took {} seconds".format(time.time() - start_time))
 
         print("Training...")
         start_time = time.time()
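The extended scale_factors table adds an x1 entry and switches indexing from self.scale - 2 to self.scale - 1; the existing invariant label_size == (image_size - padding) * scale still holds for every row. A quick check, assuming the values above:

    padding = 4
    scale_factors = [[40 + padding, 40], [20 + padding, 40], [14 + padding, 42], [12 + padding, 48]]
    for scale in (1, 2, 3, 4):
        image_size, label_size = scale_factors[scale - 1]
        assert label_size == (image_size - padding) * scale
        print(scale, image_size, label_size)  # 1 44 40 / 2 24 40 / 3 18 42 / 4 16 48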
utils.py: 7 additions & 75 deletions
@@ -12,7 +12,7 @@
 import tensorflow as tf
 from PIL import Image
 import numpy as np
-from multiprocessing import Pool, Lock, active_children
+import multiprocessing
 
 FLAGS = tf.app.flags.FLAGS
 
@@ -145,46 +145,19 @@ def train_input_worker(args):
 
     return [single_input_sequence, single_label_sequence]
 
-
-def thread_train_setup(config):
+def multiprocess_train_setup(config):
     """
-    Spawns |config.threads| worker processes to pre-process the data
-    This has not been extensively tested so use at your own risk.
-    Also this is technically multiprocessing not threading, I just say thread
-    because it's shorter to type.
+    Spawns several processes to pre-process the data
     """
     if downsample == False:
         import sys
         sys.exit()
 
-    sess = config.sess
-
-    # Load data path
-    data = prepare_data(sess, dataset=config.data_dir)
-
-    # Initialize multiprocessing pool with # of processes = config.threads
-    pool = Pool(config.threads)
-
-    # Distribute |images_per_thread| images across each worker process
-    config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
-    images_per_thread = len(data) // config.threads
-    workers = []
-    for thread in range(config.threads):
-        args_list = [(data[i], config_values) for i in range(thread * images_per_thread, (thread + 1) * images_per_thread)]
-        worker = pool.map_async(train_input_worker, args_list)
-        workers.append(worker)
-    print("{} worker processes created".format(config.threads))
-
-    pool.close()
+    data = prepare_data(config.sess, dataset=config.data_dir)
 
-    results = []
-    for i in range(len(workers)):
-        print("Waiting for worker process {}".format(i))
-        results.extend(workers[i].get(timeout=240))
-        print("Worker process {} done".format(i))
-
-    print("All worker processes done!")
+    with multiprocessing.Pool(max(multiprocessing.cpu_count() - 1, 1)) as pool:
+        config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
+        results = pool.map(train_input_worker, [(data[i], config_values) for i in range(len(data))])
 
     sub_input_sequence, sub_label_sequence = [], []
 
@@ -198,47 +171,6 @@ def thread_train_setup(config):
 
     return (arrdata, arrlabel)
 
-def train_input_setup(config):
-    """
-    Read image files, make their sub-images, and save them as a h5 file format.
-    """
-    if downsample == False:
-        import sys
-        sys.exit()
-
-    sess = config.sess
-    image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2
-
-    # Load data path
-    data = prepare_data(sess, dataset=config.data_dir)
-
-    sub_input_sequence, sub_label_sequence = [], []
-
-    for i in range(len(data)):
-        input_, label_ = preprocess(data[i], scale, distort=config.distort)
-
-        if len(input_.shape) == 3:
-            h, w, _ = input_.shape
-        else:
-            h, w = input_.shape
-
-        for x in range(0, h - image_size + 1, stride):
-            for y in range(0, w - image_size + 1, stride):
-                sub_input = input_[x : x + image_size, y : y + image_size]
-                x_loc, y_loc = x + padding, y + padding
-                sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]
-
-                sub_input = sub_input.reshape([image_size, image_size, 1])
-                sub_label = sub_label.reshape([label_size, label_size, 1])
-
-                sub_input_sequence.append(sub_input)
-                sub_label_sequence.append(sub_label)
-
-    arrdata = np.asarray(sub_input_sequence)
-    arrlabel = np.asarray(sub_label_sequence)
-
-    return (arrdata, arrlabel)
-
 def test_input_setup(config):
     sess = config.sess
 
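The rewritten setup replaces the map_async/timeout bookkeeping with a single blocking pool.map sized to cpu_count() - 1 (floored at one worker), which also makes the separate single-process train_input_setup redundant. A minimal standalone sketch of the same pattern, assuming Python 3 (where Pool works as a context manager) and a stand-in worker function:

    import multiprocessing

    def work(n):  # stand-in for train_input_worker
        return n * n

    if __name__ == '__main__':
        # Leave one core free, but never drop below a single worker.
        with multiprocessing.Pool(max(multiprocessing.cpu_count() - 1, 1)) as pool:
            results = pool.map(work, range(8))  # blocks until all items are processed
        print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]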
