# test.py — 94 lines (84 loc) · 4.42 KB
# NOTE: GitHub page chrome and the copied line-number gutter from the
# original scrape were removed so this file parses as Python.
import torch
from models import Transformer
from models.layers.self_attention import SelfAttention
from models.layers.multi_head_attention import MultiHeadAttention
from models.layers.encoder_layer import EncoderLayer
from models.blocks.encoder import Encoder
from models.blocks.decoder import Decoder
from utils import pad_tensor
# ---------------------------------------------------------------------------
# Test fixtures: one padded batch (batch_size=1, seq_len=40) of token ids.
# Ids 2 / 3 / 1 look like BOS / EOS / PAD markers — TODO confirm against the
# project's vocabulary.
# ---------------------------------------------------------------------------
src = torch.tensor([[2, 35, 12474, 4, 5438, 5, 2461, 29, 1154, 6,
                     192, 50, 151, 248, 3, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                    ])
tgt = torch.tensor([[2, 9, 12983, 4, 20848, 32, 1261, 7, 1507, 6,
                     0, 3, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                    ])
# Padding mask for `src`: 1 over the 15 real tokens, 0 over the padding tail.
src_mask = torch.tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
                           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])
# Pre-computed encoding of `src`, shape (1, 40, 6) — presumably an embedding
# plus sinusoidal positional encoding; TODO confirm against the encoder.
# BUG FIX: the original line read `src_enc = torch.src = torch.tensor(...)`,
# which chain-assigned the tensor as a spurious `src` attribute on the torch
# module itself.  Only the local name `src_enc` should be bound.
src_enc = torch.tensor([[[ 0.0202,  1.0201, 0.0276, 1.0342, 0.0347, 0.9632],
                         [ 0.8599,  0.5529, 0.0563, 1.0290, -0.0094, 1.0237],
                         [ 0.9085, -0.3905, 0.1088, 0.9823, 0.0186, 1.0135],
                         [ 0.1611, -0.9929, 0.1588, 1.0215, -0.0087, 1.0171],
                         [-0.7826, -0.6713, 0.1484, 0.9667, 0.0458, 1.0238],
                         [-0.9911,  0.2937, 0.2248, 0.9540, -0.0188, 1.0104],
                         [-0.3169,  0.9555, 0.3046, 0.9900, 0.0219, 0.9733],
                         [ 0.6241,  0.7764, 0.2903, 0.9178, 0.0084, 1.0049],
                         [ 1.0141, -0.1701, 0.3924, 0.9360, 0.0338, 0.9663],
                         [ 0.4027, -0.9138, 0.4287, 0.9417, 0.0420, 0.9894],
                         [-0.5603, -0.8216, 0.4154, 0.9134, 0.0373, 0.9836],
                         [-0.9832,  0.0410, 0.5117, 0.8532, 0.0020, 0.9838],
                         [-0.5524,  0.8516, 0.5175, 0.8668, 0.0152, 1.0208],
                         [ 0.4491,  0.8806, 0.5891, 0.8576, -0.0066, 1.0079],
                         [ 0.9528,  0.1025, 0.6164, 0.8104, 0.0535, 0.9871],
                         [ 0.6229, -0.7366, 0.6641, 0.7326, 0.0340, 0.9657],
                         [-0.3153, -0.9345, 0.6991, 0.7020, 0.0362, 0.9656],
                         [-0.9888, -0.2520, 0.7325, 0.6698, 0.0383, 0.9656],
                         [-0.7784,  0.6834, 0.7644, 0.6361, 0.0405, 0.9655],
                         [ 0.1225,  1.0118, 0.7948, 0.6010, 0.0426, 0.9654],
                         [ 0.8856,  0.4312, 0.8234, 0.5645, 0.0448, 0.9653],
                         [ 0.8093, -0.5246, 0.8504, 0.5267, 0.0469, 0.9652],
                         [-0.0362, -0.9768, 0.8755, 0.4877, 0.0491, 0.9651],
                         [-0.8736, -0.5097, 0.8988, 0.4476, 0.0512, 0.9650],
                         [-0.9330,  0.4473, 0.9203, 0.4064, 0.0534, 0.9649],
                         [-0.1597,  1.0143, 0.9398, 0.3643, 0.0556, 0.9648],
                         [ 0.7352,  0.6700, 0.9573, 0.3213, 0.0577, 0.9647],
                         [ 0.9290, -0.2690, 0.9728, 0.2776, 0.0599, 0.9645],
                         [ 0.2435, -0.9395, 0.9863, 0.2332, 0.0620, 0.9644],
                         [-0.6910, -0.7249, 0.9977, 0.1882, 0.0642, 0.9643],
                         [-1.0154,  0.1774, 1.0069, 0.1427, 0.0663, 0.9641],
                         [-0.4314,  0.9379, 1.0141, 0.0968, 0.0685, 0.9640],
                         [ 0.5240,  0.8573, 1.0192, 0.0507, 0.0706, 0.9639],
                         [ 0.9725,  0.0098, 1.0220, 0.0044, 0.0728, 0.9637],
                         [ 0.5017, -0.8255, 1.0228, -0.0420, 0.0749, 0.9636],
                         [-0.4556, -0.8806, 1.0214, -0.0884, 0.0771, 0.9634],
                         [-1.0192, -0.1048, 1.0178, -0.1347, 0.0792, 0.9632],
                         [-0.6709,  0.7885, 1.0121, -0.1808, 0.0813, 0.9631],
                         [ 0.2690,  0.9782, 1.0042, -0.2265, 0.0835, 0.9629],
                         [ 0.9364,  0.2898, 0.9943, -0.2718, 0.0856, 0.9627]]])
# `tgt_mask` would be the causal (lower-triangular) decoder mask; the original
# kept this elided printout as a dangling triple-quoted string, preserved here
# as a real comment for reference:
# tgt_mask = torch.tensor([[[1, 0, 0, ..., 0, 0, 0],
#                           [1, 1, 0, ..., 0, 0, 0],
#                           [1, 1, 1, ..., 0, 0, 0],
#                           ...,
#                           [1, 1, 1, ..., 0, 0, 0],
#                           [1, 1, 1, ..., 0, 0, 0],
#                           [1, 1, 1, ..., 0, 0, 0]]])
# ---------------------------------------------------------------------------
# Deterministic single forward pass through the project Transformer.
# ---------------------------------------------------------------------------
# Pin randomness so repeated runs print identical outputs.
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)

# Positional hyperparameters — the constructor's parameter names are not
# visible from this file; presumed meanings below, TODO confirm against
# models.Transformer's signature.
model = Transformer(
    25_000,  # presumably vocabulary size
    6,       # presumably model/embedding dim (matches src_enc's last axis)
    40,      # presumably max sequence length (matches the 40-token fixtures)
    0.3,     # presumably dropout prob — inert after .eval() below
    1,
    6,
    2048,    # presumably feed-forward hidden size
)
model.eval()  # inference mode: disables dropout for reproducible output

# Print full tensors instead of the elided "..." summary.
torch.set_printoptions(profile="full")
out = model(src, tgt)

# Alternative unit under test, kept for manual experimentation:
# test = MultiHeadAttention(6, 6)
# out = test(src_enc, src_enc, src_enc, src_mask)