Skip to content

Commit

Permalink
train on tiles
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanKuchin committed Aug 18, 2024
1 parent 122cd30 commit 35bcd58
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions dataset/craft_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def py_read_data_and_label(data_fname:str):
data_label = content["arr_0"]
data_array = data_label[0]
label_array = data_label[1]
return (tf.convert_to_tensor(data_array, dtype=tf.float32), tf.convert_to_tensor(label_array, dtype=tf.int32))
return data_array, label_array


def read_data_and_label(patient_id:str, src_folder:str):
Expand Down Expand Up @@ -69,6 +69,22 @@ class Array3d_read_and_resize:
def __init__(self, folder):
self.folder = folder

def random_crop(self, data, label, x, y, z):
data_shape = np.shape(data)
random_range = [data_shape[0] - x, data_shape[1] - y, data_shape[2] - z]
random_offset = np.random.randint(0, random_range, size = 3)
_data = data[
random_offset[0]:random_offset[0] + x,
random_offset[1]:random_offset[1] + y,
random_offset[2]:random_offset[2] + z,
...]
_label = label[
random_offset[0]:random_offset[0] + x,
random_offset[1]:random_offset[1] + y,
random_offset[2]:random_offset[2] + z,
...]
return _data, _label

def __call__(self):
self.file_list = FileIterator(self.folder)
for data_file in self.file_list:
Expand All @@ -79,14 +95,16 @@ def __call__(self):
data, label = read_data_and_label(patient_id, self.folder)
finish_reading = time.time()

data, label = self.random_crop(data, label, config.IMAGE_DIMENSION_X, config.IMAGE_DIMENSION_Y, config.IMAGE_DIMENSION_Z)

start_resize = time.time()
# data, label = borders.cut_and_resize_including_pancreas(data, label, np.random.rand(), np.random.rand())
finish_resize = time.time()

if DEBUG_DATA_LOADING_PERFORMANCE:
print(f"\tDATA_LOADING_PERFORMANCE: reading time: {finish_reading - start_reading:.1f} resize time: {finish_resize - start_resize:.1f}")

yield data, label
yield tf.convert_to_tensor(data, dtype=tf.float32), tf.convert_to_tensor(label, dtype=tf.int8)


# def array3d_read_and_resize():
Expand Down Expand Up @@ -209,6 +227,10 @@ def __run_through_data_wo_any_action(ds_train, ds_valid):


if __name__ == "__main__":
# read_and_resize = Array3d_read_and_resize(os.path.join(config.TFRECORD_FOLDER, "train"))
# item1 = next(read_and_resize())
# print("item1:", item1[0].shape, item1[1].shape)

train_ds = craft_datasets(os.path.join(config.TFRECORD_FOLDER, "train"))
valid_ds = craft_datasets(os.path.join(config.TFRECORD_FOLDER, "valid"))
__run_through_data_wo_any_action(train_ds, valid_ds)

0 comments on commit 35bcd58

Please sign in to comment.