Skip to content

Commit

Permalink
Merge pull request #10 from remicres/8-fix_image_summary_dynamic
Browse files Browse the repository at this point in the history
fix image summary dynamic
  • Loading branch information
remicres authored Mar 10, 2021
2 parents 16131d7 + 2ca187e commit 4addaf8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
23 changes: 8 additions & 15 deletions code/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,16 @@ def _make_output(net, factor):
return hr_images


def nice_preview(x, refs):
def nice_preview(x):
"""
Beautiful previews
Keep only first 3 bands --> RGB
"""
bands = [0, 1, 2]

_mean = np.zeros(3)
_std = np.zeros(3)
_ninv = 1.0 / float(len(refs))
for ref in refs:
_mean += _ninv * np.asarray([np.mean(ref[0, :, :, i]) for i in bands])
_std += _ninv * np.asarray([np.std(ref[0, :, :, i]) for i in bands])
_min = [__mean - 2 * __std for __mean, __std in zip(_mean, _std)]
_max = [__mean + 2 * __std for __mean, __std in zip(_mean, _std)]
return tf.cast(255 * tf.stack(
[1.0 / (__max - __min) * (tf.clip_by_value(x[:, :, :, i], __min, __max) - __min) for i, __min, __max in
zip(bands, _min, _max)],
axis=3), tf.uint8)
x = x[:, :, :, :3]
axis = [0, 1, 2]
stds = tf.math.reduce_std(x, axis=axis, keepdims=True)
means = tf.math.reduce_mean(x, axis=axis, keepdims=True)
mins = means - 2 * stds
maxs = means + 2 * stds
return tf.cast(255 * tf.divide(x - mins, maxs - mins), tf.uint8)

4 changes: 2 additions & 2 deletions code/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _get_normalized_input(key, scale, name):
tf.identity(gen[1][:, pad:-pad, pad:-pad, :], name="{}{}".format(constants.outputs_prefix, pad))
if lr_image_for_prev is not None:
for factor in constants.factors:
prev = network.nice_preview(gen[factor], refs=[lr_image_for_prev])
prev = network.nice_preview(gen[factor])
tf.compat.v1.summary.image("preview_factor{}".format(factor), prev, collections=[constants.epoch_key])

# discriminator
Expand Down Expand Up @@ -212,7 +212,7 @@ def _append_desc(key, value):
return "_{}{}".format(key, value)

now = datetime.datetime.now()
summaries_fn = "SR4RS_"
summaries_fn = "SR4RS"
summaries_fn += _append_desc("E", params.epochs)
summaries_fn += _append_desc("B", params.batchsize)
summaries_fn += _append_desc("LR", params.adam_lr)
Expand Down

0 comments on commit 4addaf8

Please sign in to comment.