attention_layer.py
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K


class AttentionLayer(Layer):
    """Simple additive attention over the time axis of a 3D input
    (batch, timesteps, features)."""

    def __init__(self, return_sequences=True, **kwargs):
        self.return_sequences = return_sequences
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Per-feature projection used to score each timestep.
        self.W = self.add_weight(name='attention_weight',
                                 shape=(input_shape[-1], 1),
                                 initializer='random_normal',
                                 trainable=True)
        # One bias per timestep; this requires a fixed (non-None) timestep length.
        self.b = self.add_weight(name='attention_bias',
                                 shape=(input_shape[1], 1),
                                 initializer='zeros',
                                 trainable=True)
        super(AttentionLayer, self).build(input_shape)

    def get_config(self):
        # Include return_sequences so the layer can be serialized and reloaded.
        config = super().get_config().copy()
        config.update({
            'return_sequences': self.return_sequences
        })
        return config

    def call(self, x):
        # Alignment scores, passed through a tanh non-linearity.
        e = K.tanh(K.dot(x, self.W) + self.b)
        # Remove the trailing dimension of size 1: (batch, timesteps).
        e = K.squeeze(e, axis=-1)
        # Attention weights: softmax over the timestep axis.
        alpha = K.softmax(e)
        # Restore the trailing dimension so the weights broadcast over features.
        alpha = K.expand_dims(alpha, axis=-1)
        # Weight each timestep by its attention score.
        context = x * alpha
        if self.return_sequences:
            return context
        # Collapse the time axis into a single context vector per sample.
        return K.sum(context, axis=1)
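

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of dropping AttentionLayer into a tf.keras
# model after an LSTM. The shapes, layer sizes, and model below are
# illustrative assumptions, not taken from this repository.
if __name__ == '__main__':
    import numpy as np
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import LSTM, Dense

    timesteps, features = 20, 8  # assumed input shape
    model = Sequential([
        LSTM(32, return_sequences=True, input_shape=(timesteps, features)),
        # return_sequences=False collapses the weighted timesteps into a
        # single context vector of shape (batch, 32).
        AttentionLayer(return_sequences=False),
        Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy')

    # Dummy data just to show the expected shapes.
    x = np.random.random((4, timesteps, features)).astype('float32')
    print(model.predict(x).shape)  # -> (4, 1)

    # Because get_config() includes return_sequences, a saved model can be
    # restored by passing the class via custom_objects, e.g.:
    # tf.keras.models.load_model(path, custom_objects={'AttentionLayer': AttentionLayer})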