forked from zhangks98/eeg-adapt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_adapt.py
187 lines (161 loc) · 6.71 KB
/
train_adapt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#!/usr/bin/env python
# coding: utf-8
'''Subject-adaptive classification with KU Data,
using the Deep ConvNet model from [1].

References
----------
.. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
   Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
   Deep learning with convolutional neural networks for EEG decoding and
   visualization.
   Human Brain Mapping, Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
'''
import argparse
import json
import logging
import sys
from os.path import join as pjoin
import h5py
import torch
import torch.nn.functional as F
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from torch import nn
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)
parser = argparse.ArgumentParser(
    description='Subject-adaptative classification with KU Data')
parser.add_argument('datapath', type=str, help='Path to the h5 data file')
parser.add_argument('modelpath', type=str,
                    help='Path to the base model folder')
parser.add_argument('outpath', type=str, help='Path to the result folder')
# Schemes 1-4 unfreeze progressively more layers from the output side
# (classifier, then conv blocks 4, 3, 2); scheme 5 fine-tunes the whole
# network. Any other value would freeze every parameter and only fail later
# inside the optimizer (empty parameter list), so reject it up front.
parser.add_argument('-scheme', type=int, help='Adaptation scheme', default=4,
                    choices=[1, 2, 3, 4, 5])
parser.add_argument(
    '-trfrate', type=int, help='The percentage of data for adaptation', default=100)
parser.add_argument('-lr', type=float, help='Learning rate', default=0.0005)
parser.add_argument('-gpu', type=int, help='The gpu device to use', default=0)
args = parser.parse_args()
# trfrate is a percentage of the session-1 adaptation trials: a value <= 0
# gives an empty (or, via negative slicing, a leaking) training slice and a
# value > 100 would overlap the validation trials.
if not 0 < args.trfrate <= 100:
    parser.error('-trfrate must be in (0, 100], got {}'.format(args.trfrate))
datapath = args.datapath
outpath = args.outpath
modelpath = args.modelpath
scheme = args.scheme
rate = args.trfrate
lr = args.lr
dfile = h5py.File(datapath, 'r')
torch.cuda.set_device(args.gpu)
set_random_seeds(seed=20200205, cuda=True)
BATCH_SIZE = 16
TRAIN_EPOCH = 200
# Randomly shuffled subjects; a subject's position in this list is the fold
# number whose base-model checkpoint is used for it.
subjs = [35, 47, 46, 37, 13, 27, 12, 32, 53, 54, 4, 40, 19, 41, 18, 42, 34, 7,
         49, 9, 5, 48, 29, 15, 21, 17, 31, 45, 1, 38, 51, 8, 11, 16, 28, 44, 24,
         52, 3, 26, 39, 50, 6, 23, 2, 14, 25, 20, 10, 33, 22, 43, 36, 30]
# Get data from a single subject.
def get_data(subj, h5file=None):
    """Read all trials of one subject from the HDF5 data file.

    Parameters
    ----------
    subj : int
        Subject number; data is read from the ``/s<subj>`` group.
    h5file : h5py.File or mapping, optional
        Object to read from, indexed by full HDF5 path. Defaults to the
        module-level ``dfile`` opened from ``datapath``.

    Returns
    -------
    X, Y : array
        The subject's ``X`` and ``Y`` datasets, read fully into memory.
    """
    if h5file is None:
        h5file = dfile
    # HDF5 object names always use '/' as the separator; build them directly
    # rather than with os.path.join, which inserts '\\' on Windows.
    group = '/s{}'.format(subj)
    X = h5file[group + '/X']
    Y = h5file[group + '/Y']
    return X[:], Y[:]
# The network is built once; its input geometry (channels along axis 1,
# time samples along axis 2) is taken from the first subject's trials.
X, Y = get_data(subjs[0])
n_classes = 2
in_chans = X.shape[1]
# final_conv_length='auto' sizes the classifier so that each trial produces a
# single output in the time dimension.
model = Deep4Net(
    in_chans=in_chans,
    n_classes=n_classes,
    input_time_length=X.shape[2],
    final_conv_length='auto',
)
model = model.cuda()
# Deprecated.
def reset_conv_pool_block(network, block_nr):
    """Replace one conv-pool block's conv (and batch-norm) layer with fresh ones.

    Kept for reference: in the visible code its only calls (in
    ``reset_model``) are commented out.

    The new ``conv_<block_nr>`` copies the old layer's channel counts, kernel
    size, stride, padding, dilation, groups and bias setting. It is
    Xavier-uniform initialised, with its bias (if any) set to 0. An existing
    ``bnorm_<block_nr>`` is replaced by a BatchNorm2d with the same momentum
    and eps, weight 1 and bias 0. New layers are created on the device of
    the old conv weight.

    Parameters
    ----------
    network : torch.nn.Module
        Module holding ``conv_<n>`` / ``bnorm_<n>`` children
        (e.g. ``model.network``).
    block_nr : int
        Index of the block to reset.
    """
    suffix = "_{:d}".format(block_nr)
    conv_name = 'conv' + suffix
    bnorm_name = 'bnorm' + suffix
    old_conv = getattr(network, conv_name)
    # The model in this script lives on the GPU, so build the replacements
    # where the old layer is instead of on the default (CPU) device.
    device = old_conv.weight.device
    # Copy the geometry of the old layer: a hard-coded stride/padding would
    # change the output length and break the downstream classifier.
    new_conv = nn.Conv2d(
        old_conv.in_channels,
        old_conv.out_channels,
        old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        groups=old_conv.groups,
        bias=old_conv.bias is not None,
    ).to(device)
    nn.init.xavier_uniform_(new_conv.weight, gain=1)
    if new_conv.bias is not None:
        nn.init.constant_(new_conv.bias, 0)
    setattr(network, conv_name, new_conv)
    # Only replace a batch-norm layer that already exists: setting a new
    # attribute on an nn.Sequential appends it to the END of the forward pass.
    old_bnorm = getattr(network, bnorm_name, None)
    if old_bnorm is not None:
        new_bnorm = nn.BatchNorm2d(
            old_conv.out_channels,
            momentum=old_bnorm.momentum,
            affine=True,
            eps=old_bnorm.eps,
        ).to(device)
        nn.init.constant_(new_bnorm.weight, 1)
        nn.init.constant_(new_bnorm.bias, 0)
        setattr(network, bnorm_name, new_bnorm)
