AttentionPooling
# BasicAttention is assumed to be defined elsewhere in this repository.
class AttentionPooling(BasicAttention):
    def __init__(self, embd_size, q_k_hidden_size=1, num_heads=1, drop_rate=0., **kwargs):
        # Query, key, value, and output projections all share the input embedding size.
        q_embd_size = embd_size
        k_embd_size = embd_size
        v_embd_size = embd_size
        output_hidden_size = embd_size
        is_q = True
        # is_v = True
        score_func = self.score
        super(AttentionPooling, self).__init__(q_embd_size, k_embd_size, v_embd_size, q_k_hidden_size,
                                               output_hidden_size, num_heads, score_func,
                                               drop_rate=drop_rate, is_q=is_q, **kwargs)

    def score(self, q, k):
        # The pooling score ignores q: it returns the transposed key projection, so each
        # position contributes its projected key value as an (unnormalized) attention weight.
        score = k.permute(0, 2, 1)
        # score = nn.functional.softmax(score, dim=-1)
        return score

    def forward(self, embd, mask=None):
        # Self-attention pooling: the same embedding serves as query, key, and value.
        q_embd = embd
        k_embd = embd
        v_embd = embd
        return super(AttentionPooling, self).forward(q_embd, k_embd, v_embd, mask)
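
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original file). It
# assumes BasicAttention is an nn.Module accepting the constructor arguments
# used above and returning a batch-first tensor; the shapes below are
# illustrative assumptions only.
if __name__ == "__main__":
    import torch

    batch_size, seq_len, embd_size = 4, 16, 128
    token_embd = torch.randn(batch_size, seq_len, embd_size)

    # Pool a batch of token embeddings with a single attention head.
    pooling = AttentionPooling(embd_size, q_k_hidden_size=1, num_heads=1, drop_rate=0.1)
    pooled = pooling(token_embd)  # output shape is determined by BasicAttention
    print(pooled.shape)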