-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtransformer.py
597 lines (500 loc) · 23.5 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import Tensor
from typing import Optional, Tuple
def get_attn_pad_mask(inputs, input_lengths, expand_length):
    """
    Build a boolean attention mask that is True at padded key positions.

    Args:
        inputs: tensor of shape (batch, time) or (batch, time, dim); only its
            shape and device are used.
        input_lengths: number of valid (non-padded) steps for each batch
            element, as a tensor or any sequence accepted by
            ``torch.as_tensor``.
        expand_length: how many times the per-sample mask row is repeated
            along dimension 1 (one row per query position).

    Returns:
        Bool tensor of shape (batch, expand_length, time). Entry
        ``[b, :, t]`` is True when ``t >= input_lengths[b]``.

    Raises:
        ValueError: if ``inputs`` is neither 2- nor 3-dimensional.
    """
    if inputs.dim() not in (2, 3):
        raise ValueError(f"Unsupported input shape {inputs.size()}")

    batch_size, seq_length = inputs.size(0), inputs.size(1)

    # One vectorised comparison replaces the per-sample Python loop, which
    # used a tensor element as a slice bound on every iteration.
    lengths = torch.as_tensor(input_lengths, device=inputs.device).reshape(-1)[:batch_size]
    positions = torch.arange(seq_length, device=inputs.device)
    pad_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)  # B x T

    # Repeat the key-side mask for every query position (a view, no copy).
    return pad_mask.unsqueeze(1).expand(-1, expand_length, -1)
