-
Notifications
You must be signed in to change notification settings - Fork 2
/
deap_cnn_loader.py
92 lines (78 loc) · 3.05 KB
/
deap_cnn_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import numpy as np
import random
import os
import sys
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
sys.path.append('../')
import configparam
# Optional test-time perturbation read by deap_cnn_loader.__getitem__:
#   None             -> samples are returned unchanged
#   'Gaussian noise' -> adds randn * eps to each sample, then clips to [0, 1]
attack, eps = None, 0.1
class deap_cnn_loader(Dataset):
    """Map-style Dataset over pre-segmented EEG ``.npy`` files on disk.

    ``__init__`` only collects file paths. Each segment array and its label are
    read lazily in ``__getitem__``.

    Each segment ``<data_path>SssTtt_nnnn.npy`` is paired with a one-line text
    label file. The label file has the same stem plus ``_valence``,
    ``_arousal`` or ``_multi``, depending on ``param.target_label``.
    """

    def __init__(self, param):
        """Index every available segment for the configured subjects and trials.

        Args:
            param: configuration object. Attributes read here:
                use_predefined_idx -- 1 sets ``self.use_pre_idx`` to True.
                data_path -- prefix prepended to file names by plain string
                    concatenation, so it presumably must end with a path
                    separator.
                target_label -- 'valence', 'arousal' or 'both'.
                num_subject, num_trial -- loop bounds (ids are 1-based).
                target_subject -- sequence of 1-based subject ids; a leading
                    0 selects every subject.

        Raises:
            ValueError: if ``param.target_label`` is not supported. Such a
                dataset would otherwise have data paths but no label paths,
                and ``__getitem__`` would fail with IndexError.
        """
        self.use_pre_idx = param.use_predefined_idx == 1
        self.data_path = param.data_path
        target_label = param.target_label
        num_sub = param.num_subject
        num_trial = param.num_trial
        # Exclusive upper bound on the probed segment loop index.
        max_num_seq = 1000

        self.eeg_data = []
        self.eeg_label = []
        if target_label == 'valence':
            print('----------------valence--------------')
            self.gt_type = 1
        elif target_label == 'arousal':
            print('----------------arousal--------------')
            self.gt_type = 2
        elif target_label == 'both':
            print('----------------both--------------')
            self.gt_type = 3
        else:
            raise ValueError(
                "target_label must be 'valence', 'arousal' or 'both', got %r"
                % (target_label,))
        label_suffix = {1: 'valence', 2: 'arousal', 3: 'multi'}[self.gt_type]

        for s in range(num_sub):
            subject = s + 1
            if param.target_subject[0] != 0 and subject not in param.target_subject:
                print('%d is not in target subject list' % subject)
                continue
            for v in range(num_trial):
                trial = v + 1
                # Trial length is unknown, so probe segments in order and stop
                # at the first missing file (segments must be contiguous).
                # The loop starts at t = 10, so file indices begin at _0011.
                # NOTE(review): presumably _0001.._0010 are skipped on purpose
                # (e.g. a baseline period); confirm.
                for t in range(10, max_num_seq):
                    stem = self.data_path + 'S%02dT%02d_%04d' % (subject, trial, t + 1)
                    eeg_name = stem + '.npy'
                    if not os.path.exists(eeg_name):
                        break
                    self.eeg_data.append(eeg_name)
                    self.eeg_label.append(stem + '_' + label_suffix + '.txt')
        self.len = len(self.eeg_data)
        print(self.len)

    def __getitem__(self, index):
        """Load segment ``index`` and its label.

        Returns:
            (x, y). ``x`` is the stored array cast to float32 with a leading
            axis of size 1 prepended; the reshape assumes a 2-D array on
            disk. ``y`` is an int. For valence/arousal the label value is
            binarised (>= 5 -> 1, else 0); for 'both' it is ``int(value)``.

        If the module-level ``attack`` equals 'Gaussian noise', standard
        normal noise scaled by the module-level ``eps`` is added to ``x``,
        and the result is clipped to [0, 1].
        """
        x = np.load(self.eeg_data[index]).astype(np.float32)
        # `with` guarantees the label file is closed even if parsing fails.
        with open(self.eeg_label[index], 'r') as f:
            val = float(f.read().replace('\n', ''))
        if self.gt_type < 3:
            y = 1 if val >= 5 else 0
        else:
            y = int(val)
        x = x.reshape(-1, x.shape[0], x.shape[1])
        if attack == 'Gaussian noise':
            # One draw over the full shape consumes the RNG in the same C
            # order as the former per-channel loop.
            noise = np.random.randn(*x.shape) * eps
            x = np.clip(x + noise, 0, 1).astype(np.float32)
        return x, y

    def __len__(self):
        """Return the number of indexed segments."""
        return self.len