-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_class.py
57 lines (42 loc) · 1.64 KB
/
dataset_class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import Tuple, Optional, Union, Dict
from pickle import load as pickle_load
from pathlib import Path
from torch.utils.data import Dataset
import numpy as np
from utils import get_files_from_dir_with_pathlib
# Docstrings in this module are written with reStructuredText field lists
# (``:param:``, ``:type:``, ``:return:``, ``:rtype:``).
__docformat__ = 'reStructuredText'
# Names exported by ``from <module> import *``.
__all__ = ['IrmasDataset']
class IrmasDataset(Dataset):
    """Map-style PyTorch dataset of pickled examples read from a directory.

    Every file returned by ``get_files_from_dir_with_pathlib(data_dir)`` is
    unpickled eagerly at construction time and kept in memory in
    ``self.files``. Each example is expected to be a dict holding the
    features under ``key_features`` and the class/label under ``key_class``.

    .. warning::
       Files are read with :func:`pickle.load`, which can execute arbitrary
       code while loading. Only point this dataset at trusted data.
    """

    def __init__(self,
                 data_dir: Union[str, Path],
                 key_features: Optional[str] = 'features',
                 key_class: Optional[str] = 'class') \
            -> None:
        """Load every example file found in ``data_dir``.

        :param data_dir: Directory containing the pickled example files.
        :type data_dir: str | pathlib.Path
        :param key_features: Dict key under which each example stores
            its features.
        :type key_features: str
        :param key_class: Dict key under which each example stores its
            class/label.
        :type key_class: str
        """
        super().__init__()
        self.key_features = key_features
        self.key_class = key_class
        data_path = Path(data_dir)
        # Load everything up front. File order is whatever the project's
        # ``utils`` helper returns.
        self.files = [self._load_file(a_file)
                      for a_file in get_files_from_dir_with_pathlib(data_path)]

    @staticmethod
    def _load_file(file_path: Path) \
            -> Dict[str, Union[int, np.ndarray]]:
        """Load a pickled file using :class:`pathlib.Path`.

        :param file_path: File path.
        :type file_path: pathlib.Path
        :return: The unpickled object (expected: a dict with the features
            and the class of one example).
        :rtype: dict[str, int|numpy.ndarray]
        """
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with file_path.open('rb') as f:
            return pickle_load(f)

    def __len__(self) \
            -> int:
        """Returns the length of the dataset.

        :return: Number of loaded examples.
        :rtype: int
        """
        return len(self.files)

    def __getitem__(self,
                    item: int) \
            -> Tuple[np.ndarray, np.ndarray]:
        """Return the ``(features, class)`` pair of one example.

        Dict examples are indexed with ``self.key_features`` and
        ``self.key_class``. Any other indexable example (e.g. a tuple)
        is read positionally from indices 0 and 1.

        :param item: Index of the example.
        :type item: int
        :return: The features and the class of the example.
        :rtype: (numpy.ndarray, numpy.ndarray)
        """
        the_item = self.files[item]
        if isinstance(the_item, dict):
            # The documented file format is a str-keyed dict, so use the
            # configured keys; positional 0/1 lookups would raise KeyError.
            return the_item[self.key_features], the_item[self.key_class]
        return the_item[0], the_item[1]
# EOF