From 0a4db5d931ee5a208087419e908406fdaadb3866 Mon Sep 17 00:00:00 2001 From: Remi Cresson Date: Sat, 13 Mar 2021 15:22:48 +0100 Subject: [PATCH] FIX: input images filtering --- code/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/code/train.py b/code/train.py index 1a9f79b..06e1e87 100644 --- a/code/train.py +++ b/code/train.py @@ -98,12 +98,14 @@ def _get_normalized_input(key, scale, name): lr_image = _get_normalized_input(constants.lr_key, params.lr_scale, constants.lr_input_name) hr_image = _get_normalized_input(constants.hr_key, params.hr_scale, constants.hr_input_name) + + # model hr_nch = ds.output_shapes[constants.hr_key][-1] - generator = partial(network.generator, scope=constants.gen_scope, nchannels=hr_nch, nresblocks=params.nresblocks, - dim=params.depth) + generator = partial(network.generator, scope=constants.gen_scope, nchannels=hr_nch, + nresblocks=params.nresblocks, dim=params.depth) discriminator = partial(network.discriminator, scope=constants.dis_scope, dim=params.depth) - hr_images_real = {factor: downscale2d(blur2d(hr_image), factor=factor) for factor in constants.factors} + hr_images_real = {factor: downscale2d(hr_image, factor=factor) for factor in constants.factors} hr_images_fake = generator(lr_image) # model outputs