"""
This code is modified from Hengyuan Hu's repository.
https://github.com/hengyuan-hu/bottom-up-attention-vqa
"""
import torch
import torch.nn as nn
import numpy as np
from lxrt.tokenization import BertTokenizer
from lxrt.entry import convert_sents_to_features
from lxrt.modeling import BertFeatureExtraction
class WordEmbedding(nn.Module):
    """Word Embedding

    The ntoken-th dim is used for padding_idx, which agrees *implicitly*
    with the definition in Dictionary.
    """

    def __init__(self, ntoken, emb_dim, dropout, op=''):
        super(WordEmbedding, self).__init__()
        self.op = op
        self.emb = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken)
        if 'c' in op:
            self.emb_ = nn.Embedding(ntoken+1, emb_dim, padding_idx=ntoken)
            self.emb_.weight.requires_grad = False  # fixed
        self.dropout = nn.Dropout(dropout)
        self.ntoken = ntoken
        self.emb_dim = emb_dim
    def init_embedding(self, np_file, tfidf=None, tfidf_weights=None):
        weight_init = torch.from_numpy(np.load(np_file))
        assert weight_init.shape == (self.ntoken, self.emb_dim)
        self.emb.weight.data[:self.ntoken] = weight_init
        if tfidf is not None:
            if 0 < tfidf_weights.size:
                weight_init = torch.cat([weight_init, torch.from_numpy(tfidf_weights)], 0)
            weight_init = tfidf.matmul(weight_init)  # (N x N') x (N', F)
            if 'c' in self.op:
                self.emb_.weight.requires_grad = True
        if 'c' in self.op:
            self.emb_.weight.data[:self.ntoken] = weight_init.clone()
    def forward(self, x):
        emb = self.emb(x)
        if 'c' in self.op:
            if len(x.size()) < 3:
                emb = torch.cat((emb, self.emb_(x)), 2)
            else:
                emb = torch.cat((emb, self.emb_(x)), 3)
        emb = self.dropout(emb)
        return emb
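
# Illustrative usage sketch for WordEmbedding: the vocabulary size, embedding
# dimension, batch shape, and GloVe .npy path below are hypothetical and only
# show how the module is meant to be called.
def _word_embedding_usage_sketch():
    w_emb = WordEmbedding(ntoken=20000, emb_dim=300, dropout=0.0, op='c')
    # w_emb.init_embedding('path/to/glove_init_300d.npy')  # expects a (20000, 300) array
    tokens = torch.randint(0, 20000, (32, 14))  # [batch, seq] of word indices
    return w_emb(tokens)                        # [32, 14, 600]: trainable + frozen channels concatenated
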
class QuestionEmbedding(nn.Module):
    def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'):
        """Module for question embedding
        """
        super(QuestionEmbedding, self).__init__()
        assert rnn_type == 'LSTM' or rnn_type == 'GRU'
        rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU if rnn_type == 'GRU' else None
        self.rnn = rnn_cls(
            in_dim, num_hid, nlayers,
            bidirectional=bidirect,
            dropout=dropout,
            batch_first=True)

        self.in_dim = in_dim
        self.num_hid = num_hid
        self.nlayers = nlayers
        self.rnn_type = rnn_type
        self.ndirections = 1 + int(bidirect)
    def init_hidden(self, batch):
        # just to get the type of tensor
        weight = next(self.parameters()).data
        hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid)
        if self.rnn_type == 'LSTM':
            return (weight.new(*hid_shape).zero_(),
                    weight.new(*hid_shape).zero_())
        else:
            return weight.new(*hid_shape).zero_()
    def forward(self, x):
        # x: [batch, sequence, in_dim]
        batch = x.size(0)
        hidden = self.init_hidden(batch)
        output, hidden = self.rnn(x, hidden)
        # self.rnn.flatten_parameters()

        if self.ndirections == 1:
            return output[:, -1]

        # bidirectional: concatenate the last forward state and the first backward state
        forward_ = output[:, -1, :self.num_hid]
        backward = output[:, 0, self.num_hid:]
        return torch.cat((forward_, backward), dim=1)
    def forward_all(self, x):
        # x: [batch, sequence, in_dim]
        batch = x.size(0)
        hidden = self.init_hidden(batch)
        output, hidden = self.rnn(x, hidden)
        return output
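
# Illustrative usage sketch for QuestionEmbedding: the hidden size and the
# [batch, seq, in_dim] shapes are hypothetical; in_dim matches the 600-d
# concatenated word embedding from the sketch above.
def _question_embedding_usage_sketch(word_emb):
    q_emb = QuestionEmbedding(in_dim=600, num_hid=1024, nlayers=1,
                              bidirect=False, dropout=0.0, rnn_type='GRU')
    q_last = q_emb(word_emb)             # [batch, 1024]: final GRU state only
    q_all = q_emb.forward_all(word_emb)  # [batch, seq, 1024]: every time step
    return q_last, q_all
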
class BertEmbedding(nn.Module):
    def __init__(self, args, max_seq_length):
        super(BertEmbedding, self).__init__()
        # Using the bert tokenizer
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
        self.max_seq_length = max_seq_length

        # Bert embedding
        self.bert = BertFeatureExtraction.from_pretrained('bert-base-uncased', args=args)
        if args.from_scratch:
            # discard the pretrained weights and re-initialize
            print("initializing all the weights")
            self.bert.apply(self.bert.init_bert_weights)
    def forward(self, sents):
        # Tokenize and pad/truncate the raw sentences to max_seq_length
        train_features = convert_sents_to_features(
            sents, self.max_seq_length, self.tokenizer)

        input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
        input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
        segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()

        output = self.bert(input_ids, segment_ids, input_mask)
        return output
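
# Illustrative usage sketch for BertEmbedding: `args` is assumed to be the
# experiment config object used elsewhere in this repository (it must at least
# provide `from_scratch`), and a CUDA device is assumed because forward()
# moves the tokenized inputs to the GPU.
def _bert_embedding_usage_sketch(args):
    bert_emb = BertEmbedding(args, max_seq_length=20).cuda()
    sents = ["what color is the cat?", "how many dogs are in the image?"]
    return bert_emb(sents)  # sentence features from BertFeatureExtraction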