diff --git a/code/train.py b/code/train.py index 06e1e87..f122e62 100644 --- a/code/train.py +++ b/code/train.py @@ -5,7 +5,7 @@ from functools import partial import otbtf import logging -from ops import blur2d, downscale2d +from ops import downscale2d from vgg import compute_vgg_loss import network import constants @@ -89,15 +89,14 @@ def main(unused_argv): iterator_init = iterator.make_initializer(tf_ds) dataset_inputs = iterator.get_next() - # placeholders with normalization - def _get_normalized_input(key, scale, name): + # model inputs + def _get_input(key, name): default_input = dataset_inputs[key] shape = (None, None, None, ds.output_shapes[key][-1]) - ph = tf.compat.v1.placeholder_with_default(default_input, shape=shape, name=name) - return scale * ph + return tf.compat.v1.placeholder_with_default(default_input, shape=shape, name=name) - lr_image = _get_normalized_input(constants.lr_key, params.lr_scale, constants.lr_input_name) - hr_image = _get_normalized_input(constants.hr_key, params.hr_scale, constants.hr_input_name) + lr_image = _get_input(constants.lr_key, constants.lr_input_name) + hr_image = _get_input(constants.hr_key, constants.hr_input_name) # model hr_nch = ds.output_shapes[constants.hr_key][-1] @@ -105,8 +104,8 @@ def _get_normalized_input(key, scale, name): nresblocks=params.nresblocks, dim=params.depth) discriminator = partial(network.discriminator, scope=constants.dis_scope, dim=params.depth) - hr_images_real = {factor: downscale2d(hr_image, factor=factor) for factor in constants.factors} - hr_images_fake = generator(lr_image) + hr_images_real = {factor: params.hr_scale * downscale2d(hr_image, factor=factor) for factor in constants.factors} + hr_images_fake = generator(params.lr_scale * lr_image) # model outputs gen = {factor: (1.0 / params.hr_scale) * hr_images_fake[factor] for factor in constants.factors}