Skip to content

Commit

Permalink
Modify truncation on AISegment testset (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
seldauyanik-maxim authored Dec 20, 2022
1 parent ec7b362 commit 946077b
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions datasets/aisegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class AISegment(Dataset):
num_of_imgs_to_use_hr = 20000

def __init__(self, root_dir, d_type, transform=None, im_size=(80, 80), fold_ratio=1,
use_memory=False):
use_memory=False, truncate_testset=False):

if im_size not in ((80, 80), (352, 352)):
raise ValueError('im_size can only be set to (80, 80) or (352, 352)')
Expand All @@ -88,6 +88,8 @@ def __init__(self, root_dir, d_type, transform=None, im_size=(80, 80), fold_rati

self.d_type = d_type

self.is_truncated = False

vertical_crop_area = AISegment.org_img_dim[0] - AISegment.img_crp_dim[0]

if vertical_crop_area % (AISegment.num_of_cropped_imgs - 1) != 0:
Expand Down Expand Up @@ -201,13 +203,14 @@ def __init__(self, root_dir, d_type, transform=None, im_size=(80, 80), fold_rati
self.img_files_info = test_img_files_info
self.dataset_pkl_file_path = test_dataset_pkl_file_path
self.processed_folder_path = self.processed_test_data_folder
if truncate_testset:
self.is_truncated = True

else:
print(f'Unknown data type: {self.d_type}')
return

self.__create_pkl_files()
self.is_truncated = False

def __create_pkl_files(self):
if self.__check_pkl_files_exist():
Expand Down Expand Up @@ -407,7 +410,8 @@ def AISegment_get_datasets(data, load_train=True, load_test=True, im_size=(80, 8

train_dataset = AISegment(root_dir=data_dir, d_type='train',
transform=train_transform,
im_size=im_size, fold_ratio=fold_ratio, use_memory=use_memory)
im_size=im_size, fold_ratio=fold_ratio, use_memory=use_memory,
truncate_testset=False)
print(f'Train dataset length: {len(train_dataset)}\n')
else:
train_dataset = None
Expand All @@ -420,7 +424,8 @@ def AISegment_get_datasets(data, load_train=True, load_test=True, im_size=(80, 8

test_dataset = AISegment(root_dir=data_dir, d_type='test',
transform=test_transform,
im_size=im_size, fold_ratio=fold_ratio, use_memory=use_memory)
im_size=im_size, fold_ratio=fold_ratio, use_memory=use_memory,
truncate_testset=args.truncate_testset)

print(f'Test dataset length: {len(test_dataset)}\n')
else:
Expand Down

0 comments on commit 946077b

Please sign in to comment.