-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathPairedDataSet.py
76 lines (50 loc) · 2.47 KB
/
PairedDataSet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from utils import load_img, save_img, load_itk, save_itk, tensor2img, img2tensor
import torch
from torch.utils.data import Dataset
import os
import glob
import random
class PairedData(Dataset):
    """Dataset of paired images with optional segmentation masks.

    For a given ``root`` and ``target`` split, file names are taken from
    ``<root>/<target>/fd``. Every name is paired with the file of the same
    name in ``<root>/<target>/qd`` and with mask files (image extension
    swapped for ``.nii.gz``) in ``<root>/<target>/qd_seg_123456`` (fine mask)
    and ``<root>/<target>/qd_seg_123`` (coarse mask).

    Args:
        root: Dataset root directory.
        target: Name of the split sub-directory. When it equals ``"train"``,
            every sample is rotated by a random multiple of 90 degrees.
        use_fine_mask: Load the fine mask in ``__getitem__``.
        use_coarse_mask: Load the coarse mask in ``__getitem__``.
        use_num: Keep only the first ``use_num`` files (in sorted order).
            Any value ``<= 0`` keeps every file.
    """

    # Image extensions that are swapped for the mask extension.
    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.bmp')
    _MASK_EXTENSION = '.nii.gz'

    def __init__(self, root, target='train', use_fine_mask=True,
                 use_coarse_mask=False, use_num=-100):
        super().__init__()
        self.use_fine_mask = use_fine_mask
        self.use_coarse_mask = use_coarse_mask
        self.target = target

        # os.listdir returns names in arbitrary order; sort so that sample
        # indices and the `use_num` subset are reproducible across runs.
        name_list = sorted(os.listdir(os.path.join(root, target, "fd")))
        if use_num > 0:
            name_list = name_list[:use_num]

        self.HR_path = []
        self.LR_path = []
        self.MASK_fine = []
        self.MASK_coarse = []
        for name in name_list:
            mask_name = self._mask_name(name)
            self.HR_path.append(os.path.join(root, target, "fd", name))
            self.LR_path.append(os.path.join(root, target, "qd", name))
            self.MASK_fine.append(
                os.path.join(root, target, "qd_seg_123456", mask_name))
            self.MASK_coarse.append(
                os.path.join(root, target, "qd_seg_123", mask_name))

        self.length = len(self.HR_path)

    @classmethod
    def _mask_name(cls, name):
        """Return the mask file name that corresponds to image ``name``.

        Only a trailing image extension is replaced; names with any other
        extension are returned unchanged.
        """
        stem, ext = os.path.splitext(name)
        if ext in cls._IMAGE_EXTENSIONS:
            return stem + cls._MASK_EXTENSION
        return name

    def __len__(self):
        """Return the number of image pairs."""
        return self.length

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A tuple ``(hr_img, lr_img, mask_fine, mask_coarse, file_name)``.
            ``hr_img`` / ``lr_img`` are the tensors produced by
            ``img2tensor`` for the ``fd`` / ``qd`` image. Each mask is a
            tensor built from the array returned by ``load_itk``, or an
            empty list when that mask is disabled. ``file_name`` is the
            base name of the ``qd`` image.
        """
        hr_img = img2tensor(load_img(self.HR_path[idx], grayscale=True))
        lr_img = img2tensor(load_img(self.LR_path[idx], grayscale=True))
        file_name = os.path.basename(self.LR_path[idx])

        # Disabled masks are returned as empty-list placeholders so the
        # returned tuple always has five elements.
        if self.use_fine_mask:
            mask_fine = torch.from_numpy(load_itk(self.MASK_fine[idx]))
        else:
            mask_fine = []
        if self.use_coarse_mask:
            mask_coarse = torch.from_numpy(load_itk(self.MASK_coarse[idx]))
        else:
            mask_coarse = []

        if self.target == "train":
            # Augmentation: the same number of quarter turns is applied to
            # both images and to every loaded mask so they stay aligned.
            # Four quarter turns leave the sample unrotated.
            k = random.choice([1, 2, 3, 4])
            # Images are rotated in the plane of dims 1 and 2, masks in the
            # plane of dims 0 and 1.
            hr_img = torch.rot90(hr_img, k, [1, 2])
            lr_img = torch.rot90(lr_img, k, [1, 2])
            if self.use_fine_mask:
                mask_fine = torch.rot90(mask_fine, k, [0, 1])
            if self.use_coarse_mask:
                mask_coarse = torch.rot90(mask_coarse, k, [0, 1])

        return hr_img, lr_img, mask_fine, mask_coarse, file_name