-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathsoft_embedding.py
68 lines (60 loc) · 2.6 KB
/
soft_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn
class SoftEmbedding(nn.Module):
    """Prepend a learned soft prompt to the output of a word embedding.

    Each input row reserves its first ``n_tokens`` positions for the soft
    prompt. Column 0 of that reserved span holds the index of the prompt to
    use (``-1`` means "no prompt"). The other reserved positions are never
    read. The reserved span is dropped and replaced by the ``n_tokens``
    learned embedding rows of the selected prompt.
    """

    def __init__(self,
                 wte: nn.Embedding,
                 n_tokens: int = 10,
                 n_prompts: int = 1,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True):
        """Wrap ``wte`` and create the learned prompt embedding parameter.

        Args:
            wte (nn.Embedding): original transformer word embedding.
            n_tokens (int, optional): number of soft-prompt tokens per prompt.
                Defaults to 10.
            n_prompts (int, optional): number of independent learned prompts
                that rows can select between. Defaults to 1.
            random_range (float, optional): half-width of the uniform range
                used to initialize the embedding when not initializing from
                vocab. Defaults to 0.5.
            initialize_from_vocab (bool, optional): initialize every prompt
                from the first ``n_tokens`` rows of ``wte``. Defaults to True.

        Raises:
            ValueError: if ``initialize_from_vocab`` is set and ``wte`` has
                fewer than ``n_tokens`` rows.
        """
        super().__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.n_prompts = n_prompts
        # Shape: (n_prompts, n_tokens, embedding_dim).
        self.learned_embedding = nn.parameter.Parameter(
            self.initialize_embedding(wte,
                                      n_tokens,
                                      n_prompts,
                                      random_range,
                                      initialize_from_vocab)
        )

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             n_prompts: int = 1,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True) -> torch.Tensor:
        """Build the initial value of the learned embedding.

        Args:
            same as __init__

        Returns:
            torch.Tensor: tensor of shape ``(n_prompts, n_tokens, embedding_dim)``
            with the same dtype and device as ``wte.weight``. It is either a
            copy of the first ``n_tokens`` vocab rows for every prompt, or
            uniform samples in ``[-random_range, random_range]``.

        Raises:
            ValueError: if ``initialize_from_vocab`` is set and ``wte`` has
                fewer than ``n_tokens`` rows.
        """
        weight = wte.weight
        if initialize_from_vocab:
            vocab_size = weight.size(0)
            if n_tokens > vocab_size:
                # Slicing would silently return fewer than n_tokens rows, and
                # forward() would then drop more positions than it prepends.
                raise ValueError(
                    f"n_tokens={n_tokens} exceeds the vocabulary size {vocab_size}"
                )
            # detach(): the new parameter must not share autograd history with
            # wte. repeat() copies, so each prompt is an independent tensor.
            return weight[:n_tokens].detach().unsqueeze(0).repeat(n_prompts, 1, 1)
        # Sample in float32 on CPU (same RNG stream as the legacy
        # torch.FloatTensor path), then match wte's dtype/device so both
        # initialization branches agree.
        sampled = torch.empty(n_prompts, n_tokens, weight.size(1), dtype=torch.float32)
        return sampled.uniform_(-random_range, random_range).to(weight)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed ``tokens`` and prepend the selected learned prompt.

        Args:
            tokens (torch.Tensor): integer token ids, assumed to be of shape
                ``(batch, seq)`` (TODO confirm with callers). The first
                ``n_tokens`` columns are reserved. Column 0 holds the prompt
                index, or ``-1`` for no prompt.

        Returns:
            torch.Tensor: if no row has prompt index ``-1``, the learned
            embedding of each row's prompt concatenated with the embedding of
            ``tokens[:, n_tokens:]`` along dim 1. If ANY row has ``-1``, the
            prompt is omitted for the whole batch. Only the embedding of
            ``tokens[:, n_tokens:]`` is returned, which is ``n_tokens``
            positions shorter.
        """
        # The reserved prompt slots are never looked up in wte.
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        prompt_ids = tokens[:, 0]
        # All-or-nothing: a single -1 disables the prompt for every row.
        if torch.any(prompt_ids == -1):
            return input_embedding
        learned_embedding = self.learned_embedding.index_select(0, prompt_ids)
        return torch.cat([learned_embedding, input_embedding], 1)