-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcityscapes.py
128 lines (112 loc) · 4.33 KB
/
cityscapes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os.path as osp
import os
from PIL import Image
import numpy as np
import json
from transform import ColorJitter, HorizontalFlip, RandomCrop, RandomScale, Compose
class CityScapes(Dataset):
    """Cityscapes semantic-segmentation dataset (fine annotations).

    Pairs every ``*_leftImg8bit.png`` image found under
    ``<rootpth>/leftImg8bit/<mode>`` with the ``*_gtFine_labelIds.png`` file of
    the same sample name under ``<rootpth>/gtFine/<mode>``.  Raw label ids are
    remapped to train ids using the ``id``/``trainId`` table in the info JSON.
    """

    def __init__(self,
                 rootpth,
                 cropsize=(640, 480),
                 mode='train',
                 # NOTE(review): 0.675 breaks the 0.125 step of this sequence
                 # (0.625 would fit); kept unchanged - confirm it is intended.
                 randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0,
                              1.25, 1.5),
                 *args,
                 info_path='./cityscapes_info.json',
                 **kwargs):
        """Index the Cityscapes split *mode* under *rootpth*.

        Args:
            rootpth: dataset root containing ``leftImg8bit/<mode>`` and
                ``gtFine/<mode>``.
            cropsize: size passed to ``RandomCrop`` (train modes only).
            mode: one of 'train', 'val', 'test', 'trainval'.
            randomscale: scale factors passed to ``RandomScale``.
            *args, **kwargs: accepted so existing call sites (e.g. ones passing
                ``n_classes``) keep working; they are ignored.
            info_path: JSON list of label entries, each with at least ``'id'``
                and ``'trainId'`` keys.  Defaults to the previously hard-coded
                ``./cityscapes_info.json`` (resolved against the CWD).

        Raises:
            AssertionError: on an unknown *mode* or when the image and label
                sample names do not match one-to-one.
        """
        # Extra arguments are deliberately NOT forwarded: Dataset.__init__
        # resolves to object.__init__, which raises TypeError on extra
        # arguments when a subclass overrides __init__.
        super(CityScapes, self).__init__()
        assert mode in ('train', 'val', 'test', 'trainval'), \
            'unsupported mode: {!r}'.format(mode)
        self.mode = mode
        print('self.mode', self.mode)
        # Label value meant to be ignored; not used inside this class itself
        # (presumably read by the training loss - confirm at the call site).
        self.ignore_lb = 255

        # Label table of the dataset: raw label id -> train id.
        with open(info_path, 'r', encoding='utf-8') as fr:
            labels_info = json.load(fr)
        self.lb_map = {el['id']: el['trainId'] for el in labels_info}

        # Index images and ground-truth label maps by their shared sample name.
        self.imgs = self._index_split(
            osp.join(rootpth, 'leftImg8bit', mode), '_leftImg8bit.png')
        self.labels = self._index_split(
            osp.join(rootpth, 'gtFine', mode), '_gtFine_labelIds.png')
        # Sorted so the index -> sample mapping is reproducible
        # (os.listdir returns entries in arbitrary order).
        self.imnames = sorted(self.imgs)
        self.len = len(self.imnames)
        print('self.len', self.mode, self.len)
        # Every image must have exactly one label map and vice versa.
        unmatched = set(self.imgs).symmetric_difference(self.labels)
        assert not unmatched, \
            'images and labels do not match, e.g. {}'.format(
                sorted(unmatched)[:5])

        # Applied to every image: PIL -> tensor, then ImageNet mean/std.
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        # Joint image/label augmentation, used only in 'train'/'trainval'.
        self.trans_train = Compose([
            ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            HorizontalFlip(),
            RandomScale(randomscale),
            RandomCrop(cropsize)
        ])

    @staticmethod
    def _index_split(split_dir, suffix):
        """Map sample name -> path for every ``*suffix`` file in *split_dir*'s sub-folders.

        *split_dir* is expected to contain sub-folders (one per city in the
        Cityscapes layout).  The sample name is the file name with *suffix*
        stripped.  Non-directory entries and non-matching files are skipped.
        """
        index = {}
        for sub in sorted(os.listdir(split_dir)):
            sub_dir = osp.join(split_dir, sub)
            if not osp.isdir(sub_dir):
                continue  # e.g. stray files such as .DS_Store
            for fname in sorted(os.listdir(sub_dir)):
                if fname.endswith(suffix):
                    index[fname[:-len(suffix)]] = osp.join(sub_dir, fname)
        return index

    def __getitem__(self, idx):
        """Load one sample, augment it in train modes, and normalise it.

        Returns:
            img: image tensor produced by ``self.to_tensor``.
            label: int64 array of train ids with a leading axis of size 1
                (presumably (1, H, W) for single-channel label PNGs).
        """
        fn = self.imnames[idx]
        img = Image.open(self.imgs[fn]).convert('RGB')
        label = Image.open(self.labels[fn])
        if self.mode in ('train', 'trainval'):
            # Image and label go through the same random transform together.
            im_lb = self.trans_train(dict(im=img, lb=label))
            img, label = im_lb['im'], im_lb['lb']
        img = self.to_tensor(img)
        label = np.array(label).astype(np.int64)[np.newaxis, :]
        label = self.convert_labels(label)
        return img, label

    def __len__(self):
        """Return the number of indexed samples."""
        return self.len

    def convert_labels(self, label):
        """Remap raw label ids to train ids in place and return *label*.

        Masks are computed against a snapshot of the input, so a pixel that was
        already rewritten to a train id is never matched again by a later raw
        id, whatever the iteration order of ``self.lb_map``.  Values that are
        not keys of ``self.lb_map`` are left unchanged.
        """
        raw = label.copy()
        for raw_id, train_id in self.lb_map.items():
            label[raw == raw_id] = train_id
        return label
if __name__ == "__main__":
    # Sanity check: collect every distinct label value produced over the
    # 'val' split, to confirm the id -> trainId remapping.
    from tqdm import tqdm
    # No extra kwargs here: passing e.g. n_classes would be forwarded to
    # object.__init__ by constructors that pass **kwargs through, and that raises TypeError.
    ds = CityScapes('./data/', mode='val')
    uni = set()
    for im, lb in tqdm(ds):
        uni.update(np.unique(lb).tolist())
    print(sorted(uni))