import scipy.io
import numpy as np
import os
import re
import csv
import json
import glob
import keras
from keras.preprocessing.image import ImageDataGenerator, array_to_img
import tensorflow as tf
import PIL
from PIL import Image
import shutil
import sys
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import pydicom
"""
confidence_thresholds: 1d array float
each element in the array is the lower bound of a bin
The upperbound is 1.
Each interval follows the following format:
(a,b) where indices with prediction values >= a and < b are stored.
Exception: when b==1, any index with prediction value 1 is stored.
preds: 1d array (#,)
Outputs a dictionary where keys are tuple of floats representing the bin
values are 1d arrays containing the indices rel. to the preds
"""
def bin_by_confidence(confidence_thresholds,preds):
ind_by_confidence={}
num_confidence_bins=len(confidence_thresholds)
for index_confidence_lb in range(num_confidence_bins):
confidence_bin_lb=confidence_thresholds[index_confidence_lb]
confidence_bin_ub=1 if index_confidence_lb==num_confidence_bins-1 \
else confidence_thresholds[index_confidence_lb+1]
# indices are relative to the order of preds
ind_preds_lb=np.where(preds>=confidence_bin_lb)
if int(confidence_bin_ub)==1:
ind_preds_ub=np.where(preds<=confidence_bin_ub)
else:
ind_preds_ub=np.where(preds<confidence_bin_ub)
ind_with_desired_confidence= np.intersect1d(ind_preds_lb,ind_preds_ub)
ind_by_confidence[(confidence_bin_lb,confidence_bin_ub)]=ind_with_desired_confidence
if len(ind_with_desired_confidence)>0:
assert np.amin(preds[ind_with_desired_confidence]) >= confidence_bin_lb
assert np.amax(preds[ind_with_desired_confidence]) <= confidence_bin_ub
return ind_by_confidence
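# A minimal usage sketch of bin_by_confidence (hypothetical values, for illustration only):
# >>> preds = np.array([0.10, 0.55, 0.95, 1.00])
# >>> bins = bin_by_confidence([0.0, 0.5, 0.9], preds)
# >>> bins[(0.0, 0.5)]          # indices with predictions in [0.0, 0.5)
# array([0])
# >>> bins[(0.9, 1)]            # the last bin also includes predictions equal to 1
# array([2, 3])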
"""
dest_folder: str folder path to export data to
data_partition_label: str train/val/test
all_subject_data: 1d array of SubjectData
is_task_iqa: bool
Saves the data to a numpy array of the following shape: (# slices, 5) with a ***STRING*** data type
For each slice, the attributes are recorded i.e., roi_label,quality_label,data_path...
"""
def export_data_partition(dest_folder,data_partition_label,all_slice_data):
    all_slice_data_to_array=list(map(lambda x: x.get_attributes(),all_slice_data))  # list() so np.array also works under Python 3
    all_slice_data_to_array=np.array(all_slice_data_to_array)
np.save(os.path.join(dest_folder,data_partition_label+'.npy'),all_slice_data_to_array)
# save labels to CSV for computing further metrics
all_slice_quality_labels=all_slice_data_to_array[:,1].astype('int')
np.savetxt(os.path.join(dest_folder,data_partition_label+'_labels.csv'),
all_slice_quality_labels)
return
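# A minimal usage sketch (hypothetical paths, for illustration only):
# >>> export_data_partition('partitions', 'train', train_slice_data)
# writes partitions/train.npy and partitions/train_labels.csv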
"""
Stores the meta data associated with 2d slice in a volume
"""
class SliceData(object):
"""
roi_label: int 0: no roi, 1: roi present
quality_label: int 0: good, 1: bad, -1: uncertain
subject_folder_name: str subject
stack_folder_name: str stack
dicom_path: str filename of the dicom/png with ext e.g., dcm/IMA
rel. to the folder that contains all the data
"""
def __init__(self,roi_label,quality_label,
subject_folder_name='',stack_folder_name='',dicom_path=''):
self.roi_label=roi_label
self.quality_label=quality_label
self.dicom_path=dicom_path
self.stack_folder_name=stack_folder_name
self.subject_folder_name=subject_folder_name
def is_valid_for_iqa(self):
return self.roi_label==1 and self.quality_label!=-1
def is_bad_slice(self):
return self.roi_label==1 and self.quality_label==1
def is_good_slice(self):
return self.roi_label==1 and self.quality_label==0
def get_complete_path(self):
return os.path.join(self.subject_folder_name,self.stack_folder_name,self.dicom_path)
def get_attributes(self):
return np.array([self.roi_label,self.quality_label,self.subject_folder_name,
self.stack_folder_name,self.dicom_path])
def __eq__(self,another_slice):
return self.roi_label==another_slice.roi_label and \
self.quality_label==another_slice.quality_label and \
self.dicom_path == another_slice.dicom_path and \
self.stack_folder_name==another_slice.stack_folder_name and \
self.subject_folder_name == another_slice.subject_folder_name
def __hash__(self):
return hash((self.roi_label,self.quality_label,self.dicom_path,self.stack_folder_name,self.subject_folder_name))
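# A minimal usage sketch of SliceData (hypothetical names, for illustration only):
# >>> s = SliceData(roi_label=1, quality_label=1,
# ...               subject_folder_name='case1', stack_folder_name='vol_1',
# ...               dicom_path='IMG_0001.IMA')
# >>> s.is_bad_slice()
# True
# >>> s.get_complete_path()      # POSIX-style join
# 'case1/vol_1/IMG_0001.IMA'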
"""
Stores meta data associated with the 3d stack
"""
class StackData(object):
"""
slice_data: 1d array of SliceData objects
stack_folder_name: str stack
subject_folder_name: str subject
"""
def __init__(self,slice_data,stack_folder_name='',subject_folder_name=''):
self.num_slices=len(slice_data)
self.slice_data=slice_data
self.stack_folder_name=stack_folder_name
self.subject_folder_name=subject_folder_name
def __eq__(self,another_stack):
return self.num_slices==another_stack.num_slices and \
set(self.slice_data)==set(another_stack.slice_data) and \
self.stack_folder_name==another_stack.stack_folder_name and \
self.subject_folder_name==another_stack.subject_folder_name
def __hash__(self):
return hash((self.num_slices,self.stack_folder_name,self.subject_folder_name))
    def get_number_bad_slices(self):
        # list comprehension so len() also works under Python 3 (where filter returns an iterator)
        bad_slices=[slice_data for slice_data in self.slice_data if slice_data.is_bad_slice()]
        num_bad_slices=len(bad_slices)
        return num_bad_slices
    def get_number_roi_slices(self):
        roi_slices=[slice_data for slice_data in self.slice_data if slice_data.roi_label==1]
        num_roi_slices=len(roi_slices)
        return num_roi_slices
    def is_contaminated_stack(self):
        return self.get_number_bad_slices()>=1
def get_fraction_contaminated(self):
return self.get_number_bad_slices()*1.0/self.get_number_roi_slices()
# heuristic way of deciding whether the stack is a brain/body scan
def is_brain_stack(self):
threshold=0.3
frac_with_roi=self.get_number_roi_slices()*1.0/self.num_slices
return frac_with_roi>=threshold
"""
returns the parameters of the stack
"""
def get_parameters(self):
pass
"""
Returns a 1d array of SliceData objects where the slices are ordered
"""
def get_ordered_slice_data(self):
pass
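# A minimal usage sketch of StackData (hypothetical labels, for illustration only):
# >>> slices = [SliceData(1, 0), SliceData(1, 1), SliceData(0, -1)]
# >>> stack = StackData(slices, stack_folder_name='vol_1', subject_folder_name='case1')
# >>> stack.get_number_roi_slices(), stack.get_number_bad_slices()
# (2, 1)
# >>> stack.get_fraction_contaminated()
# 0.5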
class SubjectData(object):
"""
stack_data: 1d array of StackData objects
subject_folder_name: str ID of the subject
"""
def __init__(self,stack_data,subject_folder_name=''):
self.num_stacks=len(stack_data)
self.stack_data=stack_data
self.subject_folder_name=subject_folder_name
def __eq__(self,another_subject):
return self.num_stacks==another_subject.num_stacks and \
self.subject_folder_name == another_subject.subject_folder_name and \
set(self.stack_data) == set(another_subject.stack_data)
def __hash__(self):
return hash((self.num_stacks,self.subject_folder_name))
def get_all_slices(self):
        # flatten the per-stack slice lists into one list
        slice_data=[slice_data for stack_data in self.stack_data for slice_data in stack_data.slice_data]
slice_data=np.array(slice_data)
return slice_data
def get_number_bad_slices(self):
num_bad_slices_per_stack=map(lambda x: x.get_number_bad_slices(),
self.stack_data)
number_bad_slices=sum(num_bad_slices_per_stack)
return number_bad_slices
def get_number_roi_slices(self):
num_roi_slices_per_stack=map(lambda x: x.get_number_roi_slices(),
self.stack_data)
number_roi_slices=sum(num_roi_slices_per_stack)
return number_roi_slices
def get_brain_stacks(self):
brain_stacks=filter(lambda x: x.is_brain_stack(),self.stack_data)
return brain_stacks
    def filter_stack_data(self):
        # NOTE: relies on StackData.get_roi_slices(), which is not defined in this file
        filtered_stacks=map(lambda x: x.get_roi_slices(), self.stack_data)
        return filtered_stacks
    def get_number_contaminated_stacks(self):
        contaminated_stacks=[stack for stack in self.stack_data if stack.is_contaminated_stack()]
        return len(contaminated_stacks)
def get_subject_data_with_brain_stacks(self):
brain_stacks=self.get_brain_stacks()
return SubjectData(brain_stacks,self.subject_folder_name)
def get_fraction_contamination_per_stack(self):
return map(lambda x: x.get_fraction_contaminated(),self.stack_data)
def get_average_stack_contamination(self):
return np.mean(self.get_fraction_contamination_per_stack())
    def get_average_number_bad_slices_across_stacks(self):
        return np.mean(self.get_number_bad_slices_per_stack())
def get_number_bad_slices_per_stack(self):
return map(lambda x: x.get_number_bad_slices(),self.stack_data)
def get_number_roi_slices_per_stack(self):
return map(lambda x: x.get_number_roi_slices(),self.stack_data)
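# A minimal usage sketch of SubjectData (continuing the hypothetical stack above):
# >>> subject = SubjectData([stack], subject_folder_name='case1')
# >>> subject.get_number_roi_slices(), subject.get_number_bad_slices()
# (2, 1)
# >>> subject.get_average_stack_contamination()
# 0.5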
"""
all_subject_data: 1d array of SubjectData
returns 1d array of StackData
"""
def get_all_stack_data(all_subject_data):
all_stack_data=map(lambda x: x.stack_data,all_subject_data)
return flatten_nested_list(all_stack_data)
def get_all_slice_data(all_stack_data):
all_slice_data=map(lambda x:x.slice_data,all_stack_data)
return flatten_nested_list(all_slice_data)
def flatten_nested_list(nested_list):
flattened_list=[data_unit for sublist in nested_list \
for data_unit in sublist]
flattened_list=np.array(flattened_list)
return flattened_list
"""
data_by_stack: 1d array of 1d arrays of SliceData objects representing the stack data
subject_name: str
"""
def plot_stack_distribution(data_by_stack,subject_name):
fig=plt.figure()
ax=fig.add_subplot(1,1,1)
stack_names=[]
for stack_index,stack_data in enumerate(data_by_stack):
# create a bar for this stack
for slice_index,slice_data in enumerate(stack_data):
color='gray'
if slice_data.roi_label==1:
color='lightcoral' if slice_data.quality_label==1 else 'mediumseagreen'
ax.bar(stack_index,height=1,width=0.5,bottom=slice_index,color=color,align='center',
edgecolor='black')
stack_name=slice_data.dicom_path.split("\\")[-1]
stack_name=stack_name.split(".")
stack_name=str(int(stack_name[2]))
stack_names.append(stack_name)
ax.set_xlabel("stack ID")
ax.set_ylabel("slice index")
xleft=0-1
num_stacks=len(data_by_stack)
ax.set_xlim(left=xleft,right=num_stacks+1)
ax.set_xticks(np.arange(num_stacks))
ax.set_xticklabels(stack_names,rotation='vertical')
ylims=ax.get_ylim()
max_num_slices=int(ylims[1])+1
# ax.set_yticks(np.arange(max_num_slices)+0.5)
# ax.set_yticklabels(np.arange(max_num_slices))
slice_indices=np.arange(max_num_slices,step=5)
ax.set_yticks(slice_indices+0.5)
ax.set_yticklabels(slice_indices)
ax.set_title("Stack distribution %s" %subject_name)
from matplotlib.lines import Line2D
custom_lines = [Line2D([0], [0], color='gray', lw=4),
Line2D([0], [0], color='lightcoral', lw=4),
Line2D([0], [0], color='mediumseagreen', lw=4)]
ax.legend(custom_lines, ['no brain', 'Bad', 'Good'])
return fig
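# A minimal usage sketch (hypothetical data; with the Agg backend the figure is saved rather than shown):
# >>> fig = plot_stack_distribution(data_by_stack, 'case1')
# >>> fig.savefig('case1_stack_distribution.png')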
"""
data_by_stack: 1d array of 1d arrays of SliceData objects representing the stack data
subject_name: str
"""
def plot_subject_stack_distribution(subject_data):
fig=plt.figure()
ax=fig.add_subplot(1,1,1)
stack_names=[]
for stack_index,stack_data in enumerate(subject_data.stack_data):
# create a bar for this stack
for slice_index,slice_data in enumerate(stack_data.slice_data):
color='gray'
if slice_data.roi_label==1:
color='lightcoral' if slice_data.quality_label==1 else 'mediumseagreen'
ax.bar(stack_index,height=1,width=0.5,bottom=slice_index,color=color,align='center',
edgecolor='black')
stack_name=slice_data.dicom_path.split("\\")[-1]
stack_name=stack_name.split(".")
stack_name=str(int(stack_name[2]))
stack_names.append(stack_name)
ax.set_xlabel("stack ID")
ax.set_ylabel("slice index")
xleft=0-1
num_stacks=subject_data.num_stacks
ax.set_xlim(left=xleft,right=num_stacks+1)
ax.set_xticks(np.arange(num_stacks))
ax.set_xticklabels(stack_names,rotation='vertical')
ylims=ax.get_ylim()
max_num_slices=int(ylims[1])+1
# ax.set_yticks(np.arange(max_num_slices)+0.5)
# ax.set_yticklabels(np.arange(max_num_slices))
slice_indices=np.arange(max_num_slices,step=5)
ax.set_yticks(slice_indices+0.5)
ax.set_yticklabels(slice_indices)
ax.set_title("Stack distribution %s" %subject_data.subject_folder_name)
from matplotlib.lines import Line2D
custom_lines = [Line2D([0], [0], color='gray', lw=4),
Line2D([0], [0], color='lightcoral', lw=4),
Line2D([0], [0], color='mediumseagreen', lw=4)]
ax.legend(custom_lines, ['no brain', 'Bad', 'Good'])
return fig
"""
Based on the saveHASTEImages script, assumes the same order
in data_by_stack and all_dicom_names
"""
def save_dicoms_by_stack(source):
text_file=open(os.path.join(source,'vol_dicom_names.txt'),"r")
all_dicom_names=text_file.readlines()
all_dicom_names_reformatted=map(lambda x: reformat_dicom_name(x), all_dicom_names)
text_file.close()
data_by_stack=scipy.io.loadmat(os.path.join(source,'loaded_haste_vols_v2.mat'))
images_by_stack=data_by_stack['subject_data_by_vol'][0]
num_stacks=len(images_by_stack)
curr_stack_head=0
for i in range(num_stacks):
num_slices=images_by_stack[i].shape[-1]
stack_dicom_fnames=all_dicom_names_reformatted[curr_stack_head:curr_stack_head+num_slices]
curr_stack_head+=num_slices
np.save(os.path.join(source,'vol_%d'%(i+1),'stack_dicom_names'),
stack_dicom_fnames)
"""
dicom_name: str path of the dicom file
outputs the part of the dicom path just containing the dicom filepath with the extension
"""
def reformat_dicom_name(dicom_name):
reformatted_dicom_name=dicom_name[:]
reformatted_dicom_name=reformatted_dicom_name.strip()
# get only the dicom part of the path
path_parts=reformatted_dicom_name.split("\\")
dicom_path=path_parts[-1]
return dicom_path
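# Example (hypothetical Windows-style path as produced by the original label export):
# >>> reformat_dicom_name('C:\\data\\case1\\vol_1\\IMG_0001.IMA\n')
# 'IMG_0001.IMA'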
"""
stack_labels_dict: key: str image name value: label
dicom_names: 1d array str
Assumes dicom_names are of the format: XXX.IMA
corresponds to the order of the slices i.e., 1st dicom is for "slice_1"
Changes the keys of the stack_labels_dict to the ones in dicom_names
"""
def map_keys_to_dicom(stack_labels_dict,dicom_names):
num_slices=len(dicom_names)
stack_labels_dict_with_dicom_keys={}
for slice_index in range(num_slices):
slice_dicom_name=dicom_names[slice_index]
labeled_slice_index=slice_index+1
stack_labels_dict_with_dicom_keys[slice_dicom_name]=stack_labels_dict['slice_%d'%labeled_slice_index]
return stack_labels_dict_with_dicom_keys
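# Example (hypothetical names, for illustration only):
# >>> labels = {'slice_1': 'good', 'slice_2': 'bad'}
# >>> map_keys_to_dicom(labels, ['IMG_0001.IMA', 'IMG_0002.IMA'])
# {'IMG_0001.IMA': 'good', 'IMG_0002.IMA': 'bad'}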
"""
source: str path to folder for stack data
Assumes of the following format: ../../../dicom_source_folder/subject_folder/stack_folder
Outputs a StackData
"""
def get_stack_data(source):
stack_data=[]
labels_file=glob.glob(os.path.join(source,'*.csv'))[0]
stack_roi_labels,stack_quality_labels=load_single_stack_labels(labels_file)
dicom_names_file_path=os.path.join(source,'stack_dicom_names.npy')
if os.path.isfile(dicom_names_file_path):
dicom_names=np.load(dicom_names_file_path)
dicom_names=map(lambda x: reformat_dicom_name(x),dicom_names)
dicom_names=np.array(dicom_names)
stack_roi_labels=map_keys_to_dicom(stack_roi_labels,dicom_names)
stack_quality_labels=map_keys_to_dicom(stack_quality_labels,dicom_names)
stack_roi_labels={slice_key: reformat_roi_label(slice_label) for slice_key, slice_label in stack_roi_labels.items()}
stack_quality_labels={slice_key: reformat_quality_label(slice_label) for slice_key, slice_label in stack_quality_labels.items()}
slice_keys=stack_roi_labels.keys()
stack_data=[]
for slice_key in slice_keys:
roi_label=stack_roi_labels[slice_key]
quality_label=stack_quality_labels[slice_key]
# attach only the part of the path relevant to the dicom source folder
source_path_parts=source.split("/")
subject_folder_name,stack_folder_name=source_path_parts[-2:]
slice_data=SliceData(roi_label,quality_label,
subject_folder_name=subject_folder_name,
stack_folder_name=stack_folder_name,
dicom_path=slice_key)
stack_data.append(slice_data)
stack_data=np.array(stack_data)
return StackData(stack_data,subject_folder_name=subject_folder_name,
stack_folder_name=stack_folder_name)
"""
source: str
Outputs a 1d array of arrays where each array indexes a stack
"""
def get_subject_data(source):
subject_data_by_stack=[]
subject_folder_name=os.path.split(source)[-1]
for stack_folder_name in os.listdir(source):
if not os.path.isdir(os.path.join(source,stack_folder_name)):
continue
stack_data=get_stack_data(os.path.join(source,stack_folder_name))
if stack_data.num_slices>0:
subject_data_by_stack.append(stack_data)
subject_data_by_stack=np.array(subject_data_by_stack)
return SubjectData(subject_data_by_stack,subject_folder_name)
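# Expected on-disk layout (an assumption inferred from the path handling in this file):
#   <subject_folder>/
#       <stack_folder>/
#           <labels>.csv            # ROI/quality labels for the stack (any single *.csv)
#           stack_dicom_names.npy   # optional; maps slice_N keys to dicom filenames
#           jpegs/                  # per-slice image files used by get_image_array
# e.g. (hypothetical path):
# >>> subject = get_subject_data('/path/to/dicom_source_folder/case1')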
"""
image: 3d array (#row, #col, 1)
"""
def reshape_image_for_transfer(image):
reshaped_image=np.concatenate((image,image,image),axis=2)
return reshaped_image
"""
fname: str assumes of the following form ../../CASENAME/VOL
Returns the part containing the case/volume i.e., CASENAME/VOL
"""
def extract_relevant_part_filename(fname):
path_parts=fname.split("/")
relevant_path_parts=path_parts[-2:]
relevant_part_fname=os.path.join(relevant_path_parts[0],relevant_path_parts[1])
return relevant_part_fname
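# Example (hypothetical path, for illustration only):
# >>> extract_relevant_part_filename('../../case3/vol_2_5')
# 'case3/vol_2_5'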
"""
data_partition: 4d-list
order: images, iqa_labels, roi_labels, fnames
fnames: each filename path will end in at least this form: caseCASE_NUM/vol_VOL_NUM_SLICE_NUM
correction_labels: (# files, correction_type),
correction types: index 0 corresponds to no_roi, mislabeled
correpsonding_fnames: (# files, 1) -- order of this does not nec. corresponding to the order in data partition
will end in at least this form: ../caseCASE_NUM/vol_VOL_NUM_SLICE_NUM
Outputs the data_partition, corrected version
"""
def correct_dataset(data_partition,correction_labels,corresponding_fnames):
images,iqa_labels,roi_labels,fnames=data_partition
relevant_part_fnames=np.array(map(extract_relevant_part_filename,fnames))
# due to the weird format from matlab need to process this differently
relevant_part_corresponding_fnames=np.array(map(extract_relevant_part_filename,corresponding_fnames))
corrected_iqa_labels=iqa_labels.copy()
corrected_roi_labels=roi_labels.copy()
ind_corrections=np.where(np.any(correction_labels,axis=1)==True)[0]
for ind_correction in ind_corrections:
corrected_fname=relevant_part_corresponding_fnames[ind_correction]
index_correction_rel_dataset=np.where(relevant_part_fnames==corrected_fname)[0]
if correction_labels[ind_correction][0]:
corrected_roi_labels[index_correction_rel_dataset]=1-corrected_roi_labels[index_correction_rel_dataset]
if correction_labels[ind_correction][1]:
corrected_iqa_labels[index_correction_rel_dataset]=1-corrected_iqa_labels[index_correction_rel_dataset]
return images,corrected_iqa_labels,corrected_roi_labels,fnames
"""
TODO TUNE noise parameters
img: 3d tensor (width,height,channels) dtype:uint16
"""
def add_noise(img):
noise_std=1
noise=np.random.normal(loc=0,scale=noise_std,size=img.shape)
noise_img=img+noise
return noise_img
"""
Generator
skeleton code: https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
Generating augmentations source code:
-https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/image/image_data_generator.py
-https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/image/affine_transformations.py
other options:
multi-input generator
options:
-https://github.com/keras-team/keras/issues/8130
-https://github.com/keras-team/keras/issues/2568
Decision: implement the Sequence class
this will make it easier to apply my own processing on the input before augmentations, etc
e.g., brain masking or apply the same augmentations to the input images?
also if I want to incorporate multiple reference inputs, my own data generator class
would support doing this although
it might be more complex logic... so might want to rethink this
"""
"""
ROTATION_RANGE=360
WIDTH_SHIFT_RANGE=20
HEIGHT_SHIFT_RANGE=20
BRIGHTNESS_RANGE= [0.25,1.25]
SHEAR_RANGE=0.0
ZOOM_RANGE=0.0
CHANNEL_SHIFT_RANGE= 20.0
HORIZONTAL_FLIP=True
VERTICAL_FLIP=True
FILL_MODE='nearest'
CVAL=0
PREPROCESSING_FUNCTION= add_noise
"""
ROTATION_RANGE=20
WIDTH_SHIFT_RANGE=20
HEIGHT_SHIFT_RANGE=20
BRIGHTNESS_RANGE= None # [0.25,1.25]
SHEAR_RANGE=0.0
ZOOM_RANGE=0.0
CHANNEL_SHIFT_RANGE= 0.0 #20.0
HORIZONTAL_FLIP=True
VERTICAL_FLIP=True
FILL_MODE='constant'
CVAL=0
PREPROCESSING_FUNCTION= None #add_noise
def generate_generator_multiple(generator,dir1, dir2, batch_size, img_height,img_width):
genX1 = generator.flow_from_directory(dir1,
target_size = (img_height,img_width),
class_mode = 'categorical',
batch_size = batch_size,
shuffle=False,
seed=7)
genX2 = generator.flow_from_directory(dir2,
target_size = (img_height,img_width),
class_mode = 'categorical',
batch_size = batch_size,
shuffle=False,
seed=7)
while True:
X1i = genX1.next()
X2i = genX2.next()
yield [X1i[0], X2i[0]], X2i[1] #Yield both images and their mutual label
"""
mask: 2d array binary mask
assumes mask has a circle
"""
def add_context_padding(mask):
    new_mask=np.zeros(mask.shape)
    mask_indices=np.where(mask==1)
    top_row=np.amin(mask_indices[0])
    bottom_row=np.amax(mask_indices[0])
    left_col=np.amin(mask_indices[1])
    right_col=np.amax(mask_indices[1])
    # +1 so the bottom row and right column of the ROI are included in the padded box
    new_mask[top_row:bottom_row+1,left_col:right_col+1]=1
    return new_mask
"""
image_path: str corresponding to the path of the dicom ending with some dicom extension
For now, loads from the png/jpg instead of dicom
Returns ndarray representing the image
"""
def get_image_array(data_source_dir,subject_folder_path,stack_folder_path,dicom_path):
dicom_path_sans_ext,ext=os.path.splitext(dicom_path)
image_file=glob.glob(os.path.join(data_source_dir,
subject_folder_path,stack_folder_path,
'jpegs',dicom_path_sans_ext+"*"))[0] #handles any extension
image_obj=Image.open(image_file)
image=np.array(image_obj)
return image
class DataGenerator(keras.utils.Sequence):
"""
data_partition_path: str path to the .npy file containing the meta data
for the slices in this data partition
dicom_folder_path: str path to the folder containing all the dicom/jpeg data
batch_size
dim: if the # channels == 3, the image is reformatted by replicating the image along the channels
shuffle
augmentation_flag
save_images
save_images_path
save_labels: Bool should only be used when evaluating
"""
def __init__(self, data_partition_path, data_source_dir,
batch_size=1, dim=(256,256,1),
shuffle=False, augmentation_flag=False,
save_images=False,save_images_path='',save_labels=False):
self.all_slice_data=np.load(data_partition_path)
self.num_instances=len(self.all_slice_data)
self.data_source_dir=data_source_dir
self.batch_size = batch_size
self.dim=dim
self.shuffle = shuffle
self.augmentation_flag=augmentation_flag
if self.augmentation_flag:
self.image_transform_gen=ImageDataGenerator(samplewise_center=True,
samplewise_std_normalization=True,
rotation_range=ROTATION_RANGE,
width_shift_range=WIDTH_SHIFT_RANGE,
height_shift_range=HEIGHT_SHIFT_RANGE,
brightness_range=BRIGHTNESS_RANGE,
shear_range=SHEAR_RANGE,
zoom_range=ZOOM_RANGE,
channel_shift_range=CHANNEL_SHIFT_RANGE,
horizontal_flip=HORIZONTAL_FLIP,
vertical_flip=VERTICAL_FLIP,
fill_mode=FILL_MODE,
cval=CVAL,
preprocessing_function=PREPROCESSING_FUNCTION)
else: # only normalize image
self.image_transform_gen=ImageDataGenerator(samplewise_center=True,
samplewise_std_normalization=True)
self.save_images=save_images
if self.save_images:
self.save_images_path=os.path.join(save_images_path)
if os.path.isdir(self.save_images_path):
shutil.rmtree(self.save_images_path)
os.mkdir(self.save_images_path)
# useful for the test dataset
self.save_labels=save_labels
if self.save_labels:
self.labels=[]
self.on_epoch_end() # reshuffles data
"""
Number of batches per epoch
"""
def __len__(self):
return int(np.floor(self.num_instances / self.batch_size))
"""
Retrieves a batch of data
Outputs a batch of images and corresponding quality labels
"""
def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the data points relative to the total set of images in this partition
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# Generate data
X,y=self.data_generation(indexes)
if self.save_labels:
self.labels.extend(y)
return X, y
def on_epoch_end(self):
self.indexes = np.arange(self.num_instances)
        if self.shuffle:
np.random.shuffle(self.indexes)
"""
Outputs the batch of images and labels
Reads the image files rather than dicoms to potentially make data reading faster and less memory intensive
"""
def data_generation(self,indices):
# Initialization
n_channels=self.dim[-1]
X = np.empty((self.batch_size, self.dim[0],self.dim[1], n_channels))
y = np.empty((self.batch_size), dtype=int)
# Generate data
for index_rel_batch, index_rel_to_original_data in enumerate(indices):
# load image and label
slice_data=self.all_slice_data[index_rel_to_original_data]
roi_label,quality_label,subject_folder_path,stack_folder_path,dicom_path=slice_data
roi_label=int(roi_label)
quality_label=int(quality_label)
# maybe reading just the raw png is faster than the dicom? less memory?
# strip any extension in case
# image_data=pydicom.dcmread(os.path.join(self.dicom_folder_path,dicom_path))
# image=image_data.pixel_array
"""
dicom_path_sans_ext,ext=os.path.splitext(dicom_path)
image_file=glob.glob(os.path.join(self.data_source_dir,
subject_folder_path,stack_folder_path,'jpegs',dicom_path_sans_ext+"*"))[0] #handles any extension
image_obj=Image.open(image_file)
image=np.array(image_obj)
"""
image=get_image_array(self.data_source_dir,
subject_folder_path,stack_folder_path,dicom_path)
label=quality_label
image=image.astype('float')
# image*=255/np.amax(image) # rescale to 0-255 for data augmentations involving intensity shifts
image=np.reshape(image,(self.dim[0],self.dim[1],1))
transform_parameters=self.image_transform_gen.get_random_transform((self.dim[0],self.dim[1],1))
image=self.image_transform_gen.apply_transform(image,transform_parameters)
image=self.image_transform_gen.standardize(image)
if n_channels==3:
image=reshape_image_for_transfer(image)
if self.save_images:
converted_image=array_to_img(image)
image_file=os.path.join(self.save_images_path,'im_id_%d_%d.jpg'
%(index_rel_to_original_data,np.random.randint(0,100000)))
converted_image.save(image_file)
X[index_rel_batch,]=image
y[index_rel_batch]=label
# print(X.shape,y.shape)
return X, y
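# A minimal usage sketch of DataGenerator (hypothetical paths and model, for illustration only):
# >>> train_gen = DataGenerator('partitions/train.npy', '/path/to/dicom_data',
# ...                           batch_size=16, dim=(256, 256, 1),
# ...                           shuffle=True, augmentation_flag=True)
# >>> val_gen = DataGenerator('partitions/val.npy', '/path/to/dicom_data', batch_size=16)
# >>> # model.fit_generator(train_gen, validation_data=val_gen, epochs=10)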
"""
vol_images: 1d array (# volumes, )
all_vol_fnames_concatenated: 1d array (# slices altogether in stack)
"""
def organize_dicom_names_by_vol(vol_images,all_vol_fnames_concatenated):
fnames_by_vol=[]
slice_index_start_vol=0
num_vols=len(vol_images)
for i in range(num_vols):
num_slices_in_vol=vol_images[i].shape[-1]
slice_index_end_vol=slice_index_start_vol+num_slices_in_vol
vol_fnames=all_vol_fnames_concatenated[slice_index_start_vol:slice_index_end_vol]
fnames_by_vol.append(vol_fnames)
slice_index_start_vol=slice_index_end_vol
fnames_by_vol=np.array(fnames_by_vol)
return fnames_by_vol
"""
fname: path representing a case
Returns 5 1d arrays for each case, containing data per vol i.e., length of each array is the number of volumes in this subject
all_vol_names: 1d array length: # vols in the subject
each array contains the folder it came from followed by the slice index
the indexing starts at 1
"""
def load_subject_data(fname):
data_file_name=os.path.join(fname,'loaded_haste_vols_v2.mat')
data=scipy.io.loadmat(data_file_name)
vol_images=data['subject_data_by_vol']
vol_images=np.reshape(vol_images,vol_images.shape[1]) # reshapes it into 1d array
num_vols=len(vol_images)
all_vol_fnames_concatenated=np.loadtxt(os.path.join(fname,'vol_dicom_names.txt'),dtype='str')
vol_fnames=organize_dicom_names_by_vol(vol_images,all_vol_fnames_concatenated)
"""
all_vol_fnames=[]
for i in range(num_vols):
vol_fnames=np.array([os.path.join(fname,'vol_%d_%d'%(i+1,j+1)) for j in range(vol_images[i].shape[-1])])
all_vol_fnames.append(vol_fnames)
"""
vol_roi_labels,vol_quality_labels=load_volumes_labels(fname)
return vol_images, vol_roi_labels, vol_quality_labels, vol_fnames
"""
labels_file: str csv path
Returns 2 arrays
"""
def load_single_volume_labels(labels_file):
    # for each volume aggregate the roi and quality labels
labels_by_slice={}
with open(labels_file) as csvfile:
reader=csv.DictReader(csvfile)
for row in reader:
labels=json.loads(row['Label'])
fname=row['External ID']
            slice_key=os.path.splitext(fname)[0]  # drop the .png extension (str.strip would remove characters, not a suffix)
labels_by_slice[slice_key]=[]
if 'roi' in labels:
if labels['roi']=='no': # explicitly labeled no
labels_by_slice[slice_key].append('no')
labels_by_slice[slice_key].append('bad') # for now label these bad
else:
labels_by_slice[slice_key].append('yes')
if 'image_quality' in labels.keys():
labels_by_slice[slice_key].append(labels['image_quality'])
elif 'good/bad/uncertain' in labels.keys():
labels_by_slice[slice_key].append(labels['good/bad/uncertain'])
else: # assumes that if roi is not labeled, then it is present
labels_by_slice[slice_key].append('yes')
if 'image_quality' in labels.keys():
labels_by_slice[slice_key].append(labels["image_quality"])
elif 'good/bad/uncertain' in labels.keys():
labels_by_slice[slice_key].append(labels['good/bad/uncertain'])
# reorder the data in the order of slices
roi_labels_by_slice=[]
quality_labels_by_slice=[]
num_slices=len(labels_by_slice.keys())
for i in range(1,num_slices+1):
roi_labels_by_slice.append(labels_by_slice['slice_'+str(i)][0])
quality_labels_by_slice.append(labels_by_slice['slice_'+str(i)][1])
return np.array(roi_labels_by_slice),np.array(quality_labels_by_slice)
"""
labels_file: str csv path
Returns a dictionary mapping each dicom key (str) to an array ROI, Quality label
"""
def load_single_stack_labels(labels_file):
    # for each volume aggregate the roi and quality labels
roi_labels_by_slice={}
quality_labels_by_slice={}
with open(labels_file) as csvfile: