product_key_memory.py
from typing import Dict, Tuple, List, Any

import math

import torch
import torch.nn as nn


def fetch_pkm_value_parameters(
    module: nn.Module,
) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
    # Collect the value-embedding weights of every PKM layer in the module and
    # return them separately from all remaining parameters.
    params: List[nn.Parameter] = []
    for m in module.modules():
        if isinstance(m, PKM):
            params.append(m.values.weight)  # type: ignore
    paramset = set(params)
    rest = [p for p in module.parameters() if p not in paramset]
    return params, rest


def fetch_optimizer_parameters(
    module: nn.Module, pkm_learning_rate: float = 1e-2
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Build two optimizer parameter groups: the PKM value embeddings get their
    # own learning rate, everything else keeps the optimizer default.
    pkm_params, rest = fetch_pkm_value_parameters(module)
    return {"params": rest}, {"params": pkm_params, "lr": pkm_learning_rate}


class PKM(nn.Module):
    # Product-key memory layer (Lample et al., "Large Memory Layers with
    # Product Keys", 2019): each query is split into two halves, each half is
    # scored against its own sub-key set, and the Cartesian product of the two
    # per-half top-k results indexes a large table of value embeddings.
    def __init__(
        self,
        dim: int,
        heads: int = 4,
        num_keys: int = 128,
        topk: int = 32,
        dim_head: int = 256,
        input_dropout: float = 0.0,
        query_dropout: float = 0.0,
        value_dropout: float = 0.0,
    ):
        super().__init__()
        assert dim % heads == 0, "dimension must be divisible by number of heads"
        self.topk = topk
        self.heads = heads
        self.num_keys = num_keys

        dim_query = dim_head * heads
        self.to_queries = nn.Linear(dim, dim_query, bias=False)
        self.norm = nn.LayerNorm(dim_query)

        # Two half-dimension key sets of num_keys keys per head; the values
        # live in a num_keys ** 2 entry EmbeddingBag so the weighted sum over
        # the selected values is a single lookup.
        self.keys = nn.Parameter(torch.zeros(heads, num_keys, 2, dim_head // 2))
        self.values = nn.EmbeddingBag(num_keys ** 2, dim, mode="sum")

        self.input_dropout = nn.Dropout(input_dropout)
        self.query_dropout = nn.Dropout(query_dropout)
        self.value_dropout = nn.Dropout(value_dropout)
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.keys, std=1 / self.keys.size(-1))
        nn.init.normal_(
            self.values.weight, std=1 / math.sqrt(self.values.weight.size(-1))
        )
    def forward(self, x, input_mask=None, **kwargs):
        t, b, e = x.shape
        h = self.heads
        x = self.input_dropout(x)

        # Project to queries and split each query into two halves:
        # (2, t, b, h, dim_head // 2).
        queries = self.to_queries(x)
        queries = self.norm(queries)
        queries = self.query_dropout(queries)
        queries = queries.chunk(2, dim=-1)
        queries = torch.stack(queries).reshape(2, t, b, h, -1)

        # Score each query half against its corresponding key set and keep the
        # per-half top-k: scores / indices are (t, b, h, 2, topk) before the chunk.
        dots = torch.einsum("ptbhd,hnpd->tbhpn", queries, self.keys)
        scores, indices = dots.topk(k=self.topk, dim=-1)
        scores, indices = map(lambda x: x.chunk(2, dim=3), (scores, indices))

        # Combine the two halves into topk ** 2 candidate product keys, then
        # keep the overall top-k of those candidates.
        all_topk = self.topk ** 2
        shape = (t, b, h, all_topk)
        all_scores = (scores[0][..., :, None] + scores[1][..., None, :]).reshape(*shape)
        all_indices = (
            indices[0][..., :, None] * self.num_keys + indices[1][..., None, :]
        ).reshape(*shape)
        final_topk, final_indices = all_scores.topk(self.topk, dim=-1)
        value_indices = all_indices.gather(-1, final_indices)

        # Softmax over the selected scores, then take the weighted sum of the
        # corresponding value embeddings via the EmbeddingBag.
        attn = final_topk.softmax(dim=-1)
        value_indices, attn = map(
            lambda x: x.reshape(-1, self.topk * h), (value_indices, attn)
        )
        out = self.values(value_indices, per_sample_weights=attn)
        out = self.value_dropout(out)
        return out.reshape(t, b, e)
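

# Minimal usage sketch (illustration only, not part of the original module).
# It shows how a PKM layer might be instantiated and how
# fetch_optimizer_parameters can supply two parameter groups to an optimizer.
# The dimensions and learning rates below are arbitrary assumptions chosen
# just for the demo.
if __name__ == "__main__":
    dim = 512
    pkm = PKM(dim=dim, heads=4, num_keys=128, topk=32)

    # Any 3-D input whose last dimension equals `dim` works; it is unpacked
    # internally as (t, b, dim).
    x = torch.randn(16, 2, dim)
    out = pkm(x)
    print(out.shape)  # torch.Size([16, 2, 512])

    # The PKM value embeddings get their own learning rate; everything else
    # uses the optimizer default.
    base_group, pkm_group = fetch_optimizer_parameters(pkm, pkm_learning_rate=1e-2)
    optimizer = torch.optim.Adam([base_group, pkm_group], lr=1e-4)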