Skip to content

Commit

Permalink
update tensorflow api to 1.8
Browse files Browse the repository at this point in the history
  • Loading branch information
xu-song authored Jun 8, 2018
1 parent 6672f93 commit c7e8c54
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def multihead_attention(queries,
# Causality = Future blinding
if causality:
diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)
tril = tf.contrib.linalg.LinearOperatorTriL(diag_vals).to_dense() # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)

paddings = tf.ones_like(masks)*(-2**32+1)
Expand Down

0 comments on commit c7e8c54

Please sign in to comment.