def reset_model(checkpoint):
    """Load a base-model checkpoint into ``model`` and prepare it for adaptation.

    Restores ``checkpoint['model_state_dict']`` into ``model.network`` and
    sets ``requires_grad`` according to the global ``scheme``. It then
    re-compiles ``model`` with a fresh AdamW optimizer that sees only the
    trainable parameters.

    Schemes:
    - 1: train only ``conv_classifier``.
    - 2: also train block 4 (``conv_4`` + ``bnorm_4``).
    - 3: also train block 3.
    - 4: also train block 2.
    - 5: leave every ``requires_grad`` flag untouched (full fine-tuning).
    """
    net = model.network
    net.load_state_dict(checkpoint['model_state_dict'])
    # NOTE: an earlier variant additionally re-initialised conv blocks 2-4
    # (via reset_conv_pool_block) and rebuilt conv_classifier; it is disabled.
    # NOTE(review): frozen batch-norm layers presumably still update their
    # running statistics while the model trains -- confirm this is intended.
    if scheme != 5:
        for param in net.parameters():
            param.requires_grad = False
        # Layer groups ordered from the output side downwards; scheme k in
        # 1..4 re-enables the first k groups.
        layer_groups = [
            ('conv_classifier',),
            ('conv_4', 'bnorm_4'),
            ('conv_3', 'bnorm_3'),
            ('conv_2', 'bnorm_2'),
        ]
        n_unfrozen = scheme if scheme in {1, 2, 3, 4} else 0
        for group in layer_groups[:n_unfrozen]:
            for layer_name in group:
                for param in getattr(net, layer_name).parameters():
                    param.requires_grad = True
    # Hand the optimizer only the parameters that will receive gradients.
    trainable = [param for param in net.parameters() if param.requires_grad]
    optimizer = AdamW(trainable, lr=lr, weight_decay=0.5*0.001)
    model.compile(loss=F.nll_loss, optimizer=optimizer,
                  iterator_seed=20200205, )
# Only session 1 (the first 200 trials) supplies adaptation data; trials
# 200-299 are held out for validation and the remaining trials for testing.
TRAIN_END = 200
VALID_END = 300
cutoff = int(rate * TRAIN_END / 100)
# Explicit check rather than `assert` (stripped under `python -O`). A
# non-positive cutoff would either give no training data or, via negative
# slicing, silently train on validation and test trials.
if not 0 < cutoff <= TRAIN_END:
    raise ValueError(
        'trfrate={} selects {} adaptation trials; expected 1 to {}'.format(
            rate, cutoff, TRAIN_END))
try:
    for fold, subj in enumerate(subjs):
        suffix = '_s' + str(subj) + '_f' + str(fold)
        logging.info('Fold %d: adapting to subject %d', fold, subj)
        # NOTE: torch.load unpickles the file; only point modelpath at
        # trusted checkpoints.
        checkpoint = torch.load(
            pjoin(modelpath, 'model_f' + str(fold) + '.pt'),
            map_location='cuda:' + str(args.gpu))
        reset_model(checkpoint)
        X, Y = get_data(subj)
        X_train, Y_train = X[:cutoff], Y[:cutoff]
        X_val, Y_val = X[TRAIN_END:VALID_END], Y[TRAIN_END:VALID_END]
        X_test, Y_test = X[VALID_END:], Y[VALID_END:]
        model.fit(X_train, Y_train, epochs=TRAIN_EPOCH,
                  batch_size=BATCH_SIZE, scheduler='cosine',
                  validation_data=(X_val, Y_val),
                  remember_best_column='valid_loss')
        model.epochs_df.to_csv(pjoin(outpath, 'epochs' + suffix + '.csv'))
        test_loss = model.evaluate(X_test, Y_test)
        with open(pjoin(outpath, 'test' + suffix + '.json'), 'w') as f:
            json.dump(test_loss, f)
finally:
    # Release the HDF5 file even if a fold fails part-way through.
    dfile.close()