model.py
import tensorflow as tf


class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm = tf.keras.layers.LSTM(self.enc_units,
                                         return_sequences=True,
                                         return_state=True)

    def call(self, x, hidden):
        # x: (batch_size, max_len) -> (batch_size, max_len, embedding_dim)
        x = self.embedding(x)
        # output: (batch_size, max_len, enc_units)
        # state_h, state_c: (batch_size, enc_units)
        output, state_h, state_c = self.lstm(x, initial_state=hidden)
        return output, state_h, state_c

    def initialize_hidden_state(self):
        # Zero-initialized (h, c) pair for the encoder LSTM
        return (tf.zeros([self.batch_sz, self.enc_units]),
                tf.zeros([self.batch_sz, self.enc_units]))
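

# A quick shape check for the encoder. This is a hedged sketch of my own,
# not part of the original training code; the vocabulary size and other
# dimensions below are illustrative assumptions.
def _demo_encoder():
    batch_sz, max_len, vocab_size = 4, 10, 1000
    encoder = Encoder(vocab_size, embedding_dim=32, enc_units=64,
                      batch_sz=batch_sz)
    tokens = tf.random.uniform((batch_sz, max_len), maxval=vocab_size,
                               dtype=tf.int32)
    output, state_h, state_c = encoder(tokens,
                                       encoder.initialize_hidden_state())
    # output: (4, 10, 64); state_h and state_c: (4, 64)
    return output, state_h, state_c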


class LuongAttention(tf.keras.Model):
    def __init__(self, rnn_size, attention_func):
        super(LuongAttention, self).__init__()
        self.attention_func = attention_func
        if attention_func not in ['dot', 'general', 'concat']:
            raise ValueError(
                'Unknown attention score function! Must be either dot, general or concat.')
        if attention_func == 'general':
            # General score function: score = dec_out . Wa(enc_out)
            self.wa = tf.keras.layers.Dense(rnn_size)
        elif attention_func == 'concat':
            # Concat score function: score = va . tanh(Wa([dec_out; enc_out]))
            self.wa = tf.keras.layers.Dense(rnn_size, activation='tanh')
            self.va = tf.keras.layers.Dense(1)

    def call(self, decoder_output, encoder_output):
        if self.attention_func == 'dot':
            # decoder_output has shape: (batch_size, 1, rnn_size)
            # encoder_output has shape: (batch_size, max_len, rnn_size)
            # score has shape: (batch_size, 1, max_len)
            score = tf.matmul(decoder_output, encoder_output, transpose_b=True)
        elif self.attention_func == 'general':
            # score has shape: (batch_size, 1, max_len)
            score = tf.matmul(decoder_output, self.wa(encoder_output),
                              transpose_b=True)
        elif self.attention_func == 'concat':
            # Tile the decoder output across every encoder timestep so the
            # two can be concatenated: (batch_size, max_len, rnn_size)
            decoder_output = tf.tile(
                decoder_output, [1, encoder_output.shape[1], 1])
            score = self.va(
                self.wa(tf.concat((decoder_output, encoder_output), axis=-1)))
            # (batch_size, max_len, 1) => (batch_size, 1, max_len)
            score = tf.transpose(score, [0, 2, 1])
        alignment = tf.nn.softmax(score, axis=2)
        # The context vector c_t is the alignment-weighted sum of the
        # encoder output: (batch_size, 1, rnn_size)
        context = tf.matmul(alignment, encoder_output)
        return context, alignment
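

# Hedged sketch: exercising LuongAttention in 'general' mode on random
# tensors to confirm the shapes documented above. All sizes are made-up
# examples, not values from this repository.
def _demo_attention():
    batch_sz, max_len, rnn_size = 4, 10, 64
    attention = LuongAttention(rnn_size, 'general')
    decoder_output = tf.random.normal((batch_sz, 1, rnn_size))
    encoder_output = tf.random.normal((batch_sz, max_len, rnn_size))
    context, alignment = attention(decoder_output, encoder_output)
    # context: (4, 1, 64); alignment: (4, 1, 10)
    return context, alignment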


class LuongDecoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, rnn_size, attention_func):
        super(LuongDecoder, self).__init__()
        self.attention = LuongAttention(rnn_size, attention_func)
        self.rnn_size = rnn_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            rnn_size, return_sequences=True, return_state=True)
        self.wc = tf.keras.layers.Dense(rnn_size, activation='tanh')
        self.ws = tf.keras.layers.Dense(vocab_size)

    def call(self, sequence, state, encoder_output):
        # sequence has shape (batch_size, 1): one target token per step
        embed = self.embedding(sequence)
        # embed has shape (batch_size, 1, embedding_size);
        # lstm_out has shape (batch_size, 1, rnn_size)
        lstm_out, state_h, state_c = self.lstm(embed, initial_state=state)
        # Use self.attention to compute the context and alignment vectors
        # context vector's shape: (batch_size, 1, rnn_size)
        # alignment vector's shape: (batch_size, 1, source_length)
        context, alignment = self.attention(lstm_out, encoder_output)
        # Combine the context vector and the LSTM output. Before being
        # combined, both have shape (batch_size, 1, rnn_size), so squeeze
        # axis 1 first; the result has shape (batch_size, 2 * rnn_size)
        lstm_out = tf.concat(
            [tf.squeeze(context, 1), tf.squeeze(lstm_out, 1)], 1)
        # wc projects the combined vector back to (batch_size, rnn_size)
        lstm_out = self.wc(lstm_out)
        # Finally, it is converted back to vocabulary space: (batch_size, vocab_size)
        logits = self.ws(lstm_out)
        return logits, state_h, state_c, alignment
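

# End-to-end sketch (my addition, not from the original file): wiring the
# encoder and decoder together for a single teacher-forced decoding step.
# The hyperparameters and the <start> token id below are assumptions.
if __name__ == '__main__':
    batch_sz, max_len, vocab_size = 4, 10, 1000
    embedding_dim, rnn_size = 32, 64
    encoder = Encoder(vocab_size, embedding_dim, rnn_size, batch_sz)
    decoder = LuongDecoder(vocab_size, embedding_dim, rnn_size, 'concat')
    source = tf.random.uniform((batch_sz, max_len), maxval=vocab_size,
                               dtype=tf.int32)
    enc_output, state_h, state_c = encoder(source,
                                           encoder.initialize_hidden_state())
    # Decode one step, using a dummy <start> token id of 1
    dec_input = tf.ones((batch_sz, 1), dtype=tf.int32)
    logits, state_h, state_c, alignment = decoder(
        dec_input, (state_h, state_c), enc_output)
    print(logits.shape)     # (4, 1000)
    print(alignment.shape)  # (4, 1, 10)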