modules.py
# Copyright (c) 2021 Kemal Kurniawan
from typing import Optional
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import LongTensor, Tensor
import torch
import torch.nn as nn


class TransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Transformer encoder layer that uses a distance-aware self-attention layer."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        ff_size: int = 2048,
        dropout: float = 0.1,
        kv_size: int = 64,
    ) -> None:
        # call the parent with a fake n_heads of 1 to avoid errors creating MultiheadAttention
        super().__init__(d_model, 1, dim_feedforward=ff_size, dropout=dropout)
        # replace the default attention with the custom distance-aware self-attention
        self.self_attn = DistanceAwareSelfAttention(
            d_model, n_heads, dropout=dropout, kv_size=kv_size
        )
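
# A minimal usage sketch (illustrative values, not part of the original module),
# assuming a PyTorch release contemporary with this code (roughly 1.8/1.9) whose
# nn.TransformerEncoderLayer.forward calls self_attn(src, src, src, attn_mask=...,
# key_padding_mask=...) with no extra keyword arguments; newer releases pass extra
# arguments (e.g. need_weights) that the custom forward() below does not accept.
#
#     layer = TransformerEncoderLayer(d_model=256, n_heads=8, ff_size=512)
#     encoder = nn.TransformerEncoder(layer, num_layers=6)
#     src = torch.randn(35, 4, 256)  # (slen, bsz, d_model)
#     out = encoder(src)             # (slen, bsz, d_model)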


class DistanceAwareSelfAttention(nn.Module):
    """Distance-aware self-attention layer from Ahmad et al. (2019)."""

    def __init__(
        self,
        embed_dim: int,
        n_heads: int,
        dropout: float = 0.0,
        clip_dist: int = 10,
        kv_size: int = 64,
    ) -> None:
        super().__init__()
        # joint projection to queries, keys, and values for all heads at once
        self.in_proj = nn.Sequential(
            nn.Linear(embed_dim, 3 * n_heads * kv_size),
            Rearrange("slen bsz (n nhead dim) -> n bsz nhead slen dim", n=3, dim=kv_size),
        )
        # embeddings of clipped token distances for the key and value sides of attention
        self.k_dist_emb = nn.Embedding(clip_dist + 1, kv_size)
        self.v_dist_emb = nn.Embedding(clip_dist + 1, kv_size)
        self.attn_dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(n_heads * kv_size, embed_dim)

    def forward(
        self,
        inputs: Tensor,
        inputs2: Tensor,
        inputs3: Tensor,
        attn_mask: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ):
        assert inputs is inputs2 and inputs is inputs3, "must be a self-attention"
        assert attn_mask is None, "attn_mask should not be given"
        # shape: (slen, bsz, embed_dim)
        assert inputs.dim() == 3
        # shape: (bsz, slen)
        assert key_padding_mask is None or key_padding_mask.shape == (
            inputs.size(1),
            inputs.size(0),
        )

        # each shape: (bsz, nhead, slen, qdim/vdim)
        q, k, v = self.in_proj(inputs)
        # shape: (slen, slen)
        distances = self._get_distances(inputs.size(0)).to(inputs.device)

        q *= q.size(-1) ** -0.5
        k = rearrange(k, "bsz nhead slen qdim -> bsz nhead qdim slen")
        # shape: (bsz, nhead, slen, slen)
        attn_weights = q @ k + self._get_dist_attn_weights(q, distances)
        if key_padding_mask is not None:
            # broadcast over heads and queries
            mask = rearrange(key_padding_mask, "bsz slen -> bsz () () slen")
            attn_weights.masked_fill_(mask, float("-inf"))
        # shape: (bsz, nhead, slen, slen)
        attn_weights = attn_weights.softmax(dim=-1)
        # shape: (bsz, nhead, slen, slen)
        attn_weights = self.attn_dropout(attn_weights)

        # shape: (bsz, nhead, slen, vdim)
        attn_outputs = attn_weights @ v + self._get_dist_attn_outputs(attn_weights, distances)
        attn_outputs = rearrange(attn_outputs, "bsz nhead slen vdim -> slen bsz (nhead vdim)")
        # shape: (slen, bsz, embed_dim)
        attn_outputs = self.out_proj(attn_outputs)
        return attn_outputs, None  # attn_weights is not needed

    def _get_distances(self, slen: int) -> LongTensor:
        x = rearrange(torch.arange(slen), "slen -> () slen")
        y = rearrange(torch.arange(slen), "slen -> slen ()")
        # shape: (slen, slen)
        dist = torch.abs(x - y)
        clip_dist = self.k_dist_emb.num_embeddings - 1
        # shape: (slen, slen)
        return dist.clamp(max=clip_dist).long()  # type: ignore
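
    # Illustrative example (not in the original): for slen = 4 and clip_dist >= 3,
    # _get_distances returns the clipped absolute position offsets
    #   [[0, 1, 2, 3],
    #    [1, 0, 1, 2],
    #    [2, 1, 0, 1],
    #    [3, 2, 1, 0]]
    # which index the shared k_dist_emb/v_dist_emb tables for every head and batch.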

    def _get_dist_attn_weights(self, q: Tensor, dist: LongTensor) -> Tensor:
        # shape: (bsz, nhead, slen, qdim)
        assert q.dim() == 4
        # shape: (slen, slen)
        assert dist.shape == (q.size(2), q.size(2))
        # shape: (slen, slen, qdim)
        k_dist = self.k_dist_emb(dist)
        q_dist = rearrange(q, "bsz nhead slen qdim -> bsz nhead slen () qdim")  # broadcast over keys
        k_dist = rearrange(k_dist, "slen slen2 qdim -> slen qdim slen2")
        weights = q_dist @ k_dist
        return rearrange(weights, "bsz nhead slen () slen2 -> bsz nhead slen slen2")
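
    # Note (expository, not in the original): the broadcasting above is equivalent to
    #   torch.einsum("bhid,ijd->bhij", q, self.k_dist_emb(dist))
    # i.e. each query position i uses its own row of distance-key embeddings.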

    def _get_dist_attn_outputs(self, attn_weights: Tensor, dist: LongTensor) -> Tensor:
        # shape: (bsz, nhead, slen, slen)
        assert attn_weights.dim() == 4
        assert attn_weights.size(-2) == attn_weights.size(-1)
        # shape: (slen, slen)
        assert dist.shape == attn_weights.shape[-2:]
        # shape: (slen, slen, vdim)
        v_dist = self.v_dist_emb(dist)
        attn_dist = rearrange(attn_weights, "bsz nhead slen slen2 -> bsz nhead slen () slen2")
        outputs = attn_dist @ v_dist
        return rearrange(outputs, "bsz nhead slen () vdim -> bsz nhead slen vdim")
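

# A minimal standalone smoke-test sketch (an illustration added here, not part of the
# original module): run one DistanceAwareSelfAttention layer on random inputs with a
# key padding mask. All sizes below are arbitrary example values; shapes follow the
# conventions documented in forward().
if __name__ == "__main__":
    slen, bsz, embed_dim, n_heads = 7, 2, 64, 4
    attn = DistanceAwareSelfAttention(embed_dim, n_heads, kv_size=16)
    x = torch.randn(slen, bsz, embed_dim)
    # True marks padding positions that keys should be masked out at
    padding_mask = torch.zeros(bsz, slen, dtype=torch.bool)
    padding_mask[:, -2:] = True
    out, _ = attn(x, x, x, key_padding_mask=padding_mask)
    print(out.shape)  # expected: torch.Size([7, 2, 64])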