-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
119 lines (91 loc) · 4.07 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os, time, h5py, argparse, sys
import scipy.io as sio
import numpy as np
import tensorflow as tf
import keras
from keras.backend.tensorflow_backend import set_session
from keras import backend as K
from utils import set_parser, TrainConfig
from data import generate_patch_data, generate_slice_data
from train import train_model
# UTC timestamp for this run; it is handed to every train_model() call.
START_TIME = time.strftime('%Y%m%d-%H%M', time.gmtime())

# Handling Parameter: the project's set_parser() registers all CLI options.
parser = set_parser(argparse.ArgumentParser())
args = parser.parse_args()

# Expose only the requested GPU device(s) to TensorFlow. This is set before
# the tf.Session below is created.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device

# Fixed seeds for NumPy- and TensorFlow-side randomness.
np.random.seed(42)
tf.set_random_seed(42)

# Let TensorFlow grow its GPU memory on demand (allow_growth) instead of
# reserving the device up front, then register the session with Keras.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)
K.set_session(sess)
def _reset_keras_session(session_config):
    """Discard the current Keras graph and install a fresh TF session.

    Each experiment below builds a new network. Resetting first keeps the
    models from piling up in one shared graph.

    Args:
        session_config: the ``tf.ConfigProto`` used for the new session.

    Returns:
        The newly created ``tf.Session``, already registered with Keras.
    """
    K.clear_session()
    new_sess = tf.Session(config=session_config)
    K.set_session(new_sess)
    return new_sess


if __name__ == '__main__':
    # Parameter Setting
    TRSH = args.TRSH
    random_num = args.random_num
    data_chn_num = args.data_chn_num
    # Every experiment below indexes channel 0 (FLAIR) and channel 1 (IAM),
    # so fail fast, before the (slow) data loading, if they are not loaded.
    if data_chn_num < 2:
        raise ValueError(
            'data_chn_num must be >= 2 (FLAIR and IAM channels are '
            'required), got %r' % (data_chn_num,))
    if not args.Patch:
        img_x, img_y, img_z = 256, 256, 1
    else:
        img_x, img_y, img_z = args.img_size, args.img_size, 1
    win_shape = (img_x, img_y, img_z)

    # Load Data
    # Train and test read the same config directory. The second positional
    # argument (0 for train, 1 for test) presumably selects the fold of the
    # 2-fold split -- TODO confirm against data.py.
    train_config_dir = 'data/com_test_configs_2fold_adni60'
    test_config_dir = 'data/com_test_configs_2fold_adni60'
    print("Reading TRAIN data..")
    train_data, train_trgt, _train_list = generate_patch_data(
        train_config_dir, 0, TRSH, win_shape, random_num, data_chn_num,
        args.test)
    print("Reading TEST data..")
    test_data, test_trgt, _target_list, _affine = generate_slice_data(
        test_config_dir, 1, random_num, data_chn_num, args.test)
    print(train_data.shape, ' class max val:', np.max(train_trgt))

    # Split the data into one 4-D array per channel by taking index i of
    # axis 3 and re-adding it as a size-1 axis.
    # Data Order [FLAIR, IAM, T1W]
    train_data = [np.expand_dims(train_data[:, :, :, i], axis=3)
                  for i in range(data_chn_num)]
    test_data = [np.expand_dims(test_data[:, :, :, i], axis=3)
                 for i in range(data_chn_num)]

    # Train Networks
    train_config = TrainConfig(args)

    # U-Net (only FLAIR)
    train_dat = [train_data[0], train_trgt]
    test_dat = [test_data[0], test_trgt]
    train_model(train_config, START_TIME, net_depth=args.depth,
                SALIENCY=False, DILATION=False, restore_dir=None,
                net_type='FLAIR', train_dat=train_dat, test_dat=test_dat)

    # U-Net (only IAM). Start from a fresh graph like every other experiment
    # so this model is not built on top of the FLAIR one.
    sess = _reset_keras_session(config)
    train_dat = [train_data[1], train_trgt]
    test_dat = [test_data[1], test_trgt]
    train_model(train_config, START_TIME, net_depth=args.depth,
                SALIENCY=False, DILATION=False, restore_dir=None,
                net_type='IAM', train_dat=train_dat, test_dat=test_dat)

    # U-Net (FLAIR + IAM): the two channels stacked along axis 3 into a
    # single 2-channel array.
    sess = _reset_keras_session(config)
    train_dat = [np.concatenate(train_data[0:2], axis=3), train_trgt]
    test_dat = [np.concatenate(test_data[0:2], axis=3), test_trgt]
    train_model(train_config, START_TIME, net_depth=args.depth,
                SALIENCY=False, DILATION=False, restore_dir=None,
                net_type='F+I', train_dat=train_dat, test_dat=test_dat)

    # Saliency U-Net (FLAIR + IAM): the channels are kept as a list of two
    # separate arrays.
    # NOTE(review): same net_type as the plain F+I U-Net above; confirm that
    # train_model keeps their outputs apart (e.g. via SALIENCY).
    sess = _reset_keras_session(config)
    train_dat = [train_data[0:2], train_trgt]
    test_dat = [test_data[0:2], test_trgt]
    train_model(train_config, START_TIME, net_depth=args.depth,
                SALIENCY=True, DILATION=False, restore_dir=None,
                net_type='F+I', train_dat=train_dat, test_dat=test_dat)

    # Dilated Saliency U-Net (FLAIR + IAM)
    # Dilation Factor - 1224
    sess = _reset_keras_session(config)
    train_dat = [train_data[0:2], train_trgt]
    test_dat = [test_data[0:2], test_trgt]
    train_model(train_config, START_TIME, net_depth=args.depth,
                SALIENCY=True, DILATION=True, restore_dir=None,
                net_type='F+I_1224', train_dat=train_dat, test_dat=test_dat,
                dilation_factor=[[1, 2], [2, 4]])

    # Clear memory: drop the references to the large data arrays.
    del train_trgt, test_trgt, train_dat, test_dat, train_data, test_data