diff --git a/modules.py b/modules.py
index 7efa387..34bf4c9 100644
--- a/modules.py
+++ b/modules.py
@@ -53,7 +53,7 @@ def get_token_embeddings(vocab_size, num_units, zero_pad=True):
                                     embeddings[1:, :]), 0)
     return embeddings
 
-def scaled_dot_product_attention(Q, K, V,
+def scaled_dot_product_attention(Q, K, V, key_masks,
                                  causality=False, dropout_rate=0.,
                                  training=True,
                                  scope="scaled_dot_product_attention"):
@@ -61,6 +61,7 @@ def scaled_dot_product_attention(Q, K, V,
     Q: Packed queries. 3d tensor. [N, T_q, d_k].
     K: Packed keys. 3d tensor. [N, T_k, d_k].
     V: Packed values. 3d tensor. [N, T_k, d_v].
+    key_masks: A 2d tensor with shape of [N, key_seqlen]
     causality: If True, applies masking for future blinding
     dropout_rate: A floating point number in [0, 1].
     training: boolean for controlling dropout
@@ -76,7 +77,7 @@ def scaled_dot_product_attention(Q, K, V,
         outputs /= d_k ** 0.5
 
         # key masking
-        outputs = mask(outputs, Q, K, type="key")
+        outputs = mask(outputs, key_masks=key_masks, type="key")
 
         # causality or future blinding masking
         if causality:
@@ -87,8 +88,8 @@ def scaled_dot_product_attention(Q, K, V,
         attention = tf.transpose(outputs, [0, 2, 1])
         tf.summary.image("attention", tf.expand_dims(attention[:1], -1))
 
-        # query masking
-        outputs = mask(outputs, Q, K, type="query")
+        # # query masking
+        # outputs = mask(outputs, Q, K, type="query")
 
         # dropout
         outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
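With this change, callers build the padding mask themselves and pass it in, instead of the attention layer inferring it from zero key embeddings. A minimal sketch of the new calling convention, assuming TensorFlow 1.x and that id 0 is the padding token (both assumptions; neither is fixed by this diff):

    import tensorflow as tf  # TF 1.x, matching the tf.layers / tf.to_float API used here
    from modules import scaled_dot_product_attention

    x = tf.constant([[5, 3, 0],
                     [7, 0, 0]])              # illustrative token ids, (N=2, T_k=3); 0 = pad (assumption)
    key_masks = tf.equal(x, 0)                # True at padded key positions, (N, T_k)
    Q = K = V = tf.random_normal([2, 3, 4])   # (N, T, d_k); toy values

    outputs = scaled_dot_product_attention(Q, K, V, key_masks,
                                           causality=False, dropout_rate=0.,
                                           training=False)
    with tf.Session() as sess:
        print(sess.run(outputs).shape)        # (2, 3, 4)

The boolean mask works because mask() casts it with tf.to_float, so True becomes 1.0 at padded positions.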
"key" | "future" e.g., - >> queries = tf.constant([[[1.], - [2.], - [0.]]], tf.float32) # (1, 3, 1) - >> keys = tf.constant([[[4.], - [0.]]], tf.float32) # (1, 2, 1) - >> inputs = tf.constant([[[4., 0.], - [8., 0.], - [0., 0.]]], tf.float32) - >> mask(inputs, queries, keys, "key") - array([[[ 4.0000000e+00, -4.2949673e+09], - [ 8.0000000e+00, -4.2949673e+09], - [ 0.0000000e+00, -4.2949673e+09]]], dtype=float32) - >> inputs = tf.constant([[[1., 0.], - [1., 0.], - [1., 0.]]], tf.float32) - >> mask(inputs, queries, keys, "query") - array([[[1., 0.], - [1., 0.], - [0., 0.]]], dtype=float32) + >> inputs = tf.zeros([2, 2, 3], dtype=tf.float32) + >> key_masks = tf.constant([[0., 0., 1.], + [0., 1., 1.]]) + >> mask(inputs, key_masks=key_masks, type="key") + array([[[ 0.0000000e+00, 0.0000000e+00, -4.2949673e+09], + [ 0.0000000e+00, 0.0000000e+00, -4.2949673e+09]], + + [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09], + [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]], + + [[ 0.0000000e+00, 0.0000000e+00, -4.2949673e+09], + [ 0.0000000e+00, 0.0000000e+00, -4.2949673e+09]], + + [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09], + [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32) """ padding_num = -2 ** 32 + 1 if type in ("k", "key", "keys"): - # Generate masks - masks = tf.sign(tf.reduce_sum(tf.abs(keys), axis=-1)) # (N, T_k) - masks = tf.expand_dims(masks, 1) # (N, 1, T_k) - masks = tf.tile(masks, [1, tf.shape(queries)[1], 1]) # (N, T_q, T_k) - - # Apply masks to inputs - paddings = tf.ones_like(inputs) * padding_num - outputs = tf.where(tf.equal(masks, 0), paddings, inputs) # (N, T_q, T_k) - elif type in ("q", "query", "queries"): - # Generate masks - masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1)) # (N, T_q) - masks = tf.expand_dims(masks, -1) # (N, T_q, 1) - masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]]) # (N, T_q, T_k) - - # Apply masks to inputs - outputs = inputs*masks + key_masks = tf.to_float(key_masks) + key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1]) # (h*N, seqlen) + key_masks = tf.expand_dims(key_masks, 1) # (h*N, 1, seqlen) + outputs = inputs + key_masks * padding_num + # elif type in ("q", "query", "queries"): + # # Generate masks + # masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1)) # (N, T_q) + # masks = tf.expand_dims(masks, -1) # (N, T_q, 1) + # masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]]) # (N, T_q, T_k) + # + # # Apply masks to inputs + # outputs = inputs*masks elif type in ("f", "future", "right"): diag_vals = tf.ones_like(inputs[0, :, :]) # (T_q, T_k) tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k) - masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1]) # (N, T_q, T_k) + future_masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1]) # (N, T_q, T_k) - paddings = tf.ones_like(masks) * padding_num - outputs = tf.where(tf.equal(masks, 0), paddings, inputs) + paddings = tf.ones_like(future_masks) * padding_num + outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs) else: print("Check if you entered type correctly!") - return outputs -def multihead_attention(queries, keys, values, + +def multihead_attention(queries, keys, values, key_masks, num_heads=8, dropout_rate=0, training=True, @@ -166,6 +160,7 @@ def multihead_attention(queries, keys, values, queries: A 3d tensor with shape of [N, T_q, d_model]. keys: A 3d tensor with shape of [N, T_k, d_model]. values: A 3d tensor with shape of [N, T_k, d_model]. 
@@ -166,6 +160,7 @@ def multihead_attention(queries, keys, values,
     queries: A 3d tensor with shape of [N, T_q, d_model].
     keys: A 3d tensor with shape of [N, T_k, d_model].
     values: A 3d tensor with shape of [N, T_k, d_model].
+    key_masks: A 2d tensor with shape of [N, key_seqlen]
     num_heads: An int. Number of heads.
     dropout_rate: A floating point number.
     training: Boolean. Controller of mechanism for dropout.
@@ -178,9 +173,9 @@ def multihead_attention(queries, keys, values,
     d_model = queries.get_shape().as_list()[-1]
     with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
         # Linear projections
-        Q = tf.layers.dense(queries, d_model, use_bias=False)  # (N, T_q, d_model)
-        K = tf.layers.dense(keys, d_model, use_bias=False)  # (N, T_k, d_model)
-        V = tf.layers.dense(values, d_model, use_bias=False)  # (N, T_k, d_model)
+        Q = tf.layers.dense(queries, d_model, use_bias=True)  # (N, T_q, d_model)
+        K = tf.layers.dense(keys, d_model, use_bias=True)  # (N, T_k, d_model)
+        V = tf.layers.dense(values, d_model, use_bias=True)  # (N, T_k, d_model)
 
         # Split and concat
         Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, d_model/h)
@@ -188,7 +183,7 @@ def multihead_attention(queries, keys, values,
         V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, d_model/h)
 
         # Attention
-        outputs = scaled_dot_product_attention(Q_, K_, V_, causality, dropout_rate, training)
+        outputs = scaled_dot_product_attention(Q_, K_, V_, key_masks, causality, dropout_rate, training)
 
         # Restore shape
         outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, d_model)
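Since key_masks is a required positional argument, every downstream caller of multihead_attention must be updated to thread the mask through. A hedged usage sketch for an encoder-style self-attention call, assuming TensorFlow 1.x, that the signature keeps causality as a keyword argument (it is used in the body above but its default is outside this diff), and that pad id 0 is illustrative:

    import tensorflow as tf  # TF 1.x
    from modules import multihead_attention

    x = tf.constant([[4, 9, 0, 0]])          # illustrative token ids, (N=1, T=4); 0 = pad (assumption)
    src_masks = tf.equal(x, 0)               # (N, T), True at padded positions
    enc = tf.random_normal([1, 4, 8])        # (N, T, d_model)

    enc = multihead_attention(queries=enc, keys=enc, values=enc,
                              key_masks=src_masks,
                              num_heads=2,
                              dropout_rate=0.1,
                              training=True,
                              causality=False)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # the dense projections hold variables
        print(sess.run(enc).shape)           # (1, 4, 8)

Passing the 2d (N, T_k) mask once per batch, and letting mask() tile it to (h*N, 1, T_k) internally, is what lets a single mask serve all heads after the split-and-concat reshaping.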