def get_attn_subsequent_mask(seq):
    """
    Build the causal ("no peeking ahead") mask for a batch of sequences.

    Args:
        seq: 2-D tensor of shape (batch, time); only its shape and device
            are used.

    Returns:
        Float tensor of shape (batch, time, time) holding 1 strictly above
        the main diagonal (future positions) and 0 elsewhere.
    """
    assert seq.dim() == 2
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # Allocate directly on the input's device. The previous `.cuda()` call
    # always targeted the *current default* GPU, which is wrong when `seq`
    # lives on another GPU, and also paid for a host-to-device copy.
    subsequent_mask = torch.triu(torch.ones(attn_shape, device=seq.device), diagonal=1)
    return subsequent_mask
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding from "Attention Is All You Need".

    The transformer has no recurrence or convolution, so position
    information is injected with fixed sine/cosine signals:
        PE_(pos, 2i)   = sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))

    The table is precomputed once for ``max_len`` positions and stored as the
    non-trainable buffer ``pe`` of shape (1, max_len, d_model).

    Args:
        d_model: width of the encoding (default: 80). Odd values are
            supported; the last sine column then has no cosine partner.
        max_len: number of positions to precompute (default: 5000).
    """
    def __init__(self, d_model: int = 80, max_len: int = 5000) -> None:
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # max_len x ceil(d_model / 2)

        pe = torch.zeros(max_len, d_model, requires_grad=False)
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns; trimming keeps the
        # assignment valid for odd d_model (it used to raise a shape error).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, length: int) -> Tensor:
        """Return the encodings for the first ``length`` positions, shape (1, length, d_model)."""
        return self.pe[:, :length]
class Embedding(nn.Module):
    """
    Learned token embedding scaled by sqrt(d_model).

    Wraps ``nn.Embedding`` (with ``padding_idx=pad_id``) and multiplies the
    looked-up vectors by ``sqrt(d_model)``, as done for the embedding layers
    in "Attention Is All You Need".

    Args:
        num_embeddings: vocabulary size.
        pad_id: index of the padding token, passed as ``padding_idx``.
        d_model: embedding width (default: 512).
    """
    def __init__(self, num_embeddings: int, pad_id: int, d_model: int = 512) -> None:
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, d_model, padding_idx=pad_id)
        self.sqrt_dim = math.sqrt(d_model)

    def forward(self, inputs: Tensor) -> Tensor:
        """Look up ``inputs`` and return the vectors multiplied by sqrt(d_model)."""
        vectors = self.embedding(inputs)
        return vectors * self.sqrt_dim
class AddNorm(nn.Module):
    """
    Residual connection followed by layer normalization ("Add & Norm").

    Calls the wrapped sub-layer with all positional arguments, adds the first
    argument back onto the sub-layer's output and applies ``nn.LayerNorm``.
    When the sub-layer returns a tuple, only its first element is
    residual-normalized and the second element is passed through unchanged.

    Args:
        sublayer: module to wrap.
        d_model: feature width given to ``nn.LayerNorm`` (default: 512).
    """
    def __init__(self, sublayer: nn.Module, d_model: int = 512) -> None:
        super().__init__()
        self.sublayer = sublayer
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, *args):
        """Run the sub-layer on ``args`` and add-and-normalize against ``args[0]``."""
        skip = args[0]
        result = self.sublayer(*args)

        if not isinstance(result, tuple):
            return self.layer_norm(result + skip)

        # Tuple output: normalize the primary tensor, forward the extra as-is.
        return self.layer_norm(result[0] + skip), result[1]
class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention from "Attention Is All You Need".

    Scores are the batched dot products of queries and keys divided by
    sqrt(dim); a softmax over the key axis turns them into weights that are
    used to average the values.

    Args:
        dim (int): dimension used for the sqrt(dim) scaling

    Inputs: query, key, value, mask
        - **query** (batch, q_len, dim): query vectors.
        - **key** (batch, k_len, dim): key vectors.
        - **value** (batch, k_len, dim): value vectors.
        - **mask** (optional): boolean tensor viewable to the score shape
          (batch, q_len, k_len); True entries are filled with -inf before
          the softmax.

    Returns: context, attn
        - **context** (batch, q_len, dim): attention-weighted sum of values.
        - **attn** (batch, q_len, k_len): attention weights.
    """
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        keys_t = key.transpose(1, 2)
        energy = torch.bmm(query, keys_t) / self.sqrt_dim

        if mask is not None:
            # In-place fill: masked positions get zero weight after softmax.
            energy.masked_fill_(mask.view(energy.size()), -float('Inf'))

        weights = F.softmax(energy, -1)
        return torch.bmm(weights, value), weights
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention proposed in "Attention Is All You Need".

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads of width ``d_head = d_model / num_heads``, attended independently
    with scaled dot-product attention, concatenated and projected once more.

        MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
            where head_i = Attention(Q · W_q, K · W_k, V · W_v)

    Args:
        d_model (int): dimension of keys / values / queries (default: 512)
        num_heads (int): number of attention heads (default: 8)

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model)
        - **key** (batch, k_len, d_model)
        - **value** (batch, k_len, d_model)
        - **mask** (batch, q_len, k_len), optional: boolean tensor, True at
          positions that must not be attended to.

    Returns:
        - **output** (batch, q_len, d_model): the attended features. The
          attention weights are computed internally but not returned.
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8) -> None:
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model % num_heads should be zero."

        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.linear_q = nn.Linear(d_model, self.d_head * num_heads)
        self.linear_k = nn.Linear(d_model, self.d_head * num_heads)
        self.linear_v = nn.Linear(d_model, self.d_head * num_heads)
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        batch_size = value.size(0)

        query = self.linear_q(query).view(batch_size, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD
        key = self.linear_k(key).view(batch_size, -1, self.num_heads, self.d_head)  # BxK_LENxNxD
        value = self.linear_v(value).view(batch_size, -1, self.num_heads, self.d_head)  # BxV_LENxNxD

        # Heads become the leading (slowest-varying) factor of the flattened
        # batch axis: flat index = head * batch_size + batch.
        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # NBxQ_LENxD
        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # NBxK_LENxD
        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # NBxV_LENxD

        if mask is not None:
            # The mask must be tiled in the same head-major order as the
            # tensors above. Tiling it batch-major (unsqueeze(1) + repeat on
            # dim 1) pairs sample b's mask with another sample's scores once
            # it is viewed to (N*B, Q_LEN, K_LEN).
            mask = mask.unsqueeze(0).repeat(self.num_heads, 1, 1, 1)  # NxBxQ_LENxK_LEN

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        context = context.view(self.num_heads, batch_size, -1, self.d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)  # BxTxND

        return self.linear(context)
class PoswiseFeedForwardNet(nn.Module):
    """
    Position-wise feed-forward network from "Attention Is All You Need".

    The same two-layer network is applied to every position independently:
    Linear(d_model -> d_ff), Dropout, ReLU, Linear(d_ff -> d_model), Dropout.

    Args:
        d_model: input and output width (default: 512).
        d_ff: hidden width (default: 2048).
        dropout_p: dropout probability for both dropout layers (default: 0.3).
    """
    def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout_p: float = 0.3) -> None:
        super().__init__()
        stages = [
            nn.Linear(d_model, d_ff),
            nn.Dropout(dropout_p),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout_p),
        ]
        self.feed_forward = nn.Sequential(*stages)

    def forward(self, inputs: Tensor) -> Tensor:
        """Apply the feed-forward stack to ``inputs``."""
        transformed = self.feed_forward(inputs)
        return transformed
