forked from galsang/CNN-sentence-classification-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
60 lines (48 loc) · 2.41 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    """Convolutional neural network for sentence classification (Kim, 2014).

    Supported MODEL variants:
      - "rand":         embeddings trained from scratch
      - "static":       pre-trained embeddings, frozen
      - "non-static":   pre-trained embeddings, fine-tuned
      - "multichannel": two embedding channels, one fine-tuned + one frozen
    """

    def __init__(self, **kwargs):
        """Build the model from a configuration passed as keyword arguments.

        Required kwargs: MODEL, BATCH_SIZE, MAX_SENT_LEN, WORD_DIM,
        VOCAB_SIZE, CLASS_SIZE, FILTERS, FILTER_NUM, DROPOUT_PROB.
        WV_MATRIX (a numpy array, presumably of shape
        (VOCAB_SIZE + 2, WORD_DIM) to match the embedding table — confirm
        against the caller) is additionally required for the "static",
        "non-static" and "multichannel" variants.

        Raises:
            ValueError: if FILTERS and FILTER_NUM differ in length.
            KeyError: if a required configuration key is missing.
        """
        super(CNN, self).__init__()

        self.MODEL = kwargs["MODEL"]
        self.BATCH_SIZE = kwargs["BATCH_SIZE"]
        self.MAX_SENT_LEN = kwargs["MAX_SENT_LEN"]
        self.WORD_DIM = kwargs["WORD_DIM"]
        self.VOCAB_SIZE = kwargs["VOCAB_SIZE"]
        self.CLASS_SIZE = kwargs["CLASS_SIZE"]
        self.FILTERS = kwargs["FILTERS"]
        self.FILTER_NUM = kwargs["FILTER_NUM"]
        self.DROPOUT_PROB = kwargs["DROPOUT_PROB"]
        self.IN_CHANNEL = 1

        # `assert` is stripped under `python -O`; validate explicitly instead.
        if len(self.FILTERS) != len(self.FILTER_NUM):
            raise ValueError(
                "FILTERS and FILTER_NUM must have the same length: "
                f"{len(self.FILTERS)} != {len(self.FILTER_NUM)}")

        # one for UNK and one for zero padding
        self.embedding = nn.Embedding(self.VOCAB_SIZE + 2, self.WORD_DIM, padding_idx=self.VOCAB_SIZE + 1)
        if self.MODEL in ("static", "non-static", "multichannel"):
            self.WV_MATRIX = kwargs["WV_MATRIX"]
            self.embedding.weight.data.copy_(torch.from_numpy(self.WV_MATRIX))
            if self.MODEL == "static":
                self.embedding.weight.requires_grad = False
            elif self.MODEL == "multichannel":
                # Second, frozen copy of the pre-trained embeddings.
                self.embedding2 = nn.Embedding(self.VOCAB_SIZE + 2, self.WORD_DIM, padding_idx=self.VOCAB_SIZE + 1)
                self.embedding2.weight.data.copy_(torch.from_numpy(self.WV_MATRIX))
                self.embedding2.weight.requires_grad = False
                self.IN_CHANNEL = 2

        # Conv1d over the flattened (WORD_DIM * MAX_SENT_LEN) embedding:
        # the kernel spans FILTERS[i] whole words and stride=WORD_DIM moves
        # it one word at a time. Registered via setattr as `conv_{i}` so the
        # state-dict parameter names stay stable.
        for i in range(len(self.FILTERS)):
            conv = nn.Conv1d(self.IN_CHANNEL, self.FILTER_NUM[i], self.WORD_DIM * self.FILTERS[i], stride=self.WORD_DIM)
            setattr(self, f'conv_{i}', conv)

        self.fc = nn.Linear(sum(self.FILTER_NUM), self.CLASS_SIZE)

    def get_conv(self, i):
        """Return the i-th convolution layer registered in __init__."""
        return getattr(self, f'conv_{i}')

    def forward(self, inp):
        """Compute class logits for a batch of token-index sentences.

        Args:
            inp: LongTensor of token indices; each sentence is flattened to
                 length MAX_SENT_LEN (padded with index VOCAB_SIZE + 1).

        Returns:
            Tensor of shape (batch, CLASS_SIZE) with unnormalized logits.
        """
        # Flatten each sentence's embeddings into one channel of length
        # WORD_DIM * MAX_SENT_LEN so Conv1d can stride word-by-word.
        x = self.embedding(inp).view(-1, 1, self.WORD_DIM * self.MAX_SENT_LEN)
        if self.MODEL == "multichannel":
            x2 = self.embedding2(inp).view(-1, 1, self.WORD_DIM * self.MAX_SENT_LEN)
            x = torch.cat((x, x2), 1)

        # Per filter size: convolve, ReLU, then max-pool over all window
        # positions (MAX_SENT_LEN - FILTERS[i] + 1 of them) down to one
        # value per filter.
        conv_results = [
            F.max_pool1d(F.relu(self.get_conv(i)(x)), self.MAX_SENT_LEN - self.FILTERS[i] + 1)
            .view(-1, self.FILTER_NUM[i])
            for i in range(len(self.FILTERS))]

        x = torch.cat(conv_results, 1)
        # Functional dropout: active only while self.training is True.
        x = F.dropout(x, p=self.DROPOUT_PROB, training=self.training)
        x = self.fc(x)
        return x