# classifiers_wavenet.py
import os
import torch
from torch.nn import Embedding
from scipy.io import loadmat
from wavenets_simple import WavenetSimple
from classifiers_simpleNN import SimpleClassifier, ClassifierModule
class WavenetClassifier(SimpleClassifier):
    '''
    This class adds a classifier on top of the normal wavenet.
    '''
    def loaded(self, args):
        super(WavenetClassifier, self).loaded(args)
        self.wavenet.loaded(args)

    def kernel_network_FIR(self):
        self.wavenet.kernel_network_FIR()

    def analyse_kernels(self):
        self.wavenet.analyse_kernels()

    def kernelPFI(self, data, sid=None):
        return self.wavenet.kernelPFI(data, sid)

    def build_model(self, args):
        # use the wavenet class specified in args if given, else the default
        if args.wavenet_class:
            self.wavenet = args.wavenet_class(args)
        else:
            self.wavenet = WavenetSimple(args)

        # the classifier sees one timestep per receptive field for each channel
        self.class_dim = self.wavenet.ch * int(args.sample_rate / args.rf)
        self.classifier = ClassifierModule(args, self.class_dim)

    def forward(self, x, sid=None):
        '''
        Run wavenet on input then feed the output into the classifier.
        '''
        output, x = self.wavenet(x, sid)

        # subsample every rf-th timestep and flatten channels and time
        x = x[:, :, ::self.args.rf].reshape(x.shape[0], -1)
        x = self.classifier(x)

        return output, x
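
# Illustrative sketch (hypothetical helper, not used by the models above) of
# the shape bookkeeping in WavenetClassifier: the wavenet output of shape
# (batch, channels, timesteps) is subsampled at every rf-th timestep and
# flattened, so the classifier input size is ch * int(sample_rate / rf),
# which is what build_model assigns to class_dim.
def _demo_classifier_input_shape(batch=2, channels=4, sample_rate=256, rf=64):
    # dummy wavenet output: (batch, channels, timesteps)
    x = torch.randn(batch, channels, sample_rate)

    # same slicing and flattening as WavenetClassifier.forward
    flat = x[:, :, ::rf].reshape(x.shape[0], -1)
    assert flat.shape == (batch, channels * int(sample_rate / rf))
    return flat.shape
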
class WavenetClassifierSemb(WavenetClassifier):
    '''
    Wavenet Classifier for multi-subject data using subject embeddings.
    '''
    def set_sub_dict(self):
        # this dictionary is needed because subject embeddings
        # and subjects have a different ordering
        self.sub_dict = {0: 10,
                         1: 7,
                         2: 3,
                         3: 11,
                         4: 8,
                         5: 4,
                         6: 12,
                         7: 9,
                         8: 5,
                         9: 13,
                         10: 1,
                         11: 14,
                         12: 2,
                         13: 6,
                         14: 0}

    def __init__(self, args):
        super(WavenetClassifierSemb, self).__init__(args)
        self.set_sub_dict()

    def loaded(self, args):
        super(WavenetClassifierSemb, self).loaded(args)
        self.set_sub_dict()

        # change embedding to an already trained one
        if 'trained_semb' in args.result_dir:
            path = os.path.join(args.load_model, '..', 'sub_emb.mat')
            semb = torch.tensor(loadmat(path)['X']).cuda()
            self.wavenet.subject_emb.weight = torch.nn.Parameter(semb)
    def build_model(self, args):
        self.wavenet = args.wavenet_class(args)
        self.class_dim = self.wavenet.ch * int(args.sample_rate / args.rf)
        self.classifier = ClassifierModule(args, self.class_dim)

    def save_embeddings(self):
        self.wavenet.save_embeddings()

    def get_sid(self, sid):
        '''
        Get subject id based on result directory name.
        '''
        # the subject index is the number at the end of the result directory
        ind = int(self.args.result_dir.split('_')[-1].split('/')[0])

        # map it to the corresponding embedding index
        ind = self.sub_dict[ind]
        sid = torch.LongTensor([ind]).repeat(*list(sid.shape)).cuda()
        return sid
    def get_sid_exc(self, sid):
        '''
        Get subject embedding of untrained subject.
        '''
        ind = int(self.args.result_dir.split('_')[-1].split('/')[0])
        sid = torch.LongTensor([ind]).repeat(*list(sid.shape)).cuda()
        return sid

    def get_sid_best(self, sid):
        # use a fixed embedding index for all trials
        ind = 8
        sid = torch.LongTensor([ind]).repeat(*list(sid.shape)).cuda()
        return sid

    def ensemble_forward(self, x, sid):
        # average class predictions over all 15 subject embeddings
        outputs = []
        for i in range(15):
            subid = torch.LongTensor([i]).repeat(*list(sid.shape)).cuda()
            _, out_class = super(WavenetClassifierSemb, self).forward(x, subid)
            outputs.append(out_class.detach())

        outputs = torch.stack(outputs)
        outputs = torch.mean(outputs, dim=0)
        return None, outputs
    def forward(self, x, sid=None):
        # choose which subject embedding(s) to use
        # based on the result directory name
        if not self.args.keep_sid:
            if 'sub' in self.args.result_dir:
                sid = self.get_sid(sid)
            if 'exc' in self.args.result_dir:
                sid = self.get_sid_exc(sid)
            if 'best' in self.args.result_dir:
                sid = self.get_sid_best(sid)

        if 'ensemble' in self.args.result_dir:
            return self.ensemble_forward(x, sid)

        return super(WavenetClassifierSemb, self).forward(x, sid)
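
# Illustrative sketch (hypothetical helper, not called anywhere above) of two
# patterns used by WavenetClassifierSemb: broadcasting a single embedding
# index to the shape of the per-trial sid tensor (as in get_sid_best), and
# averaging class predictions over every subject embedding (as in
# ensemble_forward). Random logits stand in for real classifier outputs and
# the .cuda() calls are omitted so the sketch runs on CPU.
def _demo_sid_broadcast_and_ensemble(n_subjects=15, n_classes=5, n_trials=3):
    sid = torch.zeros(n_trials, dtype=torch.long)

    # repeat a single embedding index to match the shape of sid
    ind = 8
    repeated = torch.LongTensor([ind]).repeat(*list(sid.shape))
    assert repeated.shape == sid.shape

    # average dummy predictions over all subject embeddings
    outputs = [torch.randn(n_trials, n_classes) for _ in range(n_subjects)]
    averaged = torch.mean(torch.stack(outputs), dim=0)
    return repeated, averaged
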