-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrand_hist_norm.py
410 lines (360 loc) · 15 KB
/
rand_hist_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
# -*- coding: utf-8 -*-
"""
Implementation of
Nyúl László G., Jayaram K. Udupa, and Xuan Zhang.
"New variants of a method of MRI scale standardization."
IEEE transactions on medical imaging 19.2 (2000): 143-150.
This implementation only supports input images with floating-point values
(not integers).
"""
from __future__ import absolute_import, print_function, division
import os
import numpy as np
import numpy.ma as ma
# Fallback (low, high) cutoff fractions used when a configured cutoff is
# unusable.
DEFAULT_CUTOFF = [0.01, 0.99]
# Landmark schemes handled by transform_by_mapping.
SUPPORTED_CUTPOINTS = {'percentile', 'quartile', 'median'}
def touch_folder(model_dir):
    """
    Return the absolute path of `model_dir`, creating the folder if needed.

    `~` is expanded first. An empty string (e.g. the result of
    `os.path.dirname('file.txt')`) denotes the current working directory,
    which always exists, so nothing is created for it.

    :param model_dir: path of the folder to return (and create if missing)
    :return absolute_dir: absolute path of the folder
    :raises OSError: if the folder is missing and cannot be created
    """
    model_dir = os.path.expanduser(model_dir)
    # os.makedirs('') would fail although '' is simply the current folder
    if model_dir and not os.path.exists(model_dir):
        try:
            os.makedirs(model_dir)
        except (OSError, TypeError):
            # another process may have created the folder between the
            # existence check and makedirs; only fail if it is still missing
            if not os.path.isdir(model_dir):
                print('could not create model folder: %s' % model_dir)
                raise
    absolute_dir = os.path.abspath(model_dir)
    return absolute_dir
# Print iterations progress
def print_progress_bar(iteration, total,
                       prefix='', suffix='', decimals=1, length=10, fill='='):
    """
    Call in a loop to draw a terminal progress bar.

    The bar is redrawn in place (it ends with a carriage return); once
    `iteration` equals `total` a newline is printed.

    :param iteration: current iteration (Int)
    :param total: total iterations (Int); a value <= 0 is drawn as a
        complete bar instead of raising ZeroDivisionError
    :param prefix: prefix string (Str)
    :param suffix: suffix string (Str)
    :param decimals: number of decimals in percent complete (Int)
    :param length: character length of bar (Int)
    :param fill: bar fill character (Str)
    """
    if total > 0:
        fraction = iteration / float(total)
        # integer arithmetic avoids float rounding in the filled length
        filled_length = int(length * iteration // total)
    else:
        # empty workload: show a complete bar rather than dividing by zero
        fraction = 1.
        filled_length = length
    percent = ("{0:." + str(decimals) + "f}").format(100 * fraction)
    bars = fill * filled_length + '-' * (length - filled_length)
    print('\r%s |%s| %s%% %s' % (prefix, bars, percent, suffix), end='\r')
    # Print New Line on Complete
    if iteration == total:
        print('\n')
def __compute_percentiles(img, mask, cutoff):
    """
    Return the intensity landmarks of `img` restricted to `mask`.

    Thirteen landmarks are returned, in this order: the `cutoff[0]`
    percentile, the 10th, 20th, 25th, 30th, 40th, 50th, 60th, 70th, 75th,
    80th and 90th percentiles, and the `cutoff[1]` percentile. The cutoff
    values are fractions (they are multiplied by 100 before np.percentile).

    :param img: image on which to determine the percentiles
    :param mask: truthy where voxels of `img` are taken into account
    :param cutoff: (low, high) fractions used for the outermost landmarks
    :return perc_results: array with the 13 landmark values
    """
    inner_fractions = [0.1, 0.2, 0.25, 0.3, 0.4, 0.5,
                       0.6, 0.7, 0.75, 0.8, 0.9]
    fractions = np.array([cutoff[0]] + inner_fractions + [cutoff[1]])
    # masked_array hides entries whose mask is True, hence the negation
    hidden = np.logical_not(mask)
    selected = ma.masked_array(img, hidden).compressed()
    return np.percentile(selected, fractions * 100)
def __standardise_cutoff(cutoff, type_hist='quartile'):
    """
    Sanitise the (low, high) cutoff fractions given in the configuration.

    The pair is sorted and clipped to [0, 1], then widened so the outermost
    landmarks stay outside the inner ones: for 'quartile' the low cutoff is
    at most 0.24 and the high one at least 0.76; for any other scheme at
    most 0.09 and at least 0.91.

    :param cutoff: sequence of cutoff fractions. With more than two values
        only the min and max are kept. ``None``, a scalar, or fewer than two
        distinct values fall back to a copy of ``DEFAULT_CUTOFF``.
    :param type_hist: type of landmark normalisation chosen (median,
        quartile, percentile)
    :return cutoff: sanitised cutoff values; always a new object, the input
        is never modified
    """
    # must be tested before conversion: np.asarray(None) is not None
    if cutoff is None:
        return list(DEFAULT_CUTOFF)
    # np.array copies, so a caller's ndarray is never mutated below; the
    # float dtype keeps bounds such as 0.24 from being truncated to 0
    cutoff = np.array(cutoff, dtype=float)
    if cutoff.ndim == 0:
        return list(DEFAULT_CUTOFF)
    if len(cutoff) > 2:
        cutoff = np.unique([np.min(cutoff), np.max(cutoff)])
    if len(cutoff) < 2:
        return list(DEFAULT_CUTOFF)
    if cutoff[0] > cutoff[1]:
        cutoff[0], cutoff[1] = cutoff[1], cutoff[0]
    cutoff[0] = max(0., cutoff[0])
    cutoff[1] = min(1., cutoff[1])
    # keep the outermost landmarks outside the inner landmarks in use
    if type_hist == 'quartile':
        low_max, high_min = 0.24, 0.76
    else:
        low_max, high_min = 0.09, 0.91
    cutoff[0] = min(cutoff[0], low_max)
    cutoff[1] = max(cutoff[1], high_min)
    return cutoff
def create_mapping_from_multimod_arrayfiles(images,
                                            modalities,
                                            mod_to_train,
                                            cutoff):
    """
    Build averaged reference landmarks from a stack of multimodal images.

    For each image and each modality listed in `mod_to_train`, the
    landmarks are computed with `__compute_percentiles` over the whole image
    (all-True mask). The landmarks of each modality are then mapped onto
    the standard range (`create_standard_range`) and averaged with
    `__averaged_mapping`. The result gives the landmarks used for the linear
    mapping of new incoming data.

    :param images: iterable of images, each indexed as
        ``image[modality_index, ...]`` (e.g. an array with images along its
        first axis)
    :param modalities: names of the modalities, in the order in which they
        are stored along each image's first axis
    :param mod_to_train: collection of modality names to build mappings
        for; other modalities are skipped
    :param cutoff: minimum and maximum landmark percentile fractions, used
        as-is for the outermost landmarks
    :return mapping: dict of modality name -> tuple of averaged landmarks
    """
    n_images = len(images)
    perc_database = {}
    for i, img_data in enumerate(images):
        for mod_i, mod_name in enumerate(modalities):
            if mod_name not in mod_to_train:
                continue
            img_2d = img_data[mod_i, ...]
            mask_2d = np.ones_like(img_2d, dtype=bool)
            perc_database.setdefault(mod_name, []).append(
                __compute_percentiles(img_2d, mask_2d, cutoff))
        # report after processing so that the bar reaches 100%
        print_progress_bar(i + 1, n_images,
                           prefix='normalisation histogram training',
                           decimals=1, length=10, fill='*')
    s1, s2 = create_standard_range()
    mapping = {}
    for mod_name, perc_list in perc_database.items():
        # rows: images, columns: landmarks
        mapping[mod_name] = tuple(
            __averaged_mapping(np.vstack(perc_list), s1, s2))
    return mapping
def create_standard_range():
    """
    Return the (lower, upper) bounds of the standard intensity scale onto
    which the training landmarks are mapped.

    :return: the tuple ``(0., 100.)``
    """
    lower, upper = 0., 100.
    return lower, upper
def __averaged_mapping(perc_database, s1, s2):
    """
    Map each row of landmarks linearly onto [s1, s2] and average the rows.

    Row i with landmarks p_i is mapped with slope
    a_i = (s2 - s1) / (p_i[-1] - p_i[0]), sending p_i[0] to s1 and p_i[-1]
    to s2. The result is the column-wise mean of ``s1 + a_i * (p_i - p_i[0])``.

    Rows whose first and last landmarks coincide (e.g. a constant-intensity
    image) have no usable range. They get slope 0 and so contribute the
    constant s1, instead of an infinite slope that overflows the result.
    NaN slopes are replaced by 0 as well.

    :param perc_database: 2D array, one row of landmarks per image
    :param s1, s2: limits of the mapping range
    :return final_map: 1D array with the averaged landmarks
    """
    perc_database = np.asarray(perc_database)
    # assuming shape: (n_data_points, n_percentiles)
    value_range = perc_database[:, -1] - perc_database[:, 0]
    degenerate = value_range == 0
    # divide only where the range is non-zero; degenerate rows keep slope 0
    slope = np.zeros(value_range.shape, dtype=float)
    np.divide(s2 - s1, value_range, out=slope, where=~degenerate)
    slope = np.nan_to_num(slope)
    final_map = slope.dot(perc_database) / perc_database.shape[0]
    intercept = np.mean(s1 - slope * perc_database[:, 0])
    final_map = final_map + intercept
    return final_map
def transform_by_mapping(img, mask, mapping, cutoff, type_hist='quartile'):
    """
    Standardise an image with a piecewise linear landmark mapping.

    The landmarks of `img` inside `mask` are computed with
    `__compute_percentiles`, the subset selected by `type_hist` is kept,
    and each interval between consecutive kept landmarks is mapped linearly
    onto the matching interval of `mapping`. Intensities below the first or
    above the last kept landmark are extrapolated with the first or last
    linear piece. The mask only selects the voxels used for the landmarks;
    every voxel of `img` is transformed.

    :param img: image to standardise
    :param mask: mask over which to determine the landmarks
    :param mapping: reference landmarks, at least 13 values in the order
        produced by `__compute_percentiles`
    :param cutoff: cutoff fractions for the outermost landmarks, sanitised
        with `__standardise_cutoff`
    :param type_hist: landmark scheme: 'quartile', 'percentile' or 'median'
    :return new_img: the standardised image, with the shape of `img`
    :raises ValueError: if `type_hist` is unknown
    :raises AssertionError: if `mapping` has too few landmarks
    """
    image_shape = img.shape
    img = img.reshape(-1)
    mask = mask.reshape(-1)
    # indices into the 13 landmarks returned by __compute_percentiles
    if type_hist == 'quartile':
        range_to_use = [0, 3, 6, 9, 12]
    elif type_hist == 'percentile':
        range_to_use = [0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12]
    elif type_hist == 'median':
        range_to_use = [0, 6, 12]
    else:
        raise ValueError('unknown cutting points type_str')
    # mapping is indexed with range_to_use, so its largest index must exist
    assert len(mapping) > max(range_to_use), \
        "wrong mapping format, please check the histogram reference file"
    mapping = np.asarray(mapping)
    cutoff = __standardise_cutoff(cutoff, type_hist)
    perc = __compute_percentiles(img, mask, cutoff)

    # Apply linear histogram standardisation
    range_mapping = mapping[range_to_use]
    range_perc = perc[range_to_use]
    diff_mapping = np.diff(range_mapping)
    diff_perc = np.diff(range_perc)
    # Two identical landmarks (this usually happens when the image
    # background is not removed) would divide by zero. An infinite width
    # gives that piece a slope of 0 instead.
    diff_perc[diff_perc == 0] = np.inf

    # row 0: slope, row 1: intercept of each linear piece
    affine_map = np.zeros([2, len(range_to_use) - 1])
    affine_map[0] = diff_mapping / diff_perc
    affine_map[1] = range_mapping[:-1] - affine_map[0] * range_perc[:-1]
    # piece index per voxel: 0 below range_perc[1], last piece from
    # range_perc[-2] upwards (values outside the range are extrapolated)
    bin_id = np.digitize(img, range_perc[1:-1], right=False)
    new_img = affine_map[0, bin_id] * img + affine_map[1, bin_id]
    # NOTE: no soft thresholding is applied outside the landmark range;
    # `smooth_threshold` provides an exponential variant that is not used
    # here because it might not guarantee a one-to-one mapping.
    return new_img.reshape(image_shape)
def smooth_threshold(value, mode='high'):
    """
    Exponentially compress values towards a soft limit.

    - 'high': ``min(value) + 1 - exp(-(value - min(value)))``, which
      starts at min(value) and saturates at min(value) + 1;
    - 'low': ``max(value) + exp(value - max(value)) - 1``, which starts at
      max(value) and saturates at max(value) - 1;
    - any other mode returns `value` unchanged.

    :param value: values to compress
    :param mode: 'high' or 'low'
    :return smooth_value: the compressed values; an empty input yields an
        empty float array (instead of failing on the min/max reduction)
    """
    smoothness = 1.
    if mode not in ('high', 'low'):
        return value
    value = np.asarray(value)
    if value.size == 0:
        return value.astype(float)
    if mode == 'high':
        affine = np.min(value)
        smooth_value = (value - affine) / smoothness
        smooth_value = (1. - np.exp((-1) * smooth_value)) + affine
    else:
        affine = np.max(value)
        smooth_value = (value - affine) / smoothness
        smooth_value = (np.exp(smooth_value) - 1.) + affine
    return smooth_value
def read_mapping_file(mapping_file):
    """
    Read a histogram reference file into a dictionary.

    Each line holds a modality name followed by its landmark values,
    separated by whitespace. Lines with fewer than two fields are skipped.
    An unset or non-existent file name yields an empty dictionary.

    :param mapping_file: path of the file in which the mapping is stored
    :return mapping_dict: modality name -> tuple of float32 landmarks
    :raises ValueError: if a landmark value cannot be parsed as a number
    """
    landmarks = {}
    if not mapping_file or not os.path.isfile(mapping_file):
        return landmarks
    with open(mapping_file, "r") as handle:
        for raw_line in handle:
            if len(raw_line) <= 2:
                continue
            tokens = raw_line.split()
            if len(tokens) < 2:
                continue
            try:
                modality = tokens[0]
                values = np.float32(tokens[1:])
                landmarks[modality] = tuple(values)
            except ValueError:
                print("unknown input format: {}".format(mapping_file))
                raise
    return landmarks
# Write the (possibly merged) mapping to the histogram reference file.
def write_all_mod_mapping(hist_model_file, mapping):
    """
    Write `mapping` to `hist_model_file`, keeping a backup of any old file.

    An existing file is first copied to ``<hist_model_file>.backup``. The
    parent folder is then created if needed and the file is overwritten.

    :param hist_model_file: path of the histogram reference file
    :param mapping: dict of modality name -> sequence of landmark values
    :raises OSError: if the backup copy or the folder creation fails
    """
    from shutil import copyfile
    # backup existing file first
    if os.path.exists(hist_model_file):
        backup_name = '{}.backup'.format(hist_model_file)
        try:
            copyfile(hist_model_file, backup_name)
        except OSError:
            print('cannot backup file {}'.format(hist_model_file))
            raise
        # the file is copied, not moved: it is overwritten just below
        print(
            "backed up existing histogram reference file\n"
            " from {} to {}".format(hist_model_file, backup_name))
    model_dir = os.path.dirname(hist_model_file)
    # a bare file name has no folder component to create
    if model_dir:
        touch_folder(model_dir)
    __force_writing_new_mapping(hist_model_file, mapping)
def __force_writing_new_mapping(filename, mapping_dict):
    """
    Overwrite `filename` with one line per entry of `mapping_dict`.

    Each line is the modality name followed by its values converted with
    ``str`` and separated by single spaces.

    :param filename: name of the file in which to write the mapping
    :param mapping_dict: modality name -> sequence of landmark values
    :return: None
    """
    with open(filename, 'w+') as out:
        for modality, landmarks in mapping_dict.items():
            joined = ' '.join(str(value) for value in landmarks)
            out.write('{} {}\n'.format(modality, joined))
class RandomHistNormLayer:
    """
    Histogram standardisation layer backed by a reference landmark file.

    The landmarks of each modality are read from / written to
    `model_filename` (one line per modality: its name followed by the
    landmark values) and applied with `transform_by_mapping`.
    """

    def __init__(self,
                 modalities=None,
                 model_filename=None,
                 norm_type='percentile',
                 cutoff=(0.05, 0.95),
                 name='hist_norm'):
        """
        :param modalities: names of the modalities, in the order in which
            they are stored along the first axis of the data; defaults to
            ``['T1', 'FLAIR']``
        :param model_filename: histogram reference file; defaults to
            ``./histogram_ref_file.txt``
        :param norm_type: landmark scheme ('quartile', 'percentile' or
            'median')
        :param cutoff: (low, high) fractions of the outermost landmarks
        :param name: name of the layer, shown in log messages
        """
        self.name = name
        self.acquisition_type = '2D'
        self.norm_type = norm_type
        self.cutoff = cutoff
        # a fresh list per instance instead of a shared mutable default
        if modalities is None:
            modalities = ['T1', 'FLAIR']
        self.modalities = modalities
        if model_filename is None:
            model_filename = os.path.join('.', 'histogram_ref_file.txt')
        self.model_file = os.path.abspath(model_filename)
        self.mapping = read_mapping_file(self.model_file)

    def __check_modalities_to_train(self):
        """Return the set of modalities that have no mapping yet."""
        return set(mod for mod in self.modalities if mod not in self.mapping)

    def is_ready(self):
        """Return True when every modality has a mapping."""
        return not self.__check_modalities_to_train()

    def train(self, image_list):
        """
        Compute and save mappings for the modalities that lack one.

        New mappings are merged with the in-memory ones and with those
        currently stored in the model file, then written back to it.

        :param image_list: images indexed as ``image_list[i][modality, ...]``
        """
        if self.is_ready():
            print(
                "normalisation histogram reference models ready"
                " for {}:{}".format(self.name, self.modalities))
            return
        mod_to_train = self.__check_modalities_to_train()
        print("training normalisation histogram references "
              "using {} subjects".format(len(image_list)))
        trained_mapping = create_mapping_from_multimod_arrayfiles(
            images=image_list,
            modalities=self.modalities,
            mod_to_train=mod_to_train,
            cutoff=self.cutoff)
        # merging trained_mapping dict and self.mapping dict, keeping any
        # other modalities already stored on disk
        self.mapping.update(trained_mapping)
        all_maps = read_mapping_file(self.model_file)
        all_maps.update(self.mapping)
        write_all_mod_mapping(self.model_file, all_maps)

    def layer_op(self, inputs, interp_orders, train_on=False, *args, **kwargs):
        """
        Standardise `inputs` with the trained mappings.

        `interp_orders`, `train_on` and any extra arguments are accepted for
        interface compatibility but are not used.

        :param inputs: data with the modalities along the first axis
        :return normalised: float32 copy of `inputs`, standardised
        """
        assert self.is_ready(), \
            "histogram normalisation layer needs to be trained first."
        # np.array copies, so the caller's array is not modified in place
        image_5d = np.array(inputs, dtype=np.float32)
        image_mask = np.ones_like(image_5d, dtype=bool)
        normalised = self._normalise_5d(image_5d, image_mask)
        return normalised

    def get_piecewise_monotonic_perturbation(self, mapping):
        """Placeholder, not implemented: always returns None."""
        pass

    def _normalise_5d(self, data_array, mask_array):
        """
        Standardise each modality of `data_array` in place and return it.

        Modalities whose data is all zero (treated as missing) or whose
        mapping contains NaNs are left unchanged.

        :param data_array: float array with the modalities along axis 0
        :param mask_array: mask (same shape) used to compute the landmarks
        :return: `data_array`
        :raises RuntimeError: if no mapping is loaded
        """
        assert self.modalities
        if not self.mapping:
            message = ("calling normaliser with empty mapping, "
                       "probably {} is not loaded".format(self.model_file))
            print(message)
            raise RuntimeError(message)
        mask_array = np.asarray(mask_array, dtype=bool)
        for mod_id, mod_name in enumerate(self.modalities):
            if not np.any(data_array[mod_id, ...]):
                continue  # missing modality
            if np.any(np.isnan(self.mapping[mod_name])):
                print('FOUND NaNs, not normalising {}'.format(mod_name))
                continue
            data_array[mod_id, ...] = transform_by_mapping(
                img=data_array[mod_id, ...],
                mask=mask_array[mod_id, ...],
                mapping=self.mapping[mod_name],
                cutoff=self.cutoff,
                type_hist=self.norm_type)
        return data_array