-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathdataset_augmented.py
86 lines (73 loc) · 3.05 KB
/
dataset_augmented.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import torch
import numpy as np
from torch.utils.data import Dataset
import cv2 as cv
import h5py
from core import DefaultConfig
# Module-level configuration instance created at import time.
# NOTE(review): nothing in the visible part of this file reads `config` --
# confirm whether it is still needed (or used by code outside this view).
config = DefaultConfig()
class HDFDataset2(Dataset):
    """Dataset from HDF5 archives formed of 'groups' of specific persons.

    Every top-level key of the archive is treated as one group. The number
    of entries of a group is taken to be the length of the first dataset
    stored in that group. Samples are addressed through a flat list of
    ``(group_key, index_in_group)`` pairs (``self.index_to_query``) that is
    built once, at construction time.

    Args:
        hdf_file_path: Path to an existing HDF5 file.
        use_aug: Include entries whose ``'real'`` flag is falsy.
        use_real: Include entries whose ``'real'`` flag is truthy.
            When both flags are set, every entry is included and the
            ``'real'`` dataset is not read at all.
        prefixes: Optional list of group keys to restrict the dataset to.
            Defaults to all keys of the archive, sorted. Keys missing from
            the archive and groups without entries are dropped.

    Raises:
        AssertionError: If the file does not exist, or if both ``use_aug``
            and ``use_real`` are false.
    """

    def __init__(self, hdf_file_path, use_aug=True, use_real=True,
                 prefixes=None,
                 ):
        assert os.path.isfile(hdf_file_path), \
            'HDF5 file not found: {!r}'.format(hdf_file_path)
        self.hdf_path = hdf_file_path
        # The handle used by __getitem__ is opened lazily (see __getitem__);
        # only a short-lived handle is used here to build the index.
        self.hdf = None

        with h5py.File(self.hdf_path, 'r', libver='latest', swmr=True) as h5f:
            candidates = sorted(h5f.keys()) if prefixes is None else prefixes

            # Drop erroneous inputs: unknown keys and groups with no entries
            # (including groups that hold no datasets at all).
            self.prefixes = [
                k for k in candidates
                if k in h5f and self._num_entries(h5f[k]) > 0
            ]

            self.index_to_query = self._build_index(
                h5f, self.prefixes, use_aug, use_real)

    @staticmethod
    def _num_entries(group):
        """Return the length of the first dataset in `group` (0 if none)."""
        first = next(iter(group.values()), None)
        return 0 if first is None else len(first)

    @staticmethod
    def _build_index(h5f, prefixes, use_aug, use_real):
        """Return the flat list of ``(prefix, i)`` pairs selected by the flags.

        `h5f` only needs to map each prefix to a mapping of dataset name to
        an indexable array, which is what an open ``h5py.File`` provides.
        """
        assert use_aug or use_real, \
            'At least one of use_aug / use_real must be set'

        want_real = bool(use_real)
        index = []
        for prefix in prefixes:
            group = h5f[prefix]
            count = HDFDataset2._num_entries(group)
            if use_aug and use_real:
                index.extend((prefix, i) for i in range(count))
                continue
            # Read the whole flag array once per group instead of issuing
            # one HDF5 read per entry.
            flags = group['real'][:]
            index.extend(
                (prefix, i) for i in range(count)
                if bool(flags[i]) == want_real
            )
        return index

    def __len__(self):
        """Return the number of selected entries."""
        return len(self.index_to_query)

    def close_hdf(self):
        """Close the lazily opened HDF5 handle, if any."""
        if self.hdf is not None:
            self.hdf.close()
            self.hdf = None

    def preprocess_image(self, image):
        """Return `image` unchanged."""
        return image

    def preprocess_entry(self, entry):
        """Convert array and int values of `entry` to tensors, in place.

        ``np.ndarray`` values become float32 tensors and Python ``int``
        values become int16 tensors. Every other value (e.g. the string
        key) is left as it is. The same dict object is returned.
        """
        for key, val in entry.items():
            if isinstance(val, np.ndarray):
                entry[key] = torch.from_numpy(val.astype(np.float32))
            elif isinstance(val, int):
                # NOTE: maybe ints should be signed and 32-bits sometimes
                entry[key] = torch.tensor(val, dtype=torch.int16, requires_grad=False)
        return entry

    def __getitem__(self, idx):
        """Return the entry at flat position `idx` as a dict.

        The dict has the keys ``'key'``, ``'image_a'``, ``'gaze_a'`` and
        ``'head_a'``, after conversion by `preprocess_entry`.
        """
        if self.hdf is None:  # Need to lazy-open this to avoid read error
            self.hdf = h5py.File(self.hdf_path, 'r', libver='latest', swmr=True)

        # Resolve the flat index to a (person group, index in group) pair.
        key_a, idx_a = self.index_to_query[idx]
        group_a = self.hdf[key_a]

        # Grab the (input) entry
        entry = {
            'key': key_a,
            'image_a': self.preprocess_image(group_a['image'][idx_a]),
            'gaze_a': group_a['gaze'][idx_a],
            'head_a': group_a['head'][idx_a],
        }
        return self.preprocess_entry(entry)