cluster_model.py
import torch
import torch.nn as nn

from data_loader import torch_dataset_loader
from config import PARAS

torch.manual_seed(1)


class Model(nn.Module):

    def __init__(self, feature=PARAS.N_MEL, hidden_size=PARAS.HS, embedding_dim=PARAS.E_DIM):
        super(Model, self).__init__()
        self.embedding_dim = embedding_dim
        self.gru = nn.GRU(input_size=feature,
                          hidden_size=hidden_size,
                          num_layers=4,
                          batch_first=True,
                          dropout=0.5,
                          bidirectional=True)
        self.embedding = nn.Linear(
            hidden_size * 2,
            PARAS.N_MEL * embedding_dim,
        )
        self.activation = nn.Tanh()
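
    # The bidirectional GRU emits hidden_size * 2 features per time step, so
    # the linear layer maps each step to N_MEL * embedding_dim values, i.e.
    # one embedding_dim-dimensional vector per mel bin.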

    @staticmethod
    def l2_normalize(x, dim=0, eps=1e-12):
        assert dim < x.dim()
        norm = torch.norm(x, 2, dim, keepdim=True)
        return x / (norm + eps)
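
    # Minimal usage sketch with hypothetical values: the L2 norm of
    # [3.0, 4.0] is 5.0, so normalizing yields [0.6, 0.8].
    #   >>> Model.l2_normalize(torch.tensor([3.0, 4.0]))
    #   tensor([0.6000, 0.8000])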

    def forward(self, inp):
        # inp: (batch, seq, feature)
        n, t, f = inp.size()
        out, _ = self.gru(inp)
        out = self.embedding(out)
        out = self.activation(out)
        out = out.view(n, -1, self.embedding_dim)
        # out: (batch, T*F, embedding), normalized over the embedding dim
        out = self.l2_normalize(out, -1)
        return out
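
    # Shape flow, assuming batch_first input of shape (N, T, N_MEL):
    #   gru:       (N, T, N_MEL)           -> (N, T, hidden_size * 2)
    #   embedding: (N, T, hidden_size * 2) -> (N, T, N_MEL * embedding_dim)
    #   view:      flatten to (N, T * N_MEL, embedding_dim)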


D_model = Model()


if __name__ == '__main__':
    from utils import loss_function

    test_loader = torch_dataset_loader(PARAS.DATASET_PATH + 'test.h5', PARAS.BATCH_SIZE, True, PARAS.kwargs)

    for _index, data in enumerate(test_loader):
        spec_input = data['mix']
        label = data['binary_mask']

        if PARAS.CUDA:
            spec_input = spec_input.cuda()
            label = label.cuda()

        with torch.no_grad():
            predicted = D_model(spec_input)
            print(predicted.size())
            # Flatten the mask to (batch, T*F, n_sources) so it lines up with
            # the (batch, T*F, embedding) model output for the loss.
            shape = label.size()
            label = label.view((shape[0], shape[1] * shape[2], -1))
            print(loss_function(predicted, label))

        break
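
# Minimal sketch, assuming two sources and scikit-learn available: in a
# deep-clustering setup, the (batch, T*F, embedding) output would typically
# be grouped at inference time, e.g. with k-means, to assign each
# time-frequency bin to a source.
#
#   from sklearn.cluster import KMeans
#   emb = predicted[0].cpu().numpy()              # (T*F, embedding_dim)
#   bins = KMeans(n_clusters=2).fit_predict(emb)  # cluster id per T-F bin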