forked from microsoft/singleshotpose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
141 lines (118 loc) · 6.04 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/python
# encoding: utf-8
import os
import random
from PIL import Image
import numpy as np
from image import *
import torch
from torch.utils.data import Dataset
from utils import read_truths_args, read_truths, get_all_files
class listDataset(Dataset):
    """Dataset of images listed one path per line in a text file, with keypoint labels.

    In training mode, samples come from ``load_data_detection`` with jitter,
    hue/saturation/exposure augmentation and a randomly chosen background
    image. The input resolution is re-drawn at the start of every batch
    (multi-scale training). In evaluation mode, the image is loaded, optionally
    resized to ``shape``, and its labels are read from the matching ``.txt``
    file into a fixed-length, zero-padded tensor.
    """

    def __init__(self, root, shape=None, shuffle=True, transform=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4, cell_size=32, bg_file_names=None, num_keypoints=9, max_num_gt=50):
        """Read the image list and store the loading configuration.

        Args:
            root: path to a text file listing one image path per line.
            shape: (width, height) the image is resized to. During training it
                is overwritten by the multi-scale schedule.
            shuffle: whether to shuffle the image list once, at construction.
            transform: optional callable applied to the image.
            target_transform: optional callable applied to the label tensor.
            train: True for augmented training loading, False for evaluation.
            seen: number of examples already visited; drives the multi-scale
                schedule. # TODO: check if this is correctly assigned
            batch_size: batch size; the training resolution is re-drawn at
                indices that are multiples of it.
            num_workers: amount by which ``seen`` advances per fetched sample.
            cell_size: pixel unit; sampled training widths are multiples of it.
            bg_file_names: image paths from which a random background is chosen
                during training (required when ``train`` is True).
            num_keypoints: number of 2-D keypoints per object.
            max_num_gt: maximum number of ground-truth objects kept per image.
        """
        # Read the list of dataset images
        with open(root, 'r') as list_file:
            self.lines = list_file.readlines()
        if shuffle:
            random.shuffle(self.lines)
        self.nSamples = len(self.lines)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.bg_file_names = bg_file_names
        self.cell_size = cell_size
        self.nbatches = self.nSamples // self.batch_size
        self.num_keypoints = num_keypoints
        self.max_num_gt = max_num_gt  # maximum number of ground-truth labels an image can have

    def __len__(self):
        """Return the number of images in the list."""
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(img, label)`` for the image at ``index``.

        Raises:
            IndexError: if ``index`` is not smaller than ``len(self)``.
        """
        # Explicit IndexError instead of an assert: the check survives
        # ``python -O``, and sequence-style iteration terminates cleanly.
        if index >= len(self):
            raise IndexError('index range error')
        imgpath = self.lines[index].rstrip()

        # Multi-scale training: pick a new square input size at each batch start
        if self.train and index % self.batch_size == 0:
            width = self._multiscale_width()
            self.shape = (width, width)

        if self.train:
            img, label = self._load_train_sample(imgpath)
        else:
            img, label = self._load_eval_sample(imgpath)

        # Transform the image data to PyTorch tensors
        if self.transform is not None:
            img = self.transform(img)
        # If there is any PyTorch-specific transformation, transform the label data
        if self.target_transform is not None:
            label = self.target_transform(label)

        # NOTE(review): advancing by num_workers (not 1) presumably approximates
        # the global count across DataLoader worker copies of this dataset — confirm.
        self.seen = self.seen + self.num_workers
        return (img, label)

    def _multiscale_width(self):
        """Pick the training input width in pixels from the multi-scale schedule.

        "Epoch" here means ``nbatches * batch_size`` seen examples. For the
        first 10 epochs the width is fixed at 13 cells. During epochs 10-20 it
        is drawn from 13-20 cells. Every further 10 epochs the range widens by
        one cell on each side (12-21, 11-22, ...), reaching 7-26 cells from
        epoch 70 onwards.
        """
        samples_per_epoch = self.nbatches * self.batch_size
        # stage = number of completed 10-epoch intervals, capped at 7
        stage = 7
        for k in range(7):
            if self.seen < 10 * (k + 1) * samples_per_epoch:
                stage = k
                break
        if stage == 0:
            return 13 * self.cell_size
        smallest = 14 - stage
        return (random.randint(0, 2 * stage + 5) + smallest) * self.cell_size

    def _load_train_sample(self, imgpath):
        """Load a data-augmented training image and its label tensor."""
        # Augmentation strengths passed to load_data_detection
        jitter = 0.2
        hue = 0.1
        saturation = 1.5
        exposure = 1.5
        # Random background image path for load_data_detection
        random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
        bgpath = self.bg_file_names[random_bg_index]
        img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath, self.num_keypoints, self.max_num_gt)
        # load_data_detection returns the labels as a NumPy array
        label = torch.from_numpy(label)
        return img, label

    def _load_eval_sample(self, imgpath):
        """Load an evaluation image (optionally resized) and its zero-padded label vector."""
        img = Image.open(imgpath).convert('RGB')
        if self.shape:
            img = img.resize(self.shape)
        labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
        # Per object: 2 values per keypoint, +2 for ground-truth width/height, +1 for class label
        num_labels = 2*self.num_keypoints+3
        capacity = self.max_num_gt * num_labels
        label = torch.zeros(capacity)
        if os.path.getsize(labpath):
            tmp = torch.from_numpy(read_truths_args(labpath)).view(-1)
            # Keep at most max_num_gt objects. Always copy into the zero tensor
            # so the label has the same length and dtype on every path.
            n = min(tmp.numel(), capacity)
            label[:n] = tmp[:n]
        return img, label