attention.py
# import the necessary packages
import tensorflow as tf
from tensorflow.keras.layers import Add, Layer, LayerNormalization, MultiHeadAttention


class BaseAttention(Layer):
    """
    The base attention module. All the other attention modules are
    subclassed from this one.
    """
    def __init__(self, **kwargs):
        # note the use of **kwargs here; they are forwarded to the
        # MultiHeadAttention layer in every subclassed module
        super().__init__()
        # initialize a multihead attention layer, a layer normalization
        # layer, and an addition layer
        self.mha = MultiHeadAttention(**kwargs)
        self.layernorm = LayerNormalization()
        self.add = Add()


class CrossAttention(BaseAttention):
    def call(self, x, context):
        # apply multihead attention to the query and the context inputs
        (attentionOutputs, attentionScores) = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True,
        )
        # store the attention scores so they can be visualized later
        self.lastAttentionScores = attentionScores
        # apply the residual connection and layer normalization
        x = self.add([x, attentionOutputs])
        x = self.layernorm(x)
        # return the processed query
        return x


class GlobalSelfAttention(BaseAttention):
    def call(self, x):
        # apply multihead self-attention
        attentionOutputs = self.mha(
            query=x,
            key=x,
            value=x,
        )
        # apply the residual connection and layer normalization
        x = self.add([x, attentionOutputs])
        x = self.layernorm(x)
        # return the processed query
        return x


class CausalSelfAttention(BaseAttention):
    def call(self, x):
        # apply multihead self-attention with causal masking (a look-ahead mask)
        attentionOutputs = self.mha(
            query=x,
            key=x,
            value=x,
            use_causal_mask=True,
        )
        # apply the residual connection and layer normalization
        x = self.add([x, attentionOutputs])
        x = self.layernorm(x)
        # return the processed query
        return x
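

# --------------------------------------------------------------------------- #
# quick usage sketch (an illustrative addition, not part of the original file):
# it assumes TensorFlow 2.10+ (for `use_causal_mask`) and builds the three
# attention layers with hypothetical hyperparameters (num_heads=2, key_dim=64)
# before running them on random dummy tensors to check the output shapes
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # a dummy batch of 4 target sequences (length 10) and context sequences
    # (length 12), each with an embedding dimension of 64
    x = tf.random.normal((4, 10, 64))
    context = tf.random.normal((4, 12, 64))
    # cross-attention: the target sequence attends to the encoder context
    crossAttention = CrossAttention(num_heads=2, key_dim=64)
    print(crossAttention(x, context).shape)   # (4, 10, 64)
    # global self-attention: every position attends to every other position
    globalSelfAttention = GlobalSelfAttention(num_heads=2, key_dim=64)
    print(globalSelfAttention(x).shape)       # (4, 10, 64)
    # causal self-attention: each position attends only to earlier positions
    causalSelfAttention = CausalSelfAttention(num_heads=2, key_dim=64)
    print(causalSelfAttention(x).shape)       # (4, 10, 64)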