example.py
import logging
import os.path
import sys
from collections import OrderedDict

import numpy as np
import torch as th
import torch.nn.functional as F
from torch import optim

from braindecode.datasets.bbci import BBCIDataset
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.datautil.signalproc import (exponential_running_standardize,
                                             highpass_cnt)
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.datautil.trial_segment import \
    create_signal_target_from_raw_mne
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import (LossMonitor, MisclassMonitor,
                                              RuntimeMonitor,
                                              CroppedTrialMisclassMonitor)
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
from braindecode.models.deep4 import Deep4Net
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.torch_ext.util import np_to_var, set_random_seeds

log = logging.getLogger(__name__)
log.setLevel('DEBUG')

def load_bbci_data(filename, low_cut_hz, debug=False):
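    """Load and preprocess one subject's BBCI .mat file.

    Cuts trials, removes trials whose absolute amplitude exceeds 800 uV,
    keeps only motor-cortex ('C') sensors, resamples to 250 Hz, highpasses
    at ``low_cut_hz``, applies exponential running standardization, and
    finally cuts trials from -500 ms to 4000 ms around the markers.
    """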
    load_sensor_names = None
    if debug:
        load_sensor_names = ['C3', 'C4', 'C2']
    # We load all sensors to always get the same cleaning results,
    # independent of the later sensor selection.
    # There is an inbuilt heuristic that tries to use only EEG channels;
    # it definitely works for the datasets in our paper.
    loader = BBCIDataset(filename, load_sensor_names=load_sensor_names)

    log.info("Loading data...")
    cnt = loader.load()

    # Cleaning: first find all trials that contain absolute microvolt values
    # larger than +-800 and remember them for removal later.
    log.info("Cutting trials...")
    marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2]),
                              ('Rest', [3]), ('Feet', [4])])
    clean_ival = [0, 4000]
    set_for_cleaning = create_signal_target_from_raw_mne(cnt, marker_def,
                                                         clean_ival)
    clean_trial_mask = np.max(np.abs(set_for_cleaning.X), axis=(1, 2)) < 800
    log.info("Clean trials: {:3d} of {:3d} ({:5.1f}%)".format(
        np.sum(clean_trial_mask),
        len(set_for_cleaning.X),
        np.mean(clean_trial_mask) * 100))

    # Now pick only sensors with 'C' in their name, as they cover motor cortex.
    C_sensors = ['FC5', 'FC1', 'FC2', 'FC6', 'C3', 'C4', 'CP5',
                 'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
                 'C6', 'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h',
                 'FFC6h', 'FCC5h', 'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h',
                 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h', 'CPP3h', 'CPP4h',
                 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 'CCP1h',
                 'CCP2h', 'CPP1h', 'CPP2h']
    if debug:
        C_sensors = load_sensor_names
    cnt = cnt.pick_channels(C_sensors)

    # Further preprocessing as described in the paper.
    log.info("Resampling...")
    cnt = resample_cnt(cnt, 250.0)
    log.info("Highpassing...")
    cnt = mne_apply(
        lambda a: highpass_cnt(
            a, low_cut_hz, cnt.info['sfreq'], filt_order=3, axis=1),
        cnt)
    log.info("Standardizing...")
    cnt = mne_apply(
        lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
                                                  init_block_size=1000,
                                                  eps=1e-4).T,
        cnt)

    # Trial interval: start already at -500 ms before the marker, since this
    # improved decoding for the networks.
    ival = [-500, 4000]
    dataset = create_signal_target_from_raw_mne(cnt, marker_def, ival)
    # Remove the trials marked as too noisy during cleaning.
    dataset.X = dataset.X[clean_trial_mask]
    dataset.y = dataset.y[clean_trial_mask]
    return dataset
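
# A minimal usage sketch (the path is hypothetical; point it at wherever the
# High-Gamma Dataset .mat files live on your machine):
#
#     train_set = load_bbci_data('/data/HGD-public/train/1.mat', low_cut_hz=4)
#     print(train_set.X.shape)  # (n_trials, n_channels, n_timepoints)
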
def load_train_valid_test(
        train_filename, test_filename, low_cut_hz, debug=False):
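    """Load train and test files and split off a validation set.

    The first 80% of the (cleaned) training trials become the train set,
    the remaining 20% the validation set; the test file is used as-is.
    """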
log.info("Loading train...")
full_train_set = load_bbci_data(
train_filename, low_cut_hz=low_cut_hz, debug=debug)
log.info("Loading test...")
test_set = load_bbci_data(
test_filename, low_cut_hz=low_cut_hz, debug=debug)
valid_set_fraction = 0.8
train_set, valid_set = split_into_two_sets(full_train_set,
valid_set_fraction)
log.info("Train set with {:4d} trials".format(len(train_set.X)))
if valid_set is not None:
log.info("Valid set with {:4d} trials".format(len(valid_set.X)))
log.info("Test set with {:4d} trials".format(len(test_set.X)))
return train_set, valid_set, test_set
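
# For example, assuming 320 clean training trials and that
# split_into_two_sets splits by trial index, the call above yields the first
# 256 trials for training and the chronologically last 64 for validation.
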
def run_exp_on_high_gamma_dataset(train_filename, test_filename,
                                  low_cut_hz, model_name,
                                  max_epochs, max_increase_epochs,
                                  np_th_seed, debug):
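    """Run cropped training of the 'deep' or 'shallow' model on one subject.

    Returns the finished braindecode Experiment, whose per-epoch monitor
    values are available as ``epochs_df``.
    """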
    input_time_length = 1000
    batch_size = 60
    lr = 1e-3
    weight_decay = 0

    train_set, valid_set, test_set = load_train_valid_test(
        train_filename=train_filename,
        test_filename=test_filename,
        low_cut_hz=low_cut_hz, debug=debug)
    if debug:
        max_epochs = 4

    set_random_seeds(np_th_seed, cuda=True)
    # th.backends.cudnn.benchmark = True  # sometimes crashes?
    n_classes = int(np.max(train_set.y) + 1)
    n_chans = int(train_set.X.shape[1])
    if model_name == 'deep':
        model = Deep4Net(n_chans, n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=2).create_network()
    elif model_name == 'shallow':
        model = ShallowFBCSPNet(
            n_chans, n_classes, input_time_length=input_time_length,
            final_conv_length=30).create_network()

    # Convert to a dense prediction model for cropped decoding, then run a
    # dummy forward pass to determine how many predictions the network
    # produces per input window.
    to_dense_prediction_model(model)
    model.cuda()
    model.eval()
    out = model(np_to_var(train_set.X[:1, :, :input_time_length, None]).cuda())
    n_preds_per_input = out.cpu().data.numpy().shape[2]

    optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay,
                           lr=lr)
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input,
                                       seed=np_th_seed)
    monitors = [LossMonitor(),
                MisclassMonitor(col_suffix='sample_misclass'),
                CroppedTrialMisclassMonitor(
                    input_time_length=input_time_length),
                RuntimeMonitor()]
    model_constraint = MaxNormDefaultConstraint()
    # Average the per-crop predictions over time before computing the
    # negative log-likelihood loss.
    loss_function = lambda preds, targets: F.nll_loss(
        th.mean(preds, dim=2), targets)

    # Stop when either the maximum number of epochs is reached or the
    # validation misclassification has not decreased for max_increase_epochs
    # epochs; afterwards continue training on train+valid data
    # (run_after_early_stop).
    run_after_early_stop = True
    do_early_stop = True
    remember_best_column = 'valid_misclass'
    stop_criterion = Or([MaxEpochs(max_epochs),
                         NoDecrease('valid_misclass', max_increase_epochs)])

    exp = Experiment(model, train_set, valid_set, test_set, iterator=iterator,
                     loss_function=loss_function, optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column=remember_best_column,
                     run_after_early_stop=run_after_early_stop, cuda=True,
                     do_early_stop=do_early_stop)
    exp.run()
    return exp
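
# The returned Experiment collects one row of monitor values per epoch in
# exp.epochs_df (a pandas DataFrame; column names follow the monitors above),
# e.g.:
#
#     exp = run_exp_on_high_gamma_dataset(...)
#     print(exp.epochs_df[['train_loss', 'valid_misclass', 'test_misclass']])
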
if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
    subject_id = 1
    # You have to change data_folder here to the location of the
    # High-Gamma Dataset on your system to make this run.
    data_folder = '/data/schirrmr/schirrmr/HGD-public/reduced/'
    train_filename = os.path.join(
        data_folder, 'train/{:d}.mat'.format(subject_id))
    test_filename = os.path.join(
        data_folder, 'test/{:d}.mat'.format(subject_id))
    max_epochs = 800
    max_increase_epochs = 80
    model_name = 'deep'  # or 'shallow'
    low_cut_hz = 0  # or 4
    np_th_seed = 0  # random seed for numpy and pytorch
    debug = False

    exp = run_exp_on_high_gamma_dataset(train_filename, test_filename,
                                        low_cut_hz, model_name,
                                        max_epochs, max_increase_epochs,
                                        np_th_seed, debug)
    log.info("Last 10 epochs")
    log.info("\n" + str(exp.epochs_df.iloc[-10:]))