diff --git a/code/constants.py b/code/constants.py
new file mode 100644
index 0000000..e0e5b51
--- /dev/null
+++ b/code/constants.py
@@ -0,0 +1,9 @@
+factors = [1, 2, 4]
+pads = [32, 64, 128, 256, 512]
+gen_scope = "gen"
+dis_scope = "dis"
+outputs_prefix = "output_"
+lr_key = "lr"
+hr_key = "hr"
+lr_input_name = "lr_input"
+hr_input_name = "hr_input"
\ No newline at end of file
diff --git a/code/network.py b/code/network.py
new file mode 100644
index 0000000..a3d36a4
--- /dev/null
+++ b/code/network.py
@@ -0,0 +1,141 @@
+import tensorflow as tf
+import numpy as np
+from functools import partial
+from ops import conv, lrelu, conv2d_downscale2d, upscale2d_conv2d, blur2d, pixel_norm
+from ops import minibatch_stddev_layer
+
+
+def discriminator(hr_images, scope, dim):
+    """
+    Discriminator
+    """
+    conv_lrelu = partial(conv, activation_fn=lrelu)
+
+    def _combine(x, newdim, name, z=None):
+        x = conv_lrelu(x, newdim, 1, 1, name)
+        y = x if z is None else tf.concat([x, z], axis=-1)
+        return minibatch_stddev_layer(y)
+
+    def _conv_downsample(x, dim, ksize, name):
+        y = conv2d_downscale2d(x, dim, ksize, name=name)
+        y = lrelu(y)
+        return y
+
+    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
+        with tf.compat.v1.variable_scope("res_4x"):
+            net = _combine(hr_images[1], newdim=dim, name="from_input")
+            net = conv_lrelu(net, dim, 3, 1, "conv1")
+            net = conv_lrelu(net, dim, 3, 1, "conv2")
+            net = conv_lrelu(net, dim, 3, 1, "conv3")
+            net = _conv_downsample(net, dim, 3, "conv_down")
+
+        with tf.compat.v1.variable_scope("res_2x"):
+            net = _combine(hr_images[2], newdim=dim, name="from_input", z=net)
+            dim *= 2
+            net = conv_lrelu(net, dim, 3, 1, "conv1")
+            net = conv_lrelu(net, dim, 3, 1, "conv2")
+            net = conv_lrelu(net, dim, 3, 1, "conv3")
+            net = _conv_downsample(net, dim, 3, "conv_down")
+
+        with tf.compat.v1.variable_scope("res_1x"):
+            net = _combine(hr_images[4], newdim=dim, name="from_input", z=net)
+            dim *= 2
+            net = conv_lrelu(net, dim, 3, 1, "conv")
+            net = _conv_downsample(net, dim, 3, "conv_down")
+
+        with tf.compat.v1.variable_scope("bn"):
+            dim *= 2
+            net = conv_lrelu(net, dim, 3, 1, "conv1")
+            net = _conv_downsample(net, dim, 3, "conv_down1")
+            net = minibatch_stddev_layer(net)
+
+        # dense
+        dim *= 2
+        net = conv_lrelu(net, dim, 1, 1, "dense1")
+        net = conv(net, 1, 1, 1, "dense2")
+        net = tf.reduce_mean(net, axis=[1, 2])
+
+    return net
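+
+
+# NB: a minimal wiring sketch of the two networks (the 4-channel input and
+# the 32x32 patch size are assumptions for illustration only):
+#
+#   lr = tf.compat.v1.placeholder(tf.float32, [None, 32, 32, 4])
+#   fakes = generator(lr, scope="gen", nchannels=4, nresblocks=16, dim=64)
+#   # fakes[4]: 32x32 (LR grid), fakes[2]: 64x64, fakes[1]: 128x128
+#   scores = discriminator(fakes, scope="dis", dim=64)  # shape [batch, 1]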
+
+
+def generator(lr_image, scope, nchannels, nresblocks, dim):
+    """
+    Generator
+    """
+    hr_images = dict()
+
+    def conv_upsample(x, dim, ksize, name):
+        y = upscale2d_conv2d(x, dim, ksize, name)
+        y = blur2d(y)
+        y = lrelu(y)
+        y = pixel_norm(y)
+        return y
+
+    def _residule_block(x, dim, name):
+        with tf.compat.v1.variable_scope(name):
+            y = conv(x, dim, 3, 1, "conv1")
+            y = lrelu(y)
+            y = pixel_norm(y)
+            y = conv(y, dim, 3, 1, "conv2")
+            y = pixel_norm(y)
+            return y + x
+
+    def conv_bn(x, dim, ksize, name):
+        y = conv(x, dim, ksize, 1, name)
+        y = lrelu(y)
+        y = pixel_norm(y)
+        return y
+
+    def _make_output(net, factor):
+        hr_images[factor] = conv(net, nchannels, 1, 1, "output")
+
+    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
+        with tf.compat.v1.variable_scope("encoder"):
+            net = lrelu(conv(lr_image, dim, 9, 1, "conv1_9x9"))
+            conv1 = net
+            for i in range(nresblocks):
+                net = _residule_block(net, dim=dim, name="ResBlock{}".format(i))
+
+        with tf.compat.v1.variable_scope("res_1x"):
+            net = conv(net, dim, 3, 1, "conv1")
+            net = pixel_norm(net)
+            net += conv1
+            _make_output(net, factor=4)
+
+        with tf.compat.v1.variable_scope("res_2x"):
+            net = conv_upsample(net, 4 * dim, 3, "conv_upsample")
+            net = conv_bn(net, 4 * dim, 3, "conv1")
+            net = conv_bn(net, 4 * dim, 3, "conv2")
+            net = conv_bn(net, 4 * dim, 5, "conv3")
+            _make_output(net, factor=2)
+
+        with tf.compat.v1.variable_scope("res_4x"):
+            net = conv_upsample(net, 4 * dim, 3, "conv_upsample")
+            net = conv_bn(net, 4 * dim, 3, "conv1")
+            net = conv_bn(net, 4 * dim, 3, "conv2")
+            net = conv_bn(net, 4 * dim, 9, "conv3")
+            _make_output(net, factor=1)
+
+    return hr_images
+
+
+def nice_preview(x, refs):
+    """
+    Beautiful previews
+    Keep only first 3 bands --> RGB
+    """
+    bands = [0, 1, 2]
+
+    _mean = np.zeros(3)
+    _std = np.zeros(3)
+    _ninv = 1.0 / float(len(refs))
+    for ref in refs:
+        _mean += _ninv * np.asarray([np.mean(ref[0, :, :, i]) for i in bands])
+        _std += _ninv * np.asarray([np.std(ref[0, :, :, i]) for i in bands])
+    _min = [__mean - 2 * __std for __mean, __std in zip(_mean, _std)]
+    _max = [__mean + 2 * __std for __mean, __std in zip(_mean, _std)]
+    return tf.cast(255 * tf.stack(
+        [1.0 / (__max - __min) * (tf.clip_by_value(x[:, :, :, i], __min, __max) - __min) for i, __min, __max in
+         zip(bands, _min, _max)],
+        axis=3), tf.uint8)
+
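+# NB (illustration only): `nice_preview` stretches each of the first 3 bands
+# to mean +/- 2 std computed over the reference arrays, e.g.:
+#   preview = nice_preview(sr_tensor, refs=[lr_numpy_batch])  # uint8 RGB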
diff --git a/code/ops.py b/code/ops.py
new file mode 100644
index 0000000..c75590e
--- /dev/null
+++ b/code/ops.py
@@ -0,0 +1,205 @@
+import tensorflow as tf
+from functools import partial
+import numpy as np
+
+lrelu = partial(tf.nn.leaky_relu, alpha=0.2)
+
+
+def get_weight(shape, gain=np.sqrt(2), lrmul=1):
+    """
+    Get weight
+    """
+    fan_in = np.prod(shape[:-1])
+    he_std = gain / np.sqrt(fan_in)  # He init
+
+    # Equalized learning rate and custom learning rate multiplier.
+    init_std = 1.0 / lrmul
+    runtime_coef = he_std * lrmul
+
+    init = tf.initializers.random_normal(0, init_std)
+    return tf.compat.v1.get_variable("weight", shape=shape, initializer=init) * runtime_coef
+
+
+def apply_bias(x):
+    """
+    Apply bias
+    """
+    b = tf.compat.v1.get_variable('bias', shape=[x.shape[-1]], initializer=tf.keras.initializers.zeros())
+    b = tf.cast(b, x.dtype)
+    if len(x.shape) == 2:
+        return x + b
+    else:
+        return x + tf.reshape(b, [1, 1, 1, -1])
+
+
+def conv_base(x, fmaps, kernel_size, stride, name, gain=np.sqrt(2), activation_fn=None, normalizer_fn=None,
+              transpose=False, padding='SAME'):
+    """
+    Convolutional layer base.
+    """
+    assert isinstance(name, str)
+    with tf.compat.v1.variable_scope(name):
+        strides = [1, stride, stride, 1]
+        if not transpose:
+            w = get_weight([kernel_size, kernel_size, x.shape[3].value, fmaps], gain=gain)
+            w = tf.cast(w, x.dtype)
+            out = tf.nn.conv2d(x, filter=w, strides=strides, padding=padding)
+        else:
+            sz0 = tf.shape(x)[0]
+            sz1 = tf.shape(x)[1]
+            sz2 = tf.shape(x)[2]
+            output_shape = [sz0, stride * sz1, stride * sz2, fmaps]
+            w = get_weight([kernel_size, kernel_size, fmaps, x.shape[3].value], gain=gain)
+            w = tf.cast(w, x.dtype)
+            out = tf.nn.conv2d_transpose(x, filter=w, output_shape=output_shape, strides=strides, padding=padding)
+        out = apply_bias(out)
+        if activation_fn is not None:
+            out = activation_fn(out)
+        if normalizer_fn is not None:
+            out = normalizer_fn(out)
+        return out
+
+
+conv = partial(conv_base, transpose=False)
+deconv = partial(conv_base, transpose=True)
+
+
+def _blur2d(x, stride=1, flip=False):
+    """
+    Blur an image tensor
+    """
+    f = [1, 2, 1]
+    f = np.array(f, dtype=np.float32)
+    f = f[:, np.newaxis] * f[np.newaxis, :]
+    f /= np.sum(f)
+    if flip:
+        f = f[::-1, ::-1]
+    f = f[:, :, np.newaxis, np.newaxis]
+    f = np.tile(f, [1, 1, int(x.shape[-1]), 1])
+    f = tf.constant(f, dtype=tf.float32, name="filter_blur2d")
+    strides = [1, stride, stride, 1]
+    return tf.nn.depthwise_conv2d(x, f, strides=strides, padding="SAME")
+
+
+def minibatch_stddev_layer(x, group_size=4):
+    with tf.compat.v1.variable_scope('MinibatchStd'):
+        group_size = tf.minimum(group_size, tf.shape(x)[0])
+        sz1 = tf.shape(x)[1]
+        sz2 = tf.shape(x)[2]
+        sz3 = tf.shape(x)[3]
+        y = tf.reshape(x, [group_size, -1, sz1, sz2, sz3])
+        y = tf.cast(y, tf.float32)
+        y -= tf.reduce_mean(y, axis=0, keepdims=True)
+        y = tf.reduce_mean(tf.square(y), axis=0)
+        y = tf.sqrt(y + 1e-8)
+        y = tf.reduce_mean(y, axis=[1, 2, 3], keepdims=True)
+        y = tf.cast(y, x.dtype)
+        y = tf.tile(y, [group_size, sz1, sz2, 1])
+        return tf.concat([x, y], axis=3)
+
+
+def pixel_norm(x):
+    with tf.compat.v1.variable_scope("PixelNorm"):
+        epsilon = tf.constant(1e-8, dtype=x.dtype, name="epsilon")
+        return x * tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + epsilon)
+
+
+def conv2d_downscale2d(x, fmaps, kernel, name):
+    assert kernel >= 1 and kernel % 2 == 1
+    with tf.compat.v1.variable_scope(name):
+        w = get_weight([kernel, kernel, x.shape[3].value, fmaps])
+        w = tf.pad(w, [[1, 1], [1, 1], [0, 0], [0, 0]], mode="CONSTANT")
+        w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25
+        w = tf.cast(w, x.dtype)
+        return tf.nn.conv2d(x, w, strides=[1, 2, 2, 1], padding="SAME")
+
+
+def upscale2d_conv2d(x, fmaps, kernel, name):
+    assert kernel >= 1 and kernel % 2 == 1
+    with tf.compat.v1.variable_scope(name):
+        w = get_weight([kernel, kernel, x.shape[3].value, fmaps])
+        w = tf.transpose(w, [0, 1, 3, 2])
+        w = tf.pad(w, [[1, 1], [1, 1], [0, 0], [0, 0]], mode="CONSTANT")
+        w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]])
+        w = tf.cast(w, x.dtype)
+        sz0 = tf.shape(x)[0]
+        sz1 = tf.shape(x)[1]
+        sz2 = tf.shape(x)[2]
+        output_shape = [sz0, 2 * sz1, 2 * sz2, fmaps]
+        return tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=[1, 2, 2, 1], padding="SAME")
+
+
+def _upscale2d(x, factor=2, gain=1):
+    if gain != 1:
+        x *= gain
+
+    if factor == 1:
+        return x
+
+    s = tf.shape(x)
+    x = tf.reshape(x, [-1, s[1], 1, s[2], 1, s[3]])
+    x = tf.tile(x, [1, 1, factor, 1, factor, 1])
+    x = tf.reshape(x, [-1, factor * s[1], factor * s[2], s[3]])
+    return x
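+
+
+# NB: `upscale2d`/`downscale2d` below wrap these helpers with custom
+# gradients (upscaling backprops through downscaling and vice versa), e.g.:
+#   y = upscale2d(x)    # [1, 8, 8, 3] -> [1, 16, 16, 3]
+#   z = downscale2d(y)  # back to [1, 8, 8, 3]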
+
+
+def _downscale2d(x, factor=2, gain=1):
+    if gain != 1:
+        x *= gain
+
+    if factor == 2 and x.dtype == tf.float32:
+        return _blur2d(x, stride=factor)
+
+    if factor == 1:
+        return x
+
+    ksize = [1, factor, factor, 1]
+    return tf.nn.avg_pool2d(x, ksize=ksize, strides=ksize, padding="VALID")
+
+
+def blur2d(x):
+    with tf.compat.v1.variable_scope("Blur2D"):
+        @tf.custom_gradient
+        def func(in_x):
+            y = _blur2d(in_x)
+
+            @tf.custom_gradient
+            def grad(dy):
+                dx = _blur2d(dy, flip=True)
+                return dx, lambda ddx: _blur2d(ddx)
+
+            return y, grad
+
+        return func(x)
+
+
+def upscale2d(x, factor=2):
+    with tf.compat.v1.variable_scope("Upscale2D"):
+        @tf.custom_gradient
+        def func(in_x):
+            y = _upscale2d(in_x, factor)
+
+            @tf.custom_gradient
+            def grad(dy):
+                dx = _downscale2d(dy, factor, gain=factor ** 2)
+                return dx, lambda ddx: _upscale2d(ddx, factor)
+
+            return y, grad
+
+        return func(x)
+
+
+def downscale2d(x, factor=2):
+    with tf.compat.v1.variable_scope("Downscale2D"):
+        @tf.custom_gradient
+        def func(in_x):
+            y = _downscale2d(in_x, factor)
+
+            @tf.custom_gradient
+            def grad(dy):
+                dx = _upscale2d(dy, factor, gain=1 / factor ** 2)
+                return dx, lambda ddx: _downscale2d(ddx, factor)
+
+            return y, grad
+
+        return func(x)
\ No newline at end of file
diff --git a/code/sr.py b/code/sr.py
new file mode 100644
index 0000000..58e9f03
--- /dev/null
+++ b/code/sr.py
@@ -0,0 +1,74 @@
+import argparse
+import otbApplication
+import constants
+import logging
+
+logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', level=logging.WARNING,
+                    datefmt='%Y-%m-%d %H:%M:%S')
+
+parser = argparse.ArgumentParser()
+
+
+def get_encoding():
+    """
+    Get the encoding of input image pixels
+    """
+    infos = otbApplication.Registry.CreateApplication('ReadImageInfo')
+    infos.SetParameterString("in", params.input)
+    infos.Execute()
+    return infos.GetImageBasePixelType("in")
+
+
+encodings = {"auto": get_encoding,
+             "uint8": lambda: otbApplication.ImagePixelType_uint8,
+             "uint16": lambda: otbApplication.ImagePixelType_uint16,
+             "int16": lambda: otbApplication.ImagePixelType_int16,
+             "float": lambda: otbApplication.ImagePixelType_float}
+
+parser.add_argument("--input", help="Input LR image", required=True)
+parser.add_argument("--savedmodel", help="Input SavedModel", required=True)
+parser.add_argument("--output", help="Output HR image", required=True)
+parser.add_argument('--encoding', type=str, default="auto", const="auto", nargs="?", choices=encodings.keys(),
+                    help="Output HR image encoding")
+parser.add_argument('--pad', type=int, default=64, const=64, nargs="?", choices=constants.pads,
+                    help="Margin size for blocking artefacts removal")
+parser.add_argument('--ts', default=512, type=int, help="Tile size")
+params = parser.parse_args()
+
+
+if __name__ == "__main__":
+
+    gen_fcn = params.pad
+    efield = params.ts  # OTBTF expression field
+    # the HR tile size must be a multiple of the largest upsampling factor
+    if efield % max(constants.factors) != 0:
+        logging.fatal("Please choose a tile size that is consistent with the network.")
+        quit()
+    ratio = 1.0 / float(max(constants.factors))  # OTBTF spacing ratio
+    rfield = int((efield + 2 * gen_fcn) * ratio)  # OTBTF receptive field
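+    # Worked example (illustration): with the defaults --ts 512 and --pad 64,
+    # ratio = 1/4 and rfield = int((512 + 2 * 64) * 0.25) = 160, i.e. a
+    # 160x160 LR tile is read to produce each 512x512 HR output tile.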
+
+    # pixel encoding
+    encoding_fn = encodings[params.encoding]
+    encoding = encoding_fn()
+
+    # call otbtf
+    logging.info("Receptive field: {}, Expression field: {}".format(rfield, efield))
+    ph = "{}{}".format(constants.outputs_prefix, params.pad)
+    infer = otbApplication.Registry.CreateApplication("TensorflowModelServe")
+    infer.SetParameterStringList("source1.il", [params.input])
+    infer.SetParameterInt("source1.rfieldx", rfield)
+    infer.SetParameterInt("source1.rfieldy", rfield)
+    infer.SetParameterString("source1.placeholder", constants.lr_input_name)
+    infer.SetParameterString("model.dir", params.savedmodel)
+    infer.SetParameterString("model.fullyconv", "on")
+    infer.SetParameterStringList("output.names", [ph])
+    infer.SetParameterInt("output.efieldx", efield)
+    infer.SetParameterInt("output.efieldy", efield)
+    infer.SetParameterFloat("output.spcscale", ratio)
+    infer.SetParameterInt("optim.tilesizex", efield)
+    infer.SetParameterInt("optim.tilesizey", efield)
+    infer.SetParameterInt("optim.disabletiling", 1)
+    out_fn = "{}{}&gdal:co:COMPRESS=DEFLATE".format(params.output, "?" if "?" not in params.output else "")
+    out_fn += "&streaming:type=tiled&streaming:sizemode=height&streaming:sizevalue={}".format(efield)
+    infer.SetParameterString("out", out_fn)
+    infer.SetParameterOutputImagePixelType("out", encoding)
+    infer.ExecuteAndWriteOutput()
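+
+# Example invocation (paths are placeholders):
+#   python sr.py --input lr.tif --savedmodel /path/to/savedmodel \
+#                --output sr.tif --encoding auto --pad 64 --ts 512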
diff --git a/code/train.py b/code/train.py
new file mode 100644
index 0000000..df4c759
--- /dev/null
+++ b/code/train.py
@@ -0,0 +1,290 @@
+# Imports
+import tensorflow as tf
+import datetime
+import argparse
+from functools import partial
+import otbtf
+import logging
+from ops import blur2d, downscale2d
+from vgg import compute_vgg_loss
+import network
+import constants
+
+logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', level=logging.WARNING,
+                    datefmt='%Y-%m-%d %H:%M:%S')
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+
+parser = argparse.ArgumentParser()
+
+# Paths
+parser.add_argument("--lr_patches", help="LR patches images list", required=True, nargs='+', default=[])
+parser.add_argument("--hr_patches", help="HR patches images list", required=True, nargs='+', default=[])
+parser.add_argument("--preview", help="LR image for preview")
+parser.add_argument("--logdir", help="output log directory")
+parser.add_argument("--save_ckpt", help="save the checkpoint", required=True)
+parser.add_argument("--load_ckpt", help="start from a given checkpoint")
+parser.add_argument("--savedmodel", help="create a SavedModel at the end")
+parser.add_argument("--vggfile", help="vgg19.npy file")
+
+# Images scaling
+parser.add_argument("--lr_scale", type=float, default=0.0001, help="LR image scaling")
+parser.add_argument("--hr_scale", type=float, default=0.0001, help="HR image scaling")
+
+# Parameters
+parser.add_argument("--previews_step", type=int, default=200, help="Number of steps between each preview summary")
+parser.add_argument("--depth", type=int, default=64, help="Generator and discriminator depth")
+parser.add_argument("--nresblocks", type=int, default=16, help="Number of ResNet blocks in Generator")
+parser.add_argument("--epochs", type=int, default=120, help="number of epochs")
+parser.add_argument("--batchsize", type=int, default=4, help="batch size")
+parser.add_argument("--adam_lr", type=float, default=0.001, help="Adam learning rate")
+parser.add_argument("--adam_b1", type=float, default=0.0, help="Adam beta1")
+parser.add_argument("--l1weight", type=float, default=0, help="L1 loss weight")
+parser.add_argument("--l2weight", type=float, default=100.0, help="L2 loss weight")
+parser.add_argument("--vggweight", type=float, default=1.0, help="VGG loss weight")
+parser.add_argument('--vggfeatures', default="vgg54", const="vgg54", nargs="?",
+                    choices=["vgg54", "vgg54lin", "vgg344454", "1234lin", "1234"])
+parser.add_argument("--ganweight", type=float, default=1.0, help="GAN loss weight")
+parser.add_argument('--losstype', default="WGAN-GP", const="WGAN-GP", nargs="?",
+                    choices=["WGAN-GP", "LSGAN"], help="GAN loss type")
+parser.add_argument('--streaming', dest='streaming', action='store_true',
+                    help="Streaming reads patches from the file system. Consumes low RAM, but stresses FS")
+parser.set_defaults(streaming=False)
+parser.add_argument('--pretrain', dest='pretrain', action='store_true',
+                    help="Pre-train the network using only L1 and L2 (use l1/l2weight)")
+parser.set_defaults(pretrain=False)
+params = parser.parse_args()
+
+step = 0
+
+
+def main(unused_argv):
+    logging.info("************ Parameters summary ************")
+    logging.info("N epochs : " + str(params.epochs))
+    logging.info("Batch size : " + str(params.batchsize))
+    logging.info("Adam learning rate : " + str(params.adam_lr))
+    logging.info("Adam beta1 : " + str(params.adam_b1))
+    logging.info("L1 weight : " + str(params.l1weight))
+    logging.info("L2 weight : " + str(params.l2weight))
+    logging.info("GAN weight : " + str(params.ganweight))
+    logging.info("VGG weight : " + str(params.vggweight))
+    logging.info("VGG features : " + str(params.vggfeatures))
+    logging.info("Base depth : " + str(params.depth))
+    logging.info("Number of ResBlocks : " + str(params.nresblocks))
+    logging.info("********************************************")
+
+    # Preview
+    lr_image_for_prev = None
+    if params.preview is not None:
+        lr_image_for_prev = otbtf.read_as_np_arr(otbtf.gdal_open(params.preview), False)
+
+    with tf.Graph().as_default():
+        # dataset and iterator
+        ds = otbtf.DatasetFromPatchesImages(filenames_dict={constants.hr_key: params.hr_patches,
+                                                            constants.lr_key: params.lr_patches},
+                                            use_streaming=params.streaming)
+        tf_ds = ds.get_tf_dataset(batch_size=params.batchsize)
+        iterator = tf.compat.v1.data.Iterator.from_structure(tf_ds.output_types)
+        iterator_init = iterator.make_initializer(tf_ds)
+        dataset_inputs = iterator.get_next()
+
+        # placeholders with normalization
+        def _get_normalized_input(key, scale, name):
+            default_input = dataset_inputs[key]
+            shape = (None, None, None, ds.output_shapes[key][-1])
+            ph = tf.compat.v1.placeholder_with_default(default_input, shape=shape, name=name)
+            return scale * ph
+
+        lr_image = _get_normalized_input(constants.lr_key, params.lr_scale, constants.lr_input_name)
+        hr_image = _get_normalized_input(constants.hr_key, params.hr_scale, constants.hr_input_name)
+
+        hr_nch = ds.output_shapes[constants.hr_key][-1]
+        generator = partial(network.generator, scope=constants.gen_scope, nchannels=hr_nch,
+                            nresblocks=params.nresblocks, dim=params.depth)
+        discriminator = partial(network.discriminator, scope=constants.dis_scope, dim=params.depth)
+
+        hr_images_real = {factor: downscale2d(blur2d(hr_image), factor=factor) for factor in constants.factors}
+        hr_images_fake = generator(lr_image)
+
+        # model outputs
+        gen = {factor: (1.0 / params.hr_scale) * hr_images_fake[factor] for factor in constants.factors}
+        for pad in constants.pads:
+            tf.identity(gen[1][:, pad:-pad, pad:-pad, :], name="{}{}".format(constants.outputs_prefix, pad))
+        if lr_image_for_prev is not None:
+            for factor in constants.factors:
+                prev = network.nice_preview(gen[factor], refs=[lr_image_for_prev])
+                tf.compat.v1.summary.image("preview_factor{}".format(factor), prev, collections=['per_epoch'])
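+
+        # NB: network.discriminator is built with AUTO_REUSE, so the two
+        # calls below (on the real and the generated pyramids) share every
+        # "dis/*" variable; each call returns a [batch, 1] score tensor.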
+
+        # discriminator
+        dis_real = discriminator(hr_images=hr_images_real)
+        dis_fake = discriminator(hr_images=hr_images_fake)
+
+        # l1 loss
+        gen_loss_l1 = tf.add_n([tf.reduce_mean(tf.abs(hr_images_fake[factor] -
+                                                      hr_images_real[factor])) for factor in constants.factors])
+
+        # l2 loss
+        gen_loss_l2 = tf.add_n([tf.reduce_mean(tf.square(hr_images_fake[factor] -
+                                                         hr_images_real[factor])) for factor in constants.factors])
+
+        # VGG loss
+        gen_loss_vgg = 0.0
+        if params.vggfile is not None:
+            gen_loss_vgg = tf.add_n([compute_vgg_loss(hr_images_real[factor],
+                                                      hr_images_fake[factor],
+                                                      params.vggfeatures,
+                                                      params.vggfile) for factor in constants.factors])
+
+        # GAN Losses
+        if params.losstype == "LSGAN":
+            dis_loss = tf.reduce_mean(tf.square(dis_real - 1) + tf.square(dis_fake))
+            gen_loss_gan = tf.reduce_mean(tf.square(dis_fake - 1))
+        elif params.losstype == "WGAN-GP":
+            dis_loss = dis_fake - dis_real
+            alpha = tf.compat.v1.random_uniform(shape=[params.batchsize, 1, 1, 1], minval=0., maxval=1.)
+            differences = {factor: hr_images_fake[factor] - hr_images_real[factor] for factor in constants.factors}
+            interpolates_scales = {factor: hr_images_real[factor] +
+                                           alpha * differences[factor] for factor in constants.factors}
+            mixed_loss = tf.reduce_sum(discriminator(interpolates_scales))
+            mixed_grads = tf.gradients(mixed_loss, list(interpolates_scales.values()))
+            mixed_norms = [tf.sqrt(tf.reduce_sum(tf.square(gradient), axis=[1, 2, 3])) for gradient in mixed_grads]
+            gradient_penalties = [tf.reduce_mean(tf.square(slope - 1.0)) for slope in mixed_norms]
+            gradient_penalty = tf.reduce_mean(gradient_penalties)
+            dis_loss += 10 * gradient_penalty
+            epsilon_penalty = tf.reduce_mean(tf.square(dis_real))
+            dis_loss += 0.001 * epsilon_penalty
+            gen_loss_gan = -1.0 * tf.reduce_mean(dis_fake)
+            dis_loss = tf.reduce_mean(dis_loss)
+        else:
+            raise Exception("Please select an available cost function")
+
+        # Total losses
+        def _new_loss(value, name, collections=None):
+            tf.compat.v1.summary.scalar(name, value, collections)
+            return value
+
+        gen_loss = _new_loss(params.ganweight * gen_loss_gan, "gen_loss_gan")
+        gen_loss += _new_loss(params.vggweight * gen_loss_vgg, "gen_loss_vgg")
+        pretrain_loss = _new_loss(params.l1weight * gen_loss_l1, "gen_loss_l1", ["pretrain"])
+        pretrain_loss += _new_loss(params.l2weight * gen_loss_l2, "gen_loss_l2", ["pretrain"])
+        gen_loss += pretrain_loss
+        dis_loss = _new_loss(dis_loss, "dis_loss")
+
+        # discriminator optimizer
+        dis_optim = tf.compat.v1.train.AdamOptimizer(learning_rate=params.adam_lr, beta1=params.adam_b1)
+        dis_tvars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, constants.dis_scope)
+        dis_grads_and_vars = dis_optim.compute_gradients(dis_loss, var_list=dis_tvars)
+        with tf.compat.v1.variable_scope("apply_dis_gradients", reuse=tf.compat.v1.AUTO_REUSE):
+            dis_train = dis_optim.apply_gradients(dis_grads_and_vars)
+
+        # generator optimizer
+        with tf.control_dependencies([dis_train]):
+            gen_optim = tf.compat.v1.train.AdamOptimizer(learning_rate=params.adam_lr, beta1=params.adam_b1)
+            gen_tvars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, constants.gen_scope)
+            gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
+            with tf.compat.v1.variable_scope("apply_gen_gradients", reuse=tf.compat.v1.AUTO_REUSE):
+                gen_train = gen_optim.apply_gradients(gen_grads_and_vars)
+
+        pretrain_op = tf.compat.v1.train.AdamOptimizer(learning_rate=params.adam_lr).minimize(pretrain_loss)
+
+        train_nodes = [gen_train]
+        if params.losstype == "LSGAN":
+            ema = tf.train.ExponentialMovingAverage(decay=0.995)
+            update_losses = ema.apply([dis_loss, gen_loss])
+            train_nodes.append(update_losses)
+        train_op = tf.group(train_nodes, name="optimizer")
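+        # NB: gen_train is created under control_dependencies([dis_train]),
+        # so a single sess.run(train_op) performs one discriminator update
+        # followed by one generator update.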
+
+        merged_losses_summaries = tf.compat.v1.summary.merge_all()
+        merged_pretrain_summaries = tf.compat.v1.summary.merge_all(key='pretrain')
+        merged_preview_summaries = tf.compat.v1.summary.merge_all(key='per_epoch')
+
+        init = tf.compat.v1.global_variables_initializer()
+        saver = tf.compat.v1.train.Saver(max_to_keep=5)
+
+        sess = tf.compat.v1.Session()
+
+        # Writer
+        def _append_desc(key, value):
+            if value == 0:
+                return ""
+            return "_{}{}".format(key, value)
+
+        now = datetime.datetime.now()
+        summaries_fn = "SR4RS_"
+        summaries_fn += _append_desc("E", params.epochs)
+        summaries_fn += _append_desc("B", params.batchsize)
+        summaries_fn += _append_desc("LR", params.adam_lr)
+        summaries_fn += _append_desc("Gan", params.ganweight)
+        summaries_fn += _append_desc("L1-", params.l1weight)
+        summaries_fn += _append_desc("L2-", params.l2weight)
+        summaries_fn += _append_desc("VGG", params.vggweight)
+        summaries_fn += _append_desc("VGGFeat", params.vggfeatures)
+        summaries_fn += _append_desc("Loss", params.losstype)
+        summaries_fn += _append_desc("D", params.depth)
+        summaries_fn += _append_desc("RB", params.nresblocks)
+        if params.pretrain:
+            summaries_fn += "pretrained"
+        summaries_fn += "_{}{}_{}h{}min".format(now.day, now.strftime("%b"), now.hour, now.minute)
+
+        train_writer = None
+        if params.logdir is not None:
+            train_writer = tf.compat.v1.summary.FileWriter(params.logdir + summaries_fn, sess.graph)
+
+        def _add_summary(summarized, _step):
+            if train_writer is not None:
+                train_writer.add_summary(summarized, _step)
+
+        sess.run(init)
+        if params.load_ckpt is not None:
+            saver.restore(sess, params.load_ckpt)
+
+        # preview
+        def _preview(_step):
+            if lr_image_for_prev is not None and step % params.previews_step == 0:
+                summary_pe = sess.run(merged_preview_summaries, {lr_image: lr_image_for_prev})
+                _add_summary(summary_pe, _step)
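+
+        # NB: feeding the `lr_image` tensor here short-circuits the default
+        # batch coming from the dataset iterator (placeholder_with_default),
+        # so previews are computed on the --preview image instead.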
+
+        def _do(_train_op, _summary_op, name):
+            global step
+            for curr_epoch in range(params.epochs):
+                logging.info("{} Epoch #{}".format(name, curr_epoch))
+                sess.run(iterator_init)
+                try:
+                    while True:
+                        _, _summary = sess.run([_train_op, _summary_op])
+                        _add_summary(_summary, step)
+                        _preview(curr_epoch)
+                        step += 1
+                except tf.errors.OutOfRangeError:
+                    fs_stall_duration = ds.get_total_wait_in_seconds()
+                    logging.info("{}: one epoch done. Total FS stall: {:.2f}s".format(name, fs_stall_duration))
+                saver.save(sess, params.save_ckpt + summaries_fn, global_step=curr_epoch)
+
+        # pre training
+        if params.pretrain:
+            _do(pretrain_op, merged_pretrain_summaries, "pre-training")
+
+        # training
+        _do(train_op, merged_losses_summaries, "training")
+
+        # cleaning
+        if train_writer is not None:
+            train_writer.close()
+
+        # Export SavedModel
+        if params.savedmodel is not None:
+            logging.info("Export SavedModel in {}".format(params.savedmodel))
+            outputs = ["{}{}:0".format(constants.outputs_prefix, pad) for pad in constants.pads]
+            inputs = ["{}:0".format(constants.lr_input_name)]
+            graph = tf.compat.v1.get_default_graph()
+            tf.compat.v1.saved_model.simple_save(sess,
+                                                 params.savedmodel,
+                                                 inputs={i: graph.get_tensor_by_name(i) for i in inputs},
+                                                 outputs={o: graph.get_tensor_by_name(o) for o in outputs})
+
+    quit()
+
+
+if __name__ == "__main__":
+    tf.compat.v1.add_check_numerics_ops()
+    tf.compat.v1.app.run(main)
diff --git a/code/vgg.py b/code/vgg.py
new file mode 100644
index 0000000..3907589
--- /dev/null
+++ b/code/vgg.py
@@ -0,0 +1,160 @@
+import tensorflow as tf
+import logging
+import numpy as np
+
+
+class Vgg19:
+    """
+    Gives access to VGG19 feature maps
+    """
+
+    def __init__(self, vgg19_npy_path):
+        self.data_dict = np.load(vgg19_npy_path, allow_pickle=True, encoding='latin1').item()
+        logging.info("npy file for VGG model loaded")
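+        # NB: data_dict is assumed to map layer names (e.g. "conv1_1") to
+        # [weights, biases] pairs, which is the layout of the commonly used
+        # vgg19.npy release (cf. get_conv_filter / get_bias below).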
+
+    def build(self, rgb, mode="1234"):
+        """
+        Load variables from the npy file to build the VGG graph
+        :param rgb: Tensor for rgb image [batch, height, width, 3] with values scaled in the [0, 1] range
+        :param mode: name of the perceptual loss to use
+        """
+        logging.info("building VGG model")
+        with tf.compat.v1.variable_scope("vgg_model", reuse=tf.compat.v1.AUTO_REUSE):
+
+            vgg_mean_rgb = [123.68, 116.779, 103.939]
+
+            # floating point RGB image in [0, 1] range --> 8-bit RGB image in [0, 255] range
+            rgb = 255.0 * tf.clip_by_value(rgb, 0.0, 1.0)
+
+            # Convert RGB to BGR
+            bgr = tf.concat([rgb[:, :, :, ch:ch + 1] - vgg_mean_rgb[ch] for ch in [2, 1, 0]], axis=-1)
+
+            # Build partial network
+            self.conv1_1, _ = self.conv_layer(bgr, "conv1_1")
+            self.conv1_2, self.conv1_2lin = self.conv_layer(self.conv1_1, "conv1_2")
+            self.pool1 = self.max_pool(self.conv1_2, 'pool1')
+
+            self.conv2_1, _ = self.conv_layer(self.pool1, "conv2_1")
+            self.conv2_2, self.conv2_2lin = self.conv_layer(self.conv2_1, "conv2_2")
+            self.pool2 = self.max_pool(self.conv2_2, 'pool2')
+
+            self.conv3_1, _ = self.conv_layer(self.pool2, "conv3_1")
+            self.conv3_2, _ = self.conv_layer(self.conv3_1, "conv3_2")
+            self.conv3_3, _ = self.conv_layer(self.conv3_2, "conv3_3")
+            self.conv3_4, self.conv3_4lin = self.conv_layer(self.conv3_3, "conv3_4")
+            self.pool3 = self.max_pool(self.conv3_4, 'pool3')
+
+            self.conv4_1, _ = self.conv_layer(self.pool3, "conv4_1")
+            self.conv4_2, _ = self.conv_layer(self.conv4_1, "conv4_2")
+            self.conv4_3, _ = self.conv_layer(self.conv4_2, "conv4_3")
+            self.conv4_4, self.conv4_4lin = self.conv_layer(self.conv4_3, "conv4_4")
+            self.pool4 = self.max_pool(self.conv4_4, 'pool4')
+
+            self.conv5_1, _ = self.conv_layer(self.pool4, "conv5_1")
+            self.conv5_2, _ = self.conv_layer(self.conv5_1, "conv5_2")
+            self.conv5_3, _ = self.conv_layer(self.conv5_2, "conv5_3")
+            self.conv5_4, self.conv5_4lin = self.conv_layer(self.conv5_3, "conv5_4")
+            self.pool5 = self.max_pool(self.conv5_4, 'pool5')
+            #
+            # self.fc6 = self.fc_layer(self.pool5, "fc6")
+            # # assert self.fc6.get_shape().as_list()[1:] == [4096]
+            # self.relu6 = tf.nn.relu(self.fc6)
+            #
+            # self.fc7 = self.fc_layer(self.relu6, "fc7")
+            # self.relu7 = tf.nn.relu(self.fc7)
+            #
+            # self.fc8 = self.fc_layer(self.relu7, "fc8")
+            #
+            # self.prob = tf.nn.softmax(self.fc8, name="prob")
+            #
+            # self.data_dict = None
+
+            if mode == "1234":
+                f1 = tf.reshape(self.pool1, shape=[-1, 1])
+                f2 = tf.reshape(self.pool2, shape=[-1, 1])
+                f3 = tf.reshape(self.pool3, shape=[-1, 1])
+                f4 = tf.reshape(self.pool4, shape=[-1, 1])
+                return [f1, f2, f3, f4]
+            elif mode == "1234lin":
+                f1 = tf.reshape(self.conv1_2lin, shape=[-1, 1])
+                f2 = tf.reshape(self.conv2_2lin, shape=[-1, 1])
+                f3 = tf.reshape(self.conv3_4lin, shape=[-1, 1])
+                f4 = tf.reshape(self.conv4_4lin, shape=[-1, 1])
+                f5 = tf.reshape(self.conv5_4lin, shape=[-1, 1])
+                return [f1, f2, f3, f4, f5]
+            elif mode == "vgg344454":
+                f1 = tf.reshape(self.conv3_4, shape=[-1, 1])
+                f2 = tf.reshape(self.conv4_4, shape=[-1, 1])
+                f3 = tf.reshape(self.conv5_4, shape=[-1, 1])
+                return [f1, f2, f3]
+            elif mode == "vgg54":
+                return [tf.reshape(self.conv5_4, shape=[-1, 1])]
+            elif mode == "vgg54lin":
+                return [tf.reshape(self.conv5_4lin, shape=[-1, 1])]
+            else:
+                raise Exception("VGG loss \"{}\" not implemented!".format(mode))
+
+    def avg_pool(self, bottom, name):
+        return tf.nn.avg_pool2d(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
+
+    def max_pool(self, bottom, name):
+        return tf.nn.max_pool2d(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
+
+    def conv_layer(self, bottom, name):
+        with tf.compat.v1.variable_scope(name):
+            filt = self.get_conv_filter(name)
+
+            conv_fn = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
+
+            conv_biases = self.get_bias(name)
+            bias = tf.nn.bias_add(conv_fn, conv_biases)
+
+            relu = tf.nn.relu(bias)
+            return relu, bias
+
+    def fc_layer(self, bottom, name):
+        with tf.compat.v1.variable_scope(name):
+            shape = bottom.get_shape().as_list()
+            dim = 1
+            for d in shape[1:]:
+                dim *= d
+            x = tf.reshape(bottom, [-1, dim])
+
+            weights = self.get_fc_weight(name)
+            biases = self.get_bias(name)
+
+            # Fully connected layer. Note that the '+' operation automatically
+            # broadcasts the biases.
+            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)
+
+            return fc
+
+    def get_conv_filter(self, name):
+        return tf.constant(self.data_dict[name][0], name="filter")
+
+    def get_bias(self, name):
+        return tf.constant(self.data_dict[name][1], name="biases")
+
+    def get_fc_weight(self, name):
+        return tf.constant(self.data_dict[name][0], name="weights")
+
+
+def compute_vgg_loss(ref, gen, mode, vggfile):
+    """
+    Compute "perceptual" (VGG19 based) loss
+    :param ref: reference image (rgb image [batch, height, width, ch])
+    :param gen: generated image (rgb image [batch, height, width, ch])
+    :param mode: name of the perceptual loss to use
+    :param vggfile: path to the vgg19.npy weights file
+    """
+    assert vggfile is not None
+
+    logging.info("Using pre-trained VGG from {}".format(vggfile))
+    vgg_model = Vgg19(vggfile)
+    features_ref = vgg_model.build(ref[:, :, :, 0:3], mode=mode)
+    features_gen = vgg_model.build(gen[:, :, :, 0:3], mode=mode)
+
+    loss = 0.0
+    for fr, fp in zip(features_ref, features_gen):
+        loss += tf.reduce_mean(tf.square(fr - fp))
+
+    return loss
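+
+
+# Minimal usage sketch (tensor names are illustrative, not part of this module):
+#   vgg_loss = compute_vgg_loss(hr_real, hr_fake, mode="vgg54",
+#                               vggfile="vgg19.npy")
+#   # hr_real / hr_fake: [batch, h, w, >=3] tensors scaled to [0, 1];
+#   # only the first 3 channels are fed to VGG19.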