-
Notifications
You must be signed in to change notification settings - Fork 0
/
OCT_54_dataset.py
82 lines (63 loc) · 2.47 KB
/
OCT_54_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from dataclasses import dataclass
from pandas.core.frame import DataFrame
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from pathlib import Path
import torch
from PIL import Image
import pandas as pd
from oct_Utils import *
@dataclass
class OCT_54_Dataset(Dataset):
    """PyTorch ``Dataset`` over a DataFrame of OCT B-scan image paths and labels.

    Attributes:
        data: DataFrame holding an ``'image_path'`` column (relative to
            ``config['image_root']``) plus the label column(s) named by
            ``config['label_col']``.
        mode: One of ``'train'``, ``'valid'``, ``'test'`` — selects which
            transform pipeline is applied in :meth:`read_image`.
        config: Dict providing ``'crop_size'``, ``'image_root'`` and
            ``'label_col'``.
    """
    data: pd.DataFrame
    mode: str
    config: dict

    def __post_init__(self):
        """Cache config values and build the per-mode transform pipelines."""
        self.crop_size = self.config['crop_size']
        # Coerce to Path so the `/` join in __getitem__ works even when the
        # config supplies a plain string path (the original assumed a Path
        # object and would raise TypeError on a str).
        self.image_root = Path(self.config['image_root'])
        self.label_col = self.config['label_col']
        # Training adds brightness jitter; valid/test are deterministic.
        # Grayscale->3 channels lets ImageNet-pretrained backbones consume
        # single-channel OCT scans; Normalize uses ImageNet statistics.
        _eval_pipeline = [
            transforms.Resize([224, 224]),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]
        self.trans = {
            'train': transforms.Compose([
                transforms.Resize([224, 224]),
                transforms.ColorJitter(brightness=0.3),
                transforms.Grayscale(num_output_channels=3),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225]),
            ]),
            'valid': transforms.Compose(_eval_pipeline),
            'test': transforms.Compose(list(_eval_pipeline)),
        }

    def read_image(self, path):
        """Load ``path`` as RGB and apply this split's transform pipeline.

        Returns a ``(3, 224, 224)`` float tensor.
        """
        img = Image.open(path).convert('RGB')
        return self.trans[self.mode](img)

    def __len__(self):
        """Number of samples, one per DataFrame row."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'img': transformed image tensor, 'label': float tensor}``.

        ``idx`` is positional; it is mapped through ``self.data.index`` so the
        DataFrame does not need a contiguous RangeIndex.
        """
        row = self.data.index[idx]
        img = self.read_image(str(self.image_root / self.data.loc[row, 'image_path']))
        label = self.data.loc[row, self.label_col]
        return {
            'img': img,
            'label': torch.tensor(label, dtype=torch.float),
        }