Merge pull request #29 from IvanKuchin/development
wip: tile dataset
IvanKuchin authored Aug 23, 2024
2 parents 2473346 + f478fa2 commit f4019f9
Showing 7 changed files with 59 additions and 23 deletions.
8 changes: 5 additions & 3 deletions config.py
@@ -8,10 +8,10 @@
 # BACKGROUND_WEIGHT = 1 # must be calculated dynamically
 # FOREGROUND_WEIGHT = 7 # must be calculated dynamically

-INITIAL_LEARNING_RATE = 0.001
+INITIAL_LEARNING_RATE = 1e-4
 INSTANCE_NORM = False # not supported yet
 BATCH_NORM = True
-BATCH_SIZE = 1
+BATCH_SIZE = 4
 BATCH_NORM_MOMENTUM = 0.8

 GRADIENT_ACCUMULATION_STEPS = 4 # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam#args
@@ -27,7 +27,7 @@
 PANCREAS_MIN_HU = -512 # -512
 PANCREAS_MAX_HU = 1024 # 1024

-IMAGE_DIMENSION_X = 160
+IMAGE_DIMENSION_X = 96
 IMAGE_DIMENSION_Y = IMAGE_DIMENSION_X
 IMAGE_DIMENSION_Z = IMAGE_DIMENSION_X
@@ -47,6 +47,8 @@
 IMAGE_ORIGINAL_DIMENSION_Y = IMAGE_ORIGINAL_DIMENSION_X
 IMAGE_ORIGINAL_DIMENSION_Z = IMAGE_ORIGINAL_DIMENSION_X

+IS_TILE = True
+
 # Dataset used for training
 # consists of pickle files of 3d numpy arrays
 TFRECORD_FOLDER = "c:/Users/ikuchin/Downloads/pancreas_data/dataset/"
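Note: with IMAGE_DIMENSION_X dropping from 160 to 96, each training example shrinks from a 160^3 crop to a 96^3 tile, which plausibly frees the memory that lets BATCH_SIZE rise from 1 to 4. A back-of-the-envelope sketch of the tile count per volume (the 512x512x256 CT shape below is an illustrative assumption, not a value from config.py):

import math

TILE = 96                   # IMAGE_DIMENSION_X/Y/Z after this commit
shape = (512, 512, 256)     # assumed CT volume shape, for illustration only

tiles_per_axis = [math.ceil(s / TILE) for s in shape]  # volumes are zero-padded up
print(tiles_per_axis, math.prod(tiles_per_axis))       # [6, 6, 3] -> 108 tiles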
2 changes: 2 additions & 0 deletions dataset/pomc_dataset.py
@@ -47,6 +47,8 @@ def __init__(self, patients_src_folder, labels_src_folder, TFRECORD_FOLDER):
         self.min_HU = float("inf")
         self.max_HU = float("-inf")

+        # self.saver = SaverFactory(config.IS_TILE)
+
     def get_patient_id_from_folder(self, folder):
         result = None
         m = re.search("(\\w+)$", folder)
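SaverFactory is still commented out here and appears nowhere else in this diff; a minimal sketch of what such a factory might look like, assuming hypothetical TileSaver and WholeVolumeSaver classes (neither name comes from the repo):

# Hypothetical sketch only - SaverFactory, TileSaver and WholeVolumeSaver are
# assumed names, not code from this commit.
def SaverFactory(is_tile):
    if is_tile:
        return TileSaver()       # would tile volumes via Slicer (see dataset/saver.py)
    return WholeVolumeSaver()    # would save one resized volume per patient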
33 changes: 20 additions & 13 deletions dataset/saver.py
@@ -12,7 +12,7 @@
 import config as config

 class Slicer:
-    def __init__(self, data, label):
+    def __init__(self, data, label, augment_margin = [0, 0, 0]):
         x = math.ceil(data.shape[0] / config.IMAGE_DIMENSION_X) * config.IMAGE_DIMENSION_X
         y = math.ceil(data.shape[1] / config.IMAGE_DIMENSION_Y) * config.IMAGE_DIMENSION_Y
         z = math.ceil(data.shape[2] / config.IMAGE_DIMENSION_Z) * config.IMAGE_DIMENSION_Z
@@ -23,26 +23,26 @@ def __init__(self, data, label):
         self.data[:data.shape[0], :data.shape[1], :data.shape[2]] = data
         self.label[:label.shape[0], :label.shape[1], :label.shape[2]] = label

+        self.augment_margin = augment_margin
+
     def __iter__(self):
-        augment_margin = [
-            int(config.IMAGE_DIMENSION_X * config.AUGMENTATIO_SHIFT_MARGIN),
-            int(config.IMAGE_DIMENSION_Y * config.AUGMENTATIO_SHIFT_MARGIN),
-            int(config.IMAGE_DIMENSION_Z * config.AUGMENTATIO_SHIFT_MARGIN)
-        ]
         for x in range(0, self.data.shape[0], config.IMAGE_DIMENSION_X):
             for y in range(0, self.data.shape[1], config.IMAGE_DIMENSION_Y):
                 for z in range(0, self.data.shape[2], config.IMAGE_DIMENSION_Z):
-                    x_start = np.max([x - augment_margin[0], 0])
-                    y_start = np.max([y - augment_margin[1], 0])
-                    z_start = np.max([z - augment_margin[2], 0])
+                    x_start = np.max([x - self.augment_margin[0], 0])
+                    y_start = np.max([y - self.augment_margin[1], 0])
+                    z_start = np.max([z - self.augment_margin[2], 0])

-                    x_finish = np.min([x + config.IMAGE_DIMENSION_X + augment_margin[0], self.data.shape[0]])
-                    y_finish = np.min([y + config.IMAGE_DIMENSION_Y + augment_margin[1], self.data.shape[1]])
-                    z_finish = np.min([z + config.IMAGE_DIMENSION_Z + augment_margin[2], self.data.shape[2]])
+                    x_finish = np.min([x + config.IMAGE_DIMENSION_X + self.augment_margin[0], self.data.shape[0]])
+                    y_finish = np.min([y + config.IMAGE_DIMENSION_Y + self.augment_margin[1], self.data.shape[1]])
+                    z_finish = np.min([z + config.IMAGE_DIMENSION_Z + self.augment_margin[2], self.data.shape[2]])

                     data = self.data [x_start:x_finish, y_start:y_finish, z_start:z_finish]
                     label = self.label[x_start:x_finish, y_start:y_finish, z_start:z_finish]

+                    if np.max(label) == 0:
+                        continue
+
                     yield data, label, x, y, z


@@ -61,7 +61,14 @@ def save(self, src_data, label_data):
         src_data = np.cast[np.float32](src_data)
         label_data = np.cast[np.int8](label_data)

-        for (data, label, x, y, z) in Slicer(src_data, label_data):
+        augment_margin = [
+            int(config.IMAGE_DIMENSION_X * config.AUGMENTATIO_SHIFT_MARGIN),
+            int(config.IMAGE_DIMENSION_Y * config.AUGMENTATIO_SHIFT_MARGIN),
+            int(config.IMAGE_DIMENSION_Z * config.AUGMENTATIO_SHIFT_MARGIN)
+        ]
+
+        for (data, label, x, y, z) in Slicer(src_data, label_data, augment_margin=augment_margin):
             # print(f"Saving slice at {x}, {y}, {z}...")
             np.savez_compressed(os.path.join(self.folder, self.subfolder, self.patient_id + f"_cut-{self.percentage}_slice-{x}-{y}-{z}.npz", ), [data, label])
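Two behavioral changes in Slicer worth noting: the shift margin is now injected by the caller instead of being recomputed from config inside __iter__, and tiles whose label block contains no foreground are skipped entirely. A minimal usage sketch with toy arrays (not repo data; the margin value assumes AUGMENTATIO_SHIFT_MARGIN is around 0.1):

import numpy as np
from dataset.saver import Slicer

data  = np.random.rand(200, 200, 120).astype(np.float32)
label = np.zeros(data.shape, dtype=np.int8)
label[80:120, 80:120, 40:70] = 1     # one foreground blob

margin = [9, 9, 9]                   # ~ int(96 * 0.1), shift margin per axis
for tile, tile_label, x, y, z in Slicer(data, label, augment_margin=margin):
    # all-background tiles are dropped by the new `if np.max(label) == 0` check,
    # so only tiles overlapping the blob are yielded; tile.shape varies because
    # the margin extends each tile up to 96 + 2*9 voxels per axis (clipped at edges)
    print(x, y, z, tile.shape)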
23 changes: 22 additions & 1 deletion predict.py
@@ -7,6 +7,8 @@
 import nibabel as nib
 import tools.craft_network as craft_network
 import config as config
+from tools.predict.predict_no_tile import PredictNoTile
+from tools.predict.predict_tile import PredictTile


 class Predict:
@@ -125,6 +127,15 @@ def __resize_segmentation_to_dcm_shape(self, mask, dcm_slices):
         return result

     def __save_img_to_nifti(self, data, affine, result_file_name):
+        # TODO: add meta information
+        # affine = meta['affine'][0].cpu().numpy()
+        # pixdim = meta['pixdim'][0].cpu().numpy()
+        # dim = meta['dim'][0].cpu().numpy()
+
+        # img = nib.Nifti1Image(input_nii_array, affine=affine)
+        # img.header['dim'] = dim
+        # img.header['pixdim'] = pixdim
+
         img_to_save = nib.Nifti1Image(data, affine)
         nib.save(img_to_save, result_file_name)
@@ -136,6 +147,15 @@ def __print_stat(self, data, title=""):
                 tf.reduce_mean(tf.cast(data, dtype = tf.float32)),
                 tf.reduce_max(data), tf.reduce_sum(data)))

+    def __predict(self, src_data, model):
+        if config.IS_TILE == True:
+            predict_class = PredictTile(model)
+        else:
+            predict_class = PredictNoTile(model)
+
+        prediction = predict_class(src_data)
+        return prediction
+

     def main(self, dcm_folder, result_file_name):
         dcm_slices = self.__read_dcm_slices(dcm_folder)
         raw_pixel_data = self.__get_pixel_data(dcm_slices)
@@ -146,7 +166,8 @@

         # model.summary()

-        prediction = model.predict(src_data)
+        # prediction = model.predict(src_data)
+        prediction = self.__predict(src_data, model)
         mask = self.__create_segmentation(prediction)
         mask = self.__resize_segmentation_to_dcm_shape(mask, dcm_slices)
         mask = tf.squeeze(mask)
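PredictTile and PredictNoTile bodies are not part of this diff; all __predict relies on is that both are constructed from a model and are callable on the source volume. A sketch of that contract, with the tiling loop itself left as an assumption:

# Interface sketch only - the real classes live in tools/predict/ and their
# bodies are not shown in this commit; only the callable contract is used above.
class PredictNoTile:
    def __init__(self, model):
        self.model = model

    def __call__(self, src_data):
        return self.model.predict(src_data)   # whole volume in a single pass

class PredictTile:
    def __init__(self, model):
        self.model = model

    def __call__(self, src_data):
        # presumably: split src_data into 96^3 tiles, predict each, and stitch
        # the predictions back at their (x, y, z) offsets, mirroring Slicer
        raise NotImplementedError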
3 changes: 2 additions & 1 deletion tools/craft_network/att_unet.py
@@ -43,7 +43,8 @@ def res_block(filters, input_shape, kernel_size, apply_batchnorm, apply_instance
     if (apply_dropout):
         x = tf.keras.layers.Dropout(0.5)(x)

-    return tf.keras.models.Model(inputs = input_layer, outputs = x)
+    model = tf.keras.models.Model(inputs = input_layer, outputs = x, name = "res_block_{}_{}".format(input_shape[-1], filters))
+    return model

 def double_conv(filters, input_shape, kernel_size, apply_batchnorm, apply_instancenorm, apply_dropout=False):
     model = tf.keras.models.Sequential()
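Naming the returned Model means each residual block shows up under a readable name in model.summary() and can be fetched directly, for example (layer name assumed, for a block built from a 32-channel input with 64 filters):

block = model.get_layer("res_block_32_64")   # names follow "res_block_{input_channels}_{filters}"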
6 changes: 4 additions & 2 deletions tools/craft_network/att_unet_dsv.py
@@ -29,7 +29,7 @@ def craft_network(checkpoint_file = None, apply_batchnorm = True, apply_instance
         if idx < len(filters) - 1:
             x = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2), padding = "same")(x)

-    gating_base = get_gating_base(filters[-2], apply_batchnorm)(x)
+    gating_base = get_gating_base(filters[-1], apply_batchnorm)(x)

     dsv_outputs = []
     skip_conns = reversed(generator_steps_output[:-1])
@@ -40,7 +40,9 @@
             # --- don't gate signal due to no useful features at top level
             gated_skip = skip_conn
         else:
-            gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, gating_base))
+            if idx == 0:
+                gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, gating_base))
+            gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, x))

         x = tf.keras.layers.Concatenate(name = "concat_{}".format(_filter))([x, gated_skip])
         x = res_block(_filter, x.shape, kernel_size = config.KERNEL_SIZE, apply_batchnorm = apply_batchnorm, apply_instancenorm = apply_instancenorm)(x)
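Read literally, the new branch computes the gating_base attention only to overwrite it on the next line, so every gated skip connection now attends against the local decoder feature x rather than the shared gating_base; given the "wip" commit message, the idx == 0 special case may still be in flux. The effective behavior of the block as committed:

# effective result of the branch above (the idx == 0 assignment is
# immediately overwritten by the unconditional one that follows it):
gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, x))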
7 changes: 4 additions & 3 deletions train_segmentation.py
@@ -22,19 +22,20 @@ def get_csv_dir():


 def __dice_coef(y_true, y_pred):
-    gamma = 100.0
+    gamma = 0.01
     y_true = tf.cast(y_true, dtype = tf.float32)
     y_pred = tf.cast(y_pred[..., 1:2], dtype = tf.float32)

     # print("y_true shape: ", y_true.shape)
     # print("y_pred shape: ", y_pred.shape)

     intersection = tf.reduce_sum(y_true * y_pred)
-    dice = (2. * intersection + gamma) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + gamma)
+    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
+    dice = (2. * intersection + gamma) / (union + gamma)
     return dice

 def __dice_loss(y_true, y_pred):
-    return -tf.math.log(__dice_coef(y_true, y_pred))
+    return 1 - __dice_coef(y_true, y_pred)


 def __weighted_loss(y_true, y_pred):
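Two related fixes here: the Dice smoothing term gamma drops from 100 to 0.01, and the loss becomes the bounded 1 - dice instead of -log(dice). A toy check (illustrative numbers, not from training) of why the large gamma was a problem for a small target like the pancreas:

def dice(intersection, total, gamma):
    # total = sum(y_true) + sum(y_pred), as in __dice_coef above
    return (2.0 * intersection + gamma) / (total + gamma)

# empty prediction against a 50-voxel ground truth:
print(dice(0.0, 50.0, 100.0))   # ~0.667 -> old gamma rewarded predicting nothing
print(dice(0.0, 50.0, 0.01))    # ~0.0002 -> new gamma keeps the score near zero

# loss change: 1 - dice stays in [0, 1]; -log(dice) blows up as dice -> 0,
# which can let a single bad batch dominate the gradients.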
