-
Notifications
You must be signed in to change notification settings - Fork 5
/
dataset.py
124 lines (78 loc) · 3.99 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from random import shuffle, randrange
import numpy as np
import random
#from .DataLoaders.hcpRestLoader import hcpRestLoader
#from .DataLoaders.hcpTaskLoader import hcpTaskLoader
from .DataLoaders.abide1Loader import abide1Loader
# Registry mapping a dataset name (as given in the run options) to the
# callable that loads it. The HCP loaders are currently disabled:
#   "hcpRest" -> hcpRestLoader
#   "hcpTask" -> hcpTaskLoader
loaderMapper = {"abide1": abide1Loader}
def getDataset(options):
    """Build the dataset described by ``options``.

    ``options`` is handed unchanged to ``SupervisedDataset``; see that
    class for the attributes it reads.
    """
    dataset = SupervisedDataset(options)
    return dataset
class SupervisedDataset(Dataset):
    """Labelled timeseries dataset with optional stratified k-fold splitting.

    The full dataset is loaded once in ``__init__``. ``setFold`` then selects
    which subset (train or test part of a fold) is exposed through
    ``__len__`` / ``__getitem__``; ``getFold`` does the same and wraps the
    dataset in a ``DataLoader``.

    Each item is a dict with keys ``"timeseries"`` (float32 array, normalised
    per row), ``"label"`` and ``"subjId"``.
    """

    def __init__(self, datasetDetails):
        """Load and deterministically shuffle the dataset.

        ``datasetDetails`` must provide: ``batchSize``, ``dynamicLength``
        (window length for random cropping during training, or ``None``),
        ``foldCount`` (number of folds, or ``None`` for no cross-validation),
        ``datasetSeed``, ``datasetName`` (key into ``loaderMapper``),
        ``atlas`` and ``targetTask`` (both forwarded to the loader).
        """
        self.batchSize = datasetDetails.batchSize
        self.dynamicLength = datasetDetails.dynamicLength
        self.foldCount = datasetDetails.foldCount
        self.seed = datasetDetails.datasetSeed

        loader = loaderMapper[datasetDetails.datasetName]

        self.kFold = (
            StratifiedKFold(datasetDetails.foldCount, shuffle=False, random_state=None)
            if datasetDetails.foldCount is not None
            else None
        )
        self.k = None

        self.data, self.labels, self.subjectIds = loader(datasetDetails.atlas, datasetDetails.targetTask)

        # A fresh generator with the same seed is used for each sequence so
        # that all three receive the same permutation and stay aligned.
        # NOTE(review): assumes the loader returns three equal-length mutable
        # sequences (e.g. lists) - confirm against the loader.
        random.Random(self.seed).shuffle(self.data)
        random.Random(self.seed).shuffle(self.labels)
        random.Random(self.seed).shuffle(self.subjectIds)

        # Currently selected subset; populated by setFold().
        self.targetData = None
        self.targetLabels = None
        self.targetSubjIds = None

        # Per-training-subject stacks of pre-drawn crop start positions.
        self.randomRanges = None

        self.trainIdx = None
        self.testIdx = None

        # Whether the selected subset is the training part; set by setFold().
        self.train = None

    def __len__(self):
        """Return the size of the selected subset, or of the whole dataset
        if no fold has been selected yet."""
        return len(self.data) if self.targetData is None else len(self.targetData)

    def get_nOfTrains_perFold(self):
        """Return the (upper-bound) number of training samples in one fold."""
        if self.foldCount is not None:
            return int(np.ceil(len(self.data) * (self.foldCount - 1) / self.foldCount))
        return len(self.data)

    def setFold(self, fold, train=True):
        """Select the train or test part of fold ``fold`` as the active subset.

        When ``foldCount`` is ``None`` there is no split: every sample is a
        training sample and the test index list is empty, so ``train`` is
        expected to be ``True`` in that mode.
        """
        self.k = fold
        self.train = train

        if self.foldCount is None:
            trainIdx = list(range(len(self.data)))
            # Without folds there is no held-out part; bind an empty list so
            # the assignments below do not fail on an undefined name.
            testIdx = []
        else:
            trainIdx, testIdx = list(self.kFold.split(self.data, self.labels))[fold]

        self.trainIdx = trainIdx
        self.testIdx = testIdx

        # In-place shuffle; self.trainIdx refers to the same object and is
        # therefore shuffled as well.
        random.Random(self.seed).shuffle(trainIdx)

        activeIdx = trainIdx if train else testIdx
        self.targetData = [self.data[idx] for idx in activeIdx]
        self.targetLabels = [self.labels[idx] for idx in activeIdx]
        self.targetSubjIds = [self.subjectIds[idx] for idx in activeIdx]

        if train and self.dynamicLength is not None:
            # Seeds NumPy's global generator so the crop positions are
            # reproducible for a given dataset seed.
            np.random.seed(self.seed + 1)
            # The upper bound is clamped to 1 so a series whose length equals
            # dynamicLength gets start position 0 instead of randint raising
            # on an empty range. Draws for longer series are unchanged.
            self.randomRanges = [
                [
                    np.random.randint(0, max(1, self.data[idx].shape[-1] - self.dynamicLength))
                    for _ in range(9999)
                ]
                for idx in trainIdx
            ]

    def getFold(self, fold, train=True):
        """Select fold ``fold`` and return a ``DataLoader`` over this dataset.

        Training loaders use ``batchSize``; test loaders use batch size 1.
        Neither shuffles (the training order was fixed in ``setFold``).
        """
        self.setFold(fold, train)
        return DataLoader(self, batch_size=self.batchSize if train else 1, shuffle=False)

    def __getitem__(self, idx):
        """Return sample ``idx`` of the active subset as a dict.

        The timeseries is z-scored along axis 1 (rows with zero variance
        become all zeros). In training mode with ``dynamicLength`` set, a
        window of that length is cropped at a pre-drawn start position.
        """
        subject = self.targetData[idx]
        label = self.targetLabels[idx]
        subjId = self.targetSubjIds[idx]

        # normalize timeseries
        timeseries = subject  # (numberOfRois, time)
        # Constant rows give 0/0 here; that is expected and handled by
        # nan_to_num below, so the warnings are suppressed.
        with np.errstate(divide="ignore", invalid="ignore"):
            timeseries = (timeseries - np.mean(timeseries, axis=1, keepdims=True)) / np.std(
                timeseries, axis=1, keepdims=True
            )
        timeseries = np.nan_to_num(timeseries, nan=0.0)

        # dynamic sampling if train
        if self.train and self.dynamicLength is not None:
            if timeseries.shape[1] < self.dynamicLength:
                print(timeseries.shape[1], self.dynamicLength)

            # Each start position is consumed once; 9999 are drawn per
            # subject in setFold(), which must be called again to refill.
            samplingInit = self.randomRanges[idx].pop()
            timeseries = timeseries[:, samplingInit : samplingInit + self.dynamicLength]

        return {"timeseries": timeseries.astype(np.float32), "label": label, "subjId": subjId}