-
Notifications
You must be signed in to change notification settings - Fork 5
/
utils.py
257 lines (223 loc) · 7.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""
Different utilities such as orthogonalization of weights, initialization of
loggers, etc
Copyright (C) 2018, Matias Tassano <[email protected]>
This program is free software: you can use, modify and/or
redistribute it under the terms of the GNU General Public
License as published by the Free Software Foundation, either
version 3 of the License, or (at your option) any later
version. You should have received a copy of this license along
this program. If not, see <http://www.gnu.org/licenses/>.
"""
import subprocess
import math
import logging
import numpy as np
import cv2
import torch
import torch.nn as nn
from skimage.measure.simple_metrics import compare_psnr
def weights_init_kaiming(lyr):
r"""Initializes weights of the model according to the "He" initialization
method described in "Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification" - He, K. et al. (2015), using a
normal distribution.
This function is to be called by the torch.nn.Module.apply() method,
which applies weights_init_kaiming() to every layer of the model.
"""
classname = lyr.__class__.__name__
if classname.find('Conv') != -1:
nn.init.kaiming_normal(lyr.weight.data, a=0, mode='fan_in')
elif classname.find('Linear') != -1:
nn.init.kaiming_normal(lyr.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm') != -1:
lyr.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).\
clamp_(-0.025, 0.025)
nn.init.constant(lyr.bias.data, 0.0)
def batch_psnr(img, imclean, data_range):
    r"""
    Computes the PSNR along the batch dimension (not pixel-wise): one PSNR
    value per image, then the mean over the batch.

    The per-image computation mirrors scikit-image's (now removed)
    `compare_psnr`: the squared error is formed in float32, the MSE is
    accumulated in float64, and PSNR = 10 * log10(data_range**2 / MSE).

    Args:
        img: a `torch.Tensor` containing the restored images; dimension 0
            indexes the batch.
        imclean: a `torch.Tensor` containing the reference images, with the
            same shape as `img`.
        data_range: the data range of the images (distance between the
            minimum and maximum possible values), e.g. 1. for images in
            [0, 1]. Required; there is no default.
    Returns:
        The mean PSNR in dB, as a numpy float64. An image identical to its
        reference contributes +inf.
    Raises:
        ValueError: if the shapes of `img` and `imclean` differ, or the batch
            is empty.
    """
    img_cpu = img.data.cpu().numpy().astype(np.float32)
    imgclean = imclean.data.cpu().numpy().astype(np.float32)
    if img_cpu.shape != imgclean.shape:
        raise ValueError('Input images must have the same dimensions, got '
                         '{} and {}'.format(img_cpu.shape, imgclean.shape))
    n_imgs = img_cpu.shape[0]
    if n_imgs == 0:
        raise ValueError('batch_psnr needs at least one image in the batch')
    sq_err = np.square(imgclean - img_cpu).reshape(n_imgs, -1)
    mse = np.mean(sq_err, axis=1, dtype=np.float64)
    # A perfect reconstruction (mse == 0) yields +inf, without a warning.
    with np.errstate(divide='ignore'):
        psnr = 10. * np.log10((data_range ** 2) / mse)
    return np.mean(psnr)
def data_augmentation(image, mode):
    r"""Applies one of eight flip/rotation transforms to an image.

    The array is moved from channels-first to channels-last layout, the
    transform is applied over its first two axes, and the result is moved
    back to channels-first layout.

    Args:
        image: a numpy array in (C, H, W) layout (as implied by the transposes).
        mode: int in 0..7, choice of transformation to apply to the image:
            0 - no transformation
            1 - flip up and down
            2 - rotate 90 degrees counterclockwise
            3 - rotate 90 degrees and flip up and down
            4 - rotate 180 degrees
            5 - rotate 180 degrees and flip up and down
            6 - rotate 270 degrees
            7 - rotate 270 degrees and flip up and down
    Returns:
        The transformed array in (C, H, W) layout (possibly a view of `image`).
    Raises:
        Exception: if `mode` is not one of 0..7.
    """
    hwc = np.transpose(image, (1, 2, 0))
    if mode not in range(8):
        raise Exception('Invalid choice of image transformation')
    # mode // 2 counts counterclockwise quarter turns; an odd mode adds an
    # up-down flip applied after the rotation.
    quarter_turns, flip = divmod(int(mode), 2)
    if quarter_turns:
        hwc = np.rot90(hwc, k=quarter_turns)
    if flip:
        hwc = np.flipud(hwc)
    return np.transpose(hwc, (2, 0, 1))
def variable_to_cv2_image(varim):
    r"""Converts the first image of a batched tensor to an OpenCV image.

    Args:
        varim: a torch.Tensor (or legacy torch.autograd.Variable) indexed as
            [batch, channel, ...]; only batch element 0 is converted. Values
            are scaled by 255 and clipped to [0, 255], so they are expected
            to lie in [0, 1].
    Returns:
        A C-contiguous uint8 numpy array: (H, W) for a 1-channel 4-D input,
        or (H, W, 3) in BGR channel order for a 3-channel input (the input
        channels are treated as RGB).
    Raises:
        Exception: if the tensor has neither 1 nor 3 channels.
    """
    nchannels = varim.size()[1]
    if nchannels == 1:
        res = varim.data.cpu().numpy()[0, 0, :]
    elif nchannels == 3:
        # CHW -> HWC, then reverse the channel axis (RGB -> BGR). This is the
        # same permutation as cv2.cvtColor(..., cv2.COLOR_RGB2BGR), but it
        # also works for dtypes cvtColor rejects (e.g. float64).
        res = varim.data.cpu().numpy()[0].transpose(1, 2, 0)[..., ::-1]
    else:
        raise Exception('Number of color channels not supported')
    res = (res * 255.).clip(0, 255).astype(np.uint8)
    # OpenCV functions (e.g. cv2.imwrite) expect a C-contiguous buffer.
    return np.ascontiguousarray(res)
def get_git_revision_short_hash():
    r"""Returns the abbreviated hash of the commit checked out as HEAD.

    Runs `git rev-parse --short HEAD` in the current working directory.

    Returns:
        The command's stdout with surrounding whitespace stripped, as bytes
        (subprocess.check_output returns bytes unless text mode is requested).
    Raises:
        subprocess.CalledProcessError: if git exits with a non-zero status.
        OSError: if the git executable cannot be found.
    """
    git_cmd = ['git', 'rev-parse', '--short', 'HEAD']
    raw_output = subprocess.check_output(git_cmd)
    return raw_output.strip()
def init_logger(argdict):
    r"""Initializes a logging.Logger to save all the running parameters to a
    log file.

    Appends to <argdict.log_dir>/log.txt. It writes the current Git commit,
    or an error line if that cannot be obtained, followed by one line per
    attribute of `argdict`.

    Args:
        argdict: an object holding the parameters to be logged as attributes
            (presumably an argparse.Namespace). It must have a `log_dir`
            attribute. Despite the name, it is not a dict.
    Returns:
        The configured logging.Logger.
    """
    from os.path import join
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    fh = logging.FileHandler(join(argdict.log_dir, 'log.txt'), mode='a')
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    # Best effort: a missing git binary or a non-repository directory must not
    # prevent logging the run parameters.
    try:
        commit = get_git_revision_short_hash()
        # subprocess output is bytes on Python 3; decode it so the log reads
        # "Commit: 1a2b3c4" instead of "Commit: b'1a2b3c4'".
        if isinstance(commit, bytes):
            commit = commit.decode('utf-8', errors='replace')
        logger.info("Commit: {}".format(commit))
    except Exception as e:
        logger.error("Couldn't get commit number: {}".format(e))
    logger.info("Arguments: ")
    for key, value in vars(argdict).items():
        logger.info("\t{}: {}".format(key, value))
    return logger
def init_logger_ipol():
    r"""Creates the 'testlog' logger used to report results after testing a
    model, writing plain messages (no timestamps) to ./out.txt.

    The file is opened in 'w' mode, so it is truncated on every call.
    The function takes no arguments.

    Returns:
        The logging.Logger named 'testlog', set to level INFO.
    """
    test_logger = logging.getLogger('testlog')
    test_logger.setLevel(level=logging.INFO)
    out_handler = logging.FileHandler('out.txt', mode='w')
    out_handler.setFormatter(logging.Formatter('%(message)s'))
    test_logger.addHandler(out_handler)
    return test_logger
def init_logger_test(result_dir):
    r"""Attaches a timestamped file handler to the 'testlog' logger, for
    reporting results after testing a model.

    Messages are appended to <result_dir>/log.txt as
    "<asctime> - <message>".

    Args:
        result_dir: path to the folder with the denoising results.
    Returns:
        The logging.Logger named 'testlog', set to level INFO.
    """
    from os.path import join
    test_logger = logging.getLogger('testlog')
    test_logger.setLevel(level=logging.INFO)
    log_file = join(result_dir, 'log.txt')
    file_handler = logging.FileHandler(log_file, mode='a')
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    test_logger.addHandler(file_handler)
    return test_logger
def normalize(data):
    r"""Rescales a uint8 image from [0, 255] to float32 values in [0, 1].

    Args:
        data: a uint8 numpy array with values in [0, 255].
    Returns:
        `data / 255.` converted to np.float32.
    """
    max_value = 255.
    return np.float32(data / max_value)
def svd_orthogonalization(lyr):
    r"""Applies regularization to the training by performing the
    orthogonalization technique described in the paper "FFDNet: Toward a fast
    and flexible solution for CNN based image denoising." Zhang et al. (2017).

    For each Conv layer in the model, the method replaces the matrix whose
    columns are the filters of the layer by new filters which are orthogonal
    to each other. This is achieved by setting the singular values of an SVD
    decomposition to 1, i.e. W <- U @ Vh.
    This function is to be called by the torch.nn.Module.apply() method,
    which applies svd_orthogonalization() to every layer of the model.
    Modules whose class name does not contain 'Conv' are left untouched.

    The weight is treated as (n_filters, *rest) and flattened per filter, so
    Conv1d, Conv2d and Conv3d weights are all supported. The SVD is computed
    in float64, and the result is cast back to the weight's original dtype
    and device.

    Args:
        lyr: a torch.nn.Module (one layer of the model).
    """
    if lyr.__class__.__name__.find('Conv') == -1:
        return
    weight = lyr.weight.data
    n_filters = weight.size(0)
    # One column per filter: (n_filters, *rest) -> (prod(rest), n_filters).
    # The element order inside each column does not matter: for any row
    # permutation P, the polar factor of P @ W is P @ (U @ Vh), so permuting
    # the rows and un-permuting them afterwards gives the same filters.
    mat = weight.reshape(n_filters, -1).t().cpu().double().numpy()
    # With full_matrices=False, U @ Vh equals U @ diag(1, ..., 1) @ Vh, i.e.
    # the SVD with every singular value set to 1.
    mat_u, _, mat_vh = np.linalg.svd(mat, full_matrices=False)
    ortho = np.dot(mat_u, mat_vh)
    lyr.weight.data = torch.from_numpy(np.ascontiguousarray(ortho.T)).reshape(
        weight.shape).to(device=weight.device, dtype=weight.dtype)
def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary.

    Only keys that actually start with "module." are rewritten. Other keys
    are kept as-is, so calling this on an already-unwrapped state dict is
    harmless.

    Args:
        state_dict: a torch.nn.DataParallel state dictionary (any mapping
            with string keys).
    Returns:
        A collections.OrderedDict with the same values in the same order,
        with the "module." prefix stripped from the keys that carry it.
    """
    from collections import OrderedDict
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, val in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = val
    return new_state_dict
def is_rgb(im_path):
    r""" Returns True if the image in im_path is an RGB image.

    An image counts as RGB when it has a trailing channel axis with at least
    three channels and its first three channels are not all (approximately)
    equal. A grayscale image stored as three identical channels returns
    False. The decision and the loaded image shape are also printed to stdout.

    Args:
        im_path: path of the image file, read with skimage.io.imread.
    Returns:
        bool.
    """
    from skimage.io import imread
    rgb = False
    im = imread(im_path)
    # Require >= 3 channels: 1- or 2-channel arrays (e.g. gray + alpha) would
    # otherwise raise IndexError on im[..., 2].
    if len(im.shape) == 3 and im.shape[2] >= 3:
        if not (np.allclose(im[..., 0], im[..., 1]) and
                np.allclose(im[..., 2], im[..., 1])):
            rgb = True
    print("rgb: {}".format(rgb))
    print("im shape: {}".format(im.shape))
    return rgb