dataloader.py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from pc_utils import rotate_point_cloud, PointcloudScaleAndTranslate
import rs_cnn.data.data_utils as rscnn_d_utils
from rs_cnn.data.ModelNet40Loader import ModelNet40Cls as rscnn_ModelNet40Cls
import pointnet2.utils.pointnet2_utils as pointnet2_utils
from pointnet2_tf.modelnet_h5_dataset import ModelNetH5Dataset as pointnet2_ModelNetH5Dataset
from dgcnn.pytorch.data import ModelNet40 as dgcnn_ModelNet40


# distilled from the following sources:
# https://github.com/Yochengliu/Relation-Shape-CNN/blob/master/data/ModelNet40Loader.py
# https://github.com/Yochengliu/Relation-Shape-CNN/blob/master/train_cls.py
class ModelNet40Rscnn(Dataset):
    def __init__(self, split, data_path, train_data_path,
                 valid_data_path, test_data_path, num_points):
        self.split = split
        self.num_points = num_points
        _transforms = transforms.Compose([rscnn_d_utils.PointcloudToTensor()])
        rscnn_params = {
            # placeholder; the effective point count is set by the
            # resampling in batch_proc
            'num_points': 1024,
            'root': data_path,
            'transforms': _transforms,
            'train': (split in ["train", "valid"]),
            'data_file': {
                'train': train_data_path,
                'valid': valid_data_path,
                'test': test_data_path
            }[self.split]
        }
        self.rscnn_dataset = rscnn_ModelNet40Cls(**rscnn_params)
        self.PointcloudScaleAndTranslate = PointcloudScaleAndTranslate()

    def __len__(self):
        return len(self.rscnn_dataset)

    def __getitem__(self, idx):
        point, label = self.rscnn_dataset[idx]
        # convert to ndarray / plain int for compatibility with the rest
        # of the codebase
        point = np.array(point)
        label = label[0].item()
        return {'pc': point, 'label': label}

    def batch_proc(self, data_batch, device):
        point = data_batch['pc'].to(device)
        if self.split == "train":
            # farthest-point-sample 1200 points, then randomly keep
            # num_points of them -- (B, npoint)
            fps_idx = pointnet2_utils.furthest_point_sample(point, 1200)
            fps_idx = fps_idx[:, np.random.choice(1200, self.num_points,
                                                  False)]
            point = pointnet2_utils.gather_operation(
                point.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
            # train-time augmentation: random scale and translate
            point.data = self.PointcloudScaleAndTranslate(point.data)
        else:
            fps_idx = pointnet2_utils.furthest_point_sample(
                point, self.num_points)  # (B, npoint)
            point = pointnet2_utils.gather_operation(
                point.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()
        # move back to the CPU to maintain compatibility with the rest of
        # the pipeline
        point = point.cpu()
        return {'pc': point, 'label': data_batch['label']}
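

# How the per-batch hook above is meant to be consumed (a sketch; the
# training loop below is illustrative and not part of this file):
#
#   loader = create_dataloader("train", cfg)
#   for batch in loader:
#       if loader.dataset.batch_proc is not None:
#           batch = loader.dataset.batch_proc(batch, device)
#       logits = model(batch['pc'].to(device))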


# distilled from the following sources:
# https://github.com/charlesq34/pointnet2/blob/7961e26e31d0ba5a72020635cee03aac5d0e754a/modelnet_h5_dataset.py
# https://github.com/charlesq34/pointnet2/blob/7961e26e31d0ba5a72020635cee03aac5d0e754a/train.py
class ModelNet40PN2(Dataset):
    def __init__(self, split, train_data_path,
                 valid_data_path, test_data_path, num_points):
        self.split = split
        self.dataset_name = 'modelnet40_pn2'
        data_path = {
            "train": train_data_path,
            "valid": valid_data_path,
            "test": test_data_path
        }[self.split]
        pointnet2_params = {
            'list_filename': data_path,
            # internal batch size used while reading the HDF5 files;
            # unrelated to the DataLoader batch size
            'batch_size': 32,
            'npoints': num_points,
            'shuffle': False
        }
        # load the entire PointNet++ dataset into memory up front
        self._dataset = pointnet2_ModelNetH5Dataset(**pointnet2_params)
        all_pc = []
        all_label = []
        while self._dataset.has_next_batch():
            # augment=False: data augmentation is applied later, in batch_proc
            pc, label = self._dataset.next_batch(augment=False)
            all_pc.append(pc)
            all_label.append(label)
        self.all_pc = np.concatenate(all_pc)
        self.all_label = np.concatenate(all_label)

    def __len__(self):
        return self.all_pc.shape[0]

    def __getitem__(self, idx):
        return {'pc': self.all_pc[idx], 'label': np.int64(self.all_label[idx])}

    def batch_proc(self, data_batch, device):
        if self.split == "train":
            point = np.array(data_batch['pc'])
            point = self._dataset._augment_batch_data(point)
            # convert back to a tensor for compatibility with the rest of
            # the pipeline
            data_batch['pc'] = torch.tensor(point)
        return data_batch
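

# Note: unlike ModelNet40Rscnn.batch_proc (GPU-side FPS), the hook above
# augments the raw numpy batch on the CPU and ignores `device`, so it
# should run before the batch is moved to the GPU, e.g. (illustrative only):
#
#   batch = dataset.batch_proc(batch, device)   # numpy augmentation on CPU
#   points = batch['pc'].to(device)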


class ModelNet40Dgcnn(Dataset):
    def __init__(self, split, train_data_path,
                 valid_data_path, test_data_path, num_points):
        self.split = split
        self.data_path = {
            "train": train_data_path,
            "valid": valid_data_path,
            "test": test_data_path
        }[self.split]
        dgcnn_params = {
            'partition': 'train' if split in ['train', 'valid'] else 'test',
            'num_points': num_points,
            'data_path': self.data_path
        }
        self.dataset = dgcnn_ModelNet40(**dgcnn_params)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        pc, label = self.dataset[idx]
        return {'pc': pc, 'label': label.item()}


def create_dataloader(split, cfg):
    num_workers = cfg.DATALOADER.num_workers
    batch_size = cfg.DATALOADER.batch_size

    dataset_args = {
        "split": split
    }
    if cfg.EXP.DATASET == "modelnet40_rscnn":
        dataset_args.update(dict(**cfg.DATALOADER.MODELNET40_RSCNN))
        # augmentation is done directly in the dataset code so that it
        # stays as close to the vanilla RSCNN code as possible
        dataset = ModelNet40Rscnn(**dataset_args)
    elif cfg.EXP.DATASET == "modelnet40_pn2":
        dataset_args.update(dict(**cfg.DATALOADER.MODELNET40_PN2))
        dataset = ModelNet40PN2(**dataset_args)
    elif cfg.EXP.DATASET == "modelnet40_dgcnn":
        dataset_args.update(dict(**cfg.DATALOADER.MODELNET40_DGCNN))
        dataset = ModelNet40Dgcnn(**dataset_args)
    else:
        assert False, f"unknown dataset: {cfg.EXP.DATASET}"

    # datasets without a per-batch hook get batch_proc = None so that
    # callers can test for it uniformly
    if "batch_proc" not in dir(dataset):
        dataset.batch_proc = None

    return DataLoader(
        dataset,
        batch_size,
        num_workers=num_workers,
        shuffle=(split == "train"),
        drop_last=(split == "train"),
        # pin memory only when CUDA is available and no worker processes
        # are used
        pin_memory=(torch.cuda.is_available()) and (not num_workers)
    )
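

if __name__ == "__main__":
    # Minimal smoke test -- a sketch assuming a yacs-style CfgNode with the
    # fields accessed above; the config class and file paths used by this
    # repository may differ.
    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.EXP = CN()
    cfg.EXP.DATASET = "modelnet40_dgcnn"
    cfg.DATALOADER = CN()
    cfg.DATALOADER.num_workers = 0
    cfg.DATALOADER.batch_size = 4
    cfg.DATALOADER.MODELNET40_DGCNN = CN()
    cfg.DATALOADER.MODELNET40_DGCNN.num_points = 1024
    # hypothetical paths -- point these at a local ModelNet40 copy
    cfg.DATALOADER.MODELNET40_DGCNN.train_data_path = "data/train_files.txt"
    cfg.DATALOADER.MODELNET40_DGCNN.valid_data_path = "data/valid_files.txt"
    cfg.DATALOADER.MODELNET40_DGCNN.test_data_path = "data/test_files.txt"

    loader = create_dataloader("test", cfg)
    batch = next(iter(loader))
    print(batch['pc'].shape, batch['label'].shape)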