class TransformerEncoderLayer(nn.Module):
    """
    One encoder layer: multi-head self-attention followed by a position-wise
    feed-forward network, each wrapped in ``AddNorm`` (residual + layer norm).

    Based on the paper "Attention Is All You Need".

    Args:
        d_model: dimension of model (default: 512)
        num_heads: number of attention heads (default: 8)
        d_ff: dimension of feed forward network (default: 2048)
        dropout_p: dropout probability of the feed forward network (default: 0.3)
    """
    def __init__(
            self,
            d_model: int = 512,
            num_heads: int = 8,
            d_ff: int = 2048,
            dropout_p: float = 0.3,
    ) -> None:
        super(TransformerEncoderLayer, self).__init__()
        self.self_attention = AddNorm(MultiHeadAttention(d_model, num_heads), d_model)
        self.feed_forward = AddNorm(PoswiseFeedForwardNet(d_model, d_ff, dropout_p), d_model)

    def forward(
            self,
            inputs: Tensor,
            self_attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Run self-attention and the feed-forward network over ``inputs``.

        Args:
            inputs: layer input; used as query, key and value.
            self_attn_mask: optional mask forwarded to the attention module.

        Returns:
            The layer output. The previous ``Tuple[Tensor, Tensor]``
            annotation was wrong: a single tensor is returned.
        """
        attended = self.self_attention(inputs, inputs, inputs, self_attn_mask)
        return self.feed_forward(attended)
class TransformerEncoder(nn.Module):
    """
    The TransformerEncoder is composed of a stack of N identical layers.
    Each layer has two sub-layers. The first is a multi-head self-attention mechanism,
    and the second is a simple, position-wise fully connected feed-forward network.

    Args:
        num_embeddings: size of the input vocabulary
        d_model: dimension of model (default: 512)
        d_ff: dimension of feed forward network (default: 2048)
        num_layers: number of encoder layers (default: 6)
        num_heads: number of attention heads (default: 8)
        dropout_p: dropout probability (default: 0.1)
        pad_id: index of the pad symbol (default: 0)
    """
    def __init__(
            self,
            num_embeddings: int,
            d_model: int = 512,
            d_ff: int = 2048,
            num_layers: int = 6,
            num_heads: int = 8,
            dropout_p: float = 0.1,
            pad_id: int = 0,
    ) -> None:
        super(TransformerEncoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.pad_id = pad_id
        self.embedding = Embedding(num_embeddings, pad_id, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.input_dropout = nn.Dropout(p=dropout_p)
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(d_model, num_heads, d_ff, dropout_p) for _ in range(num_layers)]
        )

    def forward(self, inputs: Tensor, input_lengths: Optional[Tensor] = None) -> Tensor:
        """
        Encode a batch of token-id sequences.

        Args:
            inputs: token ids of shape (batch, time).
            input_lengths: number of valid tokens per sequence. When omitted,
                positions equal to ``pad_id`` are treated as padding.

        Returns:
            Encoder output produced by the last layer.
        """
        length = inputs.size(1)

        output = self.input_dropout(self.embedding(inputs) + self.pos_encoding(length))

        if input_lengths is None:
            # The signature advertises None as a default, but building the
            # mask from lengths cannot handle it; fall back to the pad token.
            self_attn_mask = inputs.eq(self.pad_id).unsqueeze(1).expand(-1, length, -1)
        else:
            self_attn_mask = get_attn_pad_mask(inputs, input_lengths, length)

        for layer in self.layers:
            output = layer(output, self_attn_mask)

        return output
class TransformerDecoderLayer(nn.Module):
    r"""
    DecoderLayer is made up of self-attention, encoder-decoder attention and a
    feed-forward network. Each sub-layer is applied in pre-norm form:
    LayerNorm -> sub-layer -> residual add.

    Args:
        d_model: dimension of model (default: 512)
        num_heads: number of attention heads (default: 8)
        d_ff: dimension of feed forward network (default: 2048)
        dropout_p: probability of dropout (default: 0.3)

    Inputs:
        inputs (torch.FloatTensor): input sequence of transformer decoder layer
        encoder_outputs (torch.FloatTensor): outputs of encoder
        self_attn_mask (torch.BoolTensor): mask of self attention
        encoder_attn_mask (torch.BoolTensor): mask applied over encoder outputs

    Returns:
        outputs (torch.FloatTensor): output of transformer decoder layer

    Reference:
        Ashish Vaswani et al.: Attention Is All You Need
        https://arxiv.org/abs/1706.03762
    """
    def __init__(
            self,
            d_model: int = 512,
            num_heads: int = 8,
            d_ff: int = 2048,
            dropout_p: float = 0.3,
    ) -> None:
        super(TransformerDecoderLayer, self).__init__()
        self.self_attention_prenorm = nn.LayerNorm(d_model)
        self.decoder_attention_prenorm = nn.LayerNorm(d_model)
        self.feed_forward_prenorm = nn.LayerNorm(d_model)
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.decoder_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PoswiseFeedForwardNet(d_model, d_ff, dropout_p)

    def forward(
            self,
            inputs: Tensor,
            encoder_outputs: Tensor,
            self_attn_mask: Optional[Tensor] = None,
            encoder_attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""
        Forward propagate transformer decoder layer.

        Inputs:
            inputs (torch.FloatTensor): input sequence of transformer decoder layer
            encoder_outputs (torch.FloatTensor): outputs of encoder
            self_attn_mask (torch.BoolTensor): mask of self attention
            encoder_attn_mask (torch.BoolTensor): mask applied over encoder outputs

        Returns:
            outputs (torch.FloatTensor): output of transformer decoder layer.
            Only this tensor is returned; the earlier three-tensor annotation
            did not match the code.
        """
        # 1) masked self-attention over the decoder inputs
        normed = self.self_attention_prenorm(inputs)
        outputs = self.self_attention(normed, normed, normed, self_attn_mask) + inputs

        # 2) attention over the encoder outputs
        residual = outputs
        normed = self.decoder_attention_prenorm(outputs)
        outputs = self.decoder_attention(normed, encoder_outputs, encoder_outputs, encoder_attn_mask) + residual

        # 3) position-wise feed-forward network
        residual = outputs
        normed = self.feed_forward_prenorm(outputs)
        outputs = self.feed_forward(normed) + residual

        return outputs
class TransformerDecoder(nn.Module):
    r"""
    The TransformerDecoder is composed of a stack of N identical layers.
    Each layer has three sub-layers. The first is a multi-head self-attention mechanism,
    and the second is a multi-head attention mechanism, third is a feed-forward network.

    Args:
        num_classes: number of classes
        d_model: dimension of model (default: 512)
        d_ff: dimension of feed forward network (default: 512)
        num_layers: number of layers (default: 6)
        num_heads: number of attention heads (default: 8)
        dropout_p: probability of dropout (default: 0.3)
        pad_id (int, optional): index of the pad symbol (default: 0)
        sos_id (int, optional): index of the start of sentence symbol (default: 1)
        eos_id (int, optional): index of the end of sentence symbol (default: 2)
        max_length (int): max decoding length (default: 512)
    """

    def __init__(
            self,
            num_classes: int,
            d_model: int = 512,
            d_ff: int = 512,
            num_layers: int = 6,
            num_heads: int = 8,
            dropout_p: float = 0.3,
            pad_id: int = 0,
            sos_id: int = 1,
            eos_id: int = 2,
            max_length: int = 512,
    ) -> None:
        super(TransformerDecoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_length = max_length
        self.pad_id = pad_id
        self.sos_id = sos_id
        self.eos_id = eos_id

        self.embedding = Embedding(num_classes, pad_id, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.input_dropout = nn.Dropout(p=dropout_p)
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
                dropout_p=dropout_p,
            ) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, num_classes, bias=False)

    def forward_step(
            self,
            decoder_inputs: torch.Tensor,
            decoder_input_lengths: Optional[torch.Tensor],
            encoder_outputs: torch.Tensor,
            encoder_output_lengths: Optional[torch.Tensor],
            positional_encoding_length: int,
    ) -> torch.Tensor:
        r"""
        Run the full decoder stack over ``decoder_inputs``.

        Args:
            decoder_inputs: token ids of shape (batch, time).
            decoder_input_lengths: valid length of each decoder input row, or
                None to apply only the causal mask.
            encoder_outputs: encoder output sequence.
            encoder_output_lengths: valid length of each encoder output row,
                or None to attend over every encoder position.
            positional_encoding_length: number of positions requested from
                the positional encoding.

        Returns:
            Output of the last decoder layer (before the final projection).
        """
        query_length = decoder_inputs.size(1)

        # Self-attention mask = causal mask, optionally OR-ed with padding.
        subsequent_mask = get_attn_subsequent_mask(decoder_inputs)
        if decoder_input_lengths is None:
            self_attn_mask = torch.gt(subsequent_mask, 0)
        else:
            dec_self_attn_pad_mask = get_attn_pad_mask(decoder_inputs, decoder_input_lengths, query_length)
            self_attn_mask = torch.gt((dec_self_attn_pad_mask + subsequent_mask), 0)

        # The length arguments default to None in `forward`; without this
        # guard a None reached the mask builder and crashed.
        if encoder_output_lengths is None:
            encoder_attn_mask = None
        else:
            encoder_attn_mask = get_attn_pad_mask(encoder_outputs, encoder_output_lengths, query_length)

        outputs = self.embedding(decoder_inputs) + self.positional_encoding(positional_encoding_length)
        outputs = self.input_dropout(outputs)

        for layer in self.layers:
            outputs = layer(
                inputs=outputs,
                encoder_outputs=encoder_outputs,
                self_attn_mask=self_attn_mask,
                encoder_attn_mask=encoder_attn_mask,
            )

        return outputs

    def forward(
            self,
            encoder_outputs: torch.Tensor,
            targets: Optional[torch.LongTensor] = None,
            encoder_output_lengths: Optional[torch.Tensor] = None,
            target_lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Decode with teacher forcing (``targets`` given) or greedily (``targets`` is None).

        Args:
            encoder_outputs (torch.FloatTensor): A output sequence of encoders. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
            targets (torch.LongTensor): A target sequence passed to decoders. `IntTensor` of size
                ``(batch, seq_length)``. All ``eos_id`` entries are removed before decoding, so
                every row must contain the same number of them.
            encoder_output_lengths (torch.LongTensor): The length of encoders outputs. ``(batch)``
            target_lengths (torch.LongTensor): The length of targets. ``(batch)``

        Returns:
            * logits (torch.FloatTensor): Log probability of model predictions, of size
              ``(batch, steps, num_classes)``.
        """
        batch_size = encoder_outputs.size(0)

        if targets is not None:
            targets = targets[targets != self.eos_id].view(batch_size, -1)

            step_outputs = self.forward_step(
                decoder_inputs=targets,
                decoder_input_lengths=target_lengths,
                encoder_outputs=encoder_outputs,
                encoder_output_lengths=encoder_output_lengths,
                positional_encoding_length=targets.size(1),
            )
            # Already (batch, steps, classes): return it directly instead of
            # slicing out every step and stacking the slices back together.
            return self.fc(step_outputs).log_softmax(dim=-1)

        # Inference: greedy decoding. Every sequence is decoded for
        # max_length - 1 steps; there is no early stop on eos_id.
        logits = list()
        input_var = encoder_outputs.new_full((batch_size, self.max_length), self.pad_id, dtype=torch.long)
        input_var[:, 0] = self.sos_id

        for di in range(1, self.max_length):
            input_lengths = torch.full((batch_size,), di, dtype=torch.int)

            outputs = self.forward_step(
                decoder_inputs=input_var[:, :di],
                decoder_input_lengths=input_lengths,
                encoder_outputs=encoder_outputs,
                encoder_output_lengths=encoder_output_lengths,
                positional_encoding_length=di,
            )
            step_output = self.fc(outputs).log_softmax(dim=-1)

            logits.append(step_output[:, -1, :])
            # squeeze(-1) drops only the top-k axis, keeping the batch axis
            # intact when batch_size == 1.
            input_var[:, di] = logits[-1].topk(1)[1].squeeze(-1)

        return torch.stack(logits, dim=1)
class Transformer(nn.Module):
    """
    A Transformer model. User is able to modify the attributes as needed.
    The architecture is based on the paper "Attention Is All You Need".

    Args:
        num_input_embeddings (int): size of the input vocabulary
        num_output_embeddings (int): size of the output vocabulary
        d_model (int): dimension of model (default: 512)
        d_ff (int): dimension of feed forward network (default: 2048)
        num_heads (int): number of attention heads (default: 8)
        num_encoder_layers (int): number of encoder layers (default: 6)
        num_decoder_layers (int): number of decoder layers (default: 6)
        dropout_p (float): dropout probability (default: 0.1)
        pad_id (int): identification of <PAD_token> (default: 0)
        sos_id (int): identification of the start of sentence token (default: 1)
        eos_id (int): identification of the end of sentence token (default: 2)
        max_length (int): max decoding length (default: 512)

    Inputs: inputs, input_lengths, targets, target_lengths
        - **inputs** (batch, input_length): tensor containing input sequences
        - **input_lengths** (batch): length of each input sequence
        - **targets** (batch, target_length): tensor containing target sequences (optional)
        - **target_lengths** (batch): length of each target sequence (optional)

    Returns: output
        - **output**: tensor returned by the decoder
    """
    def __init__(
            self,
            num_input_embeddings: int,
            num_output_embeddings: int,
            d_model: int = 512,
            d_ff: int = 2048,
            num_heads: int = 8,
            num_encoder_layers: int = 6,
            num_decoder_layers: int = 6,
            dropout_p: float = 0.1,
            pad_id: int = 0,
            sos_id: int = 1,
            eos_id: int = 2,
            max_length: int = 512,
    ) -> None:
        super(Transformer, self).__init__()
        self.pad_id = pad_id
        self.encoder = TransformerEncoder(
            num_embeddings=num_input_embeddings,
            d_model=d_model,
            d_ff=d_ff,
            num_layers=num_encoder_layers,
            num_heads=num_heads,
            dropout_p=dropout_p,
            pad_id=pad_id,
        )
        self.decoder = TransformerDecoder(
            num_classes=num_output_embeddings,
            d_model=d_model,
            d_ff=d_ff,
            num_layers=num_decoder_layers,
            num_heads=num_heads,
            dropout_p=dropout_p,
            pad_id=pad_id,
            sos_id=sos_id,
            eos_id=eos_id,
            max_length=max_length,
        )

    def forward(
            self,
            inputs: Tensor,
            input_lengths: Tensor,
            targets: Optional[Tensor] = None,
            target_lengths: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Encode ``inputs`` and decode against the encoder outputs.

        ``input_lengths`` is used both for the encoder's padding mask and as
        the encoder output lengths passed to the decoder. The decoder's
        result is returned unchanged (a single tensor, not the 4-tuple the
        previous annotation declared).
        """
        encoder_outputs = self.encoder(inputs, input_lengths)
        return self.decoder(encoder_outputs, targets, input_lengths, target_lengths)