From 8d513424171ff72c7ba7db070ef8bbe725e12757 Mon Sep 17 00:00:00 2001 From: Sciroccogti Date: Thu, 20 Oct 2022 13:58:42 +0800 Subject: [PATCH 1/3] feat: update to python3.7 --- lm/reader.py | 2 +- lm/util.py | 2 +- requirements.txt | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 requirements.txt diff --git a/lm/reader.py b/lm/reader.py index cde09aa..85d1c3d 100644 --- a/lm/reader.py +++ b/lm/reader.py @@ -92,7 +92,7 @@ def ptb_raw_data(dataset="ptb", data_path=None, return_id2word=False): test_data = _file_to_word_ids(test_path, word_to_id) vocabulary = len(word_to_id) if return_id2word: - id2word = dict([(_id, word) for word, _id in word_to_id.iteritems()]) + id2word = dict([(_id, word) for word, _id in word_to_id.items()]) id2word_list = [ id2word[i] if i == _id else -1 for i, _id in enumerate(sorted(id2word))] if len(id2word) != len(id2word_list): diff --git a/lm/util.py b/lm/util.py index dd1ffa1..a66e973 100644 --- a/lm/util.py +++ b/lm/util.py @@ -18,7 +18,7 @@ from __future__ import print_function import os -from functools import partial +from functools import partial, reduce import tensorflow as tf from tensorflow.core.framework import variable_pb2 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..33ab8c9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +numpy==1.21.5 +six==1.16.0 +tensorflow==1.15.0 From fb730b51df7f788e8d21cf41dda6ab4db21d9811 Mon Sep 17 00:00:00 2001 From: Sciroccogti Date: Tue, 25 Oct 2022 14:41:13 +0800 Subject: [PATCH 2/3] format all .py with autopep8 --- core/kd_quantizer.py | 340 +++--- core/kdq_embedding.py | 163 +-- lm/ptb_word_lm.py | 772 ++++++------- lm/reader.py | 205 ++-- lm/util.py | 375 +++---- nmt/attention_model.py | 318 +++--- nmt/gnmt_model.py | 573 +++++----- nmt/inference.py | 361 ++++--- nmt/inference_test.py | 272 ++--- nmt/model.py | 1647 ++++++++++++++-------------- nmt/model_helper.py | 1158 ++++++++++---------- nmt/model_test.py | 1985 +++++++++++++++++----------------- nmt/nmt.py | 1097 +++++++++---------- nmt/nmt_test.py | 108 +- nmt/train.py | 1174 ++++++++++---------- text_classification/data.py | 177 +-- text_classification/main.py | 121 ++- text_classification/model.py | 124 +-- text_classification/util.py | 119 +- 19 files changed, 5547 insertions(+), 5542 deletions(-) diff --git a/core/kd_quantizer.py b/core/kd_quantizer.py index 5c1e1fe..f114e23 100644 --- a/core/kd_quantizer.py +++ b/core/kd_quantizer.py @@ -2,184 +2,184 @@ def safer_log(x, eps=1e-10): - """Avoid nan when x is zero by adding small eps. + """Avoid nan when x is zero by adding small eps. - Note that if x.dtype=tf.float16, \forall eps, eps < 3e-8, is equal to zero. - """ - return tf.log(x + eps) + Note that if x.dtype=tf.float16, \forall eps, eps < 3e-8, is equal to zero. + """ + return tf.log(x + eps) def sample_gumbel(shape): - """Sample from Gumbel(0, 1)""" - U = tf.random_uniform(shape, minval=0, maxval=1) - return -safer_log(-safer_log(U)) + """Sample from Gumbel(0, 1)""" + U = tf.random_uniform(shape, minval=0, maxval=1) + return -safer_log(-safer_log(U)) class KDQuantizer(object): - def __init__(self, K, D, d_in, d_out, tie_in_n_out, - query_metric="dot", shared_centroids=False, - beta=0., tau=1.0, softmax_BN=True): - """ - Args: - K, D: int, size of KD code. - d_in: dim of continuous input for each of D axis. - d_out: dim of continuous output for each of D axis. - tie_in_n_out: boolean, whether or not to tie the input/output centroids. 
- If True, it is vector quantization, else it is tempering softmax. - query_metric: string, which metric to use for input centroid matching. - shared_centroids: boolean, whether or not to share centroids for - different bits in D. - beta: float, KDQ regularization coefficient. - tau: float or None, (tempering) softmax temperature. - If None, set to learnable. - softmax_BN: whether to use BN in (tempering) softmax. - """ - self._K = K - self._D = D - self._d_in = d_in - self._d_out = d_out - self._tie_in_n_out = tie_in_n_out - self._query_metric = query_metric - self._shared_centroids = shared_centroids - self._beta = beta - if tau is None: - self._tau = tf.get_variable( - "tau", [], initializer=tf.constant_initializer(1.0)) - else: - self._tau = tf.constant(tau); - self._softmax_BN = softmax_BN - - # Create centroids for keys and values. - D_to_create = 1 if shared_centroids else D - centroids_k = tf.get_variable( - "centroids_k", [D_to_create, K, d_in]) - if tie_in_n_out: - centroids_v = centroids_k - else: - centroids_v = tf.get_variable( - "centroids_v", [D_to_create, K, d_out]) - if shared_centroids: - centroids_k = tf.tile(centroids_k, [D, 1, 1]) - if tie_in_n_out: - centroids_v = centroids_k - else: - centroids_v = tf.tile(centroids_v, [D, 1, 1]) - self._centroids_k = centroids_k - self._centroids_v = centroids_v - - def forward(self, - inputs, - sampling=False, - is_training=True): - """Returns quantized embeddings from centroids. - - Args: - inputs: embedding tensor of shape (batch_size, D, d_in) - - Returns: - code: (batch_size, D) - embs_quantized: (batch_size, D, d_out) - """ - with tf.name_scope("kdq_forward"): - # House keeping. - centroids_k = self._centroids_k - centroids_v = self._centroids_v - - # Compute distance (in a metric) between inputs and centroids_k - # the response is in the shape of (batch_size, D, K) - if self._query_metric == "euclidean": - norm_1 = tf.reduce_sum(inputs**2, -1, keep_dims=True) # (bs, D, 1) - norm_2 = tf.expand_dims(tf.reduce_sum(centroids_k**2, -1), 0) # (1, D, K) - dot = tf.matmul(tf.transpose(inputs, perm=[1, 0, 2]), - tf.transpose(centroids_k, perm=[0, 2, 1])) # (D, bs, K) - response = -norm_1 + 2 * tf.transpose(dot, perm=[1, 0, 2]) - norm_2 - elif self._query_metric == "cosine": - inputs = tf.nn.l2_normalize(inputs, -1) - centroids_k = tf.nn.l2_normalize(centroids_k, -1) - response = tf.matmul(tf.transpose(inputs, perm=[1, 0, 2]), - tf.transpose(centroids_k, perm=[0, 2, 1])) # (D, bs, K) - response = tf.transpose(response, perm=[1, 0, 2]) - elif self._query_metric == "dot": - response = tf.matmul(tf.transpose(inputs, perm=[1, 0, 2]), - tf.transpose(centroids_k, perm=[0, 2, 1])) # (D, bs, K) - response = tf.transpose(response, perm=[1, 0, 2]) - else: - raise ValueError("Unknown metric {}".format(self._query_metric)) - response = tf.reshape(response, [-1, self._D, self._K]) - if self._softmax_BN: - # response = tf.contrib.layers.instance_norm( - # response, scale=False, center=False, - # trainable=False, data_format="NCHW") - response = tf.layers.batch_normalization( - response, scale=False, center=False, training=is_training) - # Layer norm as alternative to BN. - # response = tf.contrib.layers.layer_norm( - # response, scale=False, center=False) - response_prob = tf.nn.softmax(response / self._tau, -1) - - # Compute the codes based on response. 
- codes = tf.argmax(response, -1) # (batch_size, D) - if sampling: - response = safer_log(response_prob) - noises = sample_gumbel(tf.shape(response)) - neighbor_idxs = tf.argmax(response + noises, -1) # (batch_size, D) - else: - neighbor_idxs = codes - - # Compute the outputs, which has shape (batch_size, D, d_out) - if self._tie_in_n_out: - if not self._shared_centroids: - D_base = tf.convert_to_tensor( - [self._K*d for d in range(self._D)], dtype=tf.int64) - neighbor_idxs += tf.expand_dims(D_base, 0) # (batch_size, D) - neighbor_idxs = tf.reshape(neighbor_idxs, [-1]) # (batch_size * D) - centroids_v = tf.reshape(centroids_v, [-1, self._d_out]) - outputs = tf.nn.embedding_lookup(centroids_v, neighbor_idxs) - outputs = tf.reshape(outputs, [-1, self._D, self._d_out]) - outputs_final = tf.stop_gradient(outputs - inputs) + inputs - else: - nb_idxs_onehot = tf.one_hot(neighbor_idxs, - self._K) # (batch_size, D, K) - nb_idxs_onehot = response_prob - tf.stop_gradient( - response_prob - nb_idxs_onehot) - # nb_idxs_onehot = response_prob # use continuous output - outputs = tf.matmul( - tf.transpose(nb_idxs_onehot, [1, 0, 2]), # (D, bs, K) - centroids_v) # (D, bs, d) - outputs_final = tf.transpose(outputs, [1, 0, 2]) - - # Add regularization for updating centroids / stabilization. - if is_training: - print("[INFO] Adding KDQ regularization.") - if self._tie_in_n_out: - alpha = 1. - beta = self._beta - gamma = 0.0 - reg = alpha * tf.reduce_mean( - (outputs - tf.stop_gradient(inputs))**2, name="centroids_adjust") - reg += beta * tf.reduce_mean( - (tf.stop_gradient(outputs) - inputs)**2, name="input_commit") - minaxis = [0, 1] if self._shared_centroids else [0] - reg += gamma * tf.reduce_mean( # could sg(inputs), but still not eff. - tf.reduce_min(-response, minaxis), name="de_isolation") + def __init__(self, K, D, d_in, d_out, tie_in_n_out, + query_metric="dot", shared_centroids=False, + beta=0., tau=1.0, softmax_BN=True): + """ + Args: + K, D: int, size of KD code. + d_in: dim of continuous input for each of D axis. + d_out: dim of continuous output for each of D axis. + tie_in_n_out: boolean, whether or not to tie the input/output centroids. + If True, it is vector quantization, else it is tempering softmax. + query_metric: string, which metric to use for input centroid matching. + shared_centroids: boolean, whether or not to share centroids for + different bits in D. + beta: float, KDQ regularization coefficient. + tau: float or None, (tempering) softmax temperature. + If None, set to learnable. + softmax_BN: whether to use BN in (tempering) softmax. + """ + self._K = K + self._D = D + self._d_in = d_in + self._d_out = d_out + self._tie_in_n_out = tie_in_n_out + self._query_metric = query_metric + self._shared_centroids = shared_centroids + self._beta = beta + if tau is None: + self._tau = tf.get_variable( + "tau", [], initializer=tf.constant_initializer(1.0)) else: - beta = self._beta - reg = - beta * tf.reduce_mean( - tf.reduce_sum(nb_idxs_onehot * safer_log(response_prob), [2])) - # entropy regularization - # reg = - beta * tf.reduce_mean( - # tf.reduce_sum(response_prob * safer_log(response_prob), [2])) - tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, reg) - - return codes, outputs_final + self._tau = tf.constant(tau) + self._softmax_BN = softmax_BN + + # Create centroids for keys and values. 
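+        # centroids_k hold the K query centroids per subspace (dim d_in) that the
+        # inputs are matched against; centroids_v (dim d_out) are looked up to build
+        # the output embedding and are the same tensor when tie_in_n_out=True.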
+ D_to_create = 1 if shared_centroids else D + centroids_k = tf.get_variable( + "centroids_k", [D_to_create, K, d_in]) + if tie_in_n_out: + centroids_v = centroids_k + else: + centroids_v = tf.get_variable( + "centroids_v", [D_to_create, K, d_out]) + if shared_centroids: + centroids_k = tf.tile(centroids_k, [D, 1, 1]) + if tie_in_n_out: + centroids_v = centroids_k + else: + centroids_v = tf.tile(centroids_v, [D, 1, 1]) + self._centroids_k = centroids_k + self._centroids_v = centroids_v + + def forward(self, + inputs, + sampling=False, + is_training=True): + """Returns quantized embeddings from centroids. + + Args: + inputs: embedding tensor of shape (batch_size, D, d_in) + + Returns: + code: (batch_size, D) + embs_quantized: (batch_size, D, d_out) + """ + with tf.name_scope("kdq_forward"): + # House keeping. + centroids_k = self._centroids_k # (D, K, d_in) + centroids_v = self._centroids_v + + # Compute distance (in a metric) between inputs and centroids_k + # the response is in the shape of (batch_size, D, K) + if self._query_metric == "euclidean": + norm_1 = tf.reduce_sum(inputs**2, -1, keep_dims=True) # (bs, D, 1) + norm_2 = tf.expand_dims(tf.reduce_sum(centroids_k**2, -1), 0) # (1, D, K) + dot = tf.matmul(tf.transpose(inputs, perm=[1, 0, 2]), + tf.transpose(centroids_k, perm=[0, 2, 1])) # (D, bs, K) + response = -norm_1 + 2 * tf.transpose(dot, perm=[1, 0, 2]) - norm_2 + elif self._query_metric == "cosine": + inputs = tf.nn.l2_normalize(inputs, -1) + centroids_k = tf.nn.l2_normalize(centroids_k, -1) + response = tf.matmul(tf.transpose(inputs, perm=[1, 0, 2]), + tf.transpose(centroids_k, perm=[0, 2, 1])) # (D, bs, K) + response = tf.transpose(response, perm=[1, 0, 2]) + elif self._query_metric == "dot": + response = tf.matmul(tf.transpose(inputs, perm=[1, 0, 2]), + tf.transpose(centroids_k, perm=[0, 2, 1])) # (D, bs, K) + response = tf.transpose(response, perm=[1, 0, 2]) + else: + raise ValueError("Unknown metric {}".format(self._query_metric)) + response = tf.reshape(response, [-1, self._D, self._K]) + if self._softmax_BN: + # response = tf.contrib.layers.instance_norm( + # response, scale=False, center=False, + # trainable=False, data_format="NCHW") + response = tf.layers.batch_normalization( + response, scale=False, center=False, training=is_training) + # Layer norm as alternative to BN. + # response = tf.contrib.layers.layer_norm( + # response, scale=False, center=False) + response_prob = tf.nn.softmax(response / self._tau, -1) + + # Compute the codes based on response. 
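+            # With sampling=True, Gumbel noise is added to the log-probabilities so
+            # that the argmax below draws the code from the tempered softmax
+            # distribution (Gumbel-max trick); otherwise the hard argmax code is used.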
+ codes = tf.argmax(response, -1) # (batch_size, D) + if sampling: + response = safer_log(response_prob) + noises = sample_gumbel(tf.shape(response)) + neighbor_idxs = tf.argmax(response + noises, -1) # (batch_size, D) + else: + neighbor_idxs = codes + + # Compute the outputs, which has shape (batch_size, D, d_out) + if self._tie_in_n_out: + if not self._shared_centroids: + D_base = tf.convert_to_tensor( + [self._K*d for d in range(self._D)], dtype=tf.int64) + neighbor_idxs += tf.expand_dims(D_base, 0) # (batch_size, D) + neighbor_idxs = tf.reshape(neighbor_idxs, [-1]) # (batch_size * D) + centroids_v = tf.reshape(centroids_v, [-1, self._d_out]) + outputs = tf.nn.embedding_lookup(centroids_v, neighbor_idxs) + outputs = tf.reshape(outputs, [-1, self._D, self._d_out]) + outputs_final = tf.stop_gradient(outputs - inputs) + inputs + else: + nb_idxs_onehot = tf.one_hot(neighbor_idxs, + self._K) # (batch_size, D, K) + nb_idxs_onehot = response_prob - tf.stop_gradient( + response_prob - nb_idxs_onehot) + # nb_idxs_onehot = response_prob # use continuous output + outputs = tf.matmul( + tf.transpose(nb_idxs_onehot, [1, 0, 2]), # (D, bs, K) + centroids_v) # (D, bs, d) + outputs_final = tf.transpose(outputs, [1, 0, 2]) + + # Add regularization for updating centroids / stabilization. + if is_training: + print("[INFO] Adding KDQ regularization.") + if self._tie_in_n_out: + alpha = 1. + beta = self._beta + gamma = 0.0 + reg = alpha * tf.reduce_mean( + (outputs - tf.stop_gradient(inputs))**2, name="centroids_adjust") + reg += beta * tf.reduce_mean( + (tf.stop_gradient(outputs) - inputs)**2, name="input_commit") + minaxis = [0, 1] if self._shared_centroids else [0] + reg += gamma * tf.reduce_mean( # could sg(inputs), but still not eff. + tf.reduce_min(-response, minaxis), name="de_isolation") + else: + beta = self._beta + reg = - beta * tf.reduce_mean( + tf.reduce_sum(nb_idxs_onehot * safer_log(response_prob), [2])) + # entropy regularization + # reg = - beta * tf.reduce_mean( + # tf.reduce_sum(response_prob * safer_log(response_prob), [2])) + tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, reg) + + return codes, outputs_final if __name__ == "__main__": - # VQ - with tf.variable_scope("VQ"): - kdq_demo = KDQuantizer(100, 10, 5, 5, True, "euclidean") - codes_vq, outputs_vq = kdq_demo.forward(tf.random_normal([64, 10, 5])) - # tempering softmax - with tf.variable_scope("tempering_softmax"): - kdq_demo = KDQuantizer(100, 10, 5, 10, False, "dot") - codes_ts, outputs_ts = kdq_demo.forward(tf.random_normal([64, 10, 5])) + # VQ + with tf.variable_scope("VQ"): + kdq_demo = KDQuantizer(100, 10, 5, 5, True, "euclidean") + codes_vq, outputs_vq = kdq_demo.forward(tf.random_normal([64, 10, 5])) + # tempering softmax + with tf.variable_scope("tempering_softmax"): + kdq_demo = KDQuantizer(100, 10, 5, 10, False, "dot") + codes_ts, outputs_ts = kdq_demo.forward(tf.random_normal([64, 10, 5])) diff --git a/core/kdq_embedding.py b/core/kdq_embedding.py index e1f3c9c..42e7d39 100644 --- a/core/kdq_embedding.py +++ b/core/kdq_embedding.py @@ -1,99 +1,100 @@ import tensorflow as tf from kd_quantizer import KDQuantizer + def full_embed(input, vocab_size, emb_size, hparams=None, - training=True, name="full_emb"): - """Full embedding baseline. + training=True, name="full_emb"): + """Full embedding baseline. - Args: - input: int multi-dim tensor, entity idxs. - vocab_size: int, vocab size - emb_size: int, output embedding size + Args: + input: int multi-dim tensor, entity idxs. 
+ vocab_size: int, vocab size + emb_size: int, output embedding size - Returns: - input_emb: float tensor, embedding for entity idxs. - """ - with tf.variable_scope(name): - embedding = tf.get_variable("embedding", [vocab_size, emb_size]) - input_emb = tf.nn.embedding_lookup(embedding, input) - return input_emb + Returns: + input_emb: float tensor, embedding for entity idxs. + """ + with tf.variable_scope(name): + embedding = tf.get_variable("embedding", [vocab_size, emb_size]) + input_emb = tf.nn.embedding_lookup(embedding, input) + return input_emb def kdq_embed(input, vocab_size, emb_size, hparams=None, training=True, name="kdq_emb"): - """KDQ embedding with VQ or SMX. + """KDQ embedding with VQ or SMX. - This is an drop-in replacement of ``full_embed`` baseline above. + This is an drop-in replacement of ``full_embed`` baseline above. - Args: - input: int multi-dim tensor, entity idxs. - vocab_size: int, vocab size - emb_size: int, output embedding size - hparams: hparams for KDQ, see KDQhparam class for a reference. - training: whether or not this is in training mode (related to BN) + Args: + input: int multi-dim tensor, entity idxs. + vocab_size: int, vocab size + emb_size: int, output embedding size + hparams: hparams for KDQ, see KDQhparam class for a reference. + training: whether or not this is in training mode (related to BN) - Returns: - input_emb: float tensor, embedding for entity idxs. - """ - if hparams is None: - hparams = KDQhparam() - d, K, D = emb_size, hparams.K, hparams.D - d_in = d//D if hparams.kdq_d_in <= 0 else hparams.kdq_d_in # could use diff. d_in/d_out for smx - d_in = d if hparams.additive_quantization else d_in - d_out = d if hparams.additive_quantization else d//D - out_size = [D, emb_size] if hparams.additive_quantization else [emb_size] + Returns: + input_emb: float tensor, embedding for entity idxs. + """ + if hparams is None: + hparams = KDQhparam() + d, K, D = emb_size, hparams.K, hparams.D + d_in = d//D if hparams.kdq_d_in <= 0 else hparams.kdq_d_in # could use diff. 
d_in/d_out for smx + d_in = d if hparams.additive_quantization else d_in + d_out = d if hparams.additive_quantization else d//D + out_size = [D, emb_size] if hparams.additive_quantization else [emb_size] - with tf.variable_scope(name, reuse=tf.AUTO_REUSE): - query_wemb = tf.get_variable( - "query_wemb", [vocab_size, D * d_in], dtype=tf.float32) - idxs = tf.reshape(input, [-1]) - input_emb = tf.nn.embedding_lookup(query_wemb, idxs) # (bs*len, d) + with tf.variable_scope(name, reuse=tf.AUTO_REUSE): + query_wemb = tf.get_variable( + "query_wemb", [vocab_size, D * d_in], dtype=tf.float32) + idxs = tf.reshape(input, [-1]) + input_emb = tf.nn.embedding_lookup(query_wemb, idxs) # (bs*len, d) - if hparams.kdq_type == "vq": - assert hparams.kdq_d_in <= 0, ( - "kdq d_in cannot be changed (to %d) for vq" % hparams.kdq_d_in) - tie_in_n_out = True - dist_metric = "euclidean" - beta, tau, softmax_BN = 0.0, 1.0, True - share_subspace = hparams.kdq_share_subspace - else: - assert hparams.kdq_type == "smx", [ - "unknown kdq_type %s" % hparams.kdq_type] - tie_in_n_out = False - dist_metric = "dot" - beta, tau, softmax_BN = 0.0, 1.0, True - share_subspace = hparams.kdq_share_subspace - kdq = KDQuantizer(K, D, d_in, d_out, tie_in_n_out, - dist_metric, share_subspace, - beta, tau, softmax_BN) - _, input_emb = kdq.forward(tf.reshape(input_emb, [-1, D, d_in]), - is_training=training) - final_size = tf.concat( - [tf.shape(input), tf.constant(out_size)], 0) - input_emb = tf.reshape(input_emb, final_size) - if hparams.additive_quantization: - input_emb = tf.reduce_mean(input_emb, -2) - return input_emb + if hparams.kdq_type == "vq": + assert hparams.kdq_d_in <= 0, ( + "kdq d_in cannot be changed (to %d) for vq" % hparams.kdq_d_in) + tie_in_n_out = True + dist_metric = "euclidean" + beta, tau, softmax_BN = 0.0, 1.0, True + share_subspace = hparams.kdq_share_subspace + else: + assert hparams.kdq_type == "smx", [ + "unknown kdq_type %s" % hparams.kdq_type] + tie_in_n_out = False + dist_metric = "dot" + beta, tau, softmax_BN = 0.0, 1.0, True + share_subspace = hparams.kdq_share_subspace + kdq = KDQuantizer(K, D, d_in, d_out, tie_in_n_out, + dist_metric, share_subspace, + beta, tau, softmax_BN) + _, input_emb = kdq.forward(tf.reshape(input_emb, [-1, D, d_in]), + is_training=training) + final_size = tf.concat( + [tf.shape(input), tf.constant(out_size)], 0) + input_emb = tf.reshape(input_emb, final_size) + if hparams.additive_quantization: + input_emb = tf.reduce_mean(input_emb, -2) + return input_emb class KDQhparam(object): - # A default KDQ parameter setting (demo) - def __init__(self, - K=16, - D=32, - kdq_type='smx', - kdq_d_in=0, - kdq_share_subspace=True, - additive_quantization=False): - """ - Args: - kdq_type: 'vq' or 'smx' - kdq_d_in: when kdq_type == 'smx', we could reduce d_in - kdq_share_subspace: whether or not to share the subspace among D. - """ - self.K = K - self.D = D - self.kdq_type = kdq_type - self.kdq_d_in = kdq_d_in - self.kdq_share_subspace = kdq_share_subspace - self.additive_quantization = additive_quantization + # A default KDQ parameter setting (demo) + def __init__(self, + K=16, + D=32, + kdq_type='smx', + kdq_d_in=0, + kdq_share_subspace=True, + additive_quantization=False): + """ + Args: + kdq_type: 'vq' or 'smx' + kdq_d_in: when kdq_type == 'smx', we could reduce d_in + kdq_share_subspace: whether or not to share the subspace among D. 
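+            additive_quantization: if True, each of the D codebooks emits a full
+                emb_size vector and the D vectors are averaged (tf.reduce_mean in
+                kdq_embed) instead of being concatenated into one emb_size vector.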
+ """ + self.K = K + self.D = D + self.kdq_type = kdq_type + self.kdq_d_in = kdq_d_in + self.kdq_share_subspace = kdq_share_subspace + self.additive_quantization = additive_quantization diff --git a/lm/ptb_word_lm.py b/lm/ptb_word_lm.py index 50d73ab..95c91ce 100644 --- a/lm/ptb_word_lm.py +++ b/lm/ptb_word_lm.py @@ -19,6 +19,8 @@ http://arxiv.org/abs/1409.2329 """ from __future__ import absolute_import +from kdq_embedding import full_embed, kdq_embed, KDQhparam +from kd_quantizer import KDQuantizer from __future__ import division from __future__ import print_function @@ -36,8 +38,6 @@ parent_path = "/".join(os.getcwd().split('/')[:-1]) sys.path.append(os.path.join(parent_path, "core")) -from kd_quantizer import KDQuantizer -from kdq_embedding import full_embed, kdq_embed, KDQhparam # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # disable when using tf.Print @@ -56,7 +56,7 @@ flags.DEFINE_string("save_path", None, "Model output directory.") flags.DEFINE_integer("save_model_secs", 0, - "Setting it zero to avoid saving the checkpoint.") + "Setting it zero to avoid saving the checkpoint.") flags.DEFINE_bool("use_fp16", False, "Train using 16-bit floats instead of 32bit floats.") flags.DEFINE_integer("max_max_epoch", None, @@ -91,418 +91,418 @@ def data_type(): - return tf.float16 if FLAGS.use_fp16 else tf.float32 + return tf.float16 if FLAGS.use_fp16 else tf.float32 def print_at_beginning(hparams): - global vocab_size - print("kdq_type={}, vocab_size={}, K={}, D={}".format( - FLAGS.kdq_type, vocab_size, FLAGS.K, FLAGS.D)) - print("Number of trainable params: {}".format( - util.get_parameter_count( - excludings=["code_logits", "embb", "symbol2code"]))) + global vocab_size + print("kdq_type={}, vocab_size={}, K={}, D={}".format( + FLAGS.kdq_type, vocab_size, FLAGS.K, FLAGS.D)) + print("Number of trainable params: {}".format( + util.get_parameter_count( + excludings=["code_logits", "embb", "symbol2code"]))) class PTBInput(object): - """The input data.""" + """The input data.""" - def __init__(self, config, data, name=None): - self.batch_size = batch_size = config.batch_size - self.num_steps = num_steps = config.num_steps - self.epoch_size = ((len(data) // batch_size) - 1) // num_steps - self.input_data, self.targets = reader.ptb_producer( - data, batch_size, num_steps, name=name) + def __init__(self, config, data, name=None): + self.batch_size = batch_size = config.batch_size + self.num_steps = num_steps = config.num_steps + self.epoch_size = ((len(data) // batch_size) - 1) // num_steps + self.input_data, self.targets = reader.ptb_producer( + data, batch_size, num_steps, name=name) class PTBModel(object): - """The PTB model.""" - - def __init__(self, is_training, config, input_, vocab_size): - self._is_training = is_training - self._input = input_ - self._rnn_params = None - self._cell = None - self.batch_size = input_.batch_size - self.num_steps = input_.num_steps - self._config = config - size = config.hidden_size - - if FLAGS.kdq_type == "none": - inputs = full_embed(input_.input_data, vocab_size, size) - else: - kdq_hparam = KDQhparam( - K=FLAGS.K, D=FLAGS.D, kdq_type=FLAGS.kdq_type, - kdq_d_in=FLAGS.kdq_d_in, kdq_share_subspace=FLAGS.kdq_share_subspace, - additive_quantization=FLAGS.additive_quantization) - inputs = kdq_embed( - input_.input_data, vocab_size, size, kdq_hparam, is_training) - - # RNN layers - outputs, state = self._build_rnn_graph(inputs, config, is_training) - - # final softmax layer - targets = input_.targets - softmax_w = tf.get_variable( - "softmax_w", [size, vocab_size], 
dtype=data_type()) - softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type()) - logits = tf.nn.xw_plus_b(outputs, softmax_w, softmax_b) - # Reshape logits to be a 3-D tensor for sequence loss - logits = tf.reshape(logits, - [self.batch_size, self.num_steps, vocab_size]) - loss = tf.contrib.seq2seq.sequence_loss( - logits, - targets, - tf.ones([self.batch_size, self.num_steps], dtype=data_type()), - average_across_timesteps=False, - average_across_batch=False) # (batch_size, num_steps) - - # Update the cost - self._nll = tf.reduce_sum(tf.reduce_mean(loss, 0)) - self._cost = self._nll - self._final_state = state - - # compute recall metric - _, preds_topk = tf.nn.top_k(logits, FLAGS.eval_topk) - targets_topk = tf.tile( - tf.expand_dims(targets, -1), - [1] * targets.shape.ndims + [FLAGS.eval_topk]) - hits = tf.reduce_sum( - tf.cast(tf.equal(preds_topk, targets_topk), tf.float32), -1) - self._recall_at_k = tf.reduce_sum(tf.reduce_mean(hits, 0)) - - if not is_training: - return - - # Add regularization. - print("[INFO] regularization loss", - tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - self._cost += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - - # Optimizer - self._lr = tf.Variable(0.0, trainable=False) - update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - with tf.control_dependencies(update_ops): - grads = tf.gradients(self._cost, tf.trainable_variables()) - tf.summary.scalar("global_grad_norm", tf.global_norm(grads)) - grads, _ = tf.clip_by_global_norm(grads, - config.max_grad_norm) - optimizer = tf.train.GradientDescentOptimizer(self._lr) - self._train_op = optimizer.apply_gradients( - zip(grads, tf.trainable_variables()), - global_step=tf.train.get_or_create_global_step()) - self._new_lr = tf.placeholder( - tf.float32, shape=[], name="new_learning_rate") - self._lr_update = tf.assign(self._lr, self._new_lr) - - def _build_rnn_graph(self, inputs, config, is_training): - if config.rnn_mode == CUDNN: - return self._build_rnn_graph_cudnn(inputs, config, is_training) - else: - return self._build_rnn_graph_lstm(inputs, config, is_training) - - def _build_rnn_graph_cudnn(self, inputs, config, is_training): - """Build the inference graph using CUDNN cell.""" - if is_training and config.keep_prob < 1: - inputs = tf.nn.dropout(inputs, config.keep_prob) - - inputs = tf.transpose(inputs, [1, 0, 2]) - self._cell = tf.contrib.cudnn_rnn.CudnnLSTM( - num_layers=config.num_layers, - num_units=config.hidden_size, - input_size=config.hidden_size, - dropout=1 - config.keep_prob if is_training else 0) - params_size_t = self._cell.params_size() - self._rnn_params = tf.get_variable( - "lstm_params", - initializer=tf.random_uniform( - [params_size_t], -config.init_scale, config.init_scale), - validate_shape=False) - c = tf.zeros([config.num_layers, self.batch_size, config.hidden_size], - tf.float32) - h = tf.zeros([config.num_layers, self.batch_size, config.hidden_size], - tf.float32) - self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),) - outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training) - outputs = tf.transpose(outputs, [1, 0, 2]) - outputs = tf.reshape(outputs, [-1, config.hidden_size]) - return outputs, (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),) - - def _get_lstm_cell(self, config, is_training): - if config.rnn_mode == BASIC: - return tf.contrib.rnn.BasicLSTMCell( - config.hidden_size, forget_bias=0.0, state_is_tuple=True, - reuse=not is_training) - if config.rnn_mode == BLOCK: - return tf.contrib.rnn.LSTMBlockCell( - 
config.hidden_size, forget_bias=0.0) - raise ValueError("rnn_mode %s not supported" % config.rnn_mode) - - def _build_rnn_graph_lstm(self, inputs, config, is_training): - """Build the inference graph using canonical LSTM cells without Wrapper.""" - init_sates = [] - final_states = [] - with tf.variable_scope("RNN", reuse=not is_training): - for l in range(config.num_layers): - with tf.variable_scope("layer_%d" % l): - cell = self._get_lstm_cell(config, is_training) - initial_state = cell.zero_state(self.batch_size, data_type()) - init_sates.append(initial_state) - state = init_sates[-1] - if is_training and config.keep_prob < 1: + """The PTB model.""" + + def __init__(self, is_training, config, input_, vocab_size): + self._is_training = is_training + self._input = input_ + self._rnn_params = None + self._cell = None + self.batch_size = input_.batch_size + self.num_steps = input_.num_steps + self._config = config + size = config.hidden_size + + if FLAGS.kdq_type == "none": + inputs = full_embed(input_.input_data, vocab_size, size) + else: + kdq_hparam = KDQhparam( + K=FLAGS.K, D=FLAGS.D, kdq_type=FLAGS.kdq_type, + kdq_d_in=FLAGS.kdq_d_in, kdq_share_subspace=FLAGS.kdq_share_subspace, + additive_quantization=FLAGS.additive_quantization) + inputs = kdq_embed( + input_.input_data, vocab_size, size, kdq_hparam, is_training) + + # RNN layers + outputs, state = self._build_rnn_graph(inputs, config, is_training) + + # final softmax layer + targets = input_.targets + softmax_w = tf.get_variable( + "softmax_w", [size, vocab_size], dtype=data_type()) + softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type()) + logits = tf.nn.xw_plus_b(outputs, softmax_w, softmax_b) + # Reshape logits to be a 3-D tensor for sequence loss + logits = tf.reshape(logits, + [self.batch_size, self.num_steps, vocab_size]) + loss = tf.contrib.seq2seq.sequence_loss( + logits, + targets, + tf.ones([self.batch_size, self.num_steps], dtype=data_type()), + average_across_timesteps=False, + average_across_batch=False) # (batch_size, num_steps) + + # Update the cost + self._nll = tf.reduce_sum(tf.reduce_mean(loss, 0)) + self._cost = self._nll + self._final_state = state + + # compute recall metric + _, preds_topk = tf.nn.top_k(logits, FLAGS.eval_topk) + targets_topk = tf.tile( + tf.expand_dims(targets, -1), + [1] * targets.shape.ndims + [FLAGS.eval_topk]) + hits = tf.reduce_sum( + tf.cast(tf.equal(preds_topk, targets_topk), tf.float32), -1) + self._recall_at_k = tf.reduce_sum(tf.reduce_mean(hits, 0)) + + if not is_training: + return + + # Add regularization. 
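+        # The KDQ layer registers its regularizers (centroid-adjustment / commitment
+        # terms for VQ, the softmax term for SMX) under
+        # tf.GraphKeys.REGULARIZATION_LOSSES; summing that collection folds them
+        # into the training cost. With kdq_type == "none" nothing is added here.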
+ print("[INFO] regularization loss", + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + self._cost += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + + # Optimizer + self._lr = tf.Variable(0.0, trainable=False) + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + with tf.control_dependencies(update_ops): + grads = tf.gradients(self._cost, tf.trainable_variables()) + tf.summary.scalar("global_grad_norm", tf.global_norm(grads)) + grads, _ = tf.clip_by_global_norm(grads, + config.max_grad_norm) + optimizer = tf.train.GradientDescentOptimizer(self._lr) + self._train_op = optimizer.apply_gradients( + zip(grads, tf.trainable_variables()), + global_step=tf.train.get_or_create_global_step()) + self._new_lr = tf.placeholder( + tf.float32, shape=[], name="new_learning_rate") + self._lr_update = tf.assign(self._lr, self._new_lr) + + def _build_rnn_graph(self, inputs, config, is_training): + if config.rnn_mode == CUDNN: + return self._build_rnn_graph_cudnn(inputs, config, is_training) + else: + return self._build_rnn_graph_lstm(inputs, config, is_training) + + def _build_rnn_graph_cudnn(self, inputs, config, is_training): + """Build the inference graph using CUDNN cell.""" + if is_training and config.keep_prob < 1: inputs = tf.nn.dropout(inputs, config.keep_prob) - inputs = tf.unstack(inputs, num=self.num_steps, axis=1) - outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, - initial_state=init_sates[-1]) - final_states.append(state) - outputs = [tf.expand_dims(output, 1) for output in outputs] - outputs = tf.concat(outputs, 1) - inputs = outputs - outputs = tf.reshape(outputs, [-1, config.hidden_size]) - self._initial_state = tuple(init_sates) - state = tuple(final_states) - if is_training and config.keep_prob < 1: - outputs = tf.nn.dropout(outputs, config.keep_prob) - return outputs, state - - def assign_lr(self, session, lr_value): - session.run(self._lr_update, feed_dict={self._new_lr: lr_value}) - - @property - def input(self): - return self._input - - @property - def initial_state(self): - return self._initial_state - - @property - def cost(self): - return self._cost - - @property - def nll(self): - return self._nll - - @property - def recall_at_k(self): - return self._recall_at_k - - @property - def final_state(self): - return self._final_state - - @property - def lr(self): - return self._lr - - @property - def train_op(self): - return self._train_op - - @property - def initial_state_name(self): - return self._initial_state_name - - @property - def final_state_name(self): - return self._final_state_name + + inputs = tf.transpose(inputs, [1, 0, 2]) + self._cell = tf.contrib.cudnn_rnn.CudnnLSTM( + num_layers=config.num_layers, + num_units=config.hidden_size, + input_size=config.hidden_size, + dropout=1 - config.keep_prob if is_training else 0) + params_size_t = self._cell.params_size() + self._rnn_params = tf.get_variable( + "lstm_params", + initializer=tf.random_uniform( + [params_size_t], -config.init_scale, config.init_scale), + validate_shape=False) + c = tf.zeros([config.num_layers, self.batch_size, config.hidden_size], + tf.float32) + h = tf.zeros([config.num_layers, self.batch_size, config.hidden_size], + tf.float32) + self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),) + outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training) + outputs = tf.transpose(outputs, [1, 0, 2]) + outputs = tf.reshape(outputs, [-1, config.hidden_size]) + return outputs, (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),) + + def _get_lstm_cell(self, config, 
is_training): + if config.rnn_mode == BASIC: + return tf.contrib.rnn.BasicLSTMCell( + config.hidden_size, forget_bias=0.0, state_is_tuple=True, + reuse=not is_training) + if config.rnn_mode == BLOCK: + return tf.contrib.rnn.LSTMBlockCell( + config.hidden_size, forget_bias=0.0) + raise ValueError("rnn_mode %s not supported" % config.rnn_mode) + + def _build_rnn_graph_lstm(self, inputs, config, is_training): + """Build the inference graph using canonical LSTM cells without Wrapper.""" + init_sates = [] + final_states = [] + with tf.variable_scope("RNN", reuse=not is_training): + for l in range(config.num_layers): + with tf.variable_scope("layer_%d" % l): + cell = self._get_lstm_cell(config, is_training) + initial_state = cell.zero_state(self.batch_size, data_type()) + init_sates.append(initial_state) + state = init_sates[-1] + if is_training and config.keep_prob < 1: + inputs = tf.nn.dropout(inputs, config.keep_prob) + inputs = tf.unstack(inputs, num=self.num_steps, axis=1) + outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, + initial_state=init_sates[-1]) + final_states.append(state) + outputs = [tf.expand_dims(output, 1) for output in outputs] + outputs = tf.concat(outputs, 1) + inputs = outputs + outputs = tf.reshape(outputs, [-1, config.hidden_size]) + self._initial_state = tuple(init_sates) + state = tuple(final_states) + if is_training and config.keep_prob < 1: + outputs = tf.nn.dropout(outputs, config.keep_prob) + return outputs, state + + def assign_lr(self, session, lr_value): + session.run(self._lr_update, feed_dict={self._new_lr: lr_value}) + + @property + def input(self): + return self._input + + @property + def initial_state(self): + return self._initial_state + + @property + def cost(self): + return self._cost + + @property + def nll(self): + return self._nll + + @property + def recall_at_k(self): + return self._recall_at_k + + @property + def final_state(self): + return self._final_state + + @property + def lr(self): + return self._lr + + @property + def train_op(self): + return self._train_op + + @property + def initial_state_name(self): + return self._initial_state_name + + @property + def final_state_name(self): + return self._final_state_name class SmallConfig(object): - """Small config.""" - init_scale = 0.1 - learning_rate = 1.0 - max_grad_norm = 5 if FLAGS.max_grad_norm is None else FLAGS.max_grad_norm - num_layers = 2 - num_steps = 20 - hidden_size = 200 - max_epoch = 4 - max_max_epoch = 100 if FLAGS.max_max_epoch is None else FLAGS.max_max_epoch - keep_prob = 1.0 - lr_decay = 0.5 - batch_size = 20 - global vocab_size - rnn_mode = BLOCK + """Small config.""" + init_scale = 0.1 + learning_rate = 1.0 + max_grad_norm = 5 if FLAGS.max_grad_norm is None else FLAGS.max_grad_norm + num_layers = 2 + num_steps = 20 + hidden_size = 200 + max_epoch = 4 + max_max_epoch = 100 if FLAGS.max_max_epoch is None else FLAGS.max_max_epoch + keep_prob = 1.0 + lr_decay = 0.5 + batch_size = 20 + global vocab_size + rnn_mode = BLOCK class MediumConfig(SmallConfig): - """Medium config.""" - init_scale = 0.05 - max_grad_norm = 5 if FLAGS.max_grad_norm is None else FLAGS.max_grad_norm - num_layers = 2 - num_steps = 35 - hidden_size = 650 - max_epoch = 6 - max_max_epoch = 39 if FLAGS.max_max_epoch is None else FLAGS.max_max_epoch - keep_prob = 0.5 - lr_decay = 0.8 - batch_size = 20 - global vocab_size - rnn_mode = BLOCK + """Medium config.""" + init_scale = 0.05 + max_grad_norm = 5 if FLAGS.max_grad_norm is None else FLAGS.max_grad_norm + num_layers = 2 + num_steps = 35 + hidden_size = 650 
+ max_epoch = 6 + max_max_epoch = 39 if FLAGS.max_max_epoch is None else FLAGS.max_max_epoch + keep_prob = 0.5 + lr_decay = 0.8 + batch_size = 20 + global vocab_size + rnn_mode = BLOCK class LargeConfig(SmallConfig): - """Large config.""" - init_scale = 0.04 - max_grad_norm = 10 if FLAGS.max_grad_norm is None else FLAGS.max_grad_norm - num_layers = 2 - num_steps = 35 - hidden_size = 1500 - max_epoch = 14 - max_max_epoch = 55 if FLAGS.max_max_epoch is None else FLAGS.max_max_epoch - keep_prob = 0.35 - lr_decay = 1 / 1.15 if (FLAGS.lr_decay is None or FLAGS.lr_decay < 0) else ( - FLAGS.lr_decay) - batch_size = 20 - global vocab_size - rnn_mode = BLOCK + """Large config.""" + init_scale = 0.04 + max_grad_norm = 10 if FLAGS.max_grad_norm is None else FLAGS.max_grad_norm + num_layers = 2 + num_steps = 35 + hidden_size = 1500 + max_epoch = 14 + max_max_epoch = 55 if FLAGS.max_max_epoch is None else FLAGS.max_max_epoch + keep_prob = 0.35 + lr_decay = 1 / 1.15 if (FLAGS.lr_decay is None or FLAGS.lr_decay < 0) else ( + FLAGS.lr_decay) + batch_size = 20 + global vocab_size + rnn_mode = BLOCK def run_epoch(session, model, eval_op=None, verbose=False, mode=TRAIN_MODE): - """Runs the model on the given data.""" - start_time = time.time() - costs = 0.0 - nlls = 0.0 - recalls_at_k = 0.0 - iters = 0 - state = session.run(model.initial_state) - - fetches = { - "cost": model.cost, - "nll": model.nll, - "recall_at_k": model.recall_at_k, - "final_state": model.final_state - } - - if eval_op is not None: - fetches["eval_op"] = eval_op - - for step in range(model.input.epoch_size): - feed_dict = {} - for i, (c, h) in enumerate(model.initial_state): - feed_dict[c] = state[i].c - feed_dict[h] = state[i].h - - vals = session.run(fetches, feed_dict) - cost = vals["cost"] - nll = vals["nll"] - recall_at_k = vals["recall_at_k"] - state = vals["final_state"] - - costs += cost - nlls += nll - recalls_at_k += recall_at_k - iters += model.input.num_steps - - if verbose and step % (model.input.epoch_size // 10) == 10: - print("%.3f cost %.3f perplexity: %.3f recall@%d: %.3f speed: %.0f wps" % - (step*1./model.input.epoch_size, costs / iters, np.exp(nlls/iters), - FLAGS.eval_topk, recalls_at_k / iters, - iters * model.input.batch_size / (time.time() - start_time))) - - return np.exp(nlls / iters), recalls_at_k / iters + """Runs the model on the given data.""" + start_time = time.time() + costs = 0.0 + nlls = 0.0 + recalls_at_k = 0.0 + iters = 0 + state = session.run(model.initial_state) + + fetches = { + "cost": model.cost, + "nll": model.nll, + "recall_at_k": model.recall_at_k, + "final_state": model.final_state + } + + if eval_op is not None: + fetches["eval_op"] = eval_op + + for step in range(model.input.epoch_size): + feed_dict = {} + for i, (c, h) in enumerate(model.initial_state): + feed_dict[c] = state[i].c + feed_dict[h] = state[i].h + + vals = session.run(fetches, feed_dict) + cost = vals["cost"] + nll = vals["nll"] + recall_at_k = vals["recall_at_k"] + state = vals["final_state"] + + costs += cost + nlls += nll + recalls_at_k += recall_at_k + iters += model.input.num_steps + + if verbose and step % (model.input.epoch_size // 10) == 10: + print("%.3f cost %.3f perplexity: %.3f recall@%d: %.3f speed: %.0f wps" % + (step*1./model.input.epoch_size, costs / iters, np.exp(nlls/iters), + FLAGS.eval_topk, recalls_at_k / iters, + iters * model.input.batch_size / (time.time() - start_time))) + + return np.exp(nlls / iters), recalls_at_k / iters def get_config(verbose=False): - """Get model config.""" - config = None - if 
FLAGS.model == "small": - config = SmallConfig() - elif FLAGS.model == "medium": - config = MediumConfig() - elif FLAGS.model == "large": - config = LargeConfig() - else: - raise ValueError("Invalid model: %s", FLAGS.model) - if FLAGS.rnn_mode: - config.rnn_mode = FLAGS.rnn_mode - if verbose: - config_attrs = [a for a in inspect.getmembers(config) if not ( - a[0].startswith('__') and a[0].endswith('__'))] - print(config_attrs) - return config + """Get model config.""" + config = None + if FLAGS.model == "small": + config = SmallConfig() + elif FLAGS.model == "medium": + config = MediumConfig() + elif FLAGS.model == "large": + config = LargeConfig() + else: + raise ValueError("Invalid model: %s", FLAGS.model) + if FLAGS.rnn_mode: + config.rnn_mode = FLAGS.rnn_mode + if verbose: + config_attrs = [a for a in inspect.getmembers(config) if not ( + a[0].startswith('__') and a[0].endswith('__'))] + print(config_attrs) + return config def main(_): - if not FLAGS.data_path: - raise ValueError("Must set --data_path to PTB data directory") - - raw_data = reader.ptb_raw_data(FLAGS.dataset, FLAGS.data_path, True) - train_data, valid_data, test_data, _vocab_size, _id2word = raw_data - global vocab_size - vocab_size = _vocab_size - - config = get_config(verbose=True) - eval_config = get_config() - eval_config.batch_size = 1 - eval_config.num_steps = 1 - - with tf.Graph().as_default(): - initializer = tf.random_uniform_initializer(-config.init_scale, - config.init_scale) - - with tf.name_scope("Train"): - train_input = PTBInput(config=config, data=train_data, name="TrainInput") - with tf.variable_scope("Model", reuse=None, initializer=initializer): - m = PTBModel(is_training=True, - config=config, - input_=train_input, - vocab_size=vocab_size) - tf.summary.scalar("Training Loss", m.cost) - tf.summary.scalar("Learning Rate", m.lr) - - with tf.name_scope("Valid"): - valid_input = PTBInput(config=config, data=valid_data, name="ValidInput") - with tf.variable_scope("Model", reuse=True, initializer=initializer): - mvalid = PTBModel(is_training=False, - config=config, - input_=valid_input, - vocab_size=vocab_size) - tf.summary.scalar("Validation Loss", mvalid.cost) - - with tf.name_scope("Test"): - test_input = PTBInput( - config=eval_config, data=test_data, name="TestInput") - with tf.variable_scope("Model", reuse=True, initializer=initializer): - mtest = PTBModel(is_training=False, - config=eval_config, - input_=test_input, - vocab_size=vocab_size) - - models = {"Train": m, "Valid": mvalid, "Test": mtest} - - print_at_beginning(config) - sv = tf.train.Supervisor(logdir=FLAGS.save_path, - save_model_secs=FLAGS.save_model_secs, - save_summaries_secs=10) - config_proto = tf.ConfigProto(allow_soft_placement=True) - config_proto.gpu_options.allow_growth = True - with sv.managed_session(config=config_proto) as session: - for i in range(config.max_max_epoch): - lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0) - m.assign_lr(session, config.learning_rate * lr_decay) - - print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr))) - train_perplexity, train_recall_at_k = run_epoch( - session, m, eval_op=m.train_op, verbose=True, mode=TRAIN_MODE) - print("Epoch: %d Train Perplexity: %.3f, recall@%d: %.3f" % ( - i + 1, train_perplexity, FLAGS.eval_topk, train_recall_at_k)) - valid_perplexity, valid_recall_at_k = run_epoch( - session, mvalid, mode=VALID_MODE) - print("Epoch: %d Valid Perplexity: %.3f, recall@%d: %.3f" % ( - i + 1, valid_perplexity, FLAGS.eval_topk, valid_recall_at_k)) - - 
test_perplexity, test_recall_at_k = run_epoch( - session, mtest, mode=TEST_MODE) - print("Test Perplexity: %.3f, recall@%d: %.3f" % ( - test_perplexity, FLAGS.eval_topk, test_recall_at_k)) - - if FLAGS.save_path and sv.saver is not None: - print("Saving model to %s." % FLAGS.save_path) - sv.saver.save(session, - os.path.join(FLAGS.save_path, "model"), - global_step=sv.global_step) + if not FLAGS.data_path: + raise ValueError("Must set --data_path to PTB data directory") + + raw_data = reader.ptb_raw_data(FLAGS.dataset, FLAGS.data_path, True) + train_data, valid_data, test_data, _vocab_size, _id2word = raw_data + global vocab_size + vocab_size = _vocab_size + + config = get_config(verbose=True) + eval_config = get_config() + eval_config.batch_size = 1 + eval_config.num_steps = 1 + + with tf.Graph().as_default(): + initializer = tf.random_uniform_initializer(-config.init_scale, + config.init_scale) + + with tf.name_scope("Train"): + train_input = PTBInput(config=config, data=train_data, name="TrainInput") + with tf.variable_scope("Model", reuse=None, initializer=initializer): + m = PTBModel(is_training=True, + config=config, + input_=train_input, + vocab_size=vocab_size) + tf.summary.scalar("Training Loss", m.cost) + tf.summary.scalar("Learning Rate", m.lr) + + with tf.name_scope("Valid"): + valid_input = PTBInput(config=config, data=valid_data, name="ValidInput") + with tf.variable_scope("Model", reuse=True, initializer=initializer): + mvalid = PTBModel(is_training=False, + config=config, + input_=valid_input, + vocab_size=vocab_size) + tf.summary.scalar("Validation Loss", mvalid.cost) + + with tf.name_scope("Test"): + test_input = PTBInput( + config=eval_config, data=test_data, name="TestInput") + with tf.variable_scope("Model", reuse=True, initializer=initializer): + mtest = PTBModel(is_training=False, + config=eval_config, + input_=test_input, + vocab_size=vocab_size) + + models = {"Train": m, "Valid": mvalid, "Test": mtest} + + print_at_beginning(config) + sv = tf.train.Supervisor(logdir=FLAGS.save_path, + save_model_secs=FLAGS.save_model_secs, + save_summaries_secs=10) + config_proto = tf.ConfigProto(allow_soft_placement=True) + config_proto.gpu_options.allow_growth = True + with sv.managed_session(config=config_proto) as session: + for i in range(config.max_max_epoch): + lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0) + m.assign_lr(session, config.learning_rate * lr_decay) + + print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr))) + train_perplexity, train_recall_at_k = run_epoch( + session, m, eval_op=m.train_op, verbose=True, mode=TRAIN_MODE) + print("Epoch: %d Train Perplexity: %.3f, recall@%d: %.3f" % ( + i + 1, train_perplexity, FLAGS.eval_topk, train_recall_at_k)) + valid_perplexity, valid_recall_at_k = run_epoch( + session, mvalid, mode=VALID_MODE) + print("Epoch: %d Valid Perplexity: %.3f, recall@%d: %.3f" % ( + i + 1, valid_perplexity, FLAGS.eval_topk, valid_recall_at_k)) + + test_perplexity, test_recall_at_k = run_epoch( + session, mtest, mode=TEST_MODE) + print("Test Perplexity: %.3f, recall@%d: %.3f" % ( + test_perplexity, FLAGS.eval_topk, test_recall_at_k)) + + if FLAGS.save_path and sv.saver is not None: + print("Saving model to %s." 
% FLAGS.save_path) + sv.saver.save(session, + os.path.join(FLAGS.save_path, "model"), + global_step=sv.global_step) if __name__ == "__main__": - tf.app.run() + tf.app.run() diff --git a/lm/reader.py b/lm/reader.py index 85d1c3d..6558bf2 100644 --- a/lm/reader.py +++ b/lm/reader.py @@ -27,121 +27,122 @@ Py3 = sys.version_info[0] == 3 + def _read_words(filename): - with tf.gfile.GFile(filename, "r") as f: - if Py3: - return f.read().replace("\n", "").split() - else: - return f.read().decode("utf-8").replace("\n", "").split() + with tf.gfile.GFile(filename, "r") as f: + if Py3: + return f.read().replace("\n", "").split() + else: + return f.read().decode("utf-8").replace("\n", "").split() def _build_vocab(filename): - data = _read_words(filename) + data = _read_words(filename) - counter = collections.Counter(data) - count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) + counter = collections.Counter(data) + count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) - words, _ = list(zip(*count_pairs)) - word_to_id = dict(zip(words, range(len(words)))) + words, _ = list(zip(*count_pairs)) + word_to_id = dict(zip(words, range(len(words)))) - return word_to_id + return word_to_id def _file_to_word_ids(filename, word_to_id): - data = _read_words(filename) - return [word_to_id[word] for word in data if word in word_to_id] + data = _read_words(filename) + return [word_to_id[word] for word in data if word in word_to_id] def ptb_raw_data(dataset="ptb", data_path=None, return_id2word=False): - """Load PTB raw data from data directory "data_path". - - Reads PTB text files, converts strings to integer ids, - and performs mini-batching of the inputs. - - The PTB dataset comes from Tomas Mikolov's webpage: - - http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz - - Args: - data_path: string path to the directory where simple-examples.tgz has - been extracted. - - Returns: - tuple (train_data, valid_data, test_data, vocabulary) - where each of the data objects can be passed to PTBIterator. - """ - - if dataset == "ptb": - train_path = os.path.join(data_path, "ptb.train.txt") - valid_path = os.path.join(data_path, "ptb.valid.txt") - test_path = os.path.join(data_path, "ptb.test.txt") - elif dataset == "text8": - train_path = os.path.join(data_path, "text8.train.txt") - valid_path = os.path.join(data_path, "text8.valid.txt") - test_path = os.path.join(data_path, "text8.valid.txt") - elif dataset == "wikitext-2": - train_path = os.path.join(data_path, "wiki.train.tokens") - valid_path = os.path.join(data_path, "wiki.valid.tokens") - test_path = os.path.join(data_path, "wiki.test.tokens") - else: - raise ValueError("Unknown dataset {}".format(dataset)) - - word_to_id = _build_vocab(train_path) - train_data = _file_to_word_ids(train_path, word_to_id) - valid_data = _file_to_word_ids(valid_path, word_to_id) - test_data = _file_to_word_ids(test_path, word_to_id) - vocabulary = len(word_to_id) - if return_id2word: - id2word = dict([(_id, word) for word, _id in word_to_id.items()]) - id2word_list = [ - id2word[i] if i == _id else -1 for i, _id in enumerate(sorted(id2word))] - if len(id2word) != len(id2word_list): - raise ValueError("Something is wrong..") - return train_data, valid_data, test_data, vocabulary, id2word_list - else: - return train_data, valid_data, test_data, vocabulary + """Load PTB raw data from data directory "data_path". + + Reads PTB text files, converts strings to integer ids, + and performs mini-batching of the inputs. 
+ + The PTB dataset comes from Tomas Mikolov's webpage: + + http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz + + Args: + data_path: string path to the directory where simple-examples.tgz has + been extracted. + + Returns: + tuple (train_data, valid_data, test_data, vocabulary) + where each of the data objects can be passed to PTBIterator. + """ + + if dataset == "ptb": + train_path = os.path.join(data_path, "ptb.train.txt") + valid_path = os.path.join(data_path, "ptb.valid.txt") + test_path = os.path.join(data_path, "ptb.test.txt") + elif dataset == "text8": + train_path = os.path.join(data_path, "text8.train.txt") + valid_path = os.path.join(data_path, "text8.valid.txt") + test_path = os.path.join(data_path, "text8.valid.txt") + elif dataset == "wikitext-2": + train_path = os.path.join(data_path, "wiki.train.tokens") + valid_path = os.path.join(data_path, "wiki.valid.tokens") + test_path = os.path.join(data_path, "wiki.test.tokens") + else: + raise ValueError("Unknown dataset {}".format(dataset)) + + word_to_id = _build_vocab(train_path) + train_data = _file_to_word_ids(train_path, word_to_id) + valid_data = _file_to_word_ids(valid_path, word_to_id) + test_data = _file_to_word_ids(test_path, word_to_id) + vocabulary = len(word_to_id) + if return_id2word: + id2word = dict([(_id, word) for word, _id in word_to_id.items()]) + id2word_list = [ + id2word[i] if i == _id else -1 for i, _id in enumerate(sorted(id2word))] + if len(id2word) != len(id2word_list): + raise ValueError("Something is wrong..") + return train_data, valid_data, test_data, vocabulary, id2word_list + else: + return train_data, valid_data, test_data, vocabulary def ptb_producer(raw_data, batch_size, num_steps, name=None): - """Iterate on the raw PTB data. - - This chunks up raw_data into batches of examples and returns Tensors that - are drawn from these batches. - - Args: - raw_data: one of the raw data outputs from ptb_raw_data. - batch_size: int, the batch size. - num_steps: int, the number of unrolls. - name: the name of this operation (optional). - - Returns: - A pair of Tensors, each shaped [batch_size, num_steps]. The second element - of the tuple is the same data time-shifted to the right by one. - - Raises: - tf.errors.InvalidArgumentError: if batch_size or num_steps are too high. - """ - with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]): - raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32) - - data_len = tf.size(raw_data) - batch_len = data_len // batch_size - data = tf.reshape(raw_data[0 : batch_size * batch_len], - [batch_size, batch_len]) - - epoch_size = (batch_len - 1) // num_steps - assertion = tf.assert_positive( - epoch_size, - message="epoch_size == 0, decrease batch_size or num_steps") - with tf.control_dependencies([assertion]): - epoch_size = tf.identity(epoch_size, name="epoch_size") - - i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue() - x = tf.strided_slice(data, [0, i * num_steps], - [batch_size, (i + 1) * num_steps]) - x.set_shape([batch_size, num_steps]) - y = tf.strided_slice(data, [0, i * num_steps + 1], - [batch_size, (i + 1) * num_steps + 1]) - y.set_shape([batch_size, num_steps]) - - return x, y + """Iterate on the raw PTB data. + + This chunks up raw_data into batches of examples and returns Tensors that + are drawn from these batches. + + Args: + raw_data: one of the raw data outputs from ptb_raw_data. + batch_size: int, the batch size. + num_steps: int, the number of unrolls. 
+ name: the name of this operation (optional). + + Returns: + A pair of Tensors, each shaped [batch_size, num_steps]. The second element + of the tuple is the same data time-shifted to the right by one. + + Raises: + tf.errors.InvalidArgumentError: if batch_size or num_steps are too high. + """ + with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]): + raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32) + + data_len = tf.size(raw_data) + batch_len = data_len // batch_size + data = tf.reshape(raw_data[0: batch_size * batch_len], + [batch_size, batch_len]) + + epoch_size = (batch_len - 1) // num_steps + assertion = tf.assert_positive( + epoch_size, + message="epoch_size == 0, decrease batch_size or num_steps") + with tf.control_dependencies([assertion]): + epoch_size = tf.identity(epoch_size, name="epoch_size") + + i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue() + x = tf.strided_slice(data, [0, i * num_steps], + [batch_size, (i + 1) * num_steps]) + x.set_shape([batch_size, num_steps]) + y = tf.strided_slice(data, [0, i * num_steps + 1], + [batch_size, (i + 1) * num_steps + 1]) + y.set_shape([batch_size, num_steps]) + + return x, y diff --git a/lm/util.py b/lm/util.py index a66e973..a36a03f 100644 --- a/lm/util.py +++ b/lm/util.py @@ -33,235 +33,236 @@ def export_state_tuples(state_tuples, name): - for state_tuple in state_tuples: - tf.add_to_collection(name, state_tuple.c) - tf.add_to_collection(name, state_tuple.h) + for state_tuple in state_tuples: + tf.add_to_collection(name, state_tuple.c) + tf.add_to_collection(name, state_tuple.h) def import_state_tuples(state_tuples, name, num_replicas): - restored = [] - for i in range(len(state_tuples) * num_replicas): - c = tf.get_collection_ref(name)[2 * i + 0] - h = tf.get_collection_ref(name)[2 * i + 1] - restored.append(tf.contrib.rnn.LSTMStateTuple(c, h)) - return tuple(restored) + restored = [] + for i in range(len(state_tuples) * num_replicas): + c = tf.get_collection_ref(name)[2 * i + 0] + h = tf.get_collection_ref(name)[2 * i + 1] + restored.append(tf.contrib.rnn.LSTMStateTuple(c, h)) + return tuple(restored) def with_prefix(prefix, name): - """Adds prefix to name.""" - return "/".join((prefix, name)) + """Adds prefix to name.""" + return "/".join((prefix, name)) def with_autoparallel_prefix(replica_id, name): - return with_prefix("AutoParallel-Replica-%d" % replica_id, name) + return with_prefix("AutoParallel-Replica-%d" % replica_id, name) class UpdateCollection(object): - """Update collection info in MetaGraphDef for AutoParallel optimizer.""" - - def __init__(self, metagraph, model): - self._metagraph = metagraph - self.replicate_states(model.initial_state_name) - self.replicate_states(model.final_state_name) - self.update_snapshot_name("variables") - self.update_snapshot_name("trainable_variables") - - def update_snapshot_name(self, var_coll_name): - var_list = self._metagraph.collection_def[var_coll_name] - for i, value in enumerate(var_list.bytes_list.value): - var_def = variable_pb2.VariableDef() - var_def.ParseFromString(value) - # Somehow node Model/global_step/read doesn't have any fanout and seems to - # be only used for snapshot; this is different from all other variables. 
- if var_def.snapshot_name != "Model/global_step/read:0": - var_def.snapshot_name = with_autoparallel_prefix( - 0, var_def.snapshot_name) - value = var_def.SerializeToString() - var_list.bytes_list.value[i] = value - - def replicate_states(self, state_coll_name): - state_list = self._metagraph.collection_def[state_coll_name] - num_states = len(state_list.node_list.value) - for replica_id in range(1, FLAGS.num_gpus): - for i in range(num_states): - state_list.node_list.value.append(state_list.node_list.value[i]) - for replica_id in range(FLAGS.num_gpus): - for i in range(num_states): - index = replica_id * num_states + i - state_list.node_list.value[index] = with_autoparallel_prefix( - replica_id, state_list.node_list.value[index]) + """Update collection info in MetaGraphDef for AutoParallel optimizer.""" + + def __init__(self, metagraph, model): + self._metagraph = metagraph + self.replicate_states(model.initial_state_name) + self.replicate_states(model.final_state_name) + self.update_snapshot_name("variables") + self.update_snapshot_name("trainable_variables") + + def update_snapshot_name(self, var_coll_name): + var_list = self._metagraph.collection_def[var_coll_name] + for i, value in enumerate(var_list.bytes_list.value): + var_def = variable_pb2.VariableDef() + var_def.ParseFromString(value) + # Somehow node Model/global_step/read doesn't have any fanout and seems to + # be only used for snapshot; this is different from all other variables. + if var_def.snapshot_name != "Model/global_step/read:0": + var_def.snapshot_name = with_autoparallel_prefix( + 0, var_def.snapshot_name) + value = var_def.SerializeToString() + var_list.bytes_list.value[i] = value + + def replicate_states(self, state_coll_name): + state_list = self._metagraph.collection_def[state_coll_name] + num_states = len(state_list.node_list.value) + for replica_id in range(1, FLAGS.num_gpus): + for i in range(num_states): + state_list.node_list.value.append(state_list.node_list.value[i]) + for replica_id in range(FLAGS.num_gpus): + for i in range(num_states): + index = replica_id * num_states + i + state_list.node_list.value[index] = with_autoparallel_prefix( + replica_id, state_list.node_list.value[index]) def auto_parallel(metagraph, model): - from tensorflow.python.grappler import tf_optimizer - rewriter_config = rewriter_config_pb2.RewriterConfig() - rewriter_config.optimizers.append("autoparallel") - rewriter_config.auto_parallel.enable = True - rewriter_config.auto_parallel.num_replicas = FLAGS.num_gpus - optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph) - metagraph.graph_def.CopyFrom(optimized_graph) - UpdateCollection(metagraph, model) + from tensorflow.python.grappler import tf_optimizer + rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.optimizers.append("autoparallel") + rewriter_config.auto_parallel.enable = True + rewriter_config.auto_parallel.num_replicas = FLAGS.num_gpus + optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph) + metagraph.graph_def.CopyFrom(optimized_graph) + UpdateCollection(metagraph, model) def safer_log(x, eps=eps_micro): - """Avoid nan when x is zero by adding small eps. - - Note that if x.dtype=tf.float16, \forall eps, eps < 3e-8, is equal to zero. - """ - return tf.log(x + eps) + """Avoid nan when x is zero by adding small eps. + + Note that if x.dtype=tf.float16, \forall eps, eps < 3e-8, is equal to zero. 
+    """
+    return tf.log(x + eps)


 def get_activation(name):
-  """Returns activation function given name."""
-  name = name.lower()
-  if name == "relu":
-    return tf.nn.relu
-  elif name == "sigmoid":
-    return tf.nn.sigmoid
-  elif name == "tanh":
-    return tf.nn.sigmoid
-  elif name == "elu":
-    return tf.nn.elu
-  elif name == "linear":
-    return lambda x: x
-  else:
-    raise ValueError("Unknown activation name {}".format(name))
-
-  return name
+    """Returns activation function given name."""
+    name = name.lower()
+    if name == "relu":
+        return tf.nn.relu
+    elif name == "sigmoid":
+        return tf.nn.sigmoid
+    elif name == "tanh":
+        return tf.nn.tanh
+    elif name == "elu":
+        return tf.nn.elu
+    elif name == "linear":
+        return lambda x: x
+    else:
+        raise ValueError("Unknown activation name {}".format(name))


 def hparam_fn(hparams, prefix=None):
-  """Returns a function to get hparam with prefix in the name."""
-  if prefix is None or prefix == "":
-    prefix = ""
-  elif isinstance(prefix, str):
-    prefix += "_"
-  else:
-    raise ValueError("prefix {} is invalid".format(prefix))
-  get_hparam = lambda name: getattr(hparams, prefix + name)
-  return get_hparam
+    """Returns a function to get hparam with prefix in the name."""
+    if prefix is None or prefix == "":
+        prefix = ""
+    elif isinstance(prefix, str):
+        prefix += "_"
+    else:
+        raise ValueError("prefix {} is invalid".format(prefix))
+
+    def get_hparam(name): return getattr(hparams, prefix + name)
+    return get_hparam


 def filter_activation_fn():
-  """Returns a activation if actv is an string name, otherwise input."""
-  filter_actv = lambda actv: (
-      get_activation(actv) if isinstance(actv, str) else actv)
-  return filter_actv
+    """Returns an activation if actv is a string name, otherwise the input."""
+    def filter_actv(actv): return (
+        get_activation(actv) if isinstance(actv, str) else actv)
+    return filter_actv


 def get_optimizer(name):
-  name = name.lower()
-  if name == "sgd":
-    optimizer = tf.train.GradientDescentOptimizer
-  elif name == "momentum":
-    optimizer = partial(tf.train.MomentumOptimizer,
-                        momentum=0.05, use_nesterov=True)
-  elif name == "adam":
-    optimizer = tf.train.AdamOptimizer
-    # optimizer = partial(tf.train.AdamOptimizer, beta1=0.5, beta2=0.9)
-  elif name == "lazy_adam":
-    optimizer = tf.contrib.opt.LazyAdamOptimizer
-    # optimizer = partial(tf.contrib.opt.LazyAdamOptimizer, beta1=0.5, beta2=0.9)
-  elif name == "adagrad":
-    optimizer = tf.train.AdagradOptimizer
-  elif name == "rmsprop":
-    optimizer = tf.train.RMSPropOptimizer
-  else:
-    raise ValueError("Unknown optimizer name {}.".format(name))
-
-  return optimizer
+    name = name.lower()
+    if name == "sgd":
+        optimizer = tf.train.GradientDescentOptimizer
+    elif name == "momentum":
+        optimizer = partial(tf.train.MomentumOptimizer,
+                            momentum=0.05, use_nesterov=True)
+    elif name == "adam":
+        optimizer = tf.train.AdamOptimizer
+        # optimizer = partial(tf.train.AdamOptimizer, beta1=0.5, beta2=0.9)
+    elif name == "lazy_adam":
+        optimizer = tf.contrib.opt.LazyAdamOptimizer
+        # optimizer = partial(tf.contrib.opt.LazyAdamOptimizer, beta1=0.5, beta2=0.9)
+    elif name == "adagrad":
+        optimizer = tf.train.AdagradOptimizer
+    elif name == "rmsprop":
+        optimizer = tf.train.RMSPropOptimizer
+    else:
+        raise ValueError("Unknown optimizer name {}.".format(name))
+
+    return optimizer


 def replace_list_element(data_list, x, y):
-  return [y if each == x else each for each in data_list]
+    return [y if each == x else each for each in data_list]


 def get_parameter_count(excludings=None, display_count=True):
- 
trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) - count = 0 - for var in trainables: - ignored = False - if excludings is not None: - for excluding in excludings: - if var.name.find(excluding) >= 0: - ignored = True - break - if ignored: - continue - if var.shape == tf.TensorShape(None): - tf.logging.warn("var {} has unknown shape and it is not counted.".format( - var.name)) - continue - if var.shape.as_list() == []: - count_ = 1 - else: - count_ = reduce(lambda x, y: x * y, var.shape.as_list()) - count += count_ - if display_count: - print("{0:80} {1}".format( - var.name, count_)) - return count + trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) + count = 0 + for var in trainables: + ignored = False + if excludings is not None: + for excluding in excludings: + if var.name.find(excluding) >= 0: + ignored = True + break + if ignored: + continue + if var.shape == tf.TensorShape(None): + tf.logging.warn("var {} has unknown shape and it is not counted.".format( + var.name)) + continue + if var.shape.as_list() == []: + count_ = 1 + else: + count_ = reduce(lambda x, y: x * y, var.shape.as_list()) + count += count_ + if display_count: + print("{0:80} {1}".format( + var.name, count_)) + return count def save_emb_visualize_meta(save_path, emb_var_names, label_lists, metadata_names=None): - """Save meta information about the embedding visualization in tensorboard. - - Args: - save_path: a `string` specifying the save location. - emb_var_names: a `list` containing variable names. - label_lists: a `list` of lists of labels, each label list corresponds to an - emb_var_name. - metadata_names: a `list` of file names for metadata, if not specify, will - use emb_var_names. - """ - if not isinstance(emb_var_names, list): - raise ValueError("emb_var_names must be a list of var names.") - if not isinstance(label_lists, list) or not isinstance(label_lists[0], list): - raise ValueError("label_lists must be a list of label lists.") - - if metadata_names is None: - metadata_names = emb_var_names - - config = projector.ProjectorConfig() - for emb_var_name, metadata_name, labels in zip( - emb_var_names, metadata_names, label_lists): - filename = "metadata-{}.tsv".format(metadata_name) - metadata_path = os.path.join(save_path, filename) - with open(metadata_path, "w") as fp: - for label in labels: - fp.write("{}\n".format(label)) - embedding = config.embeddings.add() - embedding.tensor_name = emb_var_name - embedding.metadata_path = metadata_path - summary_writer = tf.summary.FileWriter(save_path) - projector.visualize_embeddings(summary_writer, config) + """Save meta information about the embedding visualization in tensorboard. + + Args: + save_path: a `string` specifying the save location. + emb_var_names: a `list` containing variable names. + label_lists: a `list` of lists of labels, each label list corresponds to an + emb_var_name. + metadata_names: a `list` of file names for metadata, if not specify, will + use emb_var_names. 
+ """ + if not isinstance(emb_var_names, list): + raise ValueError("emb_var_names must be a list of var names.") + if not isinstance(label_lists, list) or not isinstance(label_lists[0], list): + raise ValueError("label_lists must be a list of label lists.") + + if metadata_names is None: + metadata_names = emb_var_names + + config = projector.ProjectorConfig() + for emb_var_name, metadata_name, labels in zip( + emb_var_names, metadata_names, label_lists): + filename = "metadata-{}.tsv".format(metadata_name) + metadata_path = os.path.join(save_path, filename) + with open(metadata_path, "w") as fp: + for label in labels: + fp.write("{}\n".format(label)) + embedding = config.embeddings.add() + embedding.tensor_name = emb_var_name + embedding.metadata_path = metadata_path + summary_writer = tf.summary.FileWriter(save_path) + projector.visualize_embeddings(summary_writer, config) def create_labels_based_on_codes(codes, K): - """Take codes and produce per-axis based and prefix based labels. - - Args: - codes: a `np.ndarray` of size (N, D) where N data points with D-dimensional - discrete code. - K: a `int` specifying the cardinality of the code in each dimension. - - Returns: - label_lists: a list of labels. - """ - N, D = codes.shape - label_lists = [] - - # Create per-axis labels. - for i in range(D): - label_lists.append(codes[:, i].tolist()) - - # Create prefix labels. - buffer_basis = 0 - for i in range(D): - buffer_basis = buffer_basis * K + codes[:, i] - label_lists.append(buffer_basis.tolist()) - - return label_lists + """Take codes and produce per-axis based and prefix based labels. + + Args: + codes: a `np.ndarray` of size (N, D) where N data points with D-dimensional + discrete code. + K: a `int` specifying the cardinality of the code in each dimension. + + Returns: + label_lists: a list of labels. + """ + N, D = codes.shape + label_lists = [] + + # Create per-axis labels. + for i in range(D): + label_lists.append(codes[:, i].tolist()) + + # Create prefix labels. + buffer_basis = 0 + for i in range(D): + buffer_basis = buffer_basis * K + codes[:, i] + label_lists.append(buffer_basis.tolist()) + + return label_lists diff --git a/nmt/attention_model.py b/nmt/attention_model.py index d262b8e..79caa71 100644 --- a/nmt/attention_model.py +++ b/nmt/attention_model.py @@ -26,169 +26,169 @@ class AttentionModel(model.Model): - """Sequence-to-sequence dynamic model with attention. - - This class implements a multi-layer recurrent neural network as encoder, - and an attention-based decoder. This is the same as the model described in - (Luong et al., EMNLP'2015) paper: https://arxiv.org/pdf/1508.04025v5.pdf. - This class also allows to use GRU cells in addition to LSTM cells with - support for dropout. 
- """ - - def __init__(self, - hparams, - mode, - iterator, - source_vocab_table, - target_vocab_table, - reverse_target_vocab_table=None, - scope=None, - extra_args=None): - self.has_attention = hparams.attention_architecture and hparams.attention - - # Set attention_mechanism_fn - if self.has_attention: - if extra_args and extra_args.attention_mechanism_fn: - self.attention_mechanism_fn = extra_args.attention_mechanism_fn - else: - self.attention_mechanism_fn = create_attention_mechanism - - super(AttentionModel, self).__init__( - hparams=hparams, - mode=mode, - iterator=iterator, - source_vocab_table=source_vocab_table, - target_vocab_table=target_vocab_table, - reverse_target_vocab_table=reverse_target_vocab_table, - scope=scope, - extra_args=extra_args) - - def _prepare_beam_search_decoder_inputs( - self, beam_width, memory, source_sequence_length, encoder_state): - memory = tf.contrib.seq2seq.tile_batch( - memory, multiplier=beam_width) - source_sequence_length = tf.contrib.seq2seq.tile_batch( - source_sequence_length, multiplier=beam_width) - encoder_state = tf.contrib.seq2seq.tile_batch( - encoder_state, multiplier=beam_width) - batch_size = self.batch_size * beam_width - return memory, source_sequence_length, encoder_state, batch_size - - def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, - source_sequence_length): - """Build a RNN cell with attention mechanism that can be used by decoder.""" - # No Attention - if not self.has_attention: - return super(AttentionModel, self)._build_decoder_cell( - hparams, encoder_outputs, encoder_state, source_sequence_length) - elif hparams.attention_architecture != "standard": - raise ValueError( - "Unknown attention architecture %s" % hparams.attention_architecture) - - num_units = hparams.num_units - num_layers = self.num_decoder_layers - num_residual_layers = self.num_decoder_residual_layers - infer_mode = hparams.infer_mode - - dtype = tf.float32 - - # Ensure memory is batch-major - if self.time_major: - memory = tf.transpose(encoder_outputs, [1, 0, 2]) - else: - memory = encoder_outputs - - if (self.mode == tf.contrib.learn.ModeKeys.INFER and - infer_mode == "beam_search"): - memory, source_sequence_length, encoder_state, batch_size = ( - self._prepare_beam_search_decoder_inputs( - hparams.beam_width, memory, source_sequence_length, - encoder_state)) - else: - batch_size = self.batch_size - - # Attention - attention_mechanism = self.attention_mechanism_fn( - hparams.attention, num_units, memory, source_sequence_length, self.mode) - - cell = model_helper.create_rnn_cell( - unit_type=hparams.unit_type, - num_units=num_units, - num_layers=num_layers, - num_residual_layers=num_residual_layers, - forget_bias=hparams.forget_bias, - dropout=hparams.dropout, - num_gpus=self.num_gpus, - mode=self.mode, - single_cell_fn=self.single_cell_fn) - - # Only generate alignment in greedy INFER mode. - alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and - infer_mode != "beam_search") - cell = tf.contrib.seq2seq.AttentionWrapper( - cell, - attention_mechanism, - attention_layer_size=num_units, - alignment_history=alignment_history, - output_attention=hparams.output_attention, - name="attention") - - # TODO(thangluong): do we need num_layers, num_gpus? 
- cell = tf.contrib.rnn.DeviceWrapper(cell, - model_helper.get_device_str( - num_layers - 1, self.num_gpus)) - - if hparams.pass_hidden_state: - decoder_initial_state = cell.zero_state(batch_size, dtype).clone( - cell_state=encoder_state) - else: - decoder_initial_state = cell.zero_state(batch_size, dtype) - - return cell, decoder_initial_state - - def _get_infer_summary(self, hparams): - if not self.has_attention or hparams.infer_mode == "beam_search": - return tf.no_op() - return _create_attention_images_summary(self.final_context_state) + """Sequence-to-sequence dynamic model with attention. + + This class implements a multi-layer recurrent neural network as encoder, + and an attention-based decoder. This is the same as the model described in + (Luong et al., EMNLP'2015) paper: https://arxiv.org/pdf/1508.04025v5.pdf. + This class also allows to use GRU cells in addition to LSTM cells with + support for dropout. + """ + + def __init__(self, + hparams, + mode, + iterator, + source_vocab_table, + target_vocab_table, + reverse_target_vocab_table=None, + scope=None, + extra_args=None): + self.has_attention = hparams.attention_architecture and hparams.attention + + # Set attention_mechanism_fn + if self.has_attention: + if extra_args and extra_args.attention_mechanism_fn: + self.attention_mechanism_fn = extra_args.attention_mechanism_fn + else: + self.attention_mechanism_fn = create_attention_mechanism + + super(AttentionModel, self).__init__( + hparams=hparams, + mode=mode, + iterator=iterator, + source_vocab_table=source_vocab_table, + target_vocab_table=target_vocab_table, + reverse_target_vocab_table=reverse_target_vocab_table, + scope=scope, + extra_args=extra_args) + + def _prepare_beam_search_decoder_inputs( + self, beam_width, memory, source_sequence_length, encoder_state): + memory = tf.contrib.seq2seq.tile_batch( + memory, multiplier=beam_width) + source_sequence_length = tf.contrib.seq2seq.tile_batch( + source_sequence_length, multiplier=beam_width) + encoder_state = tf.contrib.seq2seq.tile_batch( + encoder_state, multiplier=beam_width) + batch_size = self.batch_size * beam_width + return memory, source_sequence_length, encoder_state, batch_size + + def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, + source_sequence_length): + """Build a RNN cell with attention mechanism that can be used by decoder.""" + # No Attention + if not self.has_attention: + return super(AttentionModel, self)._build_decoder_cell( + hparams, encoder_outputs, encoder_state, source_sequence_length) + elif hparams.attention_architecture != "standard": + raise ValueError( + "Unknown attention architecture %s" % hparams.attention_architecture) + + num_units = hparams.num_units + num_layers = self.num_decoder_layers + num_residual_layers = self.num_decoder_residual_layers + infer_mode = hparams.infer_mode + + dtype = tf.float32 + + # Ensure memory is batch-major + if self.time_major: + memory = tf.transpose(encoder_outputs, [1, 0, 2]) + else: + memory = encoder_outputs + + if (self.mode == tf.contrib.learn.ModeKeys.INFER and + infer_mode == "beam_search"): + memory, source_sequence_length, encoder_state, batch_size = ( + self._prepare_beam_search_decoder_inputs( + hparams.beam_width, memory, source_sequence_length, + encoder_state)) + else: + batch_size = self.batch_size + + # Attention + attention_mechanism = self.attention_mechanism_fn( + hparams.attention, num_units, memory, source_sequence_length, self.mode) + + cell = model_helper.create_rnn_cell( + unit_type=hparams.unit_type, + 
num_units=num_units, + num_layers=num_layers, + num_residual_layers=num_residual_layers, + forget_bias=hparams.forget_bias, + dropout=hparams.dropout, + num_gpus=self.num_gpus, + mode=self.mode, + single_cell_fn=self.single_cell_fn) + + # Only generate alignment in greedy INFER mode. + alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and + infer_mode != "beam_search") + cell = tf.contrib.seq2seq.AttentionWrapper( + cell, + attention_mechanism, + attention_layer_size=num_units, + alignment_history=alignment_history, + output_attention=hparams.output_attention, + name="attention") + + # TODO(thangluong): do we need num_layers, num_gpus? + cell = tf.contrib.rnn.DeviceWrapper(cell, + model_helper.get_device_str( + num_layers - 1, self.num_gpus)) + + if hparams.pass_hidden_state: + decoder_initial_state = cell.zero_state(batch_size, dtype).clone( + cell_state=encoder_state) + else: + decoder_initial_state = cell.zero_state(batch_size, dtype) + + return cell, decoder_initial_state + + def _get_infer_summary(self, hparams): + if not self.has_attention or hparams.infer_mode == "beam_search": + return tf.no_op() + return _create_attention_images_summary(self.final_context_state) def create_attention_mechanism(attention_option, num_units, memory, source_sequence_length, mode): - """Create attention mechanism based on the attention_option.""" - del mode # unused - - # Mechanism - if attention_option == "luong": - attention_mechanism = tf.contrib.seq2seq.LuongAttention( - num_units, memory, memory_sequence_length=source_sequence_length) - elif attention_option == "scaled_luong": - attention_mechanism = tf.contrib.seq2seq.LuongAttention( - num_units, - memory, - memory_sequence_length=source_sequence_length, - scale=True) - elif attention_option == "bahdanau": - attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( - num_units, memory, memory_sequence_length=source_sequence_length) - elif attention_option == "normed_bahdanau": - attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( - num_units, - memory, - memory_sequence_length=source_sequence_length, - normalize=True) - else: - raise ValueError("Unknown attention option %s" % attention_option) - - return attention_mechanism + """Create attention mechanism based on the attention_option.""" + del mode # unused + + # Mechanism + if attention_option == "luong": + attention_mechanism = tf.contrib.seq2seq.LuongAttention( + num_units, memory, memory_sequence_length=source_sequence_length) + elif attention_option == "scaled_luong": + attention_mechanism = tf.contrib.seq2seq.LuongAttention( + num_units, + memory, + memory_sequence_length=source_sequence_length, + scale=True) + elif attention_option == "bahdanau": + attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( + num_units, memory, memory_sequence_length=source_sequence_length) + elif attention_option == "normed_bahdanau": + attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( + num_units, + memory, + memory_sequence_length=source_sequence_length, + normalize=True) + else: + raise ValueError("Unknown attention option %s" % attention_option) + + return attention_mechanism def _create_attention_images_summary(final_context_state): - """create attention image and attention summary.""" - attention_images = (final_context_state.alignment_history.stack()) - # Reshape to (batch, src_seq_len, tgt_seq_len,1) - attention_images = tf.expand_dims( - tf.transpose(attention_images, [1, 2, 0]), -1) - # Scale to range [0, 255] - attention_images *= 255 - attention_summary = 
tf.summary.image("attention_images", attention_images) - return attention_summary + """create attention image and attention summary.""" + attention_images = (final_context_state.alignment_history.stack()) + # Reshape to (batch, src_seq_len, tgt_seq_len,1) + attention_images = tf.expand_dims( + tf.transpose(attention_images, [1, 2, 0]), -1) + # Scale to range [0, 255] + attention_images *= 255 + attention_summary = tf.summary.image("attention_images", attention_images) + return attention_summary diff --git a/nmt/gnmt_model.py b/nmt/gnmt_model.py index 468a5d0..2cbfa06 100644 --- a/nmt/gnmt_model.py +++ b/nmt/gnmt_model.py @@ -29,305 +29,306 @@ class GNMTModel(attention_model.AttentionModel): - """Sequence-to-sequence dynamic model with GNMT attention architecture. - """ - - def __init__(self, - hparams, - mode, - iterator, - source_vocab_table, - target_vocab_table, - reverse_target_vocab_table=None, - scope=None, - extra_args=None): - self.is_gnmt_attention = ( - hparams.attention_architecture in ["gnmt", "gnmt_v2"]) - - super(GNMTModel, self).__init__( - hparams=hparams, - mode=mode, - iterator=iterator, - source_vocab_table=source_vocab_table, - target_vocab_table=target_vocab_table, - reverse_target_vocab_table=reverse_target_vocab_table, - scope=scope, - extra_args=extra_args) - - def _build_encoder(self, hparams): - """Build a GNMT encoder.""" - if hparams.encoder_type == "uni" or hparams.encoder_type == "bi": - return super(GNMTModel, self)._build_encoder(hparams) - - if hparams.encoder_type != "gnmt": - raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) - - # Build GNMT encoder. - num_bi_layers = 1 - num_uni_layers = self.num_encoder_layers - num_bi_layers - utils.print_out("# Build a GNMT encoder") - utils.print_out(" num_bi_layers = %d" % num_bi_layers) - utils.print_out(" num_uni_layers = %d" % num_uni_layers) - - iterator = self.iterator - source = iterator.source - if self.time_major: - source = tf.transpose(source) - - with tf.variable_scope("encoder") as scope: - dtype = scope.dtype - - self.encoder_emb_inp = self.encoder_emb_lookup_fn( - self.embedding_encoder, source) - - # Execute _build_bidirectional_rnn from Model class - bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn( - inputs=self.encoder_emb_inp, - sequence_length=iterator.source_sequence_length, - dtype=dtype, - hparams=hparams, - num_bi_layers=num_bi_layers, - num_bi_residual_layers=0, # no residual connection - ) - - # Build unidirectional layers - if self.extract_encoder_layers: - encoder_state, encoder_outputs = self._build_individual_encoder_layers( - bi_encoder_outputs, num_uni_layers, dtype, hparams) - else: - encoder_state, encoder_outputs = self._build_all_encoder_layers( - bi_encoder_outputs, num_uni_layers, dtype, hparams) - - # Pass all encoder states to the decoder - # except the first bi-directional layer - encoder_state = (bi_encoder_state[1],) + ( - (encoder_state,) if num_uni_layers == 1 else encoder_state) - - return encoder_outputs, encoder_state - - def _build_all_encoder_layers(self, bi_encoder_outputs, - num_uni_layers, dtype, hparams): - """Build encoder layers all at once.""" - uni_cell = model_helper.create_rnn_cell( - unit_type=hparams.unit_type, - num_units=hparams.num_units, - num_layers=num_uni_layers, - num_residual_layers=self.num_encoder_residual_layers, - forget_bias=hparams.forget_bias, - dropout=hparams.dropout, - num_gpus=self.num_gpus, - base_gpu=1, - mode=self.mode, - single_cell_fn=self.single_cell_fn) - encoder_outputs, encoder_state = 
tf.nn.dynamic_rnn( - uni_cell, - bi_encoder_outputs, - dtype=dtype, - sequence_length=self.iterator.source_sequence_length, - time_major=self.time_major) - - # Use the top layer for now - self.encoder_state_list = [encoder_outputs] - - return encoder_state, encoder_outputs - - def _build_individual_encoder_layers(self, bi_encoder_outputs, - num_uni_layers, dtype, hparams): - """Run each of the encoder layer separately, not used in general seq2seq.""" - uni_cell_lists = model_helper._cell_list( - unit_type=hparams.unit_type, - num_units=hparams.num_units, - num_layers=num_uni_layers, - num_residual_layers=self.num_encoder_residual_layers, - forget_bias=hparams.forget_bias, - dropout=hparams.dropout, - num_gpus=self.num_gpus, - base_gpu=1, - mode=self.mode, - single_cell_fn=self.single_cell_fn) - - encoder_inp = bi_encoder_outputs - encoder_states = [] - self.encoder_state_list = [bi_encoder_outputs[:, :, :hparams.num_units], - bi_encoder_outputs[:, :, hparams.num_units:]] - with tf.variable_scope("rnn/multi_rnn_cell"): - for i, cell in enumerate(uni_cell_lists): - with tf.variable_scope("cell_%d" % i) as scope: - encoder_inp, encoder_state = tf.nn.dynamic_rnn( - cell, - encoder_inp, - dtype=dtype, - sequence_length=self.iterator.source_sequence_length, - time_major=self.time_major, - scope=scope) - encoder_states.append(encoder_state) - self.encoder_state_list.append(encoder_inp) - - encoder_state = tuple(encoder_states) - encoder_outputs = self.encoder_state_list[-1] - return encoder_state, encoder_outputs - - def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, - source_sequence_length): - """Build a RNN cell with GNMT attention architecture.""" - # Standard attention - if not self.is_gnmt_attention: - return super(GNMTModel, self)._build_decoder_cell( - hparams, encoder_outputs, encoder_state, source_sequence_length) - - # GNMT attention - attention_option = hparams.attention - attention_architecture = hparams.attention_architecture - num_units = hparams.num_units - infer_mode = hparams.infer_mode - - dtype = tf.float32 - - if self.time_major: - memory = tf.transpose(encoder_outputs, [1, 0, 2]) - else: - memory = encoder_outputs - - if (self.mode == tf.contrib.learn.ModeKeys.INFER and - infer_mode == "beam_search"): - memory, source_sequence_length, encoder_state, batch_size = ( - self._prepare_beam_search_decoder_inputs( - hparams.beam_width, memory, source_sequence_length, - encoder_state)) - else: - batch_size = self.batch_size - - attention_mechanism = self.attention_mechanism_fn( - attention_option, num_units, memory, source_sequence_length, self.mode) - - cell_list = model_helper._cell_list( # pylint: disable=protected-access - unit_type=hparams.unit_type, - num_units=num_units, - num_layers=self.num_decoder_layers, - num_residual_layers=self.num_decoder_residual_layers, - forget_bias=hparams.forget_bias, - dropout=hparams.dropout, - num_gpus=self.num_gpus, - mode=self.mode, - single_cell_fn=self.single_cell_fn, - residual_fn=gnmt_residual_fn - ) - - # Only wrap the bottom layer with the attention mechanism. - attention_cell = cell_list.pop(0) - - # Only generate alignment in greedy INFER mode. - alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and - infer_mode != "beam_search") - attention_cell = tf.contrib.seq2seq.AttentionWrapper( - attention_cell, - attention_mechanism, - attention_layer_size=None, # don't use attention layer. 
- output_attention=False, - alignment_history=alignment_history, - name="attention") - - if attention_architecture == "gnmt": - cell = GNMTAttentionMultiCell( - attention_cell, cell_list) - elif attention_architecture == "gnmt_v2": - cell = GNMTAttentionMultiCell( - attention_cell, cell_list, use_new_attention=True) - else: - raise ValueError( - "Unknown attention_architecture %s" % attention_architecture) - - if hparams.pass_hidden_state: - decoder_initial_state = tuple( - zs.clone(cell_state=es) - if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es - for zs, es in zip( - cell.zero_state(batch_size, dtype), encoder_state)) - else: - decoder_initial_state = cell.zero_state(batch_size, dtype) - - return cell, decoder_initial_state - - def _get_infer_summary(self, hparams): - if hparams.infer_mode == "beam_search": - return tf.no_op() - elif self.is_gnmt_attention: - return attention_model._create_attention_images_summary( - self.final_context_state[0]) - else: - return super(GNMTModel, self)._get_infer_summary(hparams) + """Sequence-to-sequence dynamic model with GNMT attention architecture. + """ + + def __init__(self, + hparams, + mode, + iterator, + source_vocab_table, + target_vocab_table, + reverse_target_vocab_table=None, + scope=None, + extra_args=None): + self.is_gnmt_attention = ( + hparams.attention_architecture in ["gnmt", "gnmt_v2"]) + + super(GNMTModel, self).__init__( + hparams=hparams, + mode=mode, + iterator=iterator, + source_vocab_table=source_vocab_table, + target_vocab_table=target_vocab_table, + reverse_target_vocab_table=reverse_target_vocab_table, + scope=scope, + extra_args=extra_args) + + def _build_encoder(self, hparams): + """Build a GNMT encoder.""" + if hparams.encoder_type == "uni" or hparams.encoder_type == "bi": + return super(GNMTModel, self)._build_encoder(hparams) + + if hparams.encoder_type != "gnmt": + raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) + + # Build GNMT encoder. 
+ num_bi_layers = 1 + num_uni_layers = self.num_encoder_layers - num_bi_layers + utils.print_out("# Build a GNMT encoder") + utils.print_out(" num_bi_layers = %d" % num_bi_layers) + utils.print_out(" num_uni_layers = %d" % num_uni_layers) + + iterator = self.iterator + source = iterator.source + if self.time_major: + source = tf.transpose(source) + + with tf.variable_scope("encoder") as scope: + dtype = scope.dtype + + self.encoder_emb_inp = self.encoder_emb_lookup_fn( + self.embedding_encoder, source) + + # Execute _build_bidirectional_rnn from Model class + bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn( + inputs=self.encoder_emb_inp, + sequence_length=iterator.source_sequence_length, + dtype=dtype, + hparams=hparams, + num_bi_layers=num_bi_layers, + num_bi_residual_layers=0, # no residual connection + ) + + # Build unidirectional layers + if self.extract_encoder_layers: + encoder_state, encoder_outputs = self._build_individual_encoder_layers( + bi_encoder_outputs, num_uni_layers, dtype, hparams) + else: + encoder_state, encoder_outputs = self._build_all_encoder_layers( + bi_encoder_outputs, num_uni_layers, dtype, hparams) + + # Pass all encoder states to the decoder + # except the first bi-directional layer + encoder_state = (bi_encoder_state[1],) + ( + (encoder_state,) if num_uni_layers == 1 else encoder_state) + + return encoder_outputs, encoder_state + + def _build_all_encoder_layers(self, bi_encoder_outputs, + num_uni_layers, dtype, hparams): + """Build encoder layers all at once.""" + uni_cell = model_helper.create_rnn_cell( + unit_type=hparams.unit_type, + num_units=hparams.num_units, + num_layers=num_uni_layers, + num_residual_layers=self.num_encoder_residual_layers, + forget_bias=hparams.forget_bias, + dropout=hparams.dropout, + num_gpus=self.num_gpus, + base_gpu=1, + mode=self.mode, + single_cell_fn=self.single_cell_fn) + encoder_outputs, encoder_state = tf.nn.dynamic_rnn( + uni_cell, + bi_encoder_outputs, + dtype=dtype, + sequence_length=self.iterator.source_sequence_length, + time_major=self.time_major) + + # Use the top layer for now + self.encoder_state_list = [encoder_outputs] + + return encoder_state, encoder_outputs + + def _build_individual_encoder_layers(self, bi_encoder_outputs, + num_uni_layers, dtype, hparams): + """Run each of the encoder layer separately, not used in general seq2seq.""" + uni_cell_lists = model_helper._cell_list( + unit_type=hparams.unit_type, + num_units=hparams.num_units, + num_layers=num_uni_layers, + num_residual_layers=self.num_encoder_residual_layers, + forget_bias=hparams.forget_bias, + dropout=hparams.dropout, + num_gpus=self.num_gpus, + base_gpu=1, + mode=self.mode, + single_cell_fn=self.single_cell_fn) + + encoder_inp = bi_encoder_outputs + encoder_states = [] + self.encoder_state_list = [bi_encoder_outputs[:, :, :hparams.num_units], + bi_encoder_outputs[:, :, hparams.num_units:]] + with tf.variable_scope("rnn/multi_rnn_cell"): + for i, cell in enumerate(uni_cell_lists): + with tf.variable_scope("cell_%d" % i) as scope: + encoder_inp, encoder_state = tf.nn.dynamic_rnn( + cell, + encoder_inp, + dtype=dtype, + sequence_length=self.iterator.source_sequence_length, + time_major=self.time_major, + scope=scope) + encoder_states.append(encoder_state) + self.encoder_state_list.append(encoder_inp) + + encoder_state = tuple(encoder_states) + encoder_outputs = self.encoder_state_list[-1] + return encoder_state, encoder_outputs + + def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, + source_sequence_length): 
+ """Build a RNN cell with GNMT attention architecture.""" + # Standard attention + if not self.is_gnmt_attention: + return super(GNMTModel, self)._build_decoder_cell( + hparams, encoder_outputs, encoder_state, source_sequence_length) + + # GNMT attention + attention_option = hparams.attention + attention_architecture = hparams.attention_architecture + num_units = hparams.num_units + infer_mode = hparams.infer_mode + + dtype = tf.float32 + + if self.time_major: + memory = tf.transpose(encoder_outputs, [1, 0, 2]) + else: + memory = encoder_outputs + + if (self.mode == tf.contrib.learn.ModeKeys.INFER and + infer_mode == "beam_search"): + memory, source_sequence_length, encoder_state, batch_size = ( + self._prepare_beam_search_decoder_inputs( + hparams.beam_width, memory, source_sequence_length, + encoder_state)) + else: + batch_size = self.batch_size + + attention_mechanism = self.attention_mechanism_fn( + attention_option, num_units, memory, source_sequence_length, self.mode) + + cell_list = model_helper._cell_list( # pylint: disable=protected-access + unit_type=hparams.unit_type, + num_units=num_units, + num_layers=self.num_decoder_layers, + num_residual_layers=self.num_decoder_residual_layers, + forget_bias=hparams.forget_bias, + dropout=hparams.dropout, + num_gpus=self.num_gpus, + mode=self.mode, + single_cell_fn=self.single_cell_fn, + residual_fn=gnmt_residual_fn + ) + + # Only wrap the bottom layer with the attention mechanism. + attention_cell = cell_list.pop(0) + + # Only generate alignment in greedy INFER mode. + alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and + infer_mode != "beam_search") + attention_cell = tf.contrib.seq2seq.AttentionWrapper( + attention_cell, + attention_mechanism, + attention_layer_size=None, # don't use attention layer. + output_attention=False, + alignment_history=alignment_history, + name="attention") + + if attention_architecture == "gnmt": + cell = GNMTAttentionMultiCell( + attention_cell, cell_list) + elif attention_architecture == "gnmt_v2": + cell = GNMTAttentionMultiCell( + attention_cell, cell_list, use_new_attention=True) + else: + raise ValueError( + "Unknown attention_architecture %s" % attention_architecture) + + if hparams.pass_hidden_state: + decoder_initial_state = tuple( + zs.clone(cell_state=es) + if isinstance(zs, tf.contrib.seq2seq.AttentionWrapperState) else es + for zs, es in zip( + cell.zero_state(batch_size, dtype), encoder_state)) + else: + decoder_initial_state = cell.zero_state(batch_size, dtype) + + return cell, decoder_initial_state + + def _get_infer_summary(self, hparams): + if hparams.infer_mode == "beam_search": + return tf.no_op() + elif self.is_gnmt_attention: + return attention_model._create_attention_images_summary( + self.final_context_state[0]) + else: + return super(GNMTModel, self)._get_infer_summary(hparams) class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell): - """A MultiCell with GNMT attention style.""" + """A MultiCell with GNMT attention style.""" - def __init__(self, attention_cell, cells, use_new_attention=False): - """Creates a GNMTAttentionMultiCell. + def __init__(self, attention_cell, cells, use_new_attention=False): + """Creates a GNMTAttentionMultiCell. - Args: - attention_cell: An instance of AttentionWrapper. - cells: A list of RNNCell wrapped with AttentionInputWrapper. - use_new_attention: Whether to use the attention generated from current - step bottom layer's output. Default is False. 
- """ - cells = [attention_cell] + cells - self.use_new_attention = use_new_attention - super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True) + Args: + attention_cell: An instance of AttentionWrapper. + cells: A list of RNNCell wrapped with AttentionInputWrapper. + use_new_attention: Whether to use the attention generated from current + step bottom layer's output. Default is False. + """ + cells = [attention_cell] + cells + self.use_new_attention = use_new_attention + super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True) - def __call__(self, inputs, state, scope=None): - """Run the cell with bottom layer's attention copied to all upper layers.""" - if not tf.contrib.framework.nest.is_sequence(state): - raise ValueError( - "Expected state to be a tuple of length %d, but received: %s" - % (len(self.state_size), state)) + def __call__(self, inputs, state, scope=None): + """Run the cell with bottom layer's attention copied to all upper layers.""" + if not tf.contrib.framework.nest.is_sequence(state): + raise ValueError( + "Expected state to be a tuple of length %d, but received: %s" + % (len(self.state_size), state)) - with tf.variable_scope(scope or "multi_rnn_cell"): - new_states = [] + with tf.variable_scope(scope or "multi_rnn_cell"): + new_states = [] - with tf.variable_scope("cell_0_attention"): - attention_cell = self._cells[0] - attention_state = state[0] - cur_inp, new_attention_state = attention_cell(inputs, attention_state) - new_states.append(new_attention_state) + with tf.variable_scope("cell_0_attention"): + attention_cell = self._cells[0] + attention_state = state[0] + cur_inp, new_attention_state = attention_cell(inputs, attention_state) + new_states.append(new_attention_state) - for i in range(1, len(self._cells)): - with tf.variable_scope("cell_%d" % i): + for i in range(1, len(self._cells)): + with tf.variable_scope("cell_%d" % i): - cell = self._cells[i] - cur_state = state[i] + cell = self._cells[i] + cur_state = state[i] - if self.use_new_attention: - cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1) - else: - cur_inp = tf.concat([cur_inp, attention_state.attention], -1) + if self.use_new_attention: + cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1) + else: + cur_inp = tf.concat([cur_inp, attention_state.attention], -1) - cur_inp, new_state = cell(cur_inp, cur_state) - new_states.append(new_state) + cur_inp, new_state = cell(cur_inp, cur_state) + new_states.append(new_state) - return cur_inp, tuple(new_states) + return cur_inp, tuple(new_states) def gnmt_residual_fn(inputs, outputs): - """Residual function that handles different inputs and outputs inner dims. - - Args: - inputs: cell inputs, this is actual inputs concatenated with the attention - vector. 
- outputs: cell outputs - - Returns: - outputs + actual inputs - """ - def split_input(inp, out): - out_dim = out.get_shape().as_list()[-1] - inp_dim = inp.get_shape().as_list()[-1] - return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1) - actual_inputs, _ = tf.contrib.framework.nest.map_structure( - split_input, inputs, outputs) - def assert_shape_match(inp, out): - inp.get_shape().assert_is_compatible_with(out.get_shape()) - tf.contrib.framework.nest.assert_same_structure(actual_inputs, outputs) - tf.contrib.framework.nest.map_structure( - assert_shape_match, actual_inputs, outputs) - return tf.contrib.framework.nest.map_structure( - lambda inp, out: inp + out, actual_inputs, outputs) + """Residual function that handles different inputs and outputs inner dims. + + Args: + inputs: cell inputs, this is actual inputs concatenated with the attention + vector. + outputs: cell outputs + + Returns: + outputs + actual inputs + """ + def split_input(inp, out): + out_dim = out.get_shape().as_list()[-1] + inp_dim = inp.get_shape().as_list()[-1] + return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1) + actual_inputs, _ = tf.contrib.framework.nest.map_structure( + split_input, inputs, outputs) + + def assert_shape_match(inp, out): + inp.get_shape().assert_is_compatible_with(out.get_shape()) + tf.contrib.framework.nest.assert_same_structure(actual_inputs, outputs) + tf.contrib.framework.nest.map_structure( + assert_shape_match, actual_inputs, outputs) + return tf.contrib.framework.nest.map_structure( + lambda inp, out: inp + out, actual_inputs, outputs) diff --git a/nmt/inference.py b/nmt/inference.py index 2cbef07..d7b3069 100644 --- a/nmt/inference.py +++ b/nmt/inference.py @@ -37,72 +37,72 @@ def _decode_inference_indices(model, sess, output_infer, inference_indices, tgt_eos, subword_option): - """Decoding only a specific set of sentences.""" - utils.print_out(" decoding to output %s , num sents %d." % - (output_infer, len(inference_indices))) - start_time = time.time() - with codecs.getwriter("utf-8")( - tf.gfile.GFile(output_infer, mode="wb")) as trans_f: - trans_f.write("") # Write empty string to ensure file is created. - for decode_id in inference_indices: - nmt_outputs, infer_summary = model.decode(sess) - - # get text translation - assert nmt_outputs.shape[0] == 1 - translation = nmt_utils.get_translation( - nmt_outputs, - sent_id=0, - tgt_eos=tgt_eos, - subword_option=subword_option) - - if infer_summary is not None: # Attention models - image_file = output_infer_summary_prefix + str(decode_id) + ".png" - utils.print_out(" save attention image to %s*" % image_file) - image_summ = tf.Summary() - image_summ.ParseFromString(infer_summary) - with tf.gfile.GFile(image_file, mode="w") as img_f: - img_f.write(image_summ.value[0].image.encoded_image_string) - - trans_f.write("%s\n" % translation) - utils.print_out(translation + b"\n") - utils.print_time(" done", start_time) + """Decoding only a specific set of sentences.""" + utils.print_out(" decoding to output %s , num sents %d." % + (output_infer, len(inference_indices))) + start_time = time.time() + with codecs.getwriter("utf-8")( + tf.gfile.GFile(output_infer, mode="wb")) as trans_f: + trans_f.write("") # Write empty string to ensure file is created. 
+ for decode_id in inference_indices: + nmt_outputs, infer_summary = model.decode(sess) + + # get text translation + assert nmt_outputs.shape[0] == 1 + translation = nmt_utils.get_translation( + nmt_outputs, + sent_id=0, + tgt_eos=tgt_eos, + subword_option=subword_option) + + if infer_summary is not None: # Attention models + image_file = output_infer_summary_prefix + str(decode_id) + ".png" + utils.print_out(" save attention image to %s*" % image_file) + image_summ = tf.Summary() + image_summ.ParseFromString(infer_summary) + with tf.gfile.GFile(image_file, mode="w") as img_f: + img_f.write(image_summ.value[0].image.encoded_image_string) + + trans_f.write("%s\n" % translation) + utils.print_out(translation + b"\n") + utils.print_time(" done", start_time) def load_data(inference_input_file, hparams=None): - """Load inference data.""" - with codecs.getreader("utf-8")( - tf.gfile.GFile(inference_input_file, mode="rb")) as f: - inference_data = f.read().splitlines() + """Load inference data.""" + with codecs.getreader("utf-8")( + tf.gfile.GFile(inference_input_file, mode="rb")) as f: + inference_data = f.read().splitlines() - if hparams and hparams.inference_indices: - inference_data = [inference_data[i] for i in hparams.inference_indices] + if hparams and hparams.inference_indices: + inference_data = [inference_data[i] for i in hparams.inference_indices] - return inference_data + return inference_data def get_model_creator(hparams): - """Get the right model class depending on configuration.""" - if (hparams.encoder_type == "gnmt" or - hparams.attention_architecture in ["gnmt", "gnmt_v2"]): - model_creator = gnmt_model.GNMTModel - elif hparams.attention_architecture == "standard": - model_creator = attention_model.AttentionModel - elif not hparams.attention: - model_creator = nmt_model.Model - else: - raise ValueError("Unknown attention architecture %s" % - hparams.attention_architecture) - return model_creator + """Get the right model class depending on configuration.""" + if (hparams.encoder_type == "gnmt" or + hparams.attention_architecture in ["gnmt", "gnmt_v2"]): + model_creator = gnmt_model.GNMTModel + elif hparams.attention_architecture == "standard": + model_creator = attention_model.AttentionModel + elif not hparams.attention: + model_creator = nmt_model.Model + else: + raise ValueError("Unknown attention architecture %s" % + hparams.attention_architecture) + return model_creator def start_sess_and_load_model(infer_model, ckpt_path): - """Start session and load model.""" - sess = tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) - with infer_model.graph.as_default(): - loaded_infer_model = model_helper.load_model( - infer_model.model, ckpt_path, sess, "infer") - return sess, loaded_infer_model + """Start session and load model.""" + sess = tf.Session( + graph=infer_model.graph, config=utils.get_config_proto()) + with infer_model.graph.as_default(): + loaded_infer_model = model_helper.load_model( + infer_model.model, ckpt_path, sess, "infer") + return sess, loaded_infer_model def inference(ckpt_path, @@ -112,33 +112,33 @@ def inference(ckpt_path, num_workers=1, jobid=0, scope=None): - """Perform translation.""" - if hparams.inference_indices: - assert num_workers == 1 - - model_creator = get_model_creator(hparams) - infer_model = model_helper.create_infer_model(model_creator, hparams, scope) - sess, loaded_infer_model = start_sess_and_load_model(infer_model, ckpt_path) - - if num_workers == 1: - single_worker_inference( - sess, - infer_model, - loaded_infer_model, 
- inference_input_file, - inference_output_file, - hparams) - else: - multi_worker_inference( - sess, - infer_model, - loaded_infer_model, - inference_input_file, - inference_output_file, - hparams, - num_workers=num_workers, - jobid=jobid) - sess.close() + """Perform translation.""" + if hparams.inference_indices: + assert num_workers == 1 + + model_creator = get_model_creator(hparams) + infer_model = model_helper.create_infer_model(model_creator, hparams, scope) + sess, loaded_infer_model = start_sess_and_load_model(infer_model, ckpt_path) + + if num_workers == 1: + single_worker_inference( + sess, + infer_model, + loaded_infer_model, + inference_input_file, + inference_output_file, + hparams) + else: + multi_worker_inference( + sess, + infer_model, + loaded_infer_model, + inference_input_file, + inference_output_file, + hparams, + num_workers=num_workers, + jobid=jobid) + sess.close() def single_worker_inference(sess, @@ -147,43 +147,43 @@ def single_worker_inference(sess, inference_input_file, inference_output_file, hparams): - """Inference with a single worker.""" - output_infer = inference_output_file - - # Read data - infer_data = load_data(inference_input_file, hparams) - - with infer_model.graph.as_default(): - sess.run( - infer_model.iterator.initializer, - feed_dict={ - infer_model.src_placeholder: infer_data, - infer_model.batch_size_placeholder: hparams.infer_batch_size - }) - # Decode - utils.print_out("# Start decoding") - if hparams.inference_indices: - _decode_inference_indices( - loaded_infer_model, - sess, - output_infer=output_infer, - output_infer_summary_prefix=output_infer, - inference_indices=hparams.inference_indices, - tgt_eos=hparams.eos, - subword_option=hparams.subword_option) - else: - nmt_utils.decode_and_evaluate( - "infer", - loaded_infer_model, - sess, - output_infer, - ref_file=None, - metrics=hparams.metrics, - subword_option=hparams.subword_option, - beam_width=hparams.beam_width, - tgt_eos=hparams.eos, - num_translations_per_input=hparams.num_translations_per_input, - infer_mode=hparams.infer_mode) + """Inference with a single worker.""" + output_infer = inference_output_file + + # Read data + infer_data = load_data(inference_input_file, hparams) + + with infer_model.graph.as_default(): + sess.run( + infer_model.iterator.initializer, + feed_dict={ + infer_model.src_placeholder: infer_data, + infer_model.batch_size_placeholder: hparams.infer_batch_size + }) + # Decode + utils.print_out("# Start decoding") + if hparams.inference_indices: + _decode_inference_indices( + loaded_infer_model, + sess, + output_infer=output_infer, + output_infer_summary_prefix=output_infer, + inference_indices=hparams.inference_indices, + tgt_eos=hparams.eos, + subword_option=hparams.subword_option) + else: + nmt_utils.decode_and_evaluate( + "infer", + loaded_infer_model, + sess, + output_infer, + ref_file=None, + metrics=hparams.metrics, + subword_option=hparams.subword_option, + beam_width=hparams.beam_width, + tgt_eos=hparams.eos, + num_translations_per_input=hparams.num_translations_per_input, + infer_mode=hparams.infer_mode) def multi_worker_inference(sess, @@ -194,64 +194,65 @@ def multi_worker_inference(sess, hparams, num_workers, jobid): - """Inference using multiple workers.""" - assert num_workers > 1 - - final_output_infer = inference_output_file - output_infer = "%s_%d" % (inference_output_file, jobid) - output_infer_done = "%s_done_%d" % (inference_output_file, jobid) - - # Read data - infer_data = load_data(inference_input_file, hparams) - - # Split data to 
multiple workers - total_load = len(infer_data) - load_per_worker = int((total_load - 1) / num_workers) + 1 - start_position = jobid * load_per_worker - end_position = min(start_position + load_per_worker, total_load) - infer_data = infer_data[start_position:end_position] - - with infer_model.graph.as_default(): - sess.run(infer_model.iterator.initializer, - { - infer_model.src_placeholder: infer_data, - infer_model.batch_size_placeholder: hparams.infer_batch_size - }) - # Decode - utils.print_out("# Start decoding") - nmt_utils.decode_and_evaluate( - "infer", - loaded_infer_model, - sess, - output_infer, - ref_file=None, - metrics=hparams.metrics, - subword_option=hparams.subword_option, - beam_width=hparams.beam_width, - tgt_eos=hparams.eos, - num_translations_per_input=hparams.num_translations_per_input, - infer_mode=hparams.infer_mode) - - # Change file name to indicate the file writing is completed. - tf.gfile.Rename(output_infer, output_infer_done, overwrite=True) - - # Job 0 is responsible for the clean up. - if jobid != 0: return - - # Now write all translations - with codecs.getwriter("utf-8")( - tf.gfile.GFile(final_output_infer, mode="wb")) as final_f: - for worker_id in range(num_workers): - worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) - while not tf.gfile.Exists(worker_infer_done): - utils.print_out(" waiting job %d to complete." % worker_id) - time.sleep(10) - - with codecs.getreader("utf-8")( - tf.gfile.GFile(worker_infer_done, mode="rb")) as f: - for translation in f: - final_f.write("%s" % translation) - - for worker_id in range(num_workers): - worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) - tf.gfile.Remove(worker_infer_done) + """Inference using multiple workers.""" + assert num_workers > 1 + + final_output_infer = inference_output_file + output_infer = "%s_%d" % (inference_output_file, jobid) + output_infer_done = "%s_done_%d" % (inference_output_file, jobid) + + # Read data + infer_data = load_data(inference_input_file, hparams) + + # Split data to multiple workers + total_load = len(infer_data) + load_per_worker = int((total_load - 1) / num_workers) + 1 + start_position = jobid * load_per_worker + end_position = min(start_position + load_per_worker, total_load) + infer_data = infer_data[start_position:end_position] + + with infer_model.graph.as_default(): + sess.run(infer_model.iterator.initializer, + { + infer_model.src_placeholder: infer_data, + infer_model.batch_size_placeholder: hparams.infer_batch_size + }) + # Decode + utils.print_out("# Start decoding") + nmt_utils.decode_and_evaluate( + "infer", + loaded_infer_model, + sess, + output_infer, + ref_file=None, + metrics=hparams.metrics, + subword_option=hparams.subword_option, + beam_width=hparams.beam_width, + tgt_eos=hparams.eos, + num_translations_per_input=hparams.num_translations_per_input, + infer_mode=hparams.infer_mode) + + # Change file name to indicate the file writing is completed. + tf.gfile.Rename(output_infer, output_infer_done, overwrite=True) + + # Job 0 is responsible for the clean up. + if jobid != 0: + return + + # Now write all translations + with codecs.getwriter("utf-8")( + tf.gfile.GFile(final_output_infer, mode="wb")) as final_f: + for worker_id in range(num_workers): + worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) + while not tf.gfile.Exists(worker_infer_done): + utils.print_out(" waiting job %d to complete." 
% worker_id) + time.sleep(10) + + with codecs.getreader("utf-8")( + tf.gfile.GFile(worker_infer_done, mode="rb")) as f: + for translation in f: + final_f.write("%s" % translation) + + for worker_id in range(num_workers): + worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id) + tf.gfile.Remove(worker_infer_done) diff --git a/nmt/inference_test.py b/nmt/inference_test.py index 317024b..e7c62b8 100644 --- a/nmt/inference_test.py +++ b/nmt/inference_test.py @@ -34,142 +34,142 @@ class InferenceTest(tf.test.TestCase): - def _createTestInferCheckpoint(self, hparams, name): - # Prepare - hparams.vocab_prefix = ( - "nmt/testdata/test_infer_vocab") - hparams.src_vocab_file = hparams.vocab_prefix + "." + hparams.src - hparams.tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt - out_dir = os.path.join(tf.test.get_temp_dir(), name) - os.makedirs(out_dir) - hparams.out_dir = out_dir - - # Create check point - model_creator = inference.get_model_creator(hparams) - infer_model = model_helper.create_infer_model(model_creator, hparams) - with self.test_session(graph=infer_model.graph) as sess: - loaded_model, global_step = model_helper.create_or_load_model( - infer_model.model, out_dir, sess, "infer_name") - ckpt_path = loaded_model.saver.save( - sess, os.path.join(out_dir, "translate.ckpt"), - global_step=global_step) - return ckpt_path - - def testBasicModel(self): - hparams = common_test_utils.create_test_hparams( - encoder_type="uni", - num_layers=1, - attention="", - attention_architecture="", - use_residual=False,) - ckpt_path = self._createTestInferCheckpoint(hparams, "basic_infer") - infer_file = "nmt/testdata/test_infer_file" - output_infer = os.path.join(hparams.out_dir, "output_infer") - inference.inference(ckpt_path, infer_file, output_infer, hparams) - with open(output_infer) as f: - self.assertEqual(5, len(list(f))) - - def testBasicModelWithMultipleTranslations(self): - hparams = common_test_utils.create_test_hparams( - encoder_type="uni", - num_layers=1, - attention="", - attention_architecture="", - use_residual=False, - num_translations_per_input=2, - beam_width=2, - ) - hparams.infer_mode = "beam_search" - - ckpt_path = self._createTestInferCheckpoint(hparams, "multi_basic_infer") - infer_file = "nmt/testdata/test_infer_file" - output_infer = os.path.join(hparams.out_dir, "output_infer") - inference.inference(ckpt_path, infer_file, output_infer, hparams) - with open(output_infer) as f: - self.assertEqual(10, len(list(f))) - - def testAttentionModel(self): - hparams = common_test_utils.create_test_hparams( - encoder_type="uni", - num_layers=1, - attention="scaled_luong", - attention_architecture="standard", - use_residual=False,) - ckpt_path = self._createTestInferCheckpoint(hparams, "attention_infer") - infer_file = "nmt/testdata/test_infer_file" - output_infer = os.path.join(hparams.out_dir, "output_infer") - inference.inference(ckpt_path, infer_file, output_infer, hparams) - with open(output_infer) as f: - self.assertEqual(5, len(list(f))) - - def testMultiWorkers(self): - hparams = common_test_utils.create_test_hparams( - encoder_type="uni", - num_layers=2, - attention="scaled_luong", - attention_architecture="standard", - use_residual=False,) - - num_workers = 3 - - # There are 5 examples, make batch_size=3 makes job0 has 3 examples, job1 - # has 2 examples, and job2 has 0 example. This helps testing some edge - # cases. 
- hparams.batch_size = 3 - - ckpt_path = self._createTestInferCheckpoint(hparams, "multi_worker_infer") - infer_file = "nmt/testdata/test_infer_file" - output_infer = os.path.join(hparams.out_dir, "output_infer") - inference.inference( - ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=1) - - inference.inference( - ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=2) - - # Note: Need to start job 0 at the end; otherwise, it will block the testing - # thread. - inference.inference( - ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=0) - - with open(output_infer) as f: - self.assertEqual(5, len(list(f))) - - def testBasicModelWithInferIndices(self): - hparams = common_test_utils.create_test_hparams( - encoder_type="uni", - num_layers=1, - attention="", - attention_architecture="", - use_residual=False, - inference_indices=[0]) - ckpt_path = self._createTestInferCheckpoint(hparams, - "basic_infer_with_indices") - infer_file = "nmt/testdata/test_infer_file" - output_infer = os.path.join(hparams.out_dir, "output_infer") - inference.inference(ckpt_path, infer_file, output_infer, hparams) - with open(output_infer) as f: - self.assertEqual(1, len(list(f))) - - def testAttentionModelWithInferIndices(self): - hparams = common_test_utils.create_test_hparams( - encoder_type="uni", - num_layers=1, - attention="scaled_luong", - attention_architecture="standard", - use_residual=False, - inference_indices=[1, 2]) - # TODO(rzhao): Make infer indices support batch_size > 1. - hparams.infer_batch_size = 1 - ckpt_path = self._createTestInferCheckpoint(hparams, - "attention_infer_with_indices") - infer_file = "nmt/testdata/test_infer_file" - output_infer = os.path.join(hparams.out_dir, "output_infer") - inference.inference(ckpt_path, infer_file, output_infer, hparams) - with open(output_infer) as f: - self.assertEqual(2, len(list(f))) - self.assertTrue(os.path.exists(output_infer+str(1)+".png")) - self.assertTrue(os.path.exists(output_infer+str(2)+".png")) + def _createTestInferCheckpoint(self, hparams, name): + # Prepare + hparams.vocab_prefix = ( + "nmt/testdata/test_infer_vocab") + hparams.src_vocab_file = hparams.vocab_prefix + "." + hparams.src + hparams.tgt_vocab_file = hparams.vocab_prefix + "." 
+ hparams.tgt + out_dir = os.path.join(tf.test.get_temp_dir(), name) + os.makedirs(out_dir) + hparams.out_dir = out_dir + + # Create check point + model_creator = inference.get_model_creator(hparams) + infer_model = model_helper.create_infer_model(model_creator, hparams) + with self.test_session(graph=infer_model.graph) as sess: + loaded_model, global_step = model_helper.create_or_load_model( + infer_model.model, out_dir, sess, "infer_name") + ckpt_path = loaded_model.saver.save( + sess, os.path.join(out_dir, "translate.ckpt"), + global_step=global_step) + return ckpt_path + + def testBasicModel(self): + hparams = common_test_utils.create_test_hparams( + encoder_type="uni", + num_layers=1, + attention="", + attention_architecture="", + use_residual=False,) + ckpt_path = self._createTestInferCheckpoint(hparams, "basic_infer") + infer_file = "nmt/testdata/test_infer_file" + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) + with open(output_infer) as f: + self.assertEqual(5, len(list(f))) + + def testBasicModelWithMultipleTranslations(self): + hparams = common_test_utils.create_test_hparams( + encoder_type="uni", + num_layers=1, + attention="", + attention_architecture="", + use_residual=False, + num_translations_per_input=2, + beam_width=2, + ) + hparams.infer_mode = "beam_search" + + ckpt_path = self._createTestInferCheckpoint(hparams, "multi_basic_infer") + infer_file = "nmt/testdata/test_infer_file" + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) + with open(output_infer) as f: + self.assertEqual(10, len(list(f))) + + def testAttentionModel(self): + hparams = common_test_utils.create_test_hparams( + encoder_type="uni", + num_layers=1, + attention="scaled_luong", + attention_architecture="standard", + use_residual=False,) + ckpt_path = self._createTestInferCheckpoint(hparams, "attention_infer") + infer_file = "nmt/testdata/test_infer_file" + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) + with open(output_infer) as f: + self.assertEqual(5, len(list(f))) + + def testMultiWorkers(self): + hparams = common_test_utils.create_test_hparams( + encoder_type="uni", + num_layers=2, + attention="scaled_luong", + attention_architecture="standard", + use_residual=False,) + + num_workers = 3 + + # There are 5 examples, make batch_size=3 makes job0 has 3 examples, job1 + # has 2 examples, and job2 has 0 example. This helps testing some edge + # cases. + hparams.batch_size = 3 + + ckpt_path = self._createTestInferCheckpoint(hparams, "multi_worker_infer") + infer_file = "nmt/testdata/test_infer_file" + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference( + ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=1) + + inference.inference( + ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=2) + + # Note: Need to start job 0 at the end; otherwise, it will block the testing + # thread. 
+ inference.inference( + ckpt_path, infer_file, output_infer, hparams, num_workers, jobid=0) + + with open(output_infer) as f: + self.assertEqual(5, len(list(f))) + + def testBasicModelWithInferIndices(self): + hparams = common_test_utils.create_test_hparams( + encoder_type="uni", + num_layers=1, + attention="", + attention_architecture="", + use_residual=False, + inference_indices=[0]) + ckpt_path = self._createTestInferCheckpoint(hparams, + "basic_infer_with_indices") + infer_file = "nmt/testdata/test_infer_file" + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) + with open(output_infer) as f: + self.assertEqual(1, len(list(f))) + + def testAttentionModelWithInferIndices(self): + hparams = common_test_utils.create_test_hparams( + encoder_type="uni", + num_layers=1, + attention="scaled_luong", + attention_architecture="standard", + use_residual=False, + inference_indices=[1, 2]) + # TODO(rzhao): Make infer indices support batch_size > 1. + hparams.infer_batch_size = 1 + ckpt_path = self._createTestInferCheckpoint(hparams, + "attention_infer_with_indices") + infer_file = "nmt/testdata/test_infer_file" + output_infer = os.path.join(hparams.out_dir, "output_infer") + inference.inference(ckpt_path, infer_file, output_infer, hparams) + with open(output_infer) as f: + self.assertEqual(2, len(list(f))) + self.assertTrue(os.path.exists(output_infer+str(1)+".png")) + self.assertTrue(os.path.exists(output_infer+str(2)+".png")) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/nmt/model.py b/nmt/model.py index 431bf5c..bb5c26e 100644 --- a/nmt/model.py +++ b/nmt/model.py @@ -15,6 +15,8 @@ """Basic sequence-to-sequence model with dynamic RNN support.""" from __future__ import absolute_import +from kdq_embedding import full_embed, kdq_embed, KDQhparam +from kd_quantizer import KDQuantizer from __future__ import division from __future__ import print_function @@ -34,8 +36,6 @@ parent_path = "/".join(os.getcwd().split('/')) sys.path.append(os.path.join(parent_path, "core")) sys.path.append(os.path.join(parent_path, "kdq/core")) -from kd_quantizer import KDQuantizer -from kdq_embedding import full_embed, kdq_embed, KDQhparam utils.check_tensorflow_version() @@ -46,864 +46,865 @@ class TrainOutputTuple(collections.namedtuple( "TrainOutputTuple", ("train_summary", "train_loss", "predict_count", "global_step", "word_count", "batch_size", "grad_norm", "learning_rate"))): - """To allow for flexibily in returing different outputs.""" - pass + """To allow for flexibily in returing different outputs.""" + pass class EvalOutputTuple(collections.namedtuple( - "EvalOutputTuple", ("eval_loss", "predict_count", "batch_size"))): - """To allow for flexibily in returing different outputs.""" - pass + "EvalOutputTuple", ("eval_loss", "predict_count", "batch_size"))): + """To allow for flexibily in returing different outputs.""" + pass class InferOutputTuple(collections.namedtuple( "InferOutputTuple", ("infer_logits", "infer_summary", "sample_id", "sample_words"))): - """To allow for flexibily in returing different outputs.""" - pass + """To allow for flexibily in returing different outputs.""" + pass class BaseModel(object): - """Sequence-to-sequence base class. - """ - - def __init__(self, - hparams, - mode, - iterator, - source_vocab_table, - target_vocab_table, - reverse_target_vocab_table=None, - scope=None, - extra_args=None): - """Create the model. - - Args: - hparams: Hyperparameter configurations. 
- mode: TRAIN | EVAL | INFER - iterator: Dataset Iterator that feeds data. - source_vocab_table: Lookup table mapping source words to ids. - target_vocab_table: Lookup table mapping target words to ids. - reverse_target_vocab_table: Lookup table mapping ids to target words. Only - required in INFER mode. Defaults to None. - scope: scope of the model. - extra_args: model_helper.ExtraArgs, for passing customizable functions. - - """ - # Set params - self._set_params_initializer(hparams, mode, iterator, - source_vocab_table, target_vocab_table, - scope, extra_args) - - # Not used in general seq2seq models; when True, ignore decoder & training - self.extract_encoder_layers = (hasattr(hparams, "extract_encoder_layers") - and hparams.extract_encoder_layers) - - # Train graph - res = self.build_graph(hparams, scope=scope) - if not self.extract_encoder_layers: - self._set_train_or_infer(res, reverse_target_vocab_table, hparams) - - # Saver - self.saver = tf.train.Saver( - tf.global_variables(), max_to_keep=hparams.num_keep_ckpts) - - def _set_params_initializer(self, - hparams, - mode, - iterator, - source_vocab_table, - target_vocab_table, - scope, - extra_args=None): - """Set various params for self and initialize.""" - assert isinstance(iterator, iterator_utils.BatchedInput) - self.iterator = iterator - self.mode = mode - self.src_vocab_table = source_vocab_table - self.tgt_vocab_table = target_vocab_table - - self.src_vocab_size = hparams.src_vocab_size - self.tgt_vocab_size = hparams.tgt_vocab_size - self.num_gpus = hparams.num_gpus - self.time_major = hparams.time_major - - if hparams.use_char_encode: - assert (not self.time_major), ("Can't use time major for" - " char-level inputs.") - - self.dtype = tf.float32 - self.num_sampled_softmax = hparams.num_sampled_softmax - - # extra_args: to make it flexible for adding external customizable code - self.single_cell_fn = None - if extra_args: - self.single_cell_fn = extra_args.single_cell_fn - - # Set num units - self.num_units = hparams.num_units - - # Set num layers - self.num_encoder_layers = hparams.num_encoder_layers - self.num_decoder_layers = hparams.num_decoder_layers - assert self.num_encoder_layers - assert self.num_decoder_layers - - # Set num residual layers - if hasattr(hparams, "num_residual_layers"): # compatible common_test_utils - self.num_encoder_residual_layers = hparams.num_residual_layers - self.num_decoder_residual_layers = hparams.num_residual_layers - else: - self.num_encoder_residual_layers = hparams.num_encoder_residual_layers - self.num_decoder_residual_layers = hparams.num_decoder_residual_layers - - # Batch size - self.batch_size = tf.size(self.iterator.source_sequence_length) - - # Global step - self.global_step = tf.Variable(0, trainable=False) - - # Initializer - self.random_seed = hparams.random_seed - initializer = model_helper.get_initializer( - hparams.init_op, self.random_seed, hparams.init_weight) - tf.get_variable_scope().set_initializer(initializer) - - # Embeddings - if extra_args and extra_args.encoder_emb_lookup_fn: - self.encoder_emb_lookup_fn = extra_args.encoder_emb_lookup_fn - else: - self.encoder_emb_lookup_fn = tf.nn.embedding_lookup - self.init_embeddings(hparams, scope) - - def _set_train_or_infer(self, res, reverse_target_vocab_table, hparams): - """Set up training and inference.""" - if self.mode == tf.contrib.learn.ModeKeys.TRAIN: - self.train_loss = res[1] - self.word_count = tf.reduce_sum( - self.iterator.source_sequence_length) + tf.reduce_sum( - self.iterator.target_sequence_length) - elif 
self.mode == tf.contrib.learn.ModeKeys.EVAL: - self.eval_loss = res[1] - elif self.mode == tf.contrib.learn.ModeKeys.INFER: - self.infer_logits, _, self.final_context_state, self.sample_id = res - self.sample_words = reverse_target_vocab_table.lookup( - tf.to_int64(self.sample_id)) - - if self.mode != tf.contrib.learn.ModeKeys.INFER: - ## Count the number of predicted words for compute ppl. - self.predict_count = tf.reduce_sum( - self.iterator.target_sequence_length) - - params = tf.trainable_variables() - - # Gradients and SGD update operation for training the model. - # Arrange for the embedding vars to appear at the beginning. - if self.mode == tf.contrib.learn.ModeKeys.TRAIN: - self.learning_rate = tf.constant(hparams.learning_rate) - # warm-up - self.learning_rate = self._get_learning_rate_warmup(hparams) - # decay - self.learning_rate = self._get_learning_rate_decay(hparams) - - # Optimizer - if hparams.optimizer == "sgd": - opt = tf.train.GradientDescentOptimizer(self.learning_rate) - elif hparams.optimizer == "adam": - opt = tf.train.AdamOptimizer(self.learning_rate) - else: - raise ValueError("Unknown optimizer type %s" % hparams.optimizer) - - # Gradients - gradients = tf.gradients( - self.train_loss, - params, - colocate_gradients_with_ops=hparams.colocate_gradients_with_ops) - - clipped_grads, grad_norm_summary, grad_norm = model_helper.gradient_clip( - gradients, max_gradient_norm=hparams.max_gradient_norm) - self.grad_norm_summary = grad_norm_summary - self.grad_norm = grad_norm - - self.update = opt.apply_gradients( - zip(clipped_grads, params), global_step=self.global_step) - - # Summary - self.train_summary = self._get_train_summary() - elif self.mode == tf.contrib.learn.ModeKeys.INFER: - self.infer_summary = self._get_infer_summary(hparams) - - # Print trainable variables - utils.print_out("# Trainable variables") - utils.print_out("Format: , , <(soft) device placement>") - for param in params: - utils.print_out(" %s, %s, %s" % (param.name, str(param.get_shape()), - param.op.device)) - - def _get_learning_rate_warmup(self, hparams): - """Get learning rate warmup.""" - warmup_steps = hparams.warmup_steps - warmup_scheme = hparams.warmup_scheme - utils.print_out(" learning_rate=%g, warmup_steps=%d, warmup_scheme=%s" % - (hparams.learning_rate, warmup_steps, warmup_scheme)) - - # Apply inverse decay if global steps less than warmup steps. 
- # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3) - # When step < warmup_steps, - # learing_rate *= warmup_factor ** (warmup_steps - step) - if warmup_scheme == "t2t": - # 0.01^(1/warmup_steps): we start with a lr, 100 times smaller - warmup_factor = tf.exp(tf.log(0.01) / warmup_steps) - inv_decay = warmup_factor**( - tf.to_float(warmup_steps - self.global_step)) - else: - raise ValueError("Unknown warmup scheme %s" % warmup_scheme) - - return tf.cond( - self.global_step < hparams.warmup_steps, - lambda: inv_decay * self.learning_rate, - lambda: self.learning_rate, - name="learning_rate_warmup_cond") - - def _get_decay_info(self, hparams): - """Return decay info based on decay_scheme.""" - if hparams.decay_scheme in ["luong5", "luong10", "luong234"]: - decay_factor = 0.5 - if hparams.decay_scheme == "luong5": - start_decay_step = int(hparams.num_train_steps / 2) - decay_times = 5 - elif hparams.decay_scheme == "luong10": - start_decay_step = int(hparams.num_train_steps / 2) - decay_times = 10 - elif hparams.decay_scheme == "luong234": - start_decay_step = int(hparams.num_train_steps * 2 / 3) - decay_times = 4 - remain_steps = hparams.num_train_steps - start_decay_step - decay_steps = int(remain_steps / decay_times) - elif not hparams.decay_scheme: # no decay - start_decay_step = hparams.num_train_steps - decay_steps = 0 - decay_factor = 1.0 - elif hparams.decay_scheme: - raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme) - return start_decay_step, decay_steps, decay_factor - - def _get_learning_rate_decay(self, hparams): - """Get learning rate decay.""" - start_decay_step, decay_steps, decay_factor = self._get_decay_info(hparams) - utils.print_out(" decay_scheme=%s, start_decay_step=%d, decay_steps %d, " - "decay_factor %g" % (hparams.decay_scheme, - start_decay_step, - decay_steps, - decay_factor)) - - return tf.cond( - self.global_step < start_decay_step, - lambda: self.learning_rate, - lambda: tf.train.exponential_decay( - self.learning_rate, - (self.global_step - start_decay_step), - decay_steps, decay_factor, staircase=True), - name="learning_rate_decay_cond") - - def init_embeddings(self, hparams, scope): - """Init embeddings.""" - self.embedding_encoder, self.embedding_decoder = ( - model_helper.create_emb_for_encoder_and_decoder( - share_vocab=hparams.share_vocab, - src_vocab_size=self.src_vocab_size, - tgt_vocab_size=self.tgt_vocab_size, - src_embed_size=self.num_units, - tgt_embed_size=self.num_units, - num_enc_partitions=hparams.num_enc_emb_partitions, - num_dec_partitions=hparams.num_dec_emb_partitions, - src_vocab_file=hparams.src_vocab_file, - tgt_vocab_file=hparams.tgt_vocab_file, - src_embed_file=hparams.src_embed_file, - tgt_embed_file=hparams.tgt_embed_file, - use_char_encode=hparams.use_char_encode, - scope=scope,)) - - def _get_train_summary(self): - """Get train summary.""" - train_summary = tf.summary.merge( - [tf.summary.scalar("lr", self.learning_rate), - tf.summary.scalar("train_loss", self.train_loss)] + - self.grad_norm_summary) - return train_summary - - def train(self, sess): - """Execute train graph.""" - assert self.mode == tf.contrib.learn.ModeKeys.TRAIN - output_tuple = TrainOutputTuple(train_summary=self.train_summary, - train_loss=self.train_loss, - predict_count=self.predict_count, - global_step=self.global_step, - word_count=self.word_count, - batch_size=self.batch_size, - grad_norm=self.grad_norm, - learning_rate=self.learning_rate) - return sess.run([self.update, output_tuple]) - - def eval(self, sess): - """Execute 
eval graph.""" - assert self.mode == tf.contrib.learn.ModeKeys.EVAL - output_tuple = EvalOutputTuple(eval_loss=self.eval_loss, - predict_count=self.predict_count, - batch_size=self.batch_size) - return sess.run(output_tuple) - - def build_graph(self, hparams, scope=None): - """Subclass must implement this method. - - Creates a sequence-to-sequence model with dynamic RNN decoder API. - Args: - hparams: Hyperparameter configurations. - scope: VariableScope for the created subgraph; default "dynamic_seq2seq". - - Returns: - A tuple of the form (logits, loss_tuple, final_context_state, sample_id), - where: - logits: float32 Tensor [batch_size x num_decoder_symbols]. - loss: loss = the total loss / batch_size. - final_context_state: the final state of decoder RNN. - sample_id: sampling indices. - - Raises: - ValueError: if encoder_type differs from mono and bi, or - attention_option is not (luong | scaled_luong | - bahdanau | normed_bahdanau). - """ - utils.print_out("# Creating %s graph ..." % self.mode) - - # Projection - if not self.extract_encoder_layers: - with tf.variable_scope(scope or "build_network"): - with tf.variable_scope("decoder/output_projection"): - self.output_layer = tf.layers.Dense( - self.tgt_vocab_size, use_bias=False, name="output_projection") - - with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype): - # Encoder - if hparams.language_model: # no encoder for language modeling - utils.print_out(" language modeling: no encoder") - self.encoder_outputs = None - encoder_state = None - else: - self.encoder_outputs, encoder_state = self._build_encoder(hparams) - - # Skip decoder if extracting only encoder layers - if self.extract_encoder_layers: - return - - ## Decoder - logits, decoder_cell_outputs, sample_id, final_context_state = ( - self._build_decoder(self.encoder_outputs, encoder_state, hparams)) - - ## Loss - if self.mode != tf.contrib.learn.ModeKeys.INFER: - with tf.device(model_helper.get_device_str(self.num_encoder_layers - 1, - self.num_gpus)): - loss = self._compute_loss(logits, decoder_cell_outputs) - else: - loss = tf.constant(0.0) - - return logits, loss, final_context_state, sample_id - - @abc.abstractmethod - def _build_encoder(self, hparams): - """Subclass must implement this. - - Build and run an RNN encoder. - - Args: - hparams: Hyperparameters configurations. - - Returns: - A tuple of encoder_outputs and encoder_state. + """Sequence-to-sequence base class. 
""" - pass - def _build_encoder_cell(self, hparams, num_layers, num_residual_layers, - base_gpu=0): - """Build a multi-layer RNN cell that can be used by encoder.""" - - return model_helper.create_rnn_cell( - unit_type=hparams.unit_type, - num_units=self.num_units, - num_layers=num_layers, - num_residual_layers=num_residual_layers, - forget_bias=hparams.forget_bias, - dropout=hparams.dropout, - num_gpus=hparams.num_gpus, - mode=self.mode, - base_gpu=base_gpu, - single_cell_fn=self.single_cell_fn) - - def _get_infer_maximum_iterations(self, hparams, source_sequence_length): - """Maximum decoding steps at inference time.""" - if hparams.tgt_max_len_infer: - maximum_iterations = hparams.tgt_max_len_infer - utils.print_out(" decoding maximum_iterations %d" % maximum_iterations) - else: - # TODO(thangluong): add decoding_length_factor flag - decoding_length_factor = 2.0 - max_encoder_length = tf.reduce_max(source_sequence_length) - maximum_iterations = tf.to_int32(tf.round( - tf.to_float(max_encoder_length) * decoding_length_factor)) - return maximum_iterations - - def _build_decoder(self, encoder_outputs, encoder_state, hparams): - """Build and run a RNN decoder with a final projection layer. - - Args: - encoder_outputs: The outputs of encoder for every time step. - encoder_state: The final state of the encoder. - hparams: The Hyperparameters configurations. - - Returns: - A tuple of final logits and final decoder state: - logits: size [time, batch_size, vocab_size] when time_major=True. - """ - tgt_sos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.sos)), - tf.int32) - tgt_eos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.eos)), - tf.int32) - iterator = self.iterator - - # maximum_iteration: The maximum decoding steps. - maximum_iterations = self._get_infer_maximum_iterations( - hparams, iterator.source_sequence_length) - - ## Decoder. - with tf.variable_scope("decoder") as decoder_scope: - cell, decoder_initial_state = self._build_decoder_cell( - hparams, encoder_outputs, encoder_state, - iterator.source_sequence_length) - - # Optional ops depends on which mode we are in and which loss function we - # are using. - logits = tf.no_op() - decoder_cell_outputs = None - - ## Train or eval - if self.mode != tf.contrib.learn.ModeKeys.INFER: - # decoder_emp_inp: [max_time, batch_size, num_units] - target_input = iterator.target_input - if self.time_major: - target_input = tf.transpose(target_input) - decoder_emb_inp = tf.nn.embedding_lookup( - self.embedding_decoder, target_input) - - # Helper - helper = tf.contrib.seq2seq.TrainingHelper( - decoder_emb_inp, iterator.target_sequence_length, - time_major=self.time_major) - - # Decoder - my_decoder = tf.contrib.seq2seq.BasicDecoder( - cell, - helper, - decoder_initial_state,) - - # Dynamic decoding - outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode( - my_decoder, - output_time_major=self.time_major, - swap_memory=True, - scope=decoder_scope) - - sample_id = outputs.sample_id + def __init__(self, + hparams, + mode, + iterator, + source_vocab_table, + target_vocab_table, + reverse_target_vocab_table=None, + scope=None, + extra_args=None): + """Create the model. + + Args: + hparams: Hyperparameter configurations. + mode: TRAIN | EVAL | INFER + iterator: Dataset Iterator that feeds data. + source_vocab_table: Lookup table mapping source words to ids. + target_vocab_table: Lookup table mapping target words to ids. + reverse_target_vocab_table: Lookup table mapping ids to target words. 
Only + required in INFER mode. Defaults to None. + scope: scope of the model. + extra_args: model_helper.ExtraArgs, for passing customizable functions. + + """ + # Set params + self._set_params_initializer(hparams, mode, iterator, + source_vocab_table, target_vocab_table, + scope, extra_args) + + # Not used in general seq2seq models; when True, ignore decoder & training + self.extract_encoder_layers = (hasattr(hparams, "extract_encoder_layers") + and hparams.extract_encoder_layers) + + # Train graph + res = self.build_graph(hparams, scope=scope) + if not self.extract_encoder_layers: + self._set_train_or_infer(res, reverse_target_vocab_table, hparams) + + # Saver + self.saver = tf.train.Saver( + tf.global_variables(), max_to_keep=hparams.num_keep_ckpts) + + def _set_params_initializer(self, + hparams, + mode, + iterator, + source_vocab_table, + target_vocab_table, + scope, + extra_args=None): + """Set various params for self and initialize.""" + assert isinstance(iterator, iterator_utils.BatchedInput) + self.iterator = iterator + self.mode = mode + self.src_vocab_table = source_vocab_table + self.tgt_vocab_table = target_vocab_table + + self.src_vocab_size = hparams.src_vocab_size + self.tgt_vocab_size = hparams.tgt_vocab_size + self.num_gpus = hparams.num_gpus + self.time_major = hparams.time_major + + if hparams.use_char_encode: + assert (not self.time_major), ("Can't use time major for" + " char-level inputs.") + + self.dtype = tf.float32 + self.num_sampled_softmax = hparams.num_sampled_softmax + + # extra_args: to make it flexible for adding external customizable code + self.single_cell_fn = None + if extra_args: + self.single_cell_fn = extra_args.single_cell_fn + + # Set num units + self.num_units = hparams.num_units + + # Set num layers + self.num_encoder_layers = hparams.num_encoder_layers + self.num_decoder_layers = hparams.num_decoder_layers + assert self.num_encoder_layers + assert self.num_decoder_layers + + # Set num residual layers + if hasattr(hparams, "num_residual_layers"): # compatible common_test_utils + self.num_encoder_residual_layers = hparams.num_residual_layers + self.num_decoder_residual_layers = hparams.num_residual_layers + else: + self.num_encoder_residual_layers = hparams.num_encoder_residual_layers + self.num_decoder_residual_layers = hparams.num_decoder_residual_layers - if self.num_sampled_softmax > 0: - # Note: this is required when using sampled_softmax_loss. - decoder_cell_outputs = outputs.rnn_output - - # Note: there's a subtle difference here between train and inference. - # We could have set output_layer when create my_decoder - # and shared more code between train and inference. - # We chose to apply the output_layer to all timesteps for speed: - # 10% improvements for small models & 20% for larger ones. - # If memory is a concern, we should apply output_layer per timestep. - num_layers = self.num_decoder_layers - num_gpus = self.num_gpus - device_id = num_layers if num_layers < num_gpus else (num_layers - 1) - # Colocate output layer with the last RNN cell if there is no extra GPU - # available. Otherwise, put last layer on a separate GPU. - with tf.device(model_helper.get_device_str(device_id, num_gpus)): - logits = self.output_layer(outputs.rnn_output) + # Batch size + self.batch_size = tf.size(self.iterator.source_sequence_length) - if self.num_sampled_softmax > 0: - logits = tf.no_op() # unused when using sampled softmax loss. 
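The comment a few lines above spells out the trade-off behind applying the output projection to every timestep at once rather than per step. A minimal TF 1.x sketch of that equivalence, using toy shapes and names that are not taken from this repository:

import tensorflow as tf  # TF 1.x, as used throughout this patch

time_steps, batch, num_units, vocab = 7, 4, 16, 100   # toy sizes, illustrative only
rnn_output = tf.random_normal([time_steps, batch, num_units])
output_layer = tf.layers.Dense(vocab, use_bias=False, name="toy_output_projection")

# One projection over all timesteps: faster, but higher peak memory.
logits_all = output_layer(rnn_output)                  # [time, batch, vocab]

# Per-timestep projection with the same kernel: slower, lower peak memory.
logits_per_step = tf.stack(
    [output_layer(rnn_output[t]) for t in range(time_steps)], axis=0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a, b = sess.run([logits_all, logits_per_step])
    print(a.shape, b.shape)   # (7, 4, 100) (7, 4, 100); values match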
- - ## Inference - else: - infer_mode = hparams.infer_mode - start_tokens = tf.fill([self.batch_size], tgt_sos_id) - end_token = tgt_eos_id - utils.print_out( - " decoder: infer_mode=%sbeam_width=%d, " - "length_penalty=%f, coverage_penalty=%f" - % (infer_mode, hparams.beam_width, hparams.length_penalty_weight, - hparams.coverage_penalty_weight)) - - if infer_mode == "beam_search": - beam_width = hparams.beam_width - length_penalty_weight = hparams.length_penalty_weight - coverage_penalty_weight = hparams.coverage_penalty_weight - - my_decoder = tf.contrib.seq2seq.BeamSearchDecoder( - cell=cell, - embedding=self.embedding_decoder, - start_tokens=start_tokens, - end_token=end_token, - initial_state=decoder_initial_state, - beam_width=beam_width, - output_layer=self.output_layer, - length_penalty_weight=length_penalty_weight, - coverage_penalty_weight=coverage_penalty_weight) - elif infer_mode == "sample": - # Helper - sampling_temperature = hparams.sampling_temperature - assert sampling_temperature > 0.0, ( - "sampling_temperature must greater than 0.0 when using sample" - " decoder.") - helper = tf.contrib.seq2seq.SampleEmbeddingHelper( - self.embedding_decoder, start_tokens, end_token, - softmax_temperature=sampling_temperature, - seed=self.random_seed) - elif infer_mode == "greedy": - helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( - self.embedding_decoder, start_tokens, end_token) + # Global step + self.global_step = tf.Variable(0, trainable=False) + + # Initializer + self.random_seed = hparams.random_seed + initializer = model_helper.get_initializer( + hparams.init_op, self.random_seed, hparams.init_weight) + tf.get_variable_scope().set_initializer(initializer) + + # Embeddings + if extra_args and extra_args.encoder_emb_lookup_fn: + self.encoder_emb_lookup_fn = extra_args.encoder_emb_lookup_fn else: - raise ValueError("Unknown infer_mode '%s'", infer_mode) - - if infer_mode != "beam_search": - my_decoder = tf.contrib.seq2seq.BasicDecoder( - cell, - helper, - decoder_initial_state, - output_layer=self.output_layer # applied per timestep - ) - - # Dynamic decoding - outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode( - my_decoder, - maximum_iterations=maximum_iterations, - output_time_major=self.time_major, - swap_memory=True, - scope=decoder_scope) - - if infer_mode == "beam_search": - sample_id = outputs.predicted_ids + self.encoder_emb_lookup_fn = tf.nn.embedding_lookup + self.init_embeddings(hparams, scope) + + def _set_train_or_infer(self, res, reverse_target_vocab_table, hparams): + """Set up training and inference.""" + if self.mode == tf.contrib.learn.ModeKeys.TRAIN: + self.train_loss = res[1] + self.word_count = tf.reduce_sum( + self.iterator.source_sequence_length) + tf.reduce_sum( + self.iterator.target_sequence_length) + elif self.mode == tf.contrib.learn.ModeKeys.EVAL: + self.eval_loss = res[1] + elif self.mode == tf.contrib.learn.ModeKeys.INFER: + self.infer_logits, _, self.final_context_state, self.sample_id = res + self.sample_words = reverse_target_vocab_table.lookup( + tf.to_int64(self.sample_id)) + + if self.mode != tf.contrib.learn.ModeKeys.INFER: + # Count the number of predicted words for compute ppl. + self.predict_count = tf.reduce_sum( + self.iterator.target_sequence_length) + + params = tf.trainable_variables() + + # Gradients and SGD update operation for training the model. + # Arrange for the embedding vars to appear at the beginning. 
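The training branch that follows computes gradients, clips them by global norm through model_helper.gradient_clip (defined elsewhere in the repo), and applies the update with the chosen optimizer. A minimal sketch of that standard TF 1.x clip-and-apply pattern, assuming the helper is essentially a wrapper around tf.clip_by_global_norm:

import tensorflow as tf  # TF 1.x

x = tf.get_variable("x", [], initializer=tf.constant_initializer(5.0))  # toy parameter
loss = tf.square(x)
params = tf.trainable_variables()
global_step = tf.train.get_or_create_global_step()

grads = tf.gradients(loss, params)
clipped, grad_norm = tf.clip_by_global_norm(grads, clip_norm=5.0)

opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
update = opt.apply_gradients(zip(clipped, params), global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One clipped SGD step; grad_norm is the pre-clip global norm.
    print(sess.run([update, loss, grad_norm]))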
+ if self.mode == tf.contrib.learn.ModeKeys.TRAIN: + self.learning_rate = tf.constant(hparams.learning_rate) + # warm-up + self.learning_rate = self._get_learning_rate_warmup(hparams) + # decay + self.learning_rate = self._get_learning_rate_decay(hparams) + + # Optimizer + if hparams.optimizer == "sgd": + opt = tf.train.GradientDescentOptimizer(self.learning_rate) + elif hparams.optimizer == "adam": + opt = tf.train.AdamOptimizer(self.learning_rate) + else: + raise ValueError("Unknown optimizer type %s" % hparams.optimizer) + + # Gradients + gradients = tf.gradients( + self.train_loss, + params, + colocate_gradients_with_ops=hparams.colocate_gradients_with_ops) + + clipped_grads, grad_norm_summary, grad_norm = model_helper.gradient_clip( + gradients, max_gradient_norm=hparams.max_gradient_norm) + self.grad_norm_summary = grad_norm_summary + self.grad_norm = grad_norm + + self.update = opt.apply_gradients( + zip(clipped_grads, params), global_step=self.global_step) + + # Summary + self.train_summary = self._get_train_summary() + elif self.mode == tf.contrib.learn.ModeKeys.INFER: + self.infer_summary = self._get_infer_summary(hparams) + + # Print trainable variables + utils.print_out("# Trainable variables") + utils.print_out("Format: , , <(soft) device placement>") + for param in params: + utils.print_out(" %s, %s, %s" % (param.name, str(param.get_shape()), + param.op.device)) + + def _get_learning_rate_warmup(self, hparams): + """Get learning rate warmup.""" + warmup_steps = hparams.warmup_steps + warmup_scheme = hparams.warmup_scheme + utils.print_out(" learning_rate=%g, warmup_steps=%d, warmup_scheme=%s" % + (hparams.learning_rate, warmup_steps, warmup_scheme)) + + # Apply inverse decay if global steps less than warmup steps. + # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3) + # When step < warmup_steps, + # learing_rate *= warmup_factor ** (warmup_steps - step) + if warmup_scheme == "t2t": + # 0.01^(1/warmup_steps): we start with a lr, 100 times smaller + warmup_factor = tf.exp(tf.log(0.01) / warmup_steps) + inv_decay = warmup_factor**( + tf.to_float(warmup_steps - self.global_step)) else: - logits = outputs.rnn_output - sample_id = outputs.sample_id - - return logits, decoder_cell_outputs, sample_id, final_context_state + raise ValueError("Unknown warmup scheme %s" % warmup_scheme) + + return tf.cond( + self.global_step < hparams.warmup_steps, + lambda: inv_decay * self.learning_rate, + lambda: self.learning_rate, + name="learning_rate_warmup_cond") + + def _get_decay_info(self, hparams): + """Return decay info based on decay_scheme.""" + if hparams.decay_scheme in ["luong5", "luong10", "luong234"]: + decay_factor = 0.5 + if hparams.decay_scheme == "luong5": + start_decay_step = int(hparams.num_train_steps / 2) + decay_times = 5 + elif hparams.decay_scheme == "luong10": + start_decay_step = int(hparams.num_train_steps / 2) + decay_times = 10 + elif hparams.decay_scheme == "luong234": + start_decay_step = int(hparams.num_train_steps * 2 / 3) + decay_times = 4 + remain_steps = hparams.num_train_steps - start_decay_step + decay_steps = int(remain_steps / decay_times) + elif not hparams.decay_scheme: # no decay + start_decay_step = hparams.num_train_steps + decay_steps = 0 + decay_factor = 1.0 + elif hparams.decay_scheme: + raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme) + return start_decay_step, decay_steps, decay_factor + + def _get_learning_rate_decay(self, hparams): + """Get learning rate decay.""" + start_decay_step, decay_steps, decay_factor = 
self._get_decay_info(hparams) + utils.print_out(" decay_scheme=%s, start_decay_step=%d, decay_steps %d, " + "decay_factor %g" % (hparams.decay_scheme, + start_decay_step, + decay_steps, + decay_factor)) + + return tf.cond( + self.global_step < start_decay_step, + lambda: self.learning_rate, + lambda: tf.train.exponential_decay( + self.learning_rate, + (self.global_step - start_decay_step), + decay_steps, decay_factor, staircase=True), + name="learning_rate_decay_cond") + + def init_embeddings(self, hparams, scope): + """Init embeddings.""" + self.embedding_encoder, self.embedding_decoder = ( + model_helper.create_emb_for_encoder_and_decoder( + share_vocab=hparams.share_vocab, + src_vocab_size=self.src_vocab_size, + tgt_vocab_size=self.tgt_vocab_size, + src_embed_size=self.num_units, + tgt_embed_size=self.num_units, + num_enc_partitions=hparams.num_enc_emb_partitions, + num_dec_partitions=hparams.num_dec_emb_partitions, + src_vocab_file=hparams.src_vocab_file, + tgt_vocab_file=hparams.tgt_vocab_file, + src_embed_file=hparams.src_embed_file, + tgt_embed_file=hparams.tgt_embed_file, + use_char_encode=hparams.use_char_encode, + scope=scope,)) + + def _get_train_summary(self): + """Get train summary.""" + train_summary = tf.summary.merge( + [tf.summary.scalar("lr", self.learning_rate), + tf.summary.scalar("train_loss", self.train_loss)] + + self.grad_norm_summary) + return train_summary + + def train(self, sess): + """Execute train graph.""" + assert self.mode == tf.contrib.learn.ModeKeys.TRAIN + output_tuple = TrainOutputTuple(train_summary=self.train_summary, + train_loss=self.train_loss, + predict_count=self.predict_count, + global_step=self.global_step, + word_count=self.word_count, + batch_size=self.batch_size, + grad_norm=self.grad_norm, + learning_rate=self.learning_rate) + return sess.run([self.update, output_tuple]) + + def eval(self, sess): + """Execute eval graph.""" + assert self.mode == tf.contrib.learn.ModeKeys.EVAL + output_tuple = EvalOutputTuple(eval_loss=self.eval_loss, + predict_count=self.predict_count, + batch_size=self.batch_size) + return sess.run(output_tuple) + + def build_graph(self, hparams, scope=None): + """Subclass must implement this method. + + Creates a sequence-to-sequence model with dynamic RNN decoder API. + Args: + hparams: Hyperparameter configurations. + scope: VariableScope for the created subgraph; default "dynamic_seq2seq". + + Returns: + A tuple of the form (logits, loss_tuple, final_context_state, sample_id), + where: + logits: float32 Tensor [batch_size x num_decoder_symbols]. + loss: loss = the total loss / batch_size. + final_context_state: the final state of decoder RNN. + sample_id: sampling indices. + + Raises: + ValueError: if encoder_type differs from mono and bi, or + attention_option is not (luong | scaled_luong | + bahdanau | normed_bahdanau). + """ + utils.print_out("# Creating %s graph ..." 
% self.mode) + + # Projection + if not self.extract_encoder_layers: + with tf.variable_scope(scope or "build_network"): + with tf.variable_scope("decoder/output_projection"): + self.output_layer = tf.layers.Dense( + self.tgt_vocab_size, use_bias=False, name="output_projection") + + with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype): + # Encoder + if hparams.language_model: # no encoder for language modeling + utils.print_out(" language modeling: no encoder") + self.encoder_outputs = None + encoder_state = None + else: + self.encoder_outputs, encoder_state = self._build_encoder(hparams) + + # Skip decoder if extracting only encoder layers + if self.extract_encoder_layers: + return + + # Decoder + logits, decoder_cell_outputs, sample_id, final_context_state = ( + self._build_decoder(self.encoder_outputs, encoder_state, hparams)) + + # Loss + if self.mode != tf.contrib.learn.ModeKeys.INFER: + with tf.device(model_helper.get_device_str(self.num_encoder_layers - 1, + self.num_gpus)): + loss = self._compute_loss(logits, decoder_cell_outputs) + else: + loss = tf.constant(0.0) + + return logits, loss, final_context_state, sample_id + + @abc.abstractmethod + def _build_encoder(self, hparams): + """Subclass must implement this. + + Build and run an RNN encoder. + + Args: + hparams: Hyperparameters configurations. + + Returns: + A tuple of encoder_outputs and encoder_state. + """ + pass + + def _build_encoder_cell(self, hparams, num_layers, num_residual_layers, + base_gpu=0): + """Build a multi-layer RNN cell that can be used by encoder.""" + + return model_helper.create_rnn_cell( + unit_type=hparams.unit_type, + num_units=self.num_units, + num_layers=num_layers, + num_residual_layers=num_residual_layers, + forget_bias=hparams.forget_bias, + dropout=hparams.dropout, + num_gpus=hparams.num_gpus, + mode=self.mode, + base_gpu=base_gpu, + single_cell_fn=self.single_cell_fn) + + def _get_infer_maximum_iterations(self, hparams, source_sequence_length): + """Maximum decoding steps at inference time.""" + if hparams.tgt_max_len_infer: + maximum_iterations = hparams.tgt_max_len_infer + utils.print_out(" decoding maximum_iterations %d" % maximum_iterations) + else: + # TODO(thangluong): add decoding_length_factor flag + decoding_length_factor = 2.0 + max_encoder_length = tf.reduce_max(source_sequence_length) + maximum_iterations = tf.to_int32(tf.round( + tf.to_float(max_encoder_length) * decoding_length_factor)) + return maximum_iterations + + def _build_decoder(self, encoder_outputs, encoder_state, hparams): + """Build and run a RNN decoder with a final projection layer. + + Args: + encoder_outputs: The outputs of encoder for every time step. + encoder_state: The final state of the encoder. + hparams: The Hyperparameters configurations. + + Returns: + A tuple of final logits and final decoder state: + logits: size [time, batch_size, vocab_size] when time_major=True. + """ + tgt_sos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.sos)), + tf.int32) + tgt_eos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.eos)), + tf.int32) + iterator = self.iterator + + # maximum_iteration: The maximum decoding steps. + maximum_iterations = self._get_infer_maximum_iterations( + hparams, iterator.source_sequence_length) + + # Decoder. 
+ with tf.variable_scope("decoder") as decoder_scope: + cell, decoder_initial_state = self._build_decoder_cell( + hparams, encoder_outputs, encoder_state, + iterator.source_sequence_length) + + # Optional ops depends on which mode we are in and which loss function we + # are using. + logits = tf.no_op() + decoder_cell_outputs = None + + ## Train or eval + if self.mode != tf.contrib.learn.ModeKeys.INFER: + # decoder_emp_inp: [max_time, batch_size, num_units] + target_input = iterator.target_input + if self.time_major: + target_input = tf.transpose(target_input) + decoder_emb_inp = tf.nn.embedding_lookup( + self.embedding_decoder, target_input) + + # Helper + helper = tf.contrib.seq2seq.TrainingHelper( + decoder_emb_inp, iterator.target_sequence_length, + time_major=self.time_major) + + # Decoder + my_decoder = tf.contrib.seq2seq.BasicDecoder( + cell, + helper, + decoder_initial_state,) + + # Dynamic decoding + outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode( + my_decoder, + output_time_major=self.time_major, + swap_memory=True, + scope=decoder_scope) + + sample_id = outputs.sample_id + + if self.num_sampled_softmax > 0: + # Note: this is required when using sampled_softmax_loss. + decoder_cell_outputs = outputs.rnn_output + + # Note: there's a subtle difference here between train and inference. + # We could have set output_layer when create my_decoder + # and shared more code between train and inference. + # We chose to apply the output_layer to all timesteps for speed: + # 10% improvements for small models & 20% for larger ones. + # If memory is a concern, we should apply output_layer per timestep. + num_layers = self.num_decoder_layers + num_gpus = self.num_gpus + device_id = num_layers if num_layers < num_gpus else (num_layers - 1) + # Colocate output layer with the last RNN cell if there is no extra GPU + # available. Otherwise, put last layer on a separate GPU. + with tf.device(model_helper.get_device_str(device_id, num_gpus)): + logits = self.output_layer(outputs.rnn_output) + + if self.num_sampled_softmax > 0: + logits = tf.no_op() # unused when using sampled softmax loss. 
+ + # Inference + else: + infer_mode = hparams.infer_mode + start_tokens = tf.fill([self.batch_size], tgt_sos_id) + end_token = tgt_eos_id + utils.print_out( + " decoder: infer_mode=%sbeam_width=%d, " + "length_penalty=%f, coverage_penalty=%f" + % (infer_mode, hparams.beam_width, hparams.length_penalty_weight, + hparams.coverage_penalty_weight)) + + if infer_mode == "beam_search": + beam_width = hparams.beam_width + length_penalty_weight = hparams.length_penalty_weight + coverage_penalty_weight = hparams.coverage_penalty_weight + + my_decoder = tf.contrib.seq2seq.BeamSearchDecoder( + cell=cell, + embedding=self.embedding_decoder, + start_tokens=start_tokens, + end_token=end_token, + initial_state=decoder_initial_state, + beam_width=beam_width, + output_layer=self.output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight) + elif infer_mode == "sample": + # Helper + sampling_temperature = hparams.sampling_temperature + assert sampling_temperature > 0.0, ( + "sampling_temperature must greater than 0.0 when using sample" + " decoder.") + helper = tf.contrib.seq2seq.SampleEmbeddingHelper( + self.embedding_decoder, start_tokens, end_token, + softmax_temperature=sampling_temperature, + seed=self.random_seed) + elif infer_mode == "greedy": + helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( + self.embedding_decoder, start_tokens, end_token) + else: + raise ValueError("Unknown infer_mode '%s'", infer_mode) + + if infer_mode != "beam_search": + my_decoder = tf.contrib.seq2seq.BasicDecoder( + cell, + helper, + decoder_initial_state, + output_layer=self.output_layer # applied per timestep + ) + + # Dynamic decoding + outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode( + my_decoder, + maximum_iterations=maximum_iterations, + output_time_major=self.time_major, + swap_memory=True, + scope=decoder_scope) + + if infer_mode == "beam_search": + sample_id = outputs.predicted_ids + else: + logits = outputs.rnn_output + sample_id = outputs.sample_id + + return logits, decoder_cell_outputs, sample_id, final_context_state + + def get_max_time(self, tensor): + time_axis = 0 if self.time_major else 1 + return tensor.shape[time_axis].value or tf.shape(tensor)[time_axis] + + @abc.abstractmethod + def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, + source_sequence_length): + """Subclass must implement this. + + Args: + hparams: Hyperparameters configurations. + encoder_outputs: The outputs of encoder for every time step. + encoder_state: The final state of the encoder. + source_sequence_length: sequence length of encoder_outputs. + + Returns: + A tuple of a multi-layer RNN cell used by decoder and the initial state of + the decoder RNN. + """ + pass + + def _softmax_cross_entropy_loss( + self, logits, decoder_cell_outputs, labels): + """Compute softmax loss or sampled softmax loss.""" + if self.num_sampled_softmax > 0: - def get_max_time(self, tensor): - time_axis = 0 if self.time_major else 1 - return tensor.shape[time_axis].value or tf.shape(tensor)[time_axis] + is_sequence = (decoder_cell_outputs.shape.ndims == 3) - @abc.abstractmethod - def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, - source_sequence_length): - """Subclass must implement this. + if is_sequence: + labels = tf.reshape(labels, [-1, 1]) + inputs = tf.reshape(decoder_cell_outputs, [-1, self.num_units]) - Args: - hparams: Hyperparameters configurations. - encoder_outputs: The outputs of encoder for every time step. 
- encoder_state: The final state of the encoder. - source_sequence_length: sequence length of encoder_outputs. + crossent = tf.nn.sampled_softmax_loss( + weights=tf.transpose(self.output_layer.kernel), + biases=self.output_layer.bias or tf.zeros([self.tgt_vocab_size]), + labels=labels, + inputs=inputs, + num_sampled=self.num_sampled_softmax, + num_classes=self.tgt_vocab_size, + partition_strategy="div", + seed=self.random_seed) - Returns: - A tuple of a multi-layer RNN cell used by decoder and the initial state of - the decoder RNN. - """ - pass + if is_sequence: + if self.time_major: + crossent = tf.reshape(crossent, [-1, self.batch_size]) + else: + crossent = tf.reshape(crossent, [self.batch_size, -1]) - def _softmax_cross_entropy_loss( - self, logits, decoder_cell_outputs, labels): - """Compute softmax loss or sampled softmax loss.""" - if self.num_sampled_softmax > 0: + else: + crossent = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits) - is_sequence = (decoder_cell_outputs.shape.ndims == 3) + return crossent - if is_sequence: - labels = tf.reshape(labels, [-1, 1]) - inputs = tf.reshape(decoder_cell_outputs, [-1, self.num_units]) + def _compute_loss(self, logits, decoder_cell_outputs): + """Compute optimization loss.""" + target_output = self.iterator.target_output + if self.time_major: + target_output = tf.transpose(target_output) + max_time = self.get_max_time(target_output) - crossent = tf.nn.sampled_softmax_loss( - weights=tf.transpose(self.output_layer.kernel), - biases=self.output_layer.bias or tf.zeros([self.tgt_vocab_size]), - labels=labels, - inputs=inputs, - num_sampled=self.num_sampled_softmax, - num_classes=self.tgt_vocab_size, - partition_strategy="div", - seed=self.random_seed) + crossent = self._softmax_cross_entropy_loss( + logits, decoder_cell_outputs, target_output) - if is_sequence: + target_weights = tf.sequence_mask( + self.iterator.target_sequence_length, max_time, dtype=self.dtype) if self.time_major: - crossent = tf.reshape(crossent, [-1, self.batch_size]) + target_weights = tf.transpose(target_weights) + + loss = tf.reduce_sum( + crossent * target_weights) / tf.to_float(self.batch_size) + return loss + + def _get_infer_summary(self, hparams): + del hparams + return tf.no_op() + + def infer(self, sess): + assert self.mode == tf.contrib.learn.ModeKeys.INFER + output_tuple = InferOutputTuple(infer_logits=self.infer_logits, + infer_summary=self.infer_summary, + sample_id=self.sample_id, + sample_words=self.sample_words) + return sess.run(output_tuple) + + def decode(self, sess): + """Decode a batch. + + Args: + sess: tensorflow session to use. + + Returns: + A tuple consiting of outputs, infer_summary. + outputs: of size [batch_size, time] + """ + output_tuple = self.infer(sess) + sample_words = output_tuple.sample_words + infer_summary = output_tuple.infer_summary + + # make sure outputs is of shape [batch_size, time] or [beam_width, + # batch_size, time] when using beam search. + if self.time_major: + sample_words = sample_words.transpose() + elif sample_words.ndim == 3: + # beam search output in [batch_size, time, beam_width] shape. 
+ sample_words = sample_words.transpose([2, 0, 1]) + return sample_words, infer_summary + + def build_encoder_states(self, include_embeddings=False): + """Stack encoder states and return tensor [batch, length, layer, size].""" + assert self.mode == tf.contrib.learn.ModeKeys.INFER + if include_embeddings: + stack_state_list = tf.stack( + [self.encoder_emb_inp] + self.encoder_state_list, 2) else: - crossent = tf.reshape(crossent, [self.batch_size, -1]) - - else: - crossent = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - - return crossent - - def _compute_loss(self, logits, decoder_cell_outputs): - """Compute optimization loss.""" - target_output = self.iterator.target_output - if self.time_major: - target_output = tf.transpose(target_output) - max_time = self.get_max_time(target_output) - - crossent = self._softmax_cross_entropy_loss( - logits, decoder_cell_outputs, target_output) - - target_weights = tf.sequence_mask( - self.iterator.target_sequence_length, max_time, dtype=self.dtype) - if self.time_major: - target_weights = tf.transpose(target_weights) - - loss = tf.reduce_sum( - crossent * target_weights) / tf.to_float(self.batch_size) - return loss - - def _get_infer_summary(self, hparams): - del hparams - return tf.no_op() - - def infer(self, sess): - assert self.mode == tf.contrib.learn.ModeKeys.INFER - output_tuple = InferOutputTuple(infer_logits=self.infer_logits, - infer_summary=self.infer_summary, - sample_id=self.sample_id, - sample_words=self.sample_words) - return sess.run(output_tuple) - - def decode(self, sess): - """Decode a batch. - - Args: - sess: tensorflow session to use. - - Returns: - A tuple consiting of outputs, infer_summary. - outputs: of size [batch_size, time] - """ - output_tuple = self.infer(sess) - sample_words = output_tuple.sample_words - infer_summary = output_tuple.infer_summary - - # make sure outputs is of shape [batch_size, time] or [beam_width, - # batch_size, time] when using beam search. - if self.time_major: - sample_words = sample_words.transpose() - elif sample_words.ndim == 3: - # beam search output in [batch_size, time, beam_width] shape. - sample_words = sample_words.transpose([2, 0, 1]) - return sample_words, infer_summary - - def build_encoder_states(self, include_embeddings=False): - """Stack encoder states and return tensor [batch, length, layer, size].""" - assert self.mode == tf.contrib.learn.ModeKeys.INFER - if include_embeddings: - stack_state_list = tf.stack( - [self.encoder_emb_inp] + self.encoder_state_list, 2) - else: - stack_state_list = tf.stack(self.encoder_state_list, 2) - - # transform from [length, batch, ...] -> [batch, length, ...] - if self.time_major: - stack_state_list = tf.transpose(stack_state_list, [1, 0, 2, 3]) - - return stack_state_list + stack_state_list = tf.stack(self.encoder_state_list, 2) + + # transform from [length, batch, ...] -> [batch, length, ...] + if self.time_major: + stack_state_list = tf.transpose(stack_state_list, [1, 0, 2, 3]) + + return stack_state_list class Model(BaseModel): - """Sequence-to-sequence dynamic model. + """Sequence-to-sequence dynamic model. - This class implements a multi-layer recurrent neural network as encoder, - and a multi-layer recurrent neural network decoder. - """ - def _build_encoder_from_sequence(self, hparams, sequence, sequence_length): - """Build an encoder from a sequence. + This class implements a multi-layer recurrent neural network as encoder, + and a multi-layer recurrent neural network decoder. 
+ """ - Args: - hparams: hyperparameters. - sequence: tensor with input sequence data. - sequence_length: tensor with length of the input sequence. + def _build_encoder_from_sequence(self, hparams, sequence, sequence_length): + """Build an encoder from a sequence. - Returns: - encoder_outputs: RNN encoder outputs. - encoder_state: RNN encoder state. + Args: + hparams: hyperparameters. + sequence: tensor with input sequence data. + sequence_length: tensor with length of the input sequence. - Raises: - ValueError: if encoder_type is neither "uni" nor "bi". - """ - num_layers = self.num_encoder_layers - num_residual_layers = self.num_encoder_residual_layers - - if self.time_major: - sequence = tf.transpose(sequence) - - with tf.variable_scope("encoder") as scope: - dtype = scope.dtype - - if self.embedding_encoder is not None: - self.encoder_emb_inp = self.encoder_emb_lookup_fn( - self.embedding_encoder, sequence) - else: # KDQ controlled - # import pdb - # pdb.set_trace() - FLAGS = tf.flags.FLAGS - vocab_size = hparams.src_vocab_size - is_training = self.mode == tf.contrib.learn.ModeKeys.TRAIN - size = hparams.num_units - if FLAGS.kdq_type == "none": - inputs = full_embed(sequence, vocab_size, size) - else: - kdq_hparam = KDQhparam( - K=FLAGS.K, D=FLAGS.D, - kdq_type=FLAGS.kdq_type, - kdq_d_in=FLAGS.kdq_d_in, - kdq_share_subspace=FLAGS.kdq_share_subspace, - additive_quantization=FLAGS.additive_quantization) - inputs = kdq_embed( - sequence, vocab_size, size, kdq_hparam, is_training) - self.encoder_emb_inp = inputs - - # Encoder_outputs: [max_time, batch_size, num_units] - if hparams.encoder_type == "uni": - utils.print_out(" num_layers = %d, num_residual_layers=%d" % - (num_layers, num_residual_layers)) - cell = self._build_encoder_cell(hparams, num_layers, - num_residual_layers) - - encoder_outputs, encoder_state = tf.nn.dynamic_rnn( - cell, - self.encoder_emb_inp, + Returns: + encoder_outputs: RNN encoder outputs. + encoder_state: RNN encoder state. + + Raises: + ValueError: if encoder_type is neither "uni" nor "bi". 
+ """ + num_layers = self.num_encoder_layers + num_residual_layers = self.num_encoder_residual_layers + + if self.time_major: + sequence = tf.transpose(sequence) + + with tf.variable_scope("encoder") as scope: + dtype = scope.dtype + + if self.embedding_encoder is not None: + self.encoder_emb_inp = self.encoder_emb_lookup_fn( + self.embedding_encoder, sequence) + else: # KDQ controlled + # import pdb + # pdb.set_trace() + FLAGS = tf.flags.FLAGS + vocab_size = hparams.src_vocab_size + is_training = self.mode == tf.contrib.learn.ModeKeys.TRAIN + size = hparams.num_units + if FLAGS.kdq_type == "none": + inputs = full_embed(sequence, vocab_size, size) + else: + kdq_hparam = KDQhparam( + K=FLAGS.K, D=FLAGS.D, + kdq_type=FLAGS.kdq_type, + kdq_d_in=FLAGS.kdq_d_in, + kdq_share_subspace=FLAGS.kdq_share_subspace, + additive_quantization=FLAGS.additive_quantization) + inputs = kdq_embed( + sequence, vocab_size, size, kdq_hparam, is_training) + self.encoder_emb_inp = inputs + + # Encoder_outputs: [max_time, batch_size, num_units] + if hparams.encoder_type == "uni": + utils.print_out(" num_layers = %d, num_residual_layers=%d" % + (num_layers, num_residual_layers)) + cell = self._build_encoder_cell(hparams, num_layers, + num_residual_layers) + + encoder_outputs, encoder_state = tf.nn.dynamic_rnn( + cell, + self.encoder_emb_inp, + dtype=dtype, + sequence_length=sequence_length, + time_major=self.time_major, + swap_memory=True) + elif hparams.encoder_type == "bi": + num_bi_layers = int(num_layers / 2) + num_bi_residual_layers = int(num_residual_layers / 2) + utils.print_out(" num_bi_layers = %d, num_bi_residual_layers=%d" % + (num_bi_layers, num_bi_residual_layers)) + + encoder_outputs, bi_encoder_state = ( + self._build_bidirectional_rnn( + inputs=self.encoder_emb_inp, + sequence_length=sequence_length, + dtype=dtype, + hparams=hparams, + num_bi_layers=num_bi_layers, + num_bi_residual_layers=num_bi_residual_layers)) + + if num_bi_layers == 1: + encoder_state = bi_encoder_state + else: + # alternatively concat forward and backward states + encoder_state = [] + for layer_id in range(num_bi_layers): + encoder_state.append(bi_encoder_state[0][layer_id]) # forward + encoder_state.append(bi_encoder_state[1][layer_id]) # backward + encoder_state = tuple(encoder_state) + else: + raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) + + # Use the top layer for now + self.encoder_state_list = [encoder_outputs] + + return encoder_outputs, encoder_state + + def _build_encoder(self, hparams): + """Build encoder from source.""" + utils.print_out("# Build a basic encoder") + return self._build_encoder_from_sequence( + hparams, self.iterator.source, self.iterator.source_sequence_length) + + def _build_bidirectional_rnn(self, inputs, sequence_length, + dtype, hparams, + num_bi_layers, + num_bi_residual_layers, + base_gpu=0): + """Create and call biddirectional RNN cells. + + Args: + num_residual_layers: Number of residual layers from top to bottom. For + example, if `num_bi_layers=4` and `num_residual_layers=2`, the last 2 RNN + layers in each RNN cell will be wrapped with `ResidualWrapper`. + base_gpu: The gpu device id to use for the first forward RNN layer. The + i-th forward RNN layer will use `(base_gpu + i) % num_gpus` as its + device id. The `base_gpu` for backward RNN cell is `(base_gpu + + num_bi_layers)`. + + Returns: + The concatenated bidirectional output and the bidirectional RNN cell"s + state. 
+ """ + # Construct forward and backward cells + fw_cell = self._build_encoder_cell(hparams, + num_bi_layers, + num_bi_residual_layers, + base_gpu=base_gpu) + bw_cell = self._build_encoder_cell(hparams, + num_bi_layers, + num_bi_residual_layers, + base_gpu=(base_gpu + num_bi_layers)) + + bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn( + fw_cell, + bw_cell, + inputs, dtype=dtype, sequence_length=sequence_length, time_major=self.time_major, swap_memory=True) - elif hparams.encoder_type == "bi": - num_bi_layers = int(num_layers / 2) - num_bi_residual_layers = int(num_residual_layers / 2) - utils.print_out(" num_bi_layers = %d, num_bi_residual_layers=%d" % - (num_bi_layers, num_bi_residual_layers)) - - encoder_outputs, bi_encoder_state = ( - self._build_bidirectional_rnn( - inputs=self.encoder_emb_inp, - sequence_length=sequence_length, - dtype=dtype, - hparams=hparams, - num_bi_layers=num_bi_layers, - num_bi_residual_layers=num_bi_residual_layers)) - - if num_bi_layers == 1: - encoder_state = bi_encoder_state + + return tf.concat(bi_outputs, -1), bi_state + + def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, + source_sequence_length, base_gpu=0): + """Build an RNN cell that can be used by decoder.""" + # We only make use of encoder_outputs in attention-based models + if hparams.attention: + raise ValueError("BasicModel doesn't support attention.") + + cell = model_helper.create_rnn_cell( + unit_type=hparams.unit_type, + num_units=self.num_units, + num_layers=self.num_decoder_layers, + num_residual_layers=self.num_decoder_residual_layers, + forget_bias=hparams.forget_bias, + dropout=hparams.dropout, + num_gpus=self.num_gpus, + mode=self.mode, + single_cell_fn=self.single_cell_fn, + base_gpu=base_gpu + ) + + if hparams.language_model: + encoder_state = cell.zero_state(self.batch_size, self.dtype) + elif not hparams.pass_hidden_state: + raise ValueError("For non-attentional model, " + "pass_hidden_state needs to be set to True") + + # For beam search, we need to replicate encoder infos beam_width times + if (self.mode == tf.contrib.learn.ModeKeys.INFER and + hparams.infer_mode == "beam_search"): + decoder_initial_state = tf.contrib.seq2seq.tile_batch( + encoder_state, multiplier=hparams.beam_width) else: - # alternatively concat forward and backward states - encoder_state = [] - for layer_id in range(num_bi_layers): - encoder_state.append(bi_encoder_state[0][layer_id]) # forward - encoder_state.append(bi_encoder_state[1][layer_id]) # backward - encoder_state = tuple(encoder_state) - else: - raise ValueError("Unknown encoder_type %s" % hparams.encoder_type) - - # Use the top layer for now - self.encoder_state_list = [encoder_outputs] - - return encoder_outputs, encoder_state - - def _build_encoder(self, hparams): - """Build encoder from source.""" - utils.print_out("# Build a basic encoder") - return self._build_encoder_from_sequence( - hparams, self.iterator.source, self.iterator.source_sequence_length) - - def _build_bidirectional_rnn(self, inputs, sequence_length, - dtype, hparams, - num_bi_layers, - num_bi_residual_layers, - base_gpu=0): - """Create and call biddirectional RNN cells. - - Args: - num_residual_layers: Number of residual layers from top to bottom. For - example, if `num_bi_layers=4` and `num_residual_layers=2`, the last 2 RNN - layers in each RNN cell will be wrapped with `ResidualWrapper`. - base_gpu: The gpu device id to use for the first forward RNN layer. The - i-th forward RNN layer will use `(base_gpu + i) % num_gpus` as its - device id. 
The `base_gpu` for backward RNN cell is `(base_gpu + - num_bi_layers)`. - - Returns: - The concatenated bidirectional output and the bidirectional RNN cell"s - state. - """ - # Construct forward and backward cells - fw_cell = self._build_encoder_cell(hparams, - num_bi_layers, - num_bi_residual_layers, - base_gpu=base_gpu) - bw_cell = self._build_encoder_cell(hparams, - num_bi_layers, - num_bi_residual_layers, - base_gpu=(base_gpu + num_bi_layers)) - - bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn( - fw_cell, - bw_cell, - inputs, - dtype=dtype, - sequence_length=sequence_length, - time_major=self.time_major, - swap_memory=True) - - return tf.concat(bi_outputs, -1), bi_state - - def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, - source_sequence_length, base_gpu=0): - """Build an RNN cell that can be used by decoder.""" - # We only make use of encoder_outputs in attention-based models - if hparams.attention: - raise ValueError("BasicModel doesn't support attention.") - - cell = model_helper.create_rnn_cell( - unit_type=hparams.unit_type, - num_units=self.num_units, - num_layers=self.num_decoder_layers, - num_residual_layers=self.num_decoder_residual_layers, - forget_bias=hparams.forget_bias, - dropout=hparams.dropout, - num_gpus=self.num_gpus, - mode=self.mode, - single_cell_fn=self.single_cell_fn, - base_gpu=base_gpu - ) - - if hparams.language_model: - encoder_state = cell.zero_state(self.batch_size, self.dtype) - elif not hparams.pass_hidden_state: - raise ValueError("For non-attentional model, " - "pass_hidden_state needs to be set to True") - - # For beam search, we need to replicate encoder infos beam_width times - if (self.mode == tf.contrib.learn.ModeKeys.INFER and - hparams.infer_mode == "beam_search"): - decoder_initial_state = tf.contrib.seq2seq.tile_batch( - encoder_state, multiplier=hparams.beam_width) - else: - decoder_initial_state = encoder_state - - return cell, decoder_initial_state + decoder_initial_state = encoder_state + + return cell, decoder_initial_state diff --git a/nmt/model_helper.py b/nmt/model_helper.py index d8571e5..5e537de 100644 --- a/nmt/model_helper.py +++ b/nmt/model_helper.py @@ -41,352 +41,352 @@ def get_initializer(init_op, seed=None, init_weight=None): - """Create an initializer. init_weight is only for uniform.""" - if init_op == "uniform": - assert init_weight - return tf.random_uniform_initializer( - -init_weight, init_weight, seed=seed) - elif init_op == "glorot_normal": - return tf.keras.initializers.glorot_normal( - seed=seed) - elif init_op == "glorot_uniform": - return tf.keras.initializers.glorot_uniform( - seed=seed) - else: - raise ValueError("Unknown init_op %s" % init_op) + """Create an initializer. 
init_weight is only for uniform.""" + if init_op == "uniform": + assert init_weight + return tf.random_uniform_initializer( + -init_weight, init_weight, seed=seed) + elif init_op == "glorot_normal": + return tf.keras.initializers.glorot_normal( + seed=seed) + elif init_op == "glorot_uniform": + return tf.keras.initializers.glorot_uniform( + seed=seed) + else: + raise ValueError("Unknown init_op %s" % init_op) def get_device_str(device_id, num_gpus): - """Return a device string for multi-GPU setup.""" - if num_gpus == 0: - return "/cpu:0" - device_str_output = "/gpu:%d" % (device_id % num_gpus) - return device_str_output + """Return a device string for multi-GPU setup.""" + if num_gpus == 0: + return "/cpu:0" + device_str_output = "/gpu:%d" % (device_id % num_gpus) + return device_str_output class ExtraArgs(collections.namedtuple( "ExtraArgs", ("single_cell_fn", "model_device_fn", "attention_mechanism_fn", "encoder_emb_lookup_fn"))): - pass + pass class TrainModel( collections.namedtuple("TrainModel", ("graph", "model", "iterator", "skip_count_placeholder"))): - pass + pass def create_train_model( - model_creator, hparams, scope=None, num_workers=1, jobid=0, - extra_args=None): - """Create train graph, model, and iterator.""" - src_file = "%s.%s" % (hparams.train_prefix, hparams.src) - tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt) - src_vocab_file = hparams.src_vocab_file - tgt_vocab_file = hparams.tgt_vocab_file - - graph = tf.Graph() - - with graph.as_default(), tf.container(scope or "train"): - src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( - src_vocab_file, tgt_vocab_file, hparams.share_vocab) - - src_dataset = tf.data.TextLineDataset(tf.gfile.Glob(src_file)) - tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(tgt_file)) - skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) - - iterator = iterator_utils.get_iterator( - src_dataset, - tgt_dataset, - src_vocab_table, - tgt_vocab_table, - batch_size=hparams.batch_size, - sos=hparams.sos, - eos=hparams.eos, - random_seed=hparams.random_seed, - num_buckets=hparams.num_buckets, - src_max_len=hparams.src_max_len, - tgt_max_len=hparams.tgt_max_len, - skip_count=skip_count_placeholder, - num_shards=num_workers, - shard_index=jobid, - use_char_encode=hparams.use_char_encode) - - # Note: One can set model_device_fn to - # `tf.train.replica_device_setter(ps_tasks)` for distributed training. 
- model_device_fn = None - if extra_args: model_device_fn = extra_args.model_device_fn - with tf.device(model_device_fn): - model = model_creator( - hparams, - iterator=iterator, - mode=tf.contrib.learn.ModeKeys.TRAIN, - source_vocab_table=src_vocab_table, - target_vocab_table=tgt_vocab_table, - scope=scope, - extra_args=extra_args) - - return TrainModel( - graph=graph, - model=model, - iterator=iterator, - skip_count_placeholder=skip_count_placeholder) + model_creator, hparams, scope=None, num_workers=1, jobid=0, + extra_args=None): + """Create train graph, model, and iterator.""" + src_file = "%s.%s" % (hparams.train_prefix, hparams.src) + tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt) + src_vocab_file = hparams.src_vocab_file + tgt_vocab_file = hparams.tgt_vocab_file + + graph = tf.Graph() + + with graph.as_default(), tf.container(scope or "train"): + src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( + src_vocab_file, tgt_vocab_file, hparams.share_vocab) + + src_dataset = tf.data.TextLineDataset(tf.gfile.Glob(src_file)) + tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(tgt_file)) + skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) + + iterator = iterator_utils.get_iterator( + src_dataset, + tgt_dataset, + src_vocab_table, + tgt_vocab_table, + batch_size=hparams.batch_size, + sos=hparams.sos, + eos=hparams.eos, + random_seed=hparams.random_seed, + num_buckets=hparams.num_buckets, + src_max_len=hparams.src_max_len, + tgt_max_len=hparams.tgt_max_len, + skip_count=skip_count_placeholder, + num_shards=num_workers, + shard_index=jobid, + use_char_encode=hparams.use_char_encode) + + # Note: One can set model_device_fn to + # `tf.train.replica_device_setter(ps_tasks)` for distributed training. + model_device_fn = None + if extra_args: + model_device_fn = extra_args.model_device_fn + with tf.device(model_device_fn): + model = model_creator( + hparams, + iterator=iterator, + mode=tf.contrib.learn.ModeKeys.TRAIN, + source_vocab_table=src_vocab_table, + target_vocab_table=tgt_vocab_table, + scope=scope, + extra_args=extra_args) + + return TrainModel( + graph=graph, + model=model, + iterator=iterator, + skip_count_placeholder=skip_count_placeholder) class EvalModel( collections.namedtuple("EvalModel", ("graph", "model", "src_file_placeholder", "tgt_file_placeholder", "iterator"))): - pass + pass def create_eval_model(model_creator, hparams, scope=None, extra_args=None): - """Create train graph, model, src/tgt file holders, and iterator.""" - src_vocab_file = hparams.src_vocab_file - tgt_vocab_file = hparams.tgt_vocab_file - graph = tf.Graph() - - with graph.as_default(), tf.container(scope or "eval"): - src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( - src_vocab_file, tgt_vocab_file, hparams.share_vocab) - reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file( - tgt_vocab_file, default_value=vocab_utils.UNK) - - src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) - tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) - src_dataset = tf.data.TextLineDataset(src_file_placeholder) - tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder) - iterator = iterator_utils.get_iterator( - src_dataset, - tgt_dataset, - src_vocab_table, - tgt_vocab_table, - hparams.batch_size, - sos=hparams.sos, - eos=hparams.eos, - random_seed=hparams.random_seed, - num_buckets=hparams.num_buckets, - src_max_len=hparams.src_max_len_infer, - tgt_max_len=hparams.tgt_max_len_infer, - 
use_char_encode=hparams.use_char_encode) - model = model_creator( - hparams, - iterator=iterator, - mode=tf.contrib.learn.ModeKeys.EVAL, - source_vocab_table=src_vocab_table, - target_vocab_table=tgt_vocab_table, - reverse_target_vocab_table=reverse_tgt_vocab_table, - scope=scope, - extra_args=extra_args) - return EvalModel( - graph=graph, - model=model, - src_file_placeholder=src_file_placeholder, - tgt_file_placeholder=tgt_file_placeholder, - iterator=iterator) + """Create train graph, model, src/tgt file holders, and iterator.""" + src_vocab_file = hparams.src_vocab_file + tgt_vocab_file = hparams.tgt_vocab_file + graph = tf.Graph() + + with graph.as_default(), tf.container(scope or "eval"): + src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( + src_vocab_file, tgt_vocab_file, hparams.share_vocab) + reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file( + tgt_vocab_file, default_value=vocab_utils.UNK) + + src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) + tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string) + src_dataset = tf.data.TextLineDataset(src_file_placeholder) + tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder) + iterator = iterator_utils.get_iterator( + src_dataset, + tgt_dataset, + src_vocab_table, + tgt_vocab_table, + hparams.batch_size, + sos=hparams.sos, + eos=hparams.eos, + random_seed=hparams.random_seed, + num_buckets=hparams.num_buckets, + src_max_len=hparams.src_max_len_infer, + tgt_max_len=hparams.tgt_max_len_infer, + use_char_encode=hparams.use_char_encode) + model = model_creator( + hparams, + iterator=iterator, + mode=tf.contrib.learn.ModeKeys.EVAL, + source_vocab_table=src_vocab_table, + target_vocab_table=tgt_vocab_table, + reverse_target_vocab_table=reverse_tgt_vocab_table, + scope=scope, + extra_args=extra_args) + return EvalModel( + graph=graph, + model=model, + src_file_placeholder=src_file_placeholder, + tgt_file_placeholder=tgt_file_placeholder, + iterator=iterator) class InferModel( collections.namedtuple("InferModel", ("graph", "model", "src_placeholder", "batch_size_placeholder", "iterator"))): - pass + pass def create_infer_model(model_creator, hparams, scope=None, extra_args=None): - """Create inference model.""" - graph = tf.Graph() - src_vocab_file = hparams.src_vocab_file - tgt_vocab_file = hparams.tgt_vocab_file - - with graph.as_default(), tf.container(scope or "infer"): - src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( - src_vocab_file, tgt_vocab_file, hparams.share_vocab) - reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file( - tgt_vocab_file, default_value=vocab_utils.UNK) - - src_placeholder = tf.placeholder(shape=[None], dtype=tf.string) - batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64) - - src_dataset = tf.data.Dataset.from_tensor_slices( - src_placeholder) - iterator = iterator_utils.get_infer_iterator( - src_dataset, - src_vocab_table, - batch_size=batch_size_placeholder, - eos=hparams.eos, - src_max_len=hparams.src_max_len_infer, - use_char_encode=hparams.use_char_encode) - model = model_creator( - hparams, - iterator=iterator, - mode=tf.contrib.learn.ModeKeys.INFER, - source_vocab_table=src_vocab_table, - target_vocab_table=tgt_vocab_table, - reverse_target_vocab_table=reverse_tgt_vocab_table, - scope=scope, - extra_args=extra_args) - return InferModel( - graph=graph, - model=model, - src_placeholder=src_placeholder, - batch_size_placeholder=batch_size_placeholder, - iterator=iterator) + """Create inference 
model.""" + graph = tf.Graph() + src_vocab_file = hparams.src_vocab_file + tgt_vocab_file = hparams.tgt_vocab_file + + with graph.as_default(), tf.container(scope or "infer"): + src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables( + src_vocab_file, tgt_vocab_file, hparams.share_vocab) + reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file( + tgt_vocab_file, default_value=vocab_utils.UNK) + + src_placeholder = tf.placeholder(shape=[None], dtype=tf.string) + batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64) + + src_dataset = tf.data.Dataset.from_tensor_slices( + src_placeholder) + iterator = iterator_utils.get_infer_iterator( + src_dataset, + src_vocab_table, + batch_size=batch_size_placeholder, + eos=hparams.eos, + src_max_len=hparams.src_max_len_infer, + use_char_encode=hparams.use_char_encode) + model = model_creator( + hparams, + iterator=iterator, + mode=tf.contrib.learn.ModeKeys.INFER, + source_vocab_table=src_vocab_table, + target_vocab_table=tgt_vocab_table, + reverse_target_vocab_table=reverse_tgt_vocab_table, + scope=scope, + extra_args=extra_args) + return InferModel( + graph=graph, + model=model, + src_placeholder=src_placeholder, + batch_size_placeholder=batch_size_placeholder, + iterator=iterator) def _get_embed_device(vocab_size): - """Decide on which device to place an embed matrix given its vocab size.""" - if vocab_size > VOCAB_SIZE_THRESHOLD_CPU: - return "/cpu:0" - else: - return "/gpu:0" + """Decide on which device to place an embed matrix given its vocab size.""" + if vocab_size > VOCAB_SIZE_THRESHOLD_CPU: + return "/cpu:0" + else: + return "/gpu:0" def _create_pretrained_emb_from_txt( - vocab_file, embed_file, num_trainable_tokens=3, dtype=tf.float32, - scope=None): - """Load pretrain embeding from embed_file, and return an embedding matrix. - - Args: - embed_file: Path to a Glove formatted embedding txt file. - num_trainable_tokens: Make the first n tokens in the vocab file as trainable - variables. Default is 3, which is "", "" and "". - """ - vocab, _ = vocab_utils.load_vocab(vocab_file) - trainable_tokens = vocab[:num_trainable_tokens] - - utils.print_out("# Using pretrained embedding: %s." % embed_file) - utils.print_out(" with trainable tokens: ") - - emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file) - for token in trainable_tokens: - utils.print_out(" %s" % token) - if token not in emb_dict: - emb_dict[token] = [0.0] * emb_size - - emb_mat = np.array( - [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype()) - emb_mat = tf.constant(emb_mat) - emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1]) - with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope: - with tf.device(_get_embed_device(num_trainable_tokens)): - emb_mat_var = tf.get_variable( - "emb_mat_var", [num_trainable_tokens, emb_size]) - return tf.concat([emb_mat_var, emb_mat_const], 0) + vocab_file, embed_file, num_trainable_tokens=3, dtype=tf.float32, + scope=None): + """Load pretrain embeding from embed_file, and return an embedding matrix. + + Args: + embed_file: Path to a Glove formatted embedding txt file. + num_trainable_tokens: Make the first n tokens in the vocab file as trainable + variables. Default is 3, which is "", "" and "". + """ + vocab, _ = vocab_utils.load_vocab(vocab_file) + trainable_tokens = vocab[:num_trainable_tokens] + + utils.print_out("# Using pretrained embedding: %s." 
% embed_file) + utils.print_out(" with trainable tokens: ") + + emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file) + for token in trainable_tokens: + utils.print_out(" %s" % token) + if token not in emb_dict: + emb_dict[token] = [0.0] * emb_size + + emb_mat = np.array( + [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype()) + emb_mat = tf.constant(emb_mat) + emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1]) + with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope: + with tf.device(_get_embed_device(num_trainable_tokens)): + emb_mat_var = tf.get_variable( + "emb_mat_var", [num_trainable_tokens, emb_size]) + return tf.concat([emb_mat_var, emb_mat_const], 0) def _create_or_load_embed(embed_name, vocab_file, embed_file, vocab_size, embed_size, dtype): - """Create a new or load an existing embedding matrix.""" - if vocab_file and embed_file: - embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file) - else: - with tf.device(_get_embed_device(vocab_size)): - embedding = tf.get_variable( - embed_name, [vocab_size, embed_size], dtype) - return embedding + """Create a new or load an existing embedding matrix.""" + if vocab_file and embed_file: + embedding = _create_pretrained_emb_from_txt(vocab_file, embed_file) + else: + with tf.device(_get_embed_device(vocab_size)): + embedding = tf.get_variable( + embed_name, [vocab_size, embed_size], dtype) + return embedding def _create_emb_for_encoder_and_decoder(share_vocab, - src_vocab_size, - tgt_vocab_size, - src_embed_size, - tgt_embed_size, - dtype=tf.float32, - num_enc_partitions=0, - num_dec_partitions=0, - src_vocab_file=None, - tgt_vocab_file=None, - src_embed_file=None, - tgt_embed_file=None, - use_char_encode=False, - scope=None): - """Create embedding matrix for both encoder and decoder. - - Args: - share_vocab: A boolean. Whether to share embedding matrix for both - encoder and decoder. - src_vocab_size: An integer. The source vocab size. - tgt_vocab_size: An integer. The target vocab size. - src_embed_size: An integer. The embedding dimension for the encoder's - embedding. - tgt_embed_size: An integer. The embedding dimension for the decoder's - embedding. - dtype: dtype of the embedding matrix. Default to float32. - num_enc_partitions: number of partitions used for the encoder's embedding - vars. - num_dec_partitions: number of partitions used for the decoder's embedding - vars. - scope: VariableScope for the created subgraph. Default to "embedding". - - Returns: - embedding_encoder: Encoder's embedding matrix. - embedding_decoder: Decoder's embedding matrix. - - Raises: - ValueError: if use share_vocab but source and target have different vocab - size. - """ - if num_enc_partitions <= 1: - enc_partitioner = None - else: - # Note: num_partitions > 1 is required for distributed training due to - # embedding_lookup tries to colocate single partition-ed embedding variable - # with lookup ops. This may cause embedding variables being placed on worker - # jobs. - enc_partitioner = tf.fixed_size_partitioner(num_enc_partitions) - - if num_dec_partitions <= 1: - dec_partitioner = None - else: - # Note: num_partitions > 1 is required for distributed training due to - # embedding_lookup tries to colocate single partition-ed embedding variable - # with lookup ops. This may cause embedding variables being placed on worker - # jobs. 
- dec_partitioner = tf.fixed_size_partitioner(num_dec_partitions) - - if src_embed_file and enc_partitioner: - raise ValueError( - "Can't set num_enc_partitions > 1 when using pretrained encoder " - "embedding") - - if tgt_embed_file and dec_partitioner: - raise ValueError( - "Can't set num_dec_partitions > 1 when using pretrained decdoer " - "embedding") - - with tf.variable_scope( - scope or "embeddings", dtype=dtype, partitioner=enc_partitioner) as scope: - # Share embedding - if share_vocab: - if src_vocab_size != tgt_vocab_size: - raise ValueError("Share embedding but different src/tgt vocab sizes" - " %d vs. %d" % (src_vocab_size, tgt_vocab_size)) - assert src_embed_size == tgt_embed_size - utils.print_out("# Use the same embedding for source and target") - vocab_file = src_vocab_file or tgt_vocab_file - embed_file = src_embed_file or tgt_embed_file - - embedding_encoder = _create_or_load_embed( - "embedding_share", vocab_file, embed_file, - src_vocab_size, src_embed_size, dtype) - embedding_decoder = embedding_encoder + src_vocab_size, + tgt_vocab_size, + src_embed_size, + tgt_embed_size, + dtype=tf.float32, + num_enc_partitions=0, + num_dec_partitions=0, + src_vocab_file=None, + tgt_vocab_file=None, + src_embed_file=None, + tgt_embed_file=None, + use_char_encode=False, + scope=None): + """Create embedding matrix for both encoder and decoder. + + Args: + share_vocab: A boolean. Whether to share embedding matrix for both + encoder and decoder. + src_vocab_size: An integer. The source vocab size. + tgt_vocab_size: An integer. The target vocab size. + src_embed_size: An integer. The embedding dimension for the encoder's + embedding. + tgt_embed_size: An integer. The embedding dimension for the decoder's + embedding. + dtype: dtype of the embedding matrix. Default to float32. + num_enc_partitions: number of partitions used for the encoder's embedding + vars. + num_dec_partitions: number of partitions used for the decoder's embedding + vars. + scope: VariableScope for the created subgraph. Default to "embedding". + + Returns: + embedding_encoder: Encoder's embedding matrix. + embedding_decoder: Decoder's embedding matrix. + + Raises: + ValueError: if use share_vocab but source and target have different vocab + size. + """ + if num_enc_partitions <= 1: + enc_partitioner = None else: - if not use_char_encode: - with tf.variable_scope("encoder", partitioner=enc_partitioner): - embedding_encoder = _create_or_load_embed( - "embedding_encoder", src_vocab_file, src_embed_file, - src_vocab_size, src_embed_size, dtype) - else: - embedding_encoder = None - - with tf.variable_scope("decoder", partitioner=dec_partitioner): - embedding_decoder = _create_or_load_embed( - "embedding_decoder", tgt_vocab_file, tgt_embed_file, - tgt_vocab_size, tgt_embed_size, dtype) - - return embedding_encoder, embedding_decoder - + # Note: num_partitions > 1 is required for distributed training due to + # embedding_lookup tries to colocate single partition-ed embedding variable + # with lookup ops. This may cause embedding variables being placed on worker + # jobs. + enc_partitioner = tf.fixed_size_partitioner(num_enc_partitions) + + if num_dec_partitions <= 1: + dec_partitioner = None + else: + # Note: num_partitions > 1 is required for distributed training due to + # embedding_lookup tries to colocate single partition-ed embedding variable + # with lookup ops. This may cause embedding variables being placed on worker + # jobs. 
+ dec_partitioner = tf.fixed_size_partitioner(num_dec_partitions) + + if src_embed_file and enc_partitioner: + raise ValueError( + "Can't set num_enc_partitions > 1 when using pretrained encoder " + "embedding") + + if tgt_embed_file and dec_partitioner: + raise ValueError( + "Can't set num_dec_partitions > 1 when using pretrained decoder " + "embedding") + + with tf.variable_scope( + scope or "embeddings", dtype=dtype, partitioner=enc_partitioner) as scope: + # Share embedding + if share_vocab: + if src_vocab_size != tgt_vocab_size: + raise ValueError("Share embedding but different src/tgt vocab sizes" + " %d vs. %d" % (src_vocab_size, tgt_vocab_size)) + assert src_embed_size == tgt_embed_size + utils.print_out("# Use the same embedding for source and target") + vocab_file = src_vocab_file or tgt_vocab_file + embed_file = src_embed_file or tgt_embed_file + + embedding_encoder = _create_or_load_embed( + "embedding_share", vocab_file, embed_file, + src_vocab_size, src_embed_size, dtype) + embedding_decoder = embedding_encoder + else: + if not use_char_encode: + with tf.variable_scope("encoder", partitioner=enc_partitioner): + embedding_encoder = _create_or_load_embed( + "embedding_encoder", src_vocab_file, src_embed_file, + src_vocab_size, src_embed_size, dtype) + else: + embedding_encoder = None + + with tf.variable_scope("decoder", partitioner=dec_partitioner): + embedding_decoder = _create_or_load_embed( + "embedding_decoder", tgt_vocab_file, tgt_embed_file, + tgt_vocab_size, tgt_embed_size, dtype) + + return embedding_encoder, embedding_decoder def create_emb_for_encoder_and_decoder(share_vocab, @@ -403,323 +403,323 @@ def create_emb_for_encoder_and_decoder(share_vocab, tgt_embed_file=None, use_char_encode=False, scope=None): - """Create embedding matrix for both encoder and decoder. - - Args: - src_vocab_size: An integer. The source vocab size. - tgt_vocab_size: An integer. The target vocab size. - src_embed_size: An integer.
The embedding dimension for the encoder's + embedding. + tgt_embed_size: An integer. The embedding dimension for the decoder's + embedding. + dtype: dtype of the embedding matrix. Default to float32. + scope: VariableScope for the created subgraph. Default to "embedding". + + Returns: + embedding_encoder: Encoder's embedding matrix. + embedding_decoder: Decoder's embedding matrix. + + Raises: + ValueError: if use share_vocab but source and target have different vocab + size. + """ + assert share_vocab == False + assert num_enc_partitions == 0 + assert num_dec_partitions == 0 + assert src_embed_file == None or src_embed_file == '' + assert tgt_embed_file == None or tgt_embed_file == '' + assert use_char_encode == False + # import pdb + # pdb.set_trace() + use_kdq_to_control = True # controlled by kdq_type + + with tf.variable_scope(scope or "embeddings", dtype=dtype) as scope: + with tf.variable_scope("encoder"): + if use_kdq_to_control: + embedding_encoder = None + else: + embedding_encoder = _create_or_load_embed( + "embedding_encoder", src_vocab_file, src_embed_file, + src_vocab_size, src_embed_size, dtype) + + with tf.variable_scope("decoder"): + embedding_decoder = _create_or_load_embed( + "embedding_decoder", tgt_vocab_file, tgt_embed_file, + tgt_vocab_size, tgt_embed_size, dtype) + + return embedding_encoder, embedding_decoder def _single_cell(unit_type, num_units, forget_bias, dropout, mode, residual_connection=False, device_str=None, residual_fn=None): - """Create an instance of a single RNN cell.""" - # dropout (= 1 - keep_prob) is set to 0 during eval and infer - dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0 - - # Cell Type - if unit_type == "lstm": - utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False) - single_cell = tf.contrib.rnn.BasicLSTMCell( - num_units, - forget_bias=forget_bias) - elif unit_type == "gru": - utils.print_out(" GRU", new_line=False) - single_cell = tf.contrib.rnn.GRUCell(num_units) - elif unit_type == "layer_norm_lstm": - utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias, - new_line=False) - single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell( - num_units, - forget_bias=forget_bias, - layer_norm=True) - elif unit_type == "nas": - utils.print_out(" NASCell", new_line=False) - single_cell = tf.contrib.rnn.NASCell(num_units) - else: - raise ValueError("Unknown unit type %s!" 
% unit_type) - - # Dropout (= 1 - keep_prob) - if dropout > 0.0: - single_cell = tf.contrib.rnn.DropoutWrapper( - cell=single_cell, input_keep_prob=(1.0 - dropout)) - utils.print_out(" %s, dropout=%g " %(type(single_cell).__name__, dropout), - new_line=False) - - # Residual - if residual_connection: - single_cell = tf.contrib.rnn.ResidualWrapper( - single_cell, residual_fn=residual_fn) - utils.print_out(" %s" % type(single_cell).__name__, new_line=False) - - # Device Wrapper - if device_str: - single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str) - utils.print_out(" %s, device=%s" % - (type(single_cell).__name__, device_str), new_line=False) - - return single_cell + """Create an instance of a single RNN cell.""" + # dropout (= 1 - keep_prob) is set to 0 during eval and infer + dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0 + + # Cell Type + if unit_type == "lstm": + utils.print_out(" LSTM, forget_bias=%g" % forget_bias, new_line=False) + single_cell = tf.contrib.rnn.BasicLSTMCell( + num_units, + forget_bias=forget_bias) + elif unit_type == "gru": + utils.print_out(" GRU", new_line=False) + single_cell = tf.contrib.rnn.GRUCell(num_units) + elif unit_type == "layer_norm_lstm": + utils.print_out(" Layer Normalized LSTM, forget_bias=%g" % forget_bias, + new_line=False) + single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell( + num_units, + forget_bias=forget_bias, + layer_norm=True) + elif unit_type == "nas": + utils.print_out(" NASCell", new_line=False) + single_cell = tf.contrib.rnn.NASCell(num_units) + else: + raise ValueError("Unknown unit type %s!" % unit_type) + + # Dropout (= 1 - keep_prob) + if dropout > 0.0: + single_cell = tf.contrib.rnn.DropoutWrapper( + cell=single_cell, input_keep_prob=(1.0 - dropout)) + utils.print_out(" %s, dropout=%g " % (type(single_cell).__name__, dropout), + new_line=False) + + # Residual + if residual_connection: + single_cell = tf.contrib.rnn.ResidualWrapper( + single_cell, residual_fn=residual_fn) + utils.print_out(" %s" % type(single_cell).__name__, new_line=False) + + # Device Wrapper + if device_str: + single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str) + utils.print_out(" %s, device=%s" % + (type(single_cell).__name__, device_str), new_line=False) + + return single_cell def _cell_list(unit_type, num_units, num_layers, num_residual_layers, forget_bias, dropout, mode, num_gpus, base_gpu=0, single_cell_fn=None, residual_fn=None): - """Create a list of RNN cells.""" - if not single_cell_fn: - single_cell_fn = _single_cell - - # Multi-GPU - cell_list = [] - for i in range(num_layers): - utils.print_out(" cell %d" % i, new_line=False) - single_cell = single_cell_fn( - unit_type=unit_type, - num_units=num_units, - forget_bias=forget_bias, - dropout=dropout, - mode=mode, - residual_connection=(i >= num_layers - num_residual_layers), - device_str=get_device_str(i + base_gpu, num_gpus), - residual_fn=residual_fn - ) - utils.print_out("") - cell_list.append(single_cell) - - return cell_list + """Create a list of RNN cells.""" + if not single_cell_fn: + single_cell_fn = _single_cell + + # Multi-GPU + cell_list = [] + for i in range(num_layers): + utils.print_out(" cell %d" % i, new_line=False) + single_cell = single_cell_fn( + unit_type=unit_type, + num_units=num_units, + forget_bias=forget_bias, + dropout=dropout, + mode=mode, + residual_connection=(i >= num_layers - num_residual_layers), + device_str=get_device_str(i + base_gpu, num_gpus), + residual_fn=residual_fn + ) + utils.print_out("") + 
cell_list.append(single_cell) + + return cell_list def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers, forget_bias, dropout, mode, num_gpus, base_gpu=0, single_cell_fn=None): - """Create multi-layer RNN cell. - - Args: - unit_type: string representing the unit type, i.e. "lstm". - num_units: the depth of each unit. - num_layers: number of cells. - num_residual_layers: Number of residual layers from top to bottom. For - example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN - cells in the returned list will be wrapped with `ResidualWrapper`. - forget_bias: the initial forget bias of the RNNCell(s). - dropout: floating point value between 0.0 and 1.0: - the probability of dropout. this is ignored if `mode != TRAIN`. - mode: either tf.contrib.learn.TRAIN/EVAL/INFER - num_gpus: The number of gpus to use when performing round-robin - placement of layers. - base_gpu: The gpu device id to use for the first RNN cell in the - returned list. The i-th RNN cell will use `(base_gpu + i) % num_gpus` - as its device id. - single_cell_fn: allow for adding customized cell. - When not specified, we default to model_helper._single_cell - Returns: - An `RNNCell` instance. - """ - cell_list = _cell_list(unit_type=unit_type, - num_units=num_units, - num_layers=num_layers, - num_residual_layers=num_residual_layers, - forget_bias=forget_bias, - dropout=dropout, - mode=mode, - num_gpus=num_gpus, - base_gpu=base_gpu, - single_cell_fn=single_cell_fn) - - if len(cell_list) == 1: # Single layer. - return cell_list[0] - else: # Multi layers - return tf.contrib.rnn.MultiRNNCell(cell_list) + """Create multi-layer RNN cell. + + Args: + unit_type: string representing the unit type, i.e. "lstm". + num_units: the depth of each unit. + num_layers: number of cells. + num_residual_layers: Number of residual layers from top to bottom. For + example, if `num_layers=4` and `num_residual_layers=2`, the last 2 RNN + cells in the returned list will be wrapped with `ResidualWrapper`. + forget_bias: the initial forget bias of the RNNCell(s). + dropout: floating point value between 0.0 and 1.0: + the probability of dropout. this is ignored if `mode != TRAIN`. + mode: either tf.contrib.learn.TRAIN/EVAL/INFER + num_gpus: The number of gpus to use when performing round-robin + placement of layers. + base_gpu: The gpu device id to use for the first RNN cell in the + returned list. The i-th RNN cell will use `(base_gpu + i) % num_gpus` + as its device id. + single_cell_fn: allow for adding customized cell. + When not specified, we default to model_helper._single_cell + Returns: + An `RNNCell` instance. + """ + cell_list = _cell_list(unit_type=unit_type, + num_units=num_units, + num_layers=num_layers, + num_residual_layers=num_residual_layers, + forget_bias=forget_bias, + dropout=dropout, + mode=mode, + num_gpus=num_gpus, + base_gpu=base_gpu, + single_cell_fn=single_cell_fn) + + if len(cell_list) == 1: # Single layer. 
+ return cell_list[0] + else: # Multi layers + return tf.contrib.rnn.MultiRNNCell(cell_list) def gradient_clip(gradients, max_gradient_norm): - """Clipping gradients of a model.""" - clipped_gradients, gradient_norm = tf.clip_by_global_norm( - gradients, max_gradient_norm) - gradient_norm_summary = [tf.summary.scalar("grad_norm", gradient_norm)] - gradient_norm_summary.append( - tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients))) + """Clipping gradients of a model.""" + clipped_gradients, gradient_norm = tf.clip_by_global_norm( + gradients, max_gradient_norm) + gradient_norm_summary = [tf.summary.scalar("grad_norm", gradient_norm)] + gradient_norm_summary.append( + tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients))) - return clipped_gradients, gradient_norm_summary, gradient_norm + return clipped_gradients, gradient_norm_summary, gradient_norm def print_variables_in_ckpt(ckpt_path): - """Print a list of variables in a checkpoint together with their shapes.""" - utils.print_out("# Variables in ckpt %s" % ckpt_path) - reader = tf.train.NewCheckpointReader(ckpt_path) - variable_map = reader.get_variable_to_shape_map() - for key in sorted(variable_map.keys()): - utils.print_out(" %s: %s" % (key, variable_map[key])) + """Print a list of variables in a checkpoint together with their shapes.""" + utils.print_out("# Variables in ckpt %s" % ckpt_path) + reader = tf.train.NewCheckpointReader(ckpt_path) + variable_map = reader.get_variable_to_shape_map() + for key in sorted(variable_map.keys()): + utils.print_out(" %s: %s" % (key, variable_map[key])) def load_model(model, ckpt_path, session, name): - """Load model from a checkpoint.""" - start_time = time.time() - try: - model.saver.restore(session, ckpt_path) - except tf.errors.NotFoundError as e: - utils.print_out("Can't load checkpoint") - print_variables_in_ckpt(ckpt_path) - utils.print_out("%s" % str(e)) - - session.run(tf.tables_initializer()) - utils.print_out( - " loaded %s model parameters from %s, time %.2fs" % - (name, ckpt_path, time.time() - start_time)) - return model + """Load model from a checkpoint.""" + start_time = time.time() + try: + model.saver.restore(session, ckpt_path) + except tf.errors.NotFoundError as e: + utils.print_out("Can't load checkpoint") + print_variables_in_ckpt(ckpt_path) + utils.print_out("%s" % str(e)) + + session.run(tf.tables_initializer()) + utils.print_out( + " loaded %s model parameters from %s, time %.2fs" % + (name, ckpt_path, time.time() - start_time)) + return model def avg_checkpoints(model_dir, num_last_checkpoints, global_step, global_step_name): - """Average the last N checkpoints in the model_dir.""" - checkpoint_state = tf.train.get_checkpoint_state(model_dir) - if not checkpoint_state: - utils.print_out("# No checkpoint file found in directory: %s" % model_dir) - return None + """Average the last N checkpoints in the model_dir.""" + checkpoint_state = tf.train.get_checkpoint_state(model_dir) + if not checkpoint_state: + utils.print_out("# No checkpoint file found in directory: %s" % model_dir) + return None + + # Checkpoints are ordered from oldest to newest. + checkpoints = ( + checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:]) + + if len(checkpoints) < num_last_checkpoints: + utils.print_out( + "# Skipping averaging checkpoints because not enough checkpoints is " + "available." 
+ ) + return None + + avg_model_dir = os.path.join(model_dir, "avg_checkpoints") + if not tf.gfile.Exists(avg_model_dir): + utils.print_out( + "# Creating new directory %s for saving averaged checkpoints." % + avg_model_dir) + tf.gfile.MakeDirs(avg_model_dir) + + utils.print_out("# Reading and averaging variables in checkpoints:") + var_list = tf.contrib.framework.list_variables(checkpoints[0]) + var_values, var_dtypes = {}, {} + for (name, shape) in var_list: + if name != global_step_name: + var_values[name] = np.zeros(shape) + + for checkpoint in checkpoints: + utils.print_out(" %s" % checkpoint) + reader = tf.contrib.framework.load_checkpoint(checkpoint) + for name in var_values: + tensor = reader.get_tensor(name) + var_dtypes[name] = tensor.dtype + var_values[name] += tensor - # Checkpoints are ordered from oldest to newest. - checkpoints = ( - checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:]) + for name in var_values: + var_values[name] /= len(checkpoints) - if len(checkpoints) < num_last_checkpoints: - utils.print_out( - "# Skipping averaging checkpoints because not enough checkpoints is " - "available." - ) - return None + # Build a graph with same variables in the checkpoints, and save the averaged + # variables into the avg_model_dir. + with tf.Graph().as_default(): + tf_vars = [ + tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name]) + for v in var_values + ] - avg_model_dir = os.path.join(model_dir, "avg_checkpoints") - if not tf.gfile.Exists(avg_model_dir): - utils.print_out( - "# Creating new directory %s for saving averaged checkpoints." % - avg_model_dir) - tf.gfile.MakeDirs(avg_model_dir) - - utils.print_out("# Reading and averaging variables in checkpoints:") - var_list = tf.contrib.framework.list_variables(checkpoints[0]) - var_values, var_dtypes = {}, {} - for (name, shape) in var_list: - if name != global_step_name: - var_values[name] = np.zeros(shape) - - for checkpoint in checkpoints: - utils.print_out(" %s" % checkpoint) - reader = tf.contrib.framework.load_checkpoint(checkpoint) - for name in var_values: - tensor = reader.get_tensor(name) - var_dtypes[name] = tensor.dtype - var_values[name] += tensor - - for name in var_values: - var_values[name] /= len(checkpoints) - - # Build a graph with same variables in the checkpoints, and save the averaged - # variables into the avg_model_dir. - with tf.Graph().as_default(): - tf_vars = [ - tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name]) - for v in var_values - ] - - placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] - assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] - global_step_var = tf.Variable( - global_step, name=global_step_name, trainable=False) - saver = tf.train.Saver(tf.all_variables()) - - with tf.Session() as sess: - sess.run(tf.initialize_all_variables()) - for p, assign_op, (name, value) in zip(placeholders, assign_ops, - six.iteritems(var_values)): - sess.run(assign_op, {p: value}) - - # Use the built saver to save the averaged checkpoint. Only keep 1 - # checkpoint and the best checkpoint will be moved to avg_best_metric_dir. 
- saver.save( - sess, - os.path.join(avg_model_dir, "translate.ckpt")) - - return avg_model_dir + placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] + assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] + global_step_var = tf.Variable( + global_step, name=global_step_name, trainable=False) + saver = tf.train.Saver(tf.all_variables()) + with tf.Session() as sess: + sess.run(tf.initialize_all_variables()) + for p, assign_op, (name, value) in zip(placeholders, assign_ops, + six.iteritems(var_values)): + sess.run(assign_op, {p: value}) -def create_or_load_model(model, model_dir, session, name): - """Create translation model and initialize or load parameters in session.""" - latest_ckpt = tf.train.latest_checkpoint(model_dir) - if latest_ckpt: - model = load_model(model, latest_ckpt, session, name) - else: - start_time = time.time() - session.run(tf.global_variables_initializer()) - session.run(tf.tables_initializer()) - utils.print_out(" created %s model with fresh parameters, time %.2fs" % - (name, time.time() - start_time)) + # Use the built saver to save the averaged checkpoint. Only keep 1 + # checkpoint and the best checkpoint will be moved to avg_best_metric_dir. + saver.save( + sess, + os.path.join(avg_model_dir, "translate.ckpt")) - global_step = model.global_step.eval(session=session) - return model, global_step + return avg_model_dir -def compute_perplexity(model, sess, name): - """Compute perplexity of the output of the model. +def create_or_load_model(model, model_dir, session, name): + """Create translation model and initialize or load parameters in session.""" + latest_ckpt = tf.train.latest_checkpoint(model_dir) + if latest_ckpt: + model = load_model(model, latest_ckpt, session, name) + else: + start_time = time.time() + session.run(tf.global_variables_initializer()) + session.run(tf.tables_initializer()) + utils.print_out(" created %s model with fresh parameters, time %.2fs" % + (name, time.time() - start_time)) - Args: - model: model for compute perplexity. - sess: tensorflow session to use. - name: name of the batch. + global_step = model.global_step.eval(session=session) + return model, global_step - Returns: - The perplexity of the eval outputs. - """ - total_loss = 0 - total_predict_count = 0 - start_time = time.time() - while True: - try: - output_tuple = model.eval(sess) - total_loss += output_tuple.eval_loss * output_tuple.batch_size - total_predict_count += output_tuple.predict_count - except tf.errors.OutOfRangeError: - break - - perplexity = utils.safe_exp(total_loss / total_predict_count) - utils.print_time(" eval %s: perplexity %.2f" % (name, perplexity), - start_time) - return perplexity +def compute_perplexity(model, sess, name): + """Compute perplexity of the output of the model. + + Args: + model: model for compute perplexity. + sess: tensorflow session to use. + name: name of the batch. + + Returns: + The perplexity of the eval outputs. 
+ """ + total_loss = 0 + total_predict_count = 0 + start_time = time.time() + + while True: + try: + output_tuple = model.eval(sess) + total_loss += output_tuple.eval_loss * output_tuple.batch_size + total_predict_count += output_tuple.predict_count + except tf.errors.OutOfRangeError: + break + + perplexity = utils.safe_exp(total_loss / total_predict_count) + utils.print_time(" eval %s: perplexity %.2f" % (name, perplexity), + start_time) + return perplexity diff --git a/nmt/model_test.py b/nmt/model_test.py index e79a0ce..ae15a36 100644 --- a/nmt/model_test.py +++ b/nmt/model_test.py @@ -40,997 +40,998 @@ class ModelTest(tf.test.TestCase): - @classmethod - def setUpClass(cls): - cls.actual_vars_values = {} - cls.expected_vars_values = { - 'AttentionMechanismBahdanau/att_layer_weight/shape': (10, 5), - 'AttentionMechanismBahdanau/att_layer_weight/sum': - -0.64981574, - 'AttentionMechanismBahdanau/last_dec_weight/shape': (10, 20), - 'AttentionMechanismBahdanau/last_dec_weight/sum': - 0.058069646, - 'AttentionMechanismBahdanau/last_enc_weight/shape': (10, 20), - 'AttentionMechanismBahdanau/last_enc_weight/sum': - 0.058028102, - 'AttentionMechanismLuong/att_layer_weight/shape': (10, 5), - 'AttentionMechanismLuong/att_layer_weight/sum': - -0.64981574, - 'AttentionMechanismLuong/last_dec_weight/shape': (10, 20), - 'AttentionMechanismLuong/last_dec_weight/sum': - 0.058069646, - 'AttentionMechanismLuong/last_enc_weight/shape': (10, 20), - 'AttentionMechanismLuong/last_enc_weight/sum': - 0.058028102, - 'AttentionMechanismNormedBahdanau/att_layer_weight/shape': (10, 5), - 'AttentionMechanismNormedBahdanau/att_layer_weight/sum': - -0.64981973, - 'AttentionMechanismNormedBahdanau/last_dec_weight/shape': (10, 20), - 'AttentionMechanismNormedBahdanau/last_dec_weight/sum': - 0.058067322, - 'AttentionMechanismNormedBahdanau/last_enc_weight/shape': (10, 20), - 'AttentionMechanismNormedBahdanau/last_enc_weight/sum': - 0.058022559, - 'AttentionMechanismScaledLuong/att_layer_weight/shape': (10, 5), - 'AttentionMechanismScaledLuong/att_layer_weight/sum': - -0.64981574, - 'AttentionMechanismScaledLuong/last_dec_weight/shape': (10, 20), - 'AttentionMechanismScaledLuong/last_dec_weight/sum': - 0.058069646, - 'AttentionMechanismScaledLuong/last_enc_weight/shape': (10, 20), - 'AttentionMechanismScaledLuong/last_enc_weight/sum': - 0.058028102, - 'GNMTModel_gnmt/last_dec_weight/shape': (15, 20), - 'GNMTModel_gnmt/last_dec_weight/sum': - -0.48634407, - 'GNMTModel_gnmt/last_enc_weight/shape': (10, 20), - 'GNMTModel_gnmt/last_enc_weight/sum': - 0.058025002, - 'GNMTModel_gnmt/mem_layer_weight/shape': (5, 5), - 'GNMTModel_gnmt/mem_layer_weight/sum': - -0.44815454, - 'GNMTModel_gnmt_v2/last_dec_weight/shape': (15, 20), - 'GNMTModel_gnmt_v2/last_dec_weight/sum': - -0.48634392, - 'GNMTModel_gnmt_v2/last_enc_weight/shape': (10, 20), - 'GNMTModel_gnmt_v2/last_enc_weight/sum': - 0.058024824, - 'GNMTModel_gnmt_v2/mem_layer_weight/shape': (5, 5), - 'GNMTModel_gnmt_v2/mem_layer_weight/sum': - -0.44815454, - 'NoAttentionNoResidualUniEncoder/last_dec_weight/shape': (10, 20), - 'NoAttentionNoResidualUniEncoder/last_dec_weight/sum': - 0.057424068, - 'NoAttentionNoResidualUniEncoder/last_enc_weight/shape': (10, 20), - 'NoAttentionNoResidualUniEncoder/last_enc_weight/sum': - 0.058453858, - 'NoAttentionResidualBiEncoder/last_dec_weight/shape': (10, 20), - 'NoAttentionResidualBiEncoder/last_dec_weight/sum': - 0.058025062, - 'NoAttentionResidualBiEncoder/last_enc_weight/shape': (10, 20), - 
'NoAttentionResidualBiEncoder/last_enc_weight/sum': - 0.058053195, - 'UniEncoderBottomAttentionArchitecture/last_dec_weight/shape': (10, 20), - 'UniEncoderBottomAttentionArchitecture/last_dec_weight/sum': - 0.058024943, - 'UniEncoderBottomAttentionArchitecture/last_enc_weight/shape': (10, 20), - 'UniEncoderBottomAttentionArchitecture/last_enc_weight/sum': - 0.058025122, - 'UniEncoderBottomAttentionArchitecture/mem_layer_weight/shape': (5, 5), - 'UniEncoderBottomAttentionArchitecture/mem_layer_weight/sum': - -0.44815454, - 'UniEncoderStandardAttentionArchitecture/last_dec_weight/shape': (10, - 20), - 'UniEncoderStandardAttentionArchitecture/last_dec_weight/sum': - 0.058025002, - 'UniEncoderStandardAttentionArchitecture/last_enc_weight/shape': (10, - 20), - 'UniEncoderStandardAttentionArchitecture/last_enc_weight/sum': - 0.058024883, - 'UniEncoderStandardAttentionArchitecture/mem_layer_weight/shape': (5, - 5), - 'UniEncoderStandardAttentionArchitecture/mem_layer_weight/sum': - -0.44815454, - } - - cls.actual_train_values = {} - cls.expected_train_values = { - 'AttentionMechanismBahdanau/loss': 8.8519039, - 'AttentionMechanismLuong/loss': 8.8519039, - 'AttentionMechanismNormedBahdanau/loss': 8.851902, - 'AttentionMechanismScaledLuong/loss': 8.8519039, - 'GNMTModel_gnmt/loss': 8.8519087, - 'GNMTModel_gnmt_v2/loss': 8.8519087, - 'NoAttentionNoResidualUniEncoder/loss': 8.8516064, - 'NoAttentionResidualBiEncoder/loss': 8.851984, - 'UniEncoderStandardAttentionArchitecture/loss': 8.8519087, - 'InitializerGlorotNormal/loss': 8.9779415, - 'InitializerGlorotUniform/loss': 8.7643699, - 'SampledSoftmaxLoss/loss': 5.83928, - } - - cls.actual_eval_values = {} - cls.expected_eval_values = { - 'AttentionMechanismBahdanau/loss': 8.8517132, - 'AttentionMechanismBahdanau/predict_count': 11.0, - 'AttentionMechanismLuong/loss': 8.8517132, - 'AttentionMechanismLuong/predict_count': 11.0, - 'AttentionMechanismNormedBahdanau/loss': 8.8517132, - 'AttentionMechanismNormedBahdanau/predict_count': 11.0, - 'AttentionMechanismScaledLuong/loss': 8.8517132, - 'AttentionMechanismScaledLuong/predict_count': 11.0, - 'GNMTModel_gnmt/loss': 8.8443403, - 'GNMTModel_gnmt/predict_count': 11.0, - 'GNMTModel_gnmt_v2/loss': 8.8443756, - 'GNMTModel_gnmt_v2/predict_count': 11.0, - 'NoAttentionNoResidualUniEncoder/loss': 8.8440113, - 'NoAttentionNoResidualUniEncoder/predict_count': 11.0, - 'NoAttentionResidualBiEncoder/loss': 8.8291245, - 'NoAttentionResidualBiEncoder/predict_count': 11.0, - 'UniEncoderBottomAttentionArchitecture/loss': 8.844492, - 'UniEncoderBottomAttentionArchitecture/predict_count': 11.0, - 'UniEncoderStandardAttentionArchitecture/loss': 8.8517151, - 'UniEncoderStandardAttentionArchitecture/predict_count': 11.0 - } - - cls.actual_infer_values = {} - cls.expected_infer_values = { - 'AttentionMechanismBahdanau/logits_sum': -0.026374687, - 'AttentionMechanismLuong/logits_sum': -0.026374735, - 'AttentionMechanismNormedBahdanau/logits_sum': -0.026376063, - 'AttentionMechanismScaledLuong/logits_sum': -0.026374735, - 'GNMTModel_gnmt/logits_sum': -1.10848486, - 'GNMTModel_gnmt_v2/logits_sum': -1.10950875, - 'NoAttentionNoResidualUniEncoder/logits_sum': -1.0808625, - 'NoAttentionResidualBiEncoder/logits_sum': -2.8147559, - 'UniEncoderBottomAttentionArchitecture/logits_sum': -0.97026241, - 'UniEncoderStandardAttentionArchitecture/logits_sum': -0.02665353 - } - - cls.actual_beam_sentences = {} - cls.expected_beam_sentences = { - 'BeamSearchAttentionModel: batch 0 of beam 0': '', - 'BeamSearchAttentionModel: batch 0 of beam 1': 
'%s a %s a' % (SOS, SOS), - 'BeamSearchAttentionModel: batch 1 of beam 0': '', - 'BeamSearchAttentionModel: batch 1 of beam 1': 'b', - 'BeamSearchBasicModel: batch 0 of beam 0': 'b b b b', - 'BeamSearchBasicModel: batch 0 of beam 1': 'b b b %s' % SOS, - 'BeamSearchBasicModel: batch 0 of beam 2': 'b b b c', - 'BeamSearchBasicModel: batch 1 of beam 0': 'b b b b', - 'BeamSearchBasicModel: batch 1 of beam 1': 'a b b b', - 'BeamSearchBasicModel: batch 1 of beam 2': 'b b b %s' % SOS, - 'BeamSearchGNMTModel: batch 0 of beam 0': '', - 'BeamSearchGNMTModel: batch 1 of beam 0': '', - } - cls.expected_beam_sentences = dict( - (k, v.encode()) for k, v in cls.expected_beam_sentences.items()) - - @classmethod - def tearDownClass(cls): - print('ModelTest - actual_vars_values: ') - pprint.pprint(cls.actual_vars_values) - sys.stdout.flush() - - print('ModelTest - actual_train_values: ') - pprint.pprint(cls.actual_train_values) - sys.stdout.flush() - - print('ModelTest - actual_eval_values: ') - pprint.pprint(cls.actual_eval_values) - sys.stdout.flush() - - print('ModelTest - actual_infer_values: ') - pprint.pprint(cls.actual_infer_values) - sys.stdout.flush() - - print('ModelTest - actual_beam_sentences: ') - pprint.pprint(cls.actual_beam_sentences) - sys.stdout.flush() - - def assertAllClose(self, *args, **kwargs): - kwargs['atol'] = 5e-2 - kwargs['rtol'] = 5e-2 - return super(ModelTest, self).assertAllClose(*args, **kwargs) - - def _assertModelVariableNames(self, expected_var_names, model_var_names, - name): - - print('{} variable names are: '.format(name), model_var_names) - - self.assertEqual(len(expected_var_names), len(model_var_names)) - self.assertEqual(sorted(expected_var_names), sorted(model_var_names)) - - def _assertModelVariable(self, variable, sess, name): - var_shape = tuple(variable.get_shape().as_list()) - var_res = sess.run(variable) - var_weight_sum = np.sum(var_res) - - print('{} weight sum is: '.format(name), var_weight_sum) - expected_sum = self.expected_vars_values[name + '/sum'] - expected_shape = self.expected_vars_values[name + '/shape'] - self.actual_vars_values[name + '/sum'] = var_weight_sum - self.actual_vars_values[name + '/shape'] = var_shape - - self.assertEqual(expected_shape, var_shape) - self.assertAllClose(expected_sum, var_weight_sum) - - def _assertTrainStepsLoss(self, m, sess, name, num_steps=1): - for _ in range(num_steps): - _, output_tuple = m.train(sess) - loss = output_tuple.train_loss - print('{} {}-th step loss is: '.format(name, num_steps), loss) - expected_loss = self.expected_train_values[name + '/loss'] - self.actual_train_values[name + '/loss'] = loss - - self.assertAllClose(expected_loss, loss) - - def _assertEvalLossAndPredictCount(self, m, sess, name): - output_tuple = m.eval(sess) - loss = output_tuple.eval_loss - predict_count = output_tuple.predict_count - print('{} eval loss is: '.format(name), loss) - print('{} predict count is: '.format(name), predict_count) - expected_loss = self.expected_eval_values[name + '/loss'] - expected_predict_count = self.expected_eval_values[name + '/predict_count'] - self.actual_eval_values[name + '/loss'] = loss - self.actual_eval_values[name + '/predict_count'] = predict_count - - self.assertAllClose(expected_loss, loss) - self.assertAllClose(expected_predict_count, predict_count) - - def _assertInferLogits(self, m, sess, name): - output_tuple = m.infer(sess) - logits_sum = np.sum(output_tuple.infer_logits) - - print('{} infer logits sum is: '.format(name), logits_sum) - expected_logits_sum = 
self.expected_infer_values[name + '/logits_sum'] - self.actual_infer_values[name + '/logits_sum'] = logits_sum - - self.assertAllClose(expected_logits_sum, logits_sum) - - def _assertBeamSearchOutputs(self, m, sess, assert_top_k_sentence, name): - nmt_outputs, _ = m.decode(sess) - - for i in range(assert_top_k_sentence): - output_words = nmt_outputs[i] - for j in range(output_words.shape[0]): - sentence = nmt_utils.get_translation( - output_words, j, tgt_eos=EOS, subword_option='') - sentence_key = ('%s: batch %d of beam %d' % (name, j, i)) - self.actual_beam_sentences[sentence_key] = sentence - expected_sentence = self.expected_beam_sentences[sentence_key] - self.assertEqual(expected_sentence, sentence) - - def _createTestTrainModel(self, m_creator, hparams, sess): - train_mode = tf.contrib.learn.ModeKeys.TRAIN - train_iterator, src_vocab_table, tgt_vocab_table = ( - common_test_utils.create_test_iterator(hparams, train_mode)) - train_m = m_creator( - hparams, - train_mode, - train_iterator, - src_vocab_table, - tgt_vocab_table, - scope='dynamic_seq2seq') - sess.run(tf.global_variables_initializer()) - sess.run(tf.tables_initializer()) - sess.run(train_iterator.initializer) - return train_m - - def _createTestEvalModel(self, m_creator, hparams, sess): - eval_mode = tf.contrib.learn.ModeKeys.EVAL - eval_iterator, src_vocab_table, tgt_vocab_table = ( - common_test_utils.create_test_iterator(hparams, eval_mode)) - eval_m = m_creator( - hparams, - eval_mode, - eval_iterator, - src_vocab_table, - tgt_vocab_table, - scope='dynamic_seq2seq') - sess.run(tf.tables_initializer()) - sess.run(eval_iterator.initializer) - return eval_m - - def _createTestInferModel( - self, m_creator, hparams, sess, init_global_vars=False): - infer_mode = tf.contrib.learn.ModeKeys.INFER - (infer_iterator, src_vocab_table, - tgt_vocab_table, reverse_tgt_vocab_table) = ( - common_test_utils.create_test_iterator(hparams, infer_mode)) - infer_m = m_creator( - hparams, - infer_mode, - infer_iterator, - src_vocab_table, - tgt_vocab_table, - reverse_tgt_vocab_table, - scope='dynamic_seq2seq') - if init_global_vars: - sess.run(tf.global_variables_initializer()) - sess.run(tf.tables_initializer()) - sess.run(infer_iterator.initializer) - return infer_m - - def _get_session_config(self): - config = tf.ConfigProto() - config.allow_soft_placement = True - return config - - ## Testing 3 encoders: - # uni: no attention, no residual, 1 layers - # bi: no attention, with residual, 4 layers - def testNoAttentionNoResidualUniEncoder(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - num_layers=1, - attention='', - attention_architecture='', - use_residual=False,) - - workers, _ = tf.test.create_local_cluster(1, 0) - worker = workers[0] - - # pylint: disable=line-too-long - expected_var_names = [ - 'dynamic_seq2seq/encoder/embedding_encoder:0', - 'dynamic_seq2seq/decoder/embedding_decoder:0', - 'dynamic_seq2seq/encoder/rnn/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/output_projection/kernel:0' - ] - # pylint: enable=line-too-long - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - train_m = self._createTestTrainModel(model.Model, hparams, sess) - - m_vars = tf.trainable_variables() - self._assertModelVariableNames(expected_var_names, - [v.name for v in m_vars], - 
'NoAttentionNoResidualUniEncoder') - - with tf.variable_scope('dynamic_seq2seq', reuse=True): - last_enc_weight = tf.get_variable( - 'encoder/rnn/basic_lstm_cell/kernel') - last_dec_weight = tf.get_variable('decoder/basic_lstm_cell/kernel') - self._assertTrainStepsLoss(train_m, sess, - 'NoAttentionNoResidualUniEncoder') - self._assertModelVariable( - last_enc_weight, sess, - 'NoAttentionNoResidualUniEncoder/last_enc_weight') - self._assertModelVariable( - last_dec_weight, sess, - 'NoAttentionNoResidualUniEncoder/last_dec_weight') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - eval_m = self._createTestEvalModel(model.Model, hparams, sess) - self._assertEvalLossAndPredictCount(eval_m, sess, - 'NoAttentionNoResidualUniEncoder') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - infer_m = self._createTestInferModel(model.Model, hparams, sess) - self._assertInferLogits(infer_m, sess, - 'NoAttentionNoResidualUniEncoder') - - def testNoAttentionResidualBiEncoder(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='bi', - num_layers=4, - attention='', - attention_architecture='', - use_residual=True,) - - workers, _ = tf.test.create_local_cluster(1, 0) - worker = workers[0] - - # pylint: disable=line-too-long - expected_var_names = [ - 'dynamic_seq2seq/encoder/embedding_encoder:0', - 'dynamic_seq2seq/decoder/embedding_decoder:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/output_projection/kernel:0' - ] - # pylint: enable=line-too-long - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - train_m = self._createTestTrainModel(model.Model, hparams, sess) - - m_vars = tf.trainable_variables() - self._assertModelVariableNames(expected_var_names, - [v.name for v in m_vars], - 'NoAttentionResidualBiEncoder') - with tf.variable_scope('dynamic_seq2seq', reuse=True): - last_enc_weight = tf.get_variable( - 'encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' - ) - last_dec_weight = tf.get_variable( - 'decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel') 
- self._assertTrainStepsLoss(train_m, sess, - 'NoAttentionResidualBiEncoder') - self._assertModelVariable( - last_enc_weight, sess, - 'NoAttentionResidualBiEncoder/last_enc_weight') - self._assertModelVariable( - last_dec_weight, sess, - 'NoAttentionResidualBiEncoder/last_dec_weight') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - eval_m = self._createTestEvalModel(model.Model, hparams, sess) - self._assertEvalLossAndPredictCount(eval_m, sess, - 'NoAttentionResidualBiEncoder') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - infer_m = self._createTestInferModel(model.Model, hparams, sess) - self._assertInferLogits(infer_m, sess, 'NoAttentionResidualBiEncoder') - - ## Test attention mechanisms: luong, scaled_luong, bahdanau, normed_bahdanau - def testAttentionMechanismLuong(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - attention='luong', - attention_architecture='standard', - num_layers=2, - use_residual=False,) - - workers, _ = tf.test.create_local_cluster(1, 0) - worker = workers[0] - - # pylint: disable=line-too-long - expected_var_names = [ - 'dynamic_seq2seq/encoder/embedding_encoder:0', - 'dynamic_seq2seq/decoder/embedding_decoder:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/memory_layer/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', - 'dynamic_seq2seq/decoder/output_projection/kernel:0' - ] - # pylint: enable=line-too-long - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - train_m = self._createTestTrainModel(attention_model.AttentionModel, - hparams, sess) - - m_vars = tf.trainable_variables() - self._assertModelVariableNames( - expected_var_names, [v.name - for v in m_vars], 'AttentionMechanismLuong') - - with tf.variable_scope('dynamic_seq2seq', reuse=True): - # pylint: disable=line-too-long - last_enc_weight = tf.get_variable( - 'encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') - last_dec_weight = tf.get_variable( - 'decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') - att_layer_weight = tf.get_variable( - 'decoder/attention/attention_layer/kernel') - # pylint: enable=line-too-long - self._assertTrainStepsLoss(train_m, sess, 'AttentionMechanismLuong') - self._assertModelVariable(last_enc_weight, sess, - 'AttentionMechanismLuong/last_enc_weight') - self._assertModelVariable(last_dec_weight, sess, - 'AttentionMechanismLuong/last_dec_weight') - self._assertModelVariable(att_layer_weight, sess, - 'AttentionMechanismLuong/att_layer_weight') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - eval_m = self._createTestEvalModel(attention_model.AttentionModel, - hparams, sess) - self._assertEvalLossAndPredictCount(eval_m, 
sess, - 'AttentionMechanismLuong') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - infer_m = self._createTestInferModel(attention_model.AttentionModel, - hparams, sess) - self._assertInferLogits(infer_m, sess, 'AttentionMechanismLuong') - - def testAttentionMechanismScaledLuong(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - attention='scaled_luong', - attention_architecture='standard', - num_layers=2, - use_residual=False,) - - workers, _ = tf.test.create_local_cluster(1, 0) - worker = workers[0] - - # pylint: disable=line-too-long - expected_var_names = [ - 'dynamic_seq2seq/encoder/embedding_encoder:0', - 'dynamic_seq2seq/decoder/embedding_decoder:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/memory_layer/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/luong_attention/attention_g:0', - 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', - 'dynamic_seq2seq/decoder/output_projection/kernel:0' - ] - # pylint: enable=line-too-long - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - train_m = self._createTestTrainModel(attention_model.AttentionModel, - hparams, sess) - - m_vars = tf.trainable_variables() - self._assertModelVariableNames(expected_var_names, - [v.name for v in m_vars], - 'AttentionMechanismScaledLuong') - - with tf.variable_scope('dynamic_seq2seq', reuse=True): - # pylint: disable=line-too-long - last_enc_weight = tf.get_variable( - 'encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') - last_dec_weight = tf.get_variable( - 'decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') - att_layer_weight = tf.get_variable( - 'decoder/attention/attention_layer/kernel') - # pylint: enable=line-too-long - - self._assertTrainStepsLoss(train_m, sess, - 'AttentionMechanismScaledLuong') - self._assertModelVariable( - last_enc_weight, sess, - 'AttentionMechanismScaledLuong/last_enc_weight') - self._assertModelVariable( - last_dec_weight, sess, - 'AttentionMechanismScaledLuong/last_dec_weight') - self._assertModelVariable( - att_layer_weight, sess, - 'AttentionMechanismScaledLuong/att_layer_weight') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - eval_m = self._createTestEvalModel(attention_model.AttentionModel, - hparams, sess) - self._assertEvalLossAndPredictCount(eval_m, sess, - 'AttentionMechanismScaledLuong') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - infer_m = self._createTestInferModel(attention_model.AttentionModel, - hparams, sess) - self._assertInferLogits(infer_m, sess, 'AttentionMechanismScaledLuong') - - def testAttentionMechanismBahdanau(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - 
attention='bahdanau', - attention_architecture='standard', - num_layers=2, - use_residual=False,) - - workers, _ = tf.test.create_local_cluster(1, 0) - worker = workers[0] - - # pylint: disable=line-too-long - expected_var_names = [ - 'dynamic_seq2seq/encoder/embedding_encoder:0', - 'dynamic_seq2seq/decoder/embedding_decoder:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/memory_layer/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/bahdanau_attention/query_layer/kernel:0', - 'dynamic_seq2seq/decoder/attention/bahdanau_attention/attention_v:0', - 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', - 'dynamic_seq2seq/decoder/output_projection/kernel:0' - ] - # pylint: enable=line-too-long - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - train_m = self._createTestTrainModel(attention_model.AttentionModel, - hparams, sess) - - m_vars = tf.trainable_variables() - self._assertModelVariableNames( - expected_var_names, [v.name - for v in m_vars], 'AttentionMechanismBahdanau') - - with tf.variable_scope('dynamic_seq2seq', reuse=True): - # pylint: disable=line-too-long - last_enc_weight = tf.get_variable( - 'encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') - last_dec_weight = tf.get_variable( - 'decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') - att_layer_weight = tf.get_variable( - 'decoder/attention/attention_layer/kernel') - # pylint: enable=line-too-long - self._assertTrainStepsLoss(train_m, sess, 'AttentionMechanismBahdanau') - self._assertModelVariable(last_enc_weight, sess, - 'AttentionMechanismBahdanau/last_enc_weight') - self._assertModelVariable(last_dec_weight, sess, - 'AttentionMechanismBahdanau/last_dec_weight') - self._assertModelVariable(att_layer_weight, sess, - 'AttentionMechanismBahdanau/att_layer_weight') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - eval_m = self._createTestEvalModel(attention_model.AttentionModel, - hparams, sess) - self._assertEvalLossAndPredictCount(eval_m, sess, - 'AttentionMechanismBahdanau') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - infer_m = self._createTestInferModel(attention_model.AttentionModel, - hparams, sess) - self._assertInferLogits(infer_m, sess, 'AttentionMechanismBahdanau') - - def testAttentionMechanismNormedBahdanau(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - attention='normed_bahdanau', - attention_architecture='standard', - num_layers=2, - use_residual=False,) - - workers, _ = tf.test.create_local_cluster(1, 0) - worker = workers[0] - - # pylint: disable=line-too-long - expected_var_names = [ - 'dynamic_seq2seq/encoder/embedding_encoder:0', - 'dynamic_seq2seq/decoder/embedding_decoder:0', - 
'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/memory_layer/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/bahdanau_attention/query_layer/kernel:0', - 'dynamic_seq2seq/decoder/attention/bahdanau_attention/attention_v:0', - 'dynamic_seq2seq/decoder/attention/bahdanau_attention/attention_g:0', - 'dynamic_seq2seq/decoder/attention/bahdanau_attention/attention_b:0', - 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', - 'dynamic_seq2seq/decoder/output_projection/kernel:0' - ] - # pylint: enable=line-too-long - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - train_m = self._createTestTrainModel(attention_model.AttentionModel, - hparams, sess) - - m_vars = tf.trainable_variables() - self._assertModelVariableNames(expected_var_names, - [v.name for v in m_vars], - 'AttentionMechanismNormedBahdanau') - - with tf.variable_scope('dynamic_seq2seq', reuse=True): - # pylint: disable=line-too-long - last_enc_weight = tf.get_variable( - 'encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') - last_dec_weight = tf.get_variable( - 'decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') - att_layer_weight = tf.get_variable( - 'decoder/attention/attention_layer/kernel') - # pylint: enable=line-too-long - self._assertTrainStepsLoss(train_m, sess, - 'AttentionMechanismNormedBahdanau') - self._assertModelVariable( - last_enc_weight, sess, - 'AttentionMechanismNormedBahdanau/last_enc_weight') - self._assertModelVariable( - last_dec_weight, sess, - 'AttentionMechanismNormedBahdanau/last_dec_weight') - self._assertModelVariable( - att_layer_weight, sess, - 'AttentionMechanismNormedBahdanau/att_layer_weight') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - eval_m = self._createTestEvalModel(attention_model.AttentionModel, - hparams, sess) - self._assertEvalLossAndPredictCount(eval_m, sess, - 'AttentionMechanismNormedBahdanau') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - infer_m = self._createTestInferModel(attention_model.AttentionModel, - hparams, sess) - self._assertInferLogits(infer_m, sess, - 'AttentionMechanismNormedBahdanau') - - ## Test encoder vs. 
attention (all use residual): - # uni encoder, standard attention - def testUniEncoderStandardAttentionArchitecture(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - num_layers=4, - attention='scaled_luong', - attention_architecture='standard',) - - workers, _ = tf.test.create_local_cluster(1, 0) - worker = workers[0] - - # pylint: disable=line-too-long - expected_var_names = [ - 'dynamic_seq2seq/encoder/embedding_encoder:0', - 'dynamic_seq2seq/decoder/embedding_decoder:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_3/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_3/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/memory_layer/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_3/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_3/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/attention/luong_attention/attention_g:0', - 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', - 'dynamic_seq2seq/decoder/output_projection/kernel:0' - ] - # pylint: enable=line-too-long - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - train_m = self._createTestTrainModel(attention_model.AttentionModel, - hparams, sess) - - m_vars = tf.trainable_variables() - self._assertModelVariableNames(expected_var_names, [ - v.name for v in m_vars - ], 'UniEncoderStandardAttentionArchitecture') - with tf.variable_scope('dynamic_seq2seq', reuse=True): - last_enc_weight = tf.get_variable( - 'encoder/rnn/multi_rnn_cell/cell_3/basic_lstm_cell/kernel') - last_dec_weight = tf.get_variable( - 'decoder/attention/multi_rnn_cell/cell_3/basic_lstm_cell/kernel') - mem_layer_weight = tf.get_variable('decoder/memory_layer/kernel') - self._assertTrainStepsLoss(train_m, sess, - 'UniEncoderStandardAttentionArchitecture') - self._assertModelVariable( - last_enc_weight, sess, - 'UniEncoderStandardAttentionArchitecture/last_enc_weight') - self._assertModelVariable( - last_dec_weight, sess, - 'UniEncoderStandardAttentionArchitecture/last_dec_weight') - self._assertModelVariable( - mem_layer_weight, sess, - 'UniEncoderStandardAttentionArchitecture/mem_layer_weight') - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - eval_m = self._createTestEvalModel(attention_model.AttentionModel, - hparams, sess) - self._assertEvalLossAndPredictCount( - eval_m, sess, 'UniEncoderStandardAttentionArchitecture') - - with 
tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - infer_m = self._createTestInferModel(attention_model.AttentionModel, - hparams, sess) - self._assertInferLogits(infer_m, sess, - 'UniEncoderStandardAttentionArchitecture') - - # Test gnmt model. - def _testGNMTModel(self, architecture): - hparams = common_test_utils.create_test_hparams( - encoder_type='gnmt', - num_layers=4, - attention='scaled_luong', - attention_architecture=architecture) - - workers, _ = tf.test.create_local_cluster(1, 0) - worker = workers[0] - - # pylint: disable=line-too-long - expected_var_names = [ - 'dynamic_seq2seq/encoder/embedding_encoder:0', - 'dynamic_seq2seq/decoder/embedding_decoder:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/memory_layer/kernel:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0_attention/attention/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0_attention/attention/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0_attention/attention/luong_attention/attention_g:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel:0', - 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/bias:0', - 'dynamic_seq2seq/decoder/output_projection/kernel:0' - ] - # pylint: enable=line-too-long - - test_prefix = 'GNMTModel_%s' % architecture - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - train_m = self._createTestTrainModel(gnmt_model.GNMTModel, hparams, - sess) - - m_vars = tf.trainable_variables() - self._assertModelVariableNames(expected_var_names, - [v.name for v in m_vars], test_prefix) - with tf.variable_scope('dynamic_seq2seq', reuse=True): - last_enc_weight = tf.get_variable( - 'encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/kernel') - last_dec_weight = tf.get_variable( - 'decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel') - mem_layer_weight = tf.get_variable('decoder/memory_layer/kernel') - self._assertTrainStepsLoss(train_m, sess, test_prefix) - - self._assertModelVariable(last_enc_weight, sess, - '%s/last_enc_weight' % test_prefix) - self._assertModelVariable(last_dec_weight, sess, - '%s/last_dec_weight' % test_prefix) - self._assertModelVariable(mem_layer_weight, sess, - '%s/mem_layer_weight' % test_prefix) - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - eval_m = 
self._createTestEvalModel(gnmt_model.GNMTModel, hparams, sess) - self._assertEvalLossAndPredictCount(eval_m, sess, test_prefix) - - with tf.Graph().as_default(): - with tf.Session(worker.target, config=self._get_session_config()) as sess: - infer_m = self._createTestInferModel(gnmt_model.GNMTModel, hparams, - sess) - self._assertInferLogits(infer_m, sess, test_prefix) - - def testGNMTModel(self): - self._testGNMTModel('gnmt') - - def testGNMTModelV2(self): - self._testGNMTModel('gnmt_v2') - - # Test beam search. - def testBeamSearchBasicModel(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - num_layers=1, - attention='', - attention_architecture='', - use_residual=False,) - hparams.beam_width = 3 - hparams.infer_mode = "beam_search" - hparams.tgt_max_len_infer = 4 - assert_top_k_sentence = 3 - - with self.test_session() as sess: - infer_m = self._createTestInferModel( - model.Model, hparams, sess, True) - self._assertBeamSearchOutputs( - infer_m, sess, assert_top_k_sentence, 'BeamSearchBasicModel') - - def testBeamSearchAttentionModel(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - attention='scaled_luong', - attention_architecture='standard', - num_layers=2, - use_residual=False,) - hparams.beam_width = 3 - hparams.infer_mode = "beam_search" - hparams.tgt_max_len_infer = 4 - assert_top_k_sentence = 2 - - with self.test_session() as sess: - infer_m = self._createTestInferModel( - attention_model.AttentionModel, hparams, sess, True) - self._assertBeamSearchOutputs( - infer_m, sess, assert_top_k_sentence, 'BeamSearchAttentionModel') - - def testBeamSearchGNMTModel(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='gnmt', - num_layers=4, - attention='scaled_luong', - attention_architecture='gnmt') - hparams.beam_width = 3 - hparams.infer_mode = "beam_search" - hparams.tgt_max_len_infer = 4 - assert_top_k_sentence = 1 - - with self.test_session() as sess: - infer_m = self._createTestInferModel( - gnmt_model.GNMTModel, hparams, sess, True) - self._assertBeamSearchOutputs( - infer_m, sess, assert_top_k_sentence, 'BeamSearchGNMTModel') - - def testInitializerGlorotNormal(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - num_layers=1, - attention='', - attention_architecture='', - use_residual=False, - init_op='glorot_normal') - - with self.test_session() as sess: - train_m = self._createTestTrainModel(model.Model, hparams, sess) - self._assertTrainStepsLoss(train_m, sess, - 'InitializerGlorotNormal') - - def testInitializerGlorotUniform(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='uni', - num_layers=1, - attention='', - attention_architecture='', - use_residual=False, - init_op='glorot_uniform') - - with self.test_session() as sess: - train_m = self._createTestTrainModel(model.Model, hparams, sess) - self._assertTrainStepsLoss(train_m, sess, - 'InitializerGlorotUniform') - - def testSampledSoftmaxLoss(self): - hparams = common_test_utils.create_test_hparams( - encoder_type='gnmt', - num_layers=4, - attention='scaled_luong', - attention_architecture='gnmt') - hparams.num_sampled_softmax = 3 - - with self.test_session() as sess: - train_m = self._createTestTrainModel(gnmt_model.GNMTModel, hparams, sess) - self._assertTrainStepsLoss(train_m, sess, - 'SampledSoftmaxLoss') + @classmethod + def setUpClass(cls): + cls.actual_vars_values = {} + cls.expected_vars_values = { + 'AttentionMechanismBahdanau/att_layer_weight/shape': (10, 5), + 
'AttentionMechanismBahdanau/att_layer_weight/sum': + -0.64981574, + 'AttentionMechanismBahdanau/last_dec_weight/shape': (10, 20), + 'AttentionMechanismBahdanau/last_dec_weight/sum': + 0.058069646, + 'AttentionMechanismBahdanau/last_enc_weight/shape': (10, 20), + 'AttentionMechanismBahdanau/last_enc_weight/sum': + 0.058028102, + 'AttentionMechanismLuong/att_layer_weight/shape': (10, 5), + 'AttentionMechanismLuong/att_layer_weight/sum': + -0.64981574, + 'AttentionMechanismLuong/last_dec_weight/shape': (10, 20), + 'AttentionMechanismLuong/last_dec_weight/sum': + 0.058069646, + 'AttentionMechanismLuong/last_enc_weight/shape': (10, 20), + 'AttentionMechanismLuong/last_enc_weight/sum': + 0.058028102, + 'AttentionMechanismNormedBahdanau/att_layer_weight/shape': (10, 5), + 'AttentionMechanismNormedBahdanau/att_layer_weight/sum': + -0.64981973, + 'AttentionMechanismNormedBahdanau/last_dec_weight/shape': (10, 20), + 'AttentionMechanismNormedBahdanau/last_dec_weight/sum': + 0.058067322, + 'AttentionMechanismNormedBahdanau/last_enc_weight/shape': (10, 20), + 'AttentionMechanismNormedBahdanau/last_enc_weight/sum': + 0.058022559, + 'AttentionMechanismScaledLuong/att_layer_weight/shape': (10, 5), + 'AttentionMechanismScaledLuong/att_layer_weight/sum': + -0.64981574, + 'AttentionMechanismScaledLuong/last_dec_weight/shape': (10, 20), + 'AttentionMechanismScaledLuong/last_dec_weight/sum': + 0.058069646, + 'AttentionMechanismScaledLuong/last_enc_weight/shape': (10, 20), + 'AttentionMechanismScaledLuong/last_enc_weight/sum': + 0.058028102, + 'GNMTModel_gnmt/last_dec_weight/shape': (15, 20), + 'GNMTModel_gnmt/last_dec_weight/sum': + -0.48634407, + 'GNMTModel_gnmt/last_enc_weight/shape': (10, 20), + 'GNMTModel_gnmt/last_enc_weight/sum': + 0.058025002, + 'GNMTModel_gnmt/mem_layer_weight/shape': (5, 5), + 'GNMTModel_gnmt/mem_layer_weight/sum': + -0.44815454, + 'GNMTModel_gnmt_v2/last_dec_weight/shape': (15, 20), + 'GNMTModel_gnmt_v2/last_dec_weight/sum': + -0.48634392, + 'GNMTModel_gnmt_v2/last_enc_weight/shape': (10, 20), + 'GNMTModel_gnmt_v2/last_enc_weight/sum': + 0.058024824, + 'GNMTModel_gnmt_v2/mem_layer_weight/shape': (5, 5), + 'GNMTModel_gnmt_v2/mem_layer_weight/sum': + -0.44815454, + 'NoAttentionNoResidualUniEncoder/last_dec_weight/shape': (10, 20), + 'NoAttentionNoResidualUniEncoder/last_dec_weight/sum': + 0.057424068, + 'NoAttentionNoResidualUniEncoder/last_enc_weight/shape': (10, 20), + 'NoAttentionNoResidualUniEncoder/last_enc_weight/sum': + 0.058453858, + 'NoAttentionResidualBiEncoder/last_dec_weight/shape': (10, 20), + 'NoAttentionResidualBiEncoder/last_dec_weight/sum': + 0.058025062, + 'NoAttentionResidualBiEncoder/last_enc_weight/shape': (10, 20), + 'NoAttentionResidualBiEncoder/last_enc_weight/sum': + 0.058053195, + 'UniEncoderBottomAttentionArchitecture/last_dec_weight/shape': (10, 20), + 'UniEncoderBottomAttentionArchitecture/last_dec_weight/sum': + 0.058024943, + 'UniEncoderBottomAttentionArchitecture/last_enc_weight/shape': (10, 20), + 'UniEncoderBottomAttentionArchitecture/last_enc_weight/sum': + 0.058025122, + 'UniEncoderBottomAttentionArchitecture/mem_layer_weight/shape': (5, 5), + 'UniEncoderBottomAttentionArchitecture/mem_layer_weight/sum': + -0.44815454, + 'UniEncoderStandardAttentionArchitecture/last_dec_weight/shape': (10, + 20), + 'UniEncoderStandardAttentionArchitecture/last_dec_weight/sum': + 0.058025002, + 'UniEncoderStandardAttentionArchitecture/last_enc_weight/shape': (10, + 20), + 'UniEncoderStandardAttentionArchitecture/last_enc_weight/sum': + 0.058024883, + 
'UniEncoderStandardAttentionArchitecture/mem_layer_weight/shape': (5, + 5), + 'UniEncoderStandardAttentionArchitecture/mem_layer_weight/sum': + -0.44815454, + } + + cls.actual_train_values = {} + cls.expected_train_values = { + 'AttentionMechanismBahdanau/loss': 8.8519039, + 'AttentionMechanismLuong/loss': 8.8519039, + 'AttentionMechanismNormedBahdanau/loss': 8.851902, + 'AttentionMechanismScaledLuong/loss': 8.8519039, + 'GNMTModel_gnmt/loss': 8.8519087, + 'GNMTModel_gnmt_v2/loss': 8.8519087, + 'NoAttentionNoResidualUniEncoder/loss': 8.8516064, + 'NoAttentionResidualBiEncoder/loss': 8.851984, + 'UniEncoderStandardAttentionArchitecture/loss': 8.8519087, + 'InitializerGlorotNormal/loss': 8.9779415, + 'InitializerGlorotUniform/loss': 8.7643699, + 'SampledSoftmaxLoss/loss': 5.83928, + } + + cls.actual_eval_values = {} + cls.expected_eval_values = { + 'AttentionMechanismBahdanau/loss': 8.8517132, + 'AttentionMechanismBahdanau/predict_count': 11.0, + 'AttentionMechanismLuong/loss': 8.8517132, + 'AttentionMechanismLuong/predict_count': 11.0, + 'AttentionMechanismNormedBahdanau/loss': 8.8517132, + 'AttentionMechanismNormedBahdanau/predict_count': 11.0, + 'AttentionMechanismScaledLuong/loss': 8.8517132, + 'AttentionMechanismScaledLuong/predict_count': 11.0, + 'GNMTModel_gnmt/loss': 8.8443403, + 'GNMTModel_gnmt/predict_count': 11.0, + 'GNMTModel_gnmt_v2/loss': 8.8443756, + 'GNMTModel_gnmt_v2/predict_count': 11.0, + 'NoAttentionNoResidualUniEncoder/loss': 8.8440113, + 'NoAttentionNoResidualUniEncoder/predict_count': 11.0, + 'NoAttentionResidualBiEncoder/loss': 8.8291245, + 'NoAttentionResidualBiEncoder/predict_count': 11.0, + 'UniEncoderBottomAttentionArchitecture/loss': 8.844492, + 'UniEncoderBottomAttentionArchitecture/predict_count': 11.0, + 'UniEncoderStandardAttentionArchitecture/loss': 8.8517151, + 'UniEncoderStandardAttentionArchitecture/predict_count': 11.0 + } + + cls.actual_infer_values = {} + cls.expected_infer_values = { + 'AttentionMechanismBahdanau/logits_sum': -0.026374687, + 'AttentionMechanismLuong/logits_sum': -0.026374735, + 'AttentionMechanismNormedBahdanau/logits_sum': -0.026376063, + 'AttentionMechanismScaledLuong/logits_sum': -0.026374735, + 'GNMTModel_gnmt/logits_sum': -1.10848486, + 'GNMTModel_gnmt_v2/logits_sum': -1.10950875, + 'NoAttentionNoResidualUniEncoder/logits_sum': -1.0808625, + 'NoAttentionResidualBiEncoder/logits_sum': -2.8147559, + 'UniEncoderBottomAttentionArchitecture/logits_sum': -0.97026241, + 'UniEncoderStandardAttentionArchitecture/logits_sum': -0.02665353 + } + + cls.actual_beam_sentences = {} + cls.expected_beam_sentences = { + 'BeamSearchAttentionModel: batch 0 of beam 0': '', + 'BeamSearchAttentionModel: batch 0 of beam 1': '%s a %s a' % (SOS, SOS), + 'BeamSearchAttentionModel: batch 1 of beam 0': '', + 'BeamSearchAttentionModel: batch 1 of beam 1': 'b', + 'BeamSearchBasicModel: batch 0 of beam 0': 'b b b b', + 'BeamSearchBasicModel: batch 0 of beam 1': 'b b b %s' % SOS, + 'BeamSearchBasicModel: batch 0 of beam 2': 'b b b c', + 'BeamSearchBasicModel: batch 1 of beam 0': 'b b b b', + 'BeamSearchBasicModel: batch 1 of beam 1': 'a b b b', + 'BeamSearchBasicModel: batch 1 of beam 2': 'b b b %s' % SOS, + 'BeamSearchGNMTModel: batch 0 of beam 0': '', + 'BeamSearchGNMTModel: batch 1 of beam 0': '', + } + cls.expected_beam_sentences = dict( + (k, v.encode()) for k, v in cls.expected_beam_sentences.items()) + + @classmethod + def tearDownClass(cls): + print('ModelTest - actual_vars_values: ') + pprint.pprint(cls.actual_vars_values) + sys.stdout.flush() + + 
print('ModelTest - actual_train_values: ') + pprint.pprint(cls.actual_train_values) + sys.stdout.flush() + + print('ModelTest - actual_eval_values: ') + pprint.pprint(cls.actual_eval_values) + sys.stdout.flush() + + print('ModelTest - actual_infer_values: ') + pprint.pprint(cls.actual_infer_values) + sys.stdout.flush() + + print('ModelTest - actual_beam_sentences: ') + pprint.pprint(cls.actual_beam_sentences) + sys.stdout.flush() + + def assertAllClose(self, *args, **kwargs): + kwargs['atol'] = 5e-2 + kwargs['rtol'] = 5e-2 + return super(ModelTest, self).assertAllClose(*args, **kwargs) + + def _assertModelVariableNames(self, expected_var_names, model_var_names, + name): + + print('{} variable names are: '.format(name), model_var_names) + + self.assertEqual(len(expected_var_names), len(model_var_names)) + self.assertEqual(sorted(expected_var_names), sorted(model_var_names)) + + def _assertModelVariable(self, variable, sess, name): + var_shape = tuple(variable.get_shape().as_list()) + var_res = sess.run(variable) + var_weight_sum = np.sum(var_res) + + print('{} weight sum is: '.format(name), var_weight_sum) + expected_sum = self.expected_vars_values[name + '/sum'] + expected_shape = self.expected_vars_values[name + '/shape'] + self.actual_vars_values[name + '/sum'] = var_weight_sum + self.actual_vars_values[name + '/shape'] = var_shape + + self.assertEqual(expected_shape, var_shape) + self.assertAllClose(expected_sum, var_weight_sum) + + def _assertTrainStepsLoss(self, m, sess, name, num_steps=1): + for _ in range(num_steps): + _, output_tuple = m.train(sess) + loss = output_tuple.train_loss + print('{} {}-th step loss is: '.format(name, num_steps), loss) + expected_loss = self.expected_train_values[name + '/loss'] + self.actual_train_values[name + '/loss'] = loss + + self.assertAllClose(expected_loss, loss) + + def _assertEvalLossAndPredictCount(self, m, sess, name): + output_tuple = m.eval(sess) + loss = output_tuple.eval_loss + predict_count = output_tuple.predict_count + print('{} eval loss is: '.format(name), loss) + print('{} predict count is: '.format(name), predict_count) + expected_loss = self.expected_eval_values[name + '/loss'] + expected_predict_count = self.expected_eval_values[name + '/predict_count'] + self.actual_eval_values[name + '/loss'] = loss + self.actual_eval_values[name + '/predict_count'] = predict_count + + self.assertAllClose(expected_loss, loss) + self.assertAllClose(expected_predict_count, predict_count) + + def _assertInferLogits(self, m, sess, name): + output_tuple = m.infer(sess) + logits_sum = np.sum(output_tuple.infer_logits) + + print('{} infer logits sum is: '.format(name), logits_sum) + expected_logits_sum = self.expected_infer_values[name + '/logits_sum'] + self.actual_infer_values[name + '/logits_sum'] = logits_sum + + self.assertAllClose(expected_logits_sum, logits_sum) + + def _assertBeamSearchOutputs(self, m, sess, assert_top_k_sentence, name): + nmt_outputs, _ = m.decode(sess) + + for i in range(assert_top_k_sentence): + output_words = nmt_outputs[i] + for j in range(output_words.shape[0]): + sentence = nmt_utils.get_translation( + output_words, j, tgt_eos=EOS, subword_option='') + sentence_key = ('%s: batch %d of beam %d' % (name, j, i)) + self.actual_beam_sentences[sentence_key] = sentence + expected_sentence = self.expected_beam_sentences[sentence_key] + self.assertEqual(expected_sentence, sentence) + + def _createTestTrainModel(self, m_creator, hparams, sess): + train_mode = tf.contrib.learn.ModeKeys.TRAIN + train_iterator, src_vocab_table, 
tgt_vocab_table = ( + common_test_utils.create_test_iterator(hparams, train_mode)) + train_m = m_creator( + hparams, + train_mode, + train_iterator, + src_vocab_table, + tgt_vocab_table, + scope='dynamic_seq2seq') + sess.run(tf.global_variables_initializer()) + sess.run(tf.tables_initializer()) + sess.run(train_iterator.initializer) + return train_m + + def _createTestEvalModel(self, m_creator, hparams, sess): + eval_mode = tf.contrib.learn.ModeKeys.EVAL + eval_iterator, src_vocab_table, tgt_vocab_table = ( + common_test_utils.create_test_iterator(hparams, eval_mode)) + eval_m = m_creator( + hparams, + eval_mode, + eval_iterator, + src_vocab_table, + tgt_vocab_table, + scope='dynamic_seq2seq') + sess.run(tf.tables_initializer()) + sess.run(eval_iterator.initializer) + return eval_m + + def _createTestInferModel( + self, m_creator, hparams, sess, init_global_vars=False): + infer_mode = tf.contrib.learn.ModeKeys.INFER + (infer_iterator, src_vocab_table, + tgt_vocab_table, reverse_tgt_vocab_table) = ( + common_test_utils.create_test_iterator(hparams, infer_mode)) + infer_m = m_creator( + hparams, + infer_mode, + infer_iterator, + src_vocab_table, + tgt_vocab_table, + reverse_tgt_vocab_table, + scope='dynamic_seq2seq') + if init_global_vars: + sess.run(tf.global_variables_initializer()) + sess.run(tf.tables_initializer()) + sess.run(infer_iterator.initializer) + return infer_m + + def _get_session_config(self): + config = tf.ConfigProto() + config.allow_soft_placement = True + return config + + # Testing 3 encoders: + # uni: no attention, no residual, 1 layers + # bi: no attention, with residual, 4 layers + def testNoAttentionNoResidualUniEncoder(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + num_layers=1, + attention='', + attention_architecture='', + use_residual=False,) + + workers, _ = tf.test.create_local_cluster(1, 0) + worker = workers[0] + + # pylint: disable=line-too-long + expected_var_names = [ + 'dynamic_seq2seq/encoder/embedding_encoder:0', + 'dynamic_seq2seq/decoder/embedding_decoder:0', + 'dynamic_seq2seq/encoder/rnn/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/output_projection/kernel:0' + ] + # pylint: enable=line-too-long + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + train_m = self._createTestTrainModel(model.Model, hparams, sess) + + m_vars = tf.trainable_variables() + self._assertModelVariableNames(expected_var_names, + [v.name for v in m_vars], + 'NoAttentionNoResidualUniEncoder') + + with tf.variable_scope('dynamic_seq2seq', reuse=True): + last_enc_weight = tf.get_variable( + 'encoder/rnn/basic_lstm_cell/kernel') + last_dec_weight = tf.get_variable('decoder/basic_lstm_cell/kernel') + self._assertTrainStepsLoss(train_m, sess, + 'NoAttentionNoResidualUniEncoder') + self._assertModelVariable( + last_enc_weight, sess, + 'NoAttentionNoResidualUniEncoder/last_enc_weight') + self._assertModelVariable( + last_dec_weight, sess, + 'NoAttentionNoResidualUniEncoder/last_dec_weight') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + eval_m = self._createTestEvalModel(model.Model, hparams, sess) + self._assertEvalLossAndPredictCount(eval_m, sess, + 'NoAttentionNoResidualUniEncoder') + + with tf.Graph().as_default(): + with tf.Session(worker.target, 
config=self._get_session_config()) as sess: + infer_m = self._createTestInferModel(model.Model, hparams, sess) + self._assertInferLogits(infer_m, sess, + 'NoAttentionNoResidualUniEncoder') + + def testNoAttentionResidualBiEncoder(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='bi', + num_layers=4, + attention='', + attention_architecture='', + use_residual=True,) + + workers, _ = tf.test.create_local_cluster(1, 0) + worker = workers[0] + + # pylint: disable=line-too-long + expected_var_names = [ + 'dynamic_seq2seq/encoder/embedding_encoder:0', + 'dynamic_seq2seq/decoder/embedding_decoder:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/output_projection/kernel:0' + ] + # pylint: enable=line-too-long + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + train_m = self._createTestTrainModel(model.Model, hparams, sess) + + m_vars = tf.trainable_variables() + self._assertModelVariableNames(expected_var_names, + [v.name for v in m_vars], + 'NoAttentionResidualBiEncoder') + with tf.variable_scope('dynamic_seq2seq', reuse=True): + last_enc_weight = tf.get_variable( + 'encoder/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' + ) + last_dec_weight = tf.get_variable( + 'decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel') + self._assertTrainStepsLoss(train_m, sess, + 'NoAttentionResidualBiEncoder') + self._assertModelVariable( + last_enc_weight, sess, + 'NoAttentionResidualBiEncoder/last_enc_weight') + self._assertModelVariable( + last_dec_weight, sess, + 'NoAttentionResidualBiEncoder/last_dec_weight') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + eval_m = self._createTestEvalModel(model.Model, hparams, sess) + self._assertEvalLossAndPredictCount(eval_m, sess, + 'NoAttentionResidualBiEncoder') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + infer_m = self._createTestInferModel(model.Model, hparams, sess) + self._assertInferLogits(infer_m, sess, 'NoAttentionResidualBiEncoder') + + # Test attention mechanisms: luong, scaled_luong, bahdanau, 
normed_bahdanau + def testAttentionMechanismLuong(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + attention='luong', + attention_architecture='standard', + num_layers=2, + use_residual=False,) + + workers, _ = tf.test.create_local_cluster(1, 0) + worker = workers[0] + + # pylint: disable=line-too-long + expected_var_names = [ + 'dynamic_seq2seq/encoder/embedding_encoder:0', + 'dynamic_seq2seq/decoder/embedding_decoder:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/memory_layer/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', + 'dynamic_seq2seq/decoder/output_projection/kernel:0' + ] + # pylint: enable=line-too-long + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + train_m = self._createTestTrainModel(attention_model.AttentionModel, + hparams, sess) + + m_vars = tf.trainable_variables() + self._assertModelVariableNames( + expected_var_names, [v.name + for v in m_vars], 'AttentionMechanismLuong') + + with tf.variable_scope('dynamic_seq2seq', reuse=True): + # pylint: disable=line-too-long + last_enc_weight = tf.get_variable( + 'encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') + last_dec_weight = tf.get_variable( + 'decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') + att_layer_weight = tf.get_variable( + 'decoder/attention/attention_layer/kernel') + # pylint: enable=line-too-long + self._assertTrainStepsLoss(train_m, sess, 'AttentionMechanismLuong') + self._assertModelVariable(last_enc_weight, sess, + 'AttentionMechanismLuong/last_enc_weight') + self._assertModelVariable(last_dec_weight, sess, + 'AttentionMechanismLuong/last_dec_weight') + self._assertModelVariable(att_layer_weight, sess, + 'AttentionMechanismLuong/att_layer_weight') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + eval_m = self._createTestEvalModel(attention_model.AttentionModel, + hparams, sess) + self._assertEvalLossAndPredictCount(eval_m, sess, + 'AttentionMechanismLuong') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + infer_m = self._createTestInferModel(attention_model.AttentionModel, + hparams, sess) + self._assertInferLogits(infer_m, sess, 'AttentionMechanismLuong') + + def testAttentionMechanismScaledLuong(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + attention='scaled_luong', + attention_architecture='standard', + num_layers=2, + use_residual=False,) + + workers, _ = tf.test.create_local_cluster(1, 0) + worker = workers[0] + + # pylint: disable=line-too-long + expected_var_names = [ + 'dynamic_seq2seq/encoder/embedding_encoder:0', + 'dynamic_seq2seq/decoder/embedding_decoder:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 
'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/memory_layer/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/luong_attention/attention_g:0', + 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', + 'dynamic_seq2seq/decoder/output_projection/kernel:0' + ] + # pylint: enable=line-too-long + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + train_m = self._createTestTrainModel(attention_model.AttentionModel, + hparams, sess) + + m_vars = tf.trainable_variables() + self._assertModelVariableNames(expected_var_names, + [v.name for v in m_vars], + 'AttentionMechanismScaledLuong') + + with tf.variable_scope('dynamic_seq2seq', reuse=True): + # pylint: disable=line-too-long + last_enc_weight = tf.get_variable( + 'encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') + last_dec_weight = tf.get_variable( + 'decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') + att_layer_weight = tf.get_variable( + 'decoder/attention/attention_layer/kernel') + # pylint: enable=line-too-long + + self._assertTrainStepsLoss(train_m, sess, + 'AttentionMechanismScaledLuong') + self._assertModelVariable( + last_enc_weight, sess, + 'AttentionMechanismScaledLuong/last_enc_weight') + self._assertModelVariable( + last_dec_weight, sess, + 'AttentionMechanismScaledLuong/last_dec_weight') + self._assertModelVariable( + att_layer_weight, sess, + 'AttentionMechanismScaledLuong/att_layer_weight') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + eval_m = self._createTestEvalModel(attention_model.AttentionModel, + hparams, sess) + self._assertEvalLossAndPredictCount(eval_m, sess, + 'AttentionMechanismScaledLuong') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + infer_m = self._createTestInferModel(attention_model.AttentionModel, + hparams, sess) + self._assertInferLogits(infer_m, sess, 'AttentionMechanismScaledLuong') + + def testAttentionMechanismBahdanau(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + attention='bahdanau', + attention_architecture='standard', + num_layers=2, + use_residual=False,) + + workers, _ = tf.test.create_local_cluster(1, 0) + worker = workers[0] + + # pylint: disable=line-too-long + expected_var_names = [ + 'dynamic_seq2seq/encoder/embedding_encoder:0', + 'dynamic_seq2seq/decoder/embedding_decoder:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/memory_layer/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 
'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/bahdanau_attention/query_layer/kernel:0', + 'dynamic_seq2seq/decoder/attention/bahdanau_attention/attention_v:0', + 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', + 'dynamic_seq2seq/decoder/output_projection/kernel:0' + ] + # pylint: enable=line-too-long + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + train_m = self._createTestTrainModel(attention_model.AttentionModel, + hparams, sess) + + m_vars = tf.trainable_variables() + self._assertModelVariableNames( + expected_var_names, [v.name + for v in m_vars], 'AttentionMechanismBahdanau') + + with tf.variable_scope('dynamic_seq2seq', reuse=True): + # pylint: disable=line-too-long + last_enc_weight = tf.get_variable( + 'encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') + last_dec_weight = tf.get_variable( + 'decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') + att_layer_weight = tf.get_variable( + 'decoder/attention/attention_layer/kernel') + # pylint: enable=line-too-long + self._assertTrainStepsLoss(train_m, sess, 'AttentionMechanismBahdanau') + self._assertModelVariable(last_enc_weight, sess, + 'AttentionMechanismBahdanau/last_enc_weight') + self._assertModelVariable(last_dec_weight, sess, + 'AttentionMechanismBahdanau/last_dec_weight') + self._assertModelVariable(att_layer_weight, sess, + 'AttentionMechanismBahdanau/att_layer_weight') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + eval_m = self._createTestEvalModel(attention_model.AttentionModel, + hparams, sess) + self._assertEvalLossAndPredictCount(eval_m, sess, + 'AttentionMechanismBahdanau') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + infer_m = self._createTestInferModel(attention_model.AttentionModel, + hparams, sess) + self._assertInferLogits(infer_m, sess, 'AttentionMechanismBahdanau') + + def testAttentionMechanismNormedBahdanau(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + attention='normed_bahdanau', + attention_architecture='standard', + num_layers=2, + use_residual=False,) + + workers, _ = tf.test.create_local_cluster(1, 0) + worker = workers[0] + + # pylint: disable=line-too-long + expected_var_names = [ + 'dynamic_seq2seq/encoder/embedding_encoder:0', + 'dynamic_seq2seq/decoder/embedding_decoder:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/memory_layer/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/bahdanau_attention/query_layer/kernel:0', + 
'dynamic_seq2seq/decoder/attention/bahdanau_attention/attention_v:0', + 'dynamic_seq2seq/decoder/attention/bahdanau_attention/attention_g:0', + 'dynamic_seq2seq/decoder/attention/bahdanau_attention/attention_b:0', + 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', + 'dynamic_seq2seq/decoder/output_projection/kernel:0' + ] + # pylint: enable=line-too-long + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + train_m = self._createTestTrainModel(attention_model.AttentionModel, + hparams, sess) + + m_vars = tf.trainable_variables() + self._assertModelVariableNames(expected_var_names, + [v.name for v in m_vars], + 'AttentionMechanismNormedBahdanau') + + with tf.variable_scope('dynamic_seq2seq', reuse=True): + # pylint: disable=line-too-long + last_enc_weight = tf.get_variable( + 'encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') + last_dec_weight = tf.get_variable( + 'decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel') + att_layer_weight = tf.get_variable( + 'decoder/attention/attention_layer/kernel') + # pylint: enable=line-too-long + self._assertTrainStepsLoss(train_m, sess, + 'AttentionMechanismNormedBahdanau') + self._assertModelVariable( + last_enc_weight, sess, + 'AttentionMechanismNormedBahdanau/last_enc_weight') + self._assertModelVariable( + last_dec_weight, sess, + 'AttentionMechanismNormedBahdanau/last_dec_weight') + self._assertModelVariable( + att_layer_weight, sess, + 'AttentionMechanismNormedBahdanau/att_layer_weight') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + eval_m = self._createTestEvalModel(attention_model.AttentionModel, + hparams, sess) + self._assertEvalLossAndPredictCount(eval_m, sess, + 'AttentionMechanismNormedBahdanau') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + infer_m = self._createTestInferModel(attention_model.AttentionModel, + hparams, sess) + self._assertInferLogits(infer_m, sess, + 'AttentionMechanismNormedBahdanau') + + # Test encoder vs. 
attention (all use residual): + # uni encoder, standard attention + def testUniEncoderStandardAttentionArchitecture(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + num_layers=4, + attention='scaled_luong', + attention_architecture='standard',) + + workers, _ = tf.test.create_local_cluster(1, 0) + worker = workers[0] + + # pylint: disable=line-too-long + expected_var_names = [ + 'dynamic_seq2seq/encoder/embedding_encoder:0', + 'dynamic_seq2seq/decoder/embedding_decoder:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_3/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_3/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/memory_layer/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_3/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/attention/multi_rnn_cell/cell_3/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/attention/luong_attention/attention_g:0', + 'dynamic_seq2seq/decoder/attention/attention_layer/kernel:0', + 'dynamic_seq2seq/decoder/output_projection/kernel:0' + ] + # pylint: enable=line-too-long + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + train_m = self._createTestTrainModel(attention_model.AttentionModel, + hparams, sess) + + m_vars = tf.trainable_variables() + self._assertModelVariableNames(expected_var_names, [ + v.name for v in m_vars + ], 'UniEncoderStandardAttentionArchitecture') + with tf.variable_scope('dynamic_seq2seq', reuse=True): + last_enc_weight = tf.get_variable( + 'encoder/rnn/multi_rnn_cell/cell_3/basic_lstm_cell/kernel') + last_dec_weight = tf.get_variable( + 'decoder/attention/multi_rnn_cell/cell_3/basic_lstm_cell/kernel') + mem_layer_weight = tf.get_variable('decoder/memory_layer/kernel') + self._assertTrainStepsLoss(train_m, sess, + 'UniEncoderStandardAttentionArchitecture') + self._assertModelVariable( + last_enc_weight, sess, + 'UniEncoderStandardAttentionArchitecture/last_enc_weight') + self._assertModelVariable( + last_dec_weight, sess, + 'UniEncoderStandardAttentionArchitecture/last_dec_weight') + self._assertModelVariable( + mem_layer_weight, sess, + 'UniEncoderStandardAttentionArchitecture/mem_layer_weight') + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + eval_m = self._createTestEvalModel(attention_model.AttentionModel, + hparams, sess) + self._assertEvalLossAndPredictCount( + eval_m, sess, 'UniEncoderStandardAttentionArchitecture') + + with 
tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + infer_m = self._createTestInferModel(attention_model.AttentionModel, + hparams, sess) + self._assertInferLogits(infer_m, sess, + 'UniEncoderStandardAttentionArchitecture') + + # Test gnmt model. + def _testGNMTModel(self, architecture): + hparams = common_test_utils.create_test_hparams( + encoder_type='gnmt', + num_layers=4, + attention='scaled_luong', + attention_architecture=architecture) + + workers, _ = tf.test.create_local_cluster(1, 0) + worker = workers[0] + + # pylint: disable=line-too-long + expected_var_names = [ + 'dynamic_seq2seq/encoder/embedding_encoder:0', + 'dynamic_seq2seq/decoder/embedding_decoder:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/fw/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/bidirectional_rnn/bw/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/memory_layer/kernel:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0_attention/attention/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0_attention/attention/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0_attention/attention/luong_attention/attention_g:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel:0', + 'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/bias:0', + 'dynamic_seq2seq/decoder/output_projection/kernel:0' + ] + # pylint: enable=line-too-long + + test_prefix = 'GNMTModel_%s' % architecture + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + train_m = self._createTestTrainModel(gnmt_model.GNMTModel, hparams, + sess) + + m_vars = tf.trainable_variables() + self._assertModelVariableNames(expected_var_names, + [v.name for v in m_vars], test_prefix) + with tf.variable_scope('dynamic_seq2seq', reuse=True): + last_enc_weight = tf.get_variable( + 'encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/kernel') + last_dec_weight = tf.get_variable( + 'decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel') + mem_layer_weight = tf.get_variable('decoder/memory_layer/kernel') + self._assertTrainStepsLoss(train_m, sess, test_prefix) + + self._assertModelVariable(last_enc_weight, sess, + '%s/last_enc_weight' % test_prefix) + self._assertModelVariable(last_dec_weight, sess, + '%s/last_dec_weight' % test_prefix) + self._assertModelVariable(mem_layer_weight, sess, + '%s/mem_layer_weight' % test_prefix) + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + eval_m = 
self._createTestEvalModel(gnmt_model.GNMTModel, hparams, sess) + self._assertEvalLossAndPredictCount(eval_m, sess, test_prefix) + + with tf.Graph().as_default(): + with tf.Session(worker.target, config=self._get_session_config()) as sess: + infer_m = self._createTestInferModel(gnmt_model.GNMTModel, hparams, + sess) + self._assertInferLogits(infer_m, sess, test_prefix) + + def testGNMTModel(self): + self._testGNMTModel('gnmt') + + def testGNMTModelV2(self): + self._testGNMTModel('gnmt_v2') + + # Test beam search. + def testBeamSearchBasicModel(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + num_layers=1, + attention='', + attention_architecture='', + use_residual=False,) + hparams.beam_width = 3 + hparams.infer_mode = "beam_search" + hparams.tgt_max_len_infer = 4 + assert_top_k_sentence = 3 + + with self.test_session() as sess: + infer_m = self._createTestInferModel( + model.Model, hparams, sess, True) + self._assertBeamSearchOutputs( + infer_m, sess, assert_top_k_sentence, 'BeamSearchBasicModel') + + def testBeamSearchAttentionModel(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + attention='scaled_luong', + attention_architecture='standard', + num_layers=2, + use_residual=False,) + hparams.beam_width = 3 + hparams.infer_mode = "beam_search" + hparams.tgt_max_len_infer = 4 + assert_top_k_sentence = 2 + + with self.test_session() as sess: + infer_m = self._createTestInferModel( + attention_model.AttentionModel, hparams, sess, True) + self._assertBeamSearchOutputs( + infer_m, sess, assert_top_k_sentence, 'BeamSearchAttentionModel') + + def testBeamSearchGNMTModel(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='gnmt', + num_layers=4, + attention='scaled_luong', + attention_architecture='gnmt') + hparams.beam_width = 3 + hparams.infer_mode = "beam_search" + hparams.tgt_max_len_infer = 4 + assert_top_k_sentence = 1 + + with self.test_session() as sess: + infer_m = self._createTestInferModel( + gnmt_model.GNMTModel, hparams, sess, True) + self._assertBeamSearchOutputs( + infer_m, sess, assert_top_k_sentence, 'BeamSearchGNMTModel') + + def testInitializerGlorotNormal(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + num_layers=1, + attention='', + attention_architecture='', + use_residual=False, + init_op='glorot_normal') + + with self.test_session() as sess: + train_m = self._createTestTrainModel(model.Model, hparams, sess) + self._assertTrainStepsLoss(train_m, sess, + 'InitializerGlorotNormal') + + def testInitializerGlorotUniform(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='uni', + num_layers=1, + attention='', + attention_architecture='', + use_residual=False, + init_op='glorot_uniform') + + with self.test_session() as sess: + train_m = self._createTestTrainModel(model.Model, hparams, sess) + self._assertTrainStepsLoss(train_m, sess, + 'InitializerGlorotUniform') + + def testSampledSoftmaxLoss(self): + hparams = common_test_utils.create_test_hparams( + encoder_type='gnmt', + num_layers=4, + attention='scaled_luong', + attention_architecture='gnmt') + hparams.num_sampled_softmax = 3 + + with self.test_session() as sess: + train_m = self._createTestTrainModel(gnmt_model.GNMTModel, hparams, sess) + self._assertTrainStepsLoss(train_m, sess, + 'SampledSoftmaxLoss') + if __name__ == '__main__': - tf.test.main() + tf.test.main() diff --git a/nmt/nmt.py b/nmt/nmt.py index 41b3a81..290267e 100644 --- a/nmt/nmt.py +++ b/nmt/nmt.py @@ -51,42 +51,42 @@ 
def add_arguments(parser): - """Build ArgumentParser.""" - parser.register("type", "bool", lambda v: v.lower() == "true") - - # network - parser.add_argument("--num_units", type=int, default=32, help="Network size.") - parser.add_argument("--num_layers", type=int, default=2, - help="Network depth.") - parser.add_argument("--num_encoder_layers", type=int, default=None, - help="Encoder depth, equal to num_layers if None.") - parser.add_argument("--num_decoder_layers", type=int, default=None, - help="Decoder depth, equal to num_layers if None.") - parser.add_argument("--encoder_type", type=str, default="uni", help="""\ + """Build ArgumentParser.""" + parser.register("type", "bool", lambda v: v.lower() == "true") + + # network + parser.add_argument("--num_units", type=int, default=32, help="Network size.") + parser.add_argument("--num_layers", type=int, default=2, + help="Network depth.") + parser.add_argument("--num_encoder_layers", type=int, default=None, + help="Encoder depth, equal to num_layers if None.") + parser.add_argument("--num_decoder_layers", type=int, default=None, + help="Decoder depth, equal to num_layers if None.") + parser.add_argument("--encoder_type", type=str, default="uni", help="""\ uni | bi | gnmt. For bi, we build num_encoder_layers/2 bi-directional layers. For gnmt, we build 1 bi-directional layer, and (num_encoder_layers - 1) uni-directional layers.\ """) - parser.add_argument("--residual", type="bool", nargs="?", const=True, - default=False, - help="Whether to add residual connections.") - parser.add_argument("--time_major", type="bool", nargs="?", const=True, - default=True, - help="Whether to use time-major mode for dynamic RNN.") - parser.add_argument("--num_embeddings_partitions", type=int, default=0, - help="Number of partitions for embedding vars.") - - # attention mechanisms - parser.add_argument("--attention", type=str, default="", help="""\ + parser.add_argument("--residual", type="bool", nargs="?", const=True, + default=False, + help="Whether to add residual connections.") + parser.add_argument("--time_major", type="bool", nargs="?", const=True, + default=True, + help="Whether to use time-major mode for dynamic RNN.") + parser.add_argument("--num_embeddings_partitions", type=int, default=0, + help="Number of partitions for embedding vars.") + + # attention mechanisms + parser.add_argument("--attention", type=str, default="", help="""\ luong | scaled_luong | bahdanau | normed_bahdanau or set to "" for no attention\ """) - parser.add_argument( - "--attention_architecture", - type=str, - default="standard", - help="""\ + parser.add_argument( + "--attention_architecture", + type=str, + default="standard", + help="""\ standard | gnmt | gnmt_v2. standard: use top layer to compute attention. gnmt: GNMT style of computing attention, use previous bottom layer to @@ -94,35 +94,35 @@ def add_arguments(parser): gnmt_v2: similar to gnmt, but use current bottom layer to compute attention.\ """) - parser.add_argument( - "--output_attention", type="bool", nargs="?", const=True, - default=True, - help="""\ + parser.add_argument( + "--output_attention", type="bool", nargs="?", const=True, + default=True, + help="""\ Only used in standard attention_architecture. Whether use attention as the cell output at each timestep. 
.\ """) - parser.add_argument( - "--pass_hidden_state", type="bool", nargs="?", const=True, - default=True, - help="""\ + parser.add_argument( + "--pass_hidden_state", type="bool", nargs="?", const=True, + default=True, + help="""\ Whether to pass encoder's hidden state to decoder when using an attention based model.\ """) - # optimizer - parser.add_argument("--optimizer", type=str, default="sgd", help="sgd | adam") - parser.add_argument("--learning_rate", type=float, default=1.0, - help="Learning rate. Adam: 0.001 | 0.0001") - parser.add_argument("--warmup_steps", type=int, default=0, - help="How many steps we inverse-decay learning.") - parser.add_argument("--warmup_scheme", type=str, default="t2t", help="""\ + # optimizer + parser.add_argument("--optimizer", type=str, default="sgd", help="sgd | adam") + parser.add_argument("--learning_rate", type=float, default=1.0, + help="Learning rate. Adam: 0.001 | 0.0001") + parser.add_argument("--warmup_steps", type=int, default=0, + help="How many steps we inverse-decay learning.") + parser.add_argument("--warmup_scheme", type=str, default="t2t", help="""\ How to warmup learning rates. Options include: t2t: Tensor2Tensor's way, start with lr 100 times smaller, then exponentiate until the specified lr.\ """) - parser.add_argument( - "--decay_scheme", type=str, default="", help="""\ + parser.add_argument( + "--decay_scheme", type=str, default="", help="""\ How we decay learning rate. Options include: luong234: after 2/3 num train steps, we start halving the learning rate for 4 times before finishing. @@ -132,584 +132,585 @@ def add_arguments(parser): for 10 times before finishing.\ """) - parser.add_argument( - "--num_train_steps", type=int, default=12000, help="Num steps to train.") - parser.add_argument("--colocate_gradients_with_ops", type="bool", nargs="?", - const=True, - default=True, - help=("Whether try colocating gradients with " - "corresponding op")) - - # initializer - parser.add_argument("--init_op", type=str, default="uniform", - help="uniform | glorot_normal | glorot_uniform") - parser.add_argument("--init_weight", type=float, default=0.1, - help=("for uniform init_op, initialize weights " - "between [-this, this].")) - - # data - parser.add_argument("--src", type=str, default=None, - help="Source suffix, e.g., en.") - parser.add_argument("--tgt", type=str, default=None, - help="Target suffix, e.g., de.") - parser.add_argument("--train_prefix", type=str, default=None, - help="Train prefix, expect files with src/tgt suffixes.") - parser.add_argument("--dev_prefix", type=str, default=None, - help="Dev prefix, expect files with src/tgt suffixes.") - parser.add_argument("--test_prefix", type=str, default=None, - help="Test prefix, expect files with src/tgt suffixes.") - parser.add_argument("--out_dir", type=str, default=None, - help="Store log/model files.") - - # Vocab - parser.add_argument("--vocab_prefix", type=str, default=None, help="""\ + parser.add_argument( + "--num_train_steps", type=int, default=12000, help="Num steps to train.") + parser.add_argument("--colocate_gradients_with_ops", type="bool", nargs="?", + const=True, + default=True, + help=("Whether try colocating gradients with " + "corresponding op")) + + # initializer + parser.add_argument("--init_op", type=str, default="uniform", + help="uniform | glorot_normal | glorot_uniform") + parser.add_argument("--init_weight", type=float, default=0.1, + help=("for uniform init_op, initialize weights " + "between [-this, this].")) + + # data + parser.add_argument("--src", 
type=str, default=None, + help="Source suffix, e.g., en.") + parser.add_argument("--tgt", type=str, default=None, + help="Target suffix, e.g., de.") + parser.add_argument("--train_prefix", type=str, default=None, + help="Train prefix, expect files with src/tgt suffixes.") + parser.add_argument("--dev_prefix", type=str, default=None, + help="Dev prefix, expect files with src/tgt suffixes.") + parser.add_argument("--test_prefix", type=str, default=None, + help="Test prefix, expect files with src/tgt suffixes.") + parser.add_argument("--out_dir", type=str, default=None, + help="Store log/model files.") + + # Vocab + parser.add_argument("--vocab_prefix", type=str, default=None, help="""\ Vocab prefix, expect files with src/tgt suffixes.\ """) - parser.add_argument("--embed_prefix", type=str, default=None, help="""\ + parser.add_argument("--embed_prefix", type=str, default=None, help="""\ Pretrained embedding prefix, expect files with src/tgt suffixes. The embedding files should be Glove formatted txt files.\ """) - parser.add_argument("--sos", type=str, default="", - help="Start-of-sentence symbol.") - parser.add_argument("--eos", type=str, default="", - help="End-of-sentence symbol.") - parser.add_argument("--share_vocab", type="bool", nargs="?", const=True, - default=False, - help="""\ + parser.add_argument("--sos", type=str, default="", + help="Start-of-sentence symbol.") + parser.add_argument("--eos", type=str, default="", + help="End-of-sentence symbol.") + parser.add_argument("--share_vocab", type="bool", nargs="?", const=True, + default=False, + help="""\ Whether to use the source vocab and embeddings for both source and target.\ """) - parser.add_argument("--check_special_token", type="bool", default=True, - help="""\ + parser.add_argument("--check_special_token", type="bool", default=True, + help="""\ Whether check special sos, eos, unk tokens exist in the vocab files.\ """) - # Sequence lengths - parser.add_argument("--src_max_len", type=int, default=50, - help="Max length of src sequences during training.") - parser.add_argument("--tgt_max_len", type=int, default=50, - help="Max length of tgt sequences during training.") - parser.add_argument("--src_max_len_infer", type=int, default=None, - help="Max length of src sequences during inference.") - parser.add_argument("--tgt_max_len_infer", type=int, default=None, - help="""\ + # Sequence lengths + parser.add_argument("--src_max_len", type=int, default=50, + help="Max length of src sequences during training.") + parser.add_argument("--tgt_max_len", type=int, default=50, + help="Max length of tgt sequences during training.") + parser.add_argument("--src_max_len_infer", type=int, default=None, + help="Max length of src sequences during inference.") + parser.add_argument("--tgt_max_len_infer", type=int, default=None, + help="""\ Max length of tgt sequences during inference. 
Also use to restrict the maximum decoding length.\ """) - # Default settings works well (rarely need to change) - parser.add_argument("--unit_type", type=str, default="lstm", - help="lstm | gru | layer_norm_lstm | nas") - parser.add_argument("--forget_bias", type=float, default=1.0, - help="Forget bias for BasicLSTMCell.") - parser.add_argument("--dropout", type=float, default=0.2, - help="Dropout rate (not keep_prob)") - parser.add_argument("--max_gradient_norm", type=float, default=5.0, - help="Clip gradients to this norm.") - parser.add_argument("--batch_size", type=int, default=128, help="Batch size.") - - parser.add_argument("--steps_per_stats", type=int, default=100, - help=("How many training steps to do per stats logging." - "Save checkpoint every 10x steps_per_stats")) - parser.add_argument("--max_train", type=int, default=0, - help="Limit on the size of training data (0: no limit).") - parser.add_argument("--num_buckets", type=int, default=5, - help="Put data into similar-length buckets.") - parser.add_argument("--num_sampled_softmax", type=int, default=0, - help=("Use sampled_softmax_loss if > 0." - "Otherwise, use full softmax loss.")) - - # SPM - parser.add_argument("--subword_option", type=str, default="", - choices=["", "bpe", "spm"], - help="""\ + # Default settings works well (rarely need to change) + parser.add_argument("--unit_type", type=str, default="lstm", + help="lstm | gru | layer_norm_lstm | nas") + parser.add_argument("--forget_bias", type=float, default=1.0, + help="Forget bias for BasicLSTMCell.") + parser.add_argument("--dropout", type=float, default=0.2, + help="Dropout rate (not keep_prob)") + parser.add_argument("--max_gradient_norm", type=float, default=5.0, + help="Clip gradients to this norm.") + parser.add_argument("--batch_size", type=int, default=128, help="Batch size.") + + parser.add_argument("--steps_per_stats", type=int, default=100, + help=("How many training steps to do per stats logging." + "Save checkpoint every 10x steps_per_stats")) + parser.add_argument("--max_train", type=int, default=0, + help="Limit on the size of training data (0: no limit).") + parser.add_argument("--num_buckets", type=int, default=5, + help="Put data into similar-length buckets.") + parser.add_argument("--num_sampled_softmax", type=int, default=0, + help=("Use sampled_softmax_loss if > 0." + "Otherwise, use full softmax loss.")) + + # SPM + parser.add_argument("--subword_option", type=str, default="", + choices=["", "bpe", "spm"], + help="""\ Set to bpe or spm to activate subword desegmentation.\ """) - # Experimental encoding feature. - parser.add_argument("--use_char_encode", type="bool", default=False, - help="""\ + # Experimental encoding feature. + parser.add_argument("--use_char_encode", type="bool", default=False, + help="""\ Whether to split each word or bpe into character, and then generate the word-level representation from the character reprentation. 
""") - # Misc - parser.add_argument("--num_gpus", type=int, default=1, - help="Number of gpus in each worker.") - parser.add_argument("--log_device_placement", type="bool", nargs="?", - const=True, default=False, help="Debug GPU allocation.") - parser.add_argument("--metrics", type=str, default="bleu", - help=("Comma-separated list of evaluations " - "metrics (bleu,rouge,accuracy)")) - parser.add_argument("--steps_per_external_eval", type=int, default=None, - help="""\ + # Misc + parser.add_argument("--num_gpus", type=int, default=1, + help="Number of gpus in each worker.") + parser.add_argument("--log_device_placement", type="bool", nargs="?", + const=True, default=False, help="Debug GPU allocation.") + parser.add_argument("--metrics", type=str, default="bleu", + help=("Comma-separated list of evaluations " + "metrics (bleu,rouge,accuracy)")) + parser.add_argument("--steps_per_external_eval", type=int, default=None, + help="""\ How many training steps to do per external evaluation. Automatically set based on data if None.\ """) - parser.add_argument("--scope", type=str, default=None, - help="scope to put variables under") - parser.add_argument("--hparams_path", type=str, default=None, - help=("Path to standard hparams json file that overrides" - "hparams values from FLAGS.")) - parser.add_argument("--random_seed", type=int, default=None, - help="Random seed (>0, set a specific seed).") - parser.add_argument("--override_loaded_hparams", type="bool", nargs="?", - const=True, default=False, - help="Override loaded hparams with values specified") - parser.add_argument("--num_keep_ckpts", type=int, default=5, - help="Max number of checkpoints to keep.") - parser.add_argument("--avg_ckpts", type="bool", nargs="?", - const=True, default=False, help=("""\ + parser.add_argument("--scope", type=str, default=None, + help="scope to put variables under") + parser.add_argument("--hparams_path", type=str, default=None, + help=("Path to standard hparams json file that overrides" + "hparams values from FLAGS.")) + parser.add_argument("--random_seed", type=int, default=None, + help="Random seed (>0, set a specific seed).") + parser.add_argument("--override_loaded_hparams", type="bool", nargs="?", + const=True, default=False, + help="Override loaded hparams with values specified") + parser.add_argument("--num_keep_ckpts", type=int, default=5, + help="Max number of checkpoints to keep.") + parser.add_argument("--avg_ckpts", type="bool", nargs="?", + const=True, default=False, help=("""\ Average the last N checkpoints for external evaluation. 
N can be controlled by setting --num_keep_ckpts.\ """)) - parser.add_argument("--language_model", type="bool", nargs="?", - const=True, default=False, - help="True to train a language model, ignoring encoder") - - # Inference - parser.add_argument("--ckpt", type=str, default="", - help="Checkpoint file to load a model for inference.") - parser.add_argument("--inference_input_file", type=str, default=None, - help="Set to the text to decode.") - parser.add_argument("--inference_list", type=str, default=None, - help=("A comma-separated list of sentence indices " - "(0-based) to decode.")) - parser.add_argument("--infer_batch_size", type=int, default=32, - help="Batch size for inference mode.") - parser.add_argument("--inference_output_file", type=str, default=None, - help="Output file to store decoding results.") - parser.add_argument("--inference_ref_file", type=str, default=None, - help=("""\ + parser.add_argument("--language_model", type="bool", nargs="?", + const=True, default=False, + help="True to train a language model, ignoring encoder") + + # Inference + parser.add_argument("--ckpt", type=str, default="", + help="Checkpoint file to load a model for inference.") + parser.add_argument("--inference_input_file", type=str, default=None, + help="Set to the text to decode.") + parser.add_argument("--inference_list", type=str, default=None, + help=("A comma-separated list of sentence indices " + "(0-based) to decode.")) + parser.add_argument("--infer_batch_size", type=int, default=32, + help="Batch size for inference mode.") + parser.add_argument("--inference_output_file", type=str, default=None, + help="Output file to store decoding results.") + parser.add_argument("--inference_ref_file", type=str, default=None, + help=("""\ Reference file to compute evaluation scores (if provided).\ """)) - # Advanced inference arguments - parser.add_argument("--infer_mode", type=str, default="greedy", - choices=["greedy", "sample", "beam_search"], - help="Which type of decoder to use during inference.") - parser.add_argument("--beam_width", type=int, default=0, - help=("""\ + # Advanced inference arguments + parser.add_argument("--infer_mode", type=str, default="greedy", + choices=["greedy", "sample", "beam_search"], + help="Which type of decoder to use during inference.") + parser.add_argument("--beam_width", type=int, default=0, + help=("""\ beam width when using beam search decoder. If 0 (default), use standard decoder with greedy helper.\ """)) - parser.add_argument("--length_penalty_weight", type=float, default=0.0, - help="Length penalty for beam search.") - parser.add_argument("--coverage_penalty_weight", type=float, default=0.0, - help="Coverage penalty for beam search.") - parser.add_argument("--sampling_temperature", type=float, - default=0.0, - help=("""\ + parser.add_argument("--length_penalty_weight", type=float, default=0.0, + help="Length penalty for beam search.") + parser.add_argument("--coverage_penalty_weight", type=float, default=0.0, + help="Coverage penalty for beam search.") + parser.add_argument("--sampling_temperature", type=float, + default=0.0, + help=("""\ Softmax sampling temperature for inference decoding, 0.0 means greedy decoding. This option is ignored when using beam search.\ """)) - parser.add_argument("--num_translations_per_input", type=int, default=1, - help=("""\ + parser.add_argument("--num_translations_per_input", type=int, default=1, + help=("""\ Number of translations generated for each sentence. 
This is only used for inference.\ """)) - # Job info - parser.add_argument("--jobid", type=int, default=0, - help="Task id of the worker.") - parser.add_argument("--num_workers", type=int, default=1, - help="Number of workers (inference only).") - parser.add_argument("--num_inter_threads", type=int, default=0, - help="number of inter_op_parallelism_threads") - parser.add_argument("--num_intra_threads", type=int, default=0, - help="number of intra_op_parallelism_threads") + # Job info + parser.add_argument("--jobid", type=int, default=0, + help="Task id of the worker.") + parser.add_argument("--num_workers", type=int, default=1, + help="Number of workers (inference only).") + parser.add_argument("--num_inter_threads", type=int, default=0, + help="number of inter_op_parallelism_threads") + parser.add_argument("--num_intra_threads", type=int, default=0, + help="number of intra_op_parallelism_threads") def create_hparams(flags): - """Create training hparams.""" - return tf.contrib.training.HParams( - # Data - src=flags.src, - tgt=flags.tgt, - train_prefix=flags.train_prefix, - dev_prefix=flags.dev_prefix, - test_prefix=flags.test_prefix, - vocab_prefix=flags.vocab_prefix, - embed_prefix=flags.embed_prefix, - out_dir=flags.out_dir, - - # Networks - num_units=flags.num_units, - num_encoder_layers=(flags.num_encoder_layers or flags.num_layers), - num_decoder_layers=(flags.num_decoder_layers or flags.num_layers), - dropout=flags.dropout, - unit_type=flags.unit_type, - encoder_type=flags.encoder_type, - residual=flags.residual, - time_major=flags.time_major, - num_embeddings_partitions=flags.num_embeddings_partitions, - - # Attention mechanisms - attention=flags.attention, - attention_architecture=flags.attention_architecture, - output_attention=flags.output_attention, - pass_hidden_state=flags.pass_hidden_state, - - # Train - optimizer=flags.optimizer, - num_train_steps=flags.num_train_steps, - batch_size=flags.batch_size, - init_op=flags.init_op, - init_weight=flags.init_weight, - max_gradient_norm=flags.max_gradient_norm, - learning_rate=flags.learning_rate, - warmup_steps=flags.warmup_steps, - warmup_scheme=flags.warmup_scheme, - decay_scheme=flags.decay_scheme, - colocate_gradients_with_ops=flags.colocate_gradients_with_ops, - num_sampled_softmax=flags.num_sampled_softmax, - - # Data constraints - num_buckets=flags.num_buckets, - max_train=flags.max_train, - src_max_len=flags.src_max_len, - tgt_max_len=flags.tgt_max_len, - - # Inference - src_max_len_infer=flags.src_max_len_infer, - tgt_max_len_infer=flags.tgt_max_len_infer, - infer_batch_size=flags.infer_batch_size, - - # Advanced inference arguments - infer_mode=flags.infer_mode, - beam_width=flags.beam_width, - length_penalty_weight=flags.length_penalty_weight, - coverage_penalty_weight=flags.coverage_penalty_weight, - sampling_temperature=flags.sampling_temperature, - num_translations_per_input=flags.num_translations_per_input, - - # Vocab - sos=flags.sos if flags.sos else vocab_utils.SOS, - eos=flags.eos if flags.eos else vocab_utils.EOS, - subword_option=flags.subword_option, - check_special_token=flags.check_special_token, - use_char_encode=flags.use_char_encode, - - # Misc - forget_bias=flags.forget_bias, - num_gpus=flags.num_gpus, - epoch_step=0, # record where we were within an epoch. 
- steps_per_stats=flags.steps_per_stats, - steps_per_external_eval=flags.steps_per_external_eval, - share_vocab=flags.share_vocab, - metrics=flags.metrics.split(","), - log_device_placement=flags.log_device_placement, - random_seed=flags.random_seed, - override_loaded_hparams=flags.override_loaded_hparams, - num_keep_ckpts=flags.num_keep_ckpts, - avg_ckpts=flags.avg_ckpts, - language_model=flags.language_model, - num_intra_threads=flags.num_intra_threads, - num_inter_threads=flags.num_inter_threads, - ) + """Create training hparams.""" + return tf.contrib.training.HParams( + # Data + src=flags.src, + tgt=flags.tgt, + train_prefix=flags.train_prefix, + dev_prefix=flags.dev_prefix, + test_prefix=flags.test_prefix, + vocab_prefix=flags.vocab_prefix, + embed_prefix=flags.embed_prefix, + out_dir=flags.out_dir, + + # Networks + num_units=flags.num_units, + num_encoder_layers=(flags.num_encoder_layers or flags.num_layers), + num_decoder_layers=(flags.num_decoder_layers or flags.num_layers), + dropout=flags.dropout, + unit_type=flags.unit_type, + encoder_type=flags.encoder_type, + residual=flags.residual, + time_major=flags.time_major, + num_embeddings_partitions=flags.num_embeddings_partitions, + + # Attention mechanisms + attention=flags.attention, + attention_architecture=flags.attention_architecture, + output_attention=flags.output_attention, + pass_hidden_state=flags.pass_hidden_state, + + # Train + optimizer=flags.optimizer, + num_train_steps=flags.num_train_steps, + batch_size=flags.batch_size, + init_op=flags.init_op, + init_weight=flags.init_weight, + max_gradient_norm=flags.max_gradient_norm, + learning_rate=flags.learning_rate, + warmup_steps=flags.warmup_steps, + warmup_scheme=flags.warmup_scheme, + decay_scheme=flags.decay_scheme, + colocate_gradients_with_ops=flags.colocate_gradients_with_ops, + num_sampled_softmax=flags.num_sampled_softmax, + + # Data constraints + num_buckets=flags.num_buckets, + max_train=flags.max_train, + src_max_len=flags.src_max_len, + tgt_max_len=flags.tgt_max_len, + + # Inference + src_max_len_infer=flags.src_max_len_infer, + tgt_max_len_infer=flags.tgt_max_len_infer, + infer_batch_size=flags.infer_batch_size, + + # Advanced inference arguments + infer_mode=flags.infer_mode, + beam_width=flags.beam_width, + length_penalty_weight=flags.length_penalty_weight, + coverage_penalty_weight=flags.coverage_penalty_weight, + sampling_temperature=flags.sampling_temperature, + num_translations_per_input=flags.num_translations_per_input, + + # Vocab + sos=flags.sos if flags.sos else vocab_utils.SOS, + eos=flags.eos if flags.eos else vocab_utils.EOS, + subword_option=flags.subword_option, + check_special_token=flags.check_special_token, + use_char_encode=flags.use_char_encode, + + # Misc + forget_bias=flags.forget_bias, + num_gpus=flags.num_gpus, + epoch_step=0, # record where we were within an epoch. 
+ steps_per_stats=flags.steps_per_stats, + steps_per_external_eval=flags.steps_per_external_eval, + share_vocab=flags.share_vocab, + metrics=flags.metrics.split(","), + log_device_placement=flags.log_device_placement, + random_seed=flags.random_seed, + override_loaded_hparams=flags.override_loaded_hparams, + num_keep_ckpts=flags.num_keep_ckpts, + avg_ckpts=flags.avg_ckpts, + language_model=flags.language_model, + num_intra_threads=flags.num_intra_threads, + num_inter_threads=flags.num_inter_threads, + ) def _add_argument(hparams, key, value, update=True): - """Add an argument to hparams; if exists, change the value if update==True.""" - if hasattr(hparams, key): - if update: - setattr(hparams, key, value) - else: - hparams.add_hparam(key, value) + """Add an argument to hparams; if exists, change the value if update==True.""" + if hasattr(hparams, key): + if update: + setattr(hparams, key, value) + else: + hparams.add_hparam(key, value) def extend_hparams(hparams): - """Add new arguments to hparams.""" - # Sanity checks - if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0: - raise ValueError("For bi, num_encoder_layers %d should be even" % - hparams.num_encoder_layers) - if (hparams.attention_architecture in ["gnmt"] and - hparams.num_encoder_layers < 2): - raise ValueError("For gnmt attention architecture, " - "num_encoder_layers %d should be >= 2" % - hparams.num_encoder_layers) - if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]: - raise ValueError("subword option must be either spm, or bpe") - if hparams.infer_mode == "beam_search" and hparams.beam_width <= 0: - raise ValueError("beam_width must greater than 0 when using beam_search" - "decoder.") - if hparams.infer_mode == "sample" and hparams.sampling_temperature <= 0.0: - raise ValueError("sampling_temperature must greater than 0.0 when using" - "sample decoder.") - - # Different number of encoder / decoder layers - assert hparams.num_encoder_layers and hparams.num_decoder_layers - if hparams.num_encoder_layers != hparams.num_decoder_layers: - hparams.pass_hidden_state = False - utils.print_out("Num encoder layer %d is different from num decoder layer" - " %d, so set pass_hidden_state to False" % ( - hparams.num_encoder_layers, - hparams.num_decoder_layers)) - - # Set residual layers - num_encoder_residual_layers = 0 - num_decoder_residual_layers = 0 - if hparams.residual: - if hparams.num_encoder_layers > 1: - num_encoder_residual_layers = hparams.num_encoder_layers - 1 - if hparams.num_decoder_layers > 1: - num_decoder_residual_layers = hparams.num_decoder_layers - 1 - - if hparams.encoder_type == "gnmt": - # The first unidirectional layer (after the bi-directional layer) in - # the GNMT encoder can't have residual connection due to the input is - # the concatenation of fw_cell and bw_cell's outputs. 
- num_encoder_residual_layers = hparams.num_encoder_layers - 2 - - # Compatible for GNMT models - if hparams.num_encoder_layers == hparams.num_decoder_layers: - num_decoder_residual_layers = num_encoder_residual_layers - _add_argument(hparams, "num_encoder_residual_layers", - num_encoder_residual_layers) - _add_argument(hparams, "num_decoder_residual_layers", - num_decoder_residual_layers) - - # Language modeling - if getattr(hparams, "language_model", None): - hparams.attention = "" - hparams.attention_architecture = "" - hparams.pass_hidden_state = False - hparams.share_vocab = True - hparams.src = hparams.tgt - utils.print_out("For language modeling, we turn off attention and " - "pass_hidden_state; turn on share_vocab; set src to tgt.") - - ## Vocab - # Get vocab file names first - if hparams.vocab_prefix: - src_vocab_file = hparams.vocab_prefix + "." + hparams.src - tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt - else: - raise ValueError("hparams.vocab_prefix must be provided.") - - # Source vocab - check_special_token = getattr(hparams, "check_special_token", True) - src_vocab_size, src_vocab_file = vocab_utils.check_vocab( - src_vocab_file, - hparams.out_dir, - check_special_token=check_special_token, - sos=hparams.sos, - eos=hparams.eos, - unk=vocab_utils.UNK) - - # Target vocab - if hparams.share_vocab: - utils.print_out(" using source vocab for target") - tgt_vocab_file = src_vocab_file - tgt_vocab_size = src_vocab_size - else: - tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab( - tgt_vocab_file, + """Add new arguments to hparams.""" + # Sanity checks + if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0: + raise ValueError("For bi, num_encoder_layers %d should be even" % + hparams.num_encoder_layers) + if (hparams.attention_architecture in ["gnmt"] and + hparams.num_encoder_layers < 2): + raise ValueError("For gnmt attention architecture, " + "num_encoder_layers %d should be >= 2" % + hparams.num_encoder_layers) + if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]: + raise ValueError("subword option must be either spm, or bpe") + if hparams.infer_mode == "beam_search" and hparams.beam_width <= 0: + raise ValueError("beam_width must greater than 0 when using beam_search" + "decoder.") + if hparams.infer_mode == "sample" and hparams.sampling_temperature <= 0.0: + raise ValueError("sampling_temperature must greater than 0.0 when using" + "sample decoder.") + + # Different number of encoder / decoder layers + assert hparams.num_encoder_layers and hparams.num_decoder_layers + if hparams.num_encoder_layers != hparams.num_decoder_layers: + hparams.pass_hidden_state = False + utils.print_out("Num encoder layer %d is different from num decoder layer" + " %d, so set pass_hidden_state to False" % ( + hparams.num_encoder_layers, + hparams.num_decoder_layers)) + + # Set residual layers + num_encoder_residual_layers = 0 + num_decoder_residual_layers = 0 + if hparams.residual: + if hparams.num_encoder_layers > 1: + num_encoder_residual_layers = hparams.num_encoder_layers - 1 + if hparams.num_decoder_layers > 1: + num_decoder_residual_layers = hparams.num_decoder_layers - 1 + + if hparams.encoder_type == "gnmt": + # The first unidirectional layer (after the bi-directional layer) in + # the GNMT encoder can't have residual connection due to the input is + # the concatenation of fw_cell and bw_cell's outputs. 
+ num_encoder_residual_layers = hparams.num_encoder_layers - 2 + + # Compatible for GNMT models + if hparams.num_encoder_layers == hparams.num_decoder_layers: + num_decoder_residual_layers = num_encoder_residual_layers + _add_argument(hparams, "num_encoder_residual_layers", + num_encoder_residual_layers) + _add_argument(hparams, "num_decoder_residual_layers", + num_decoder_residual_layers) + + # Language modeling + if getattr(hparams, "language_model", None): + hparams.attention = "" + hparams.attention_architecture = "" + hparams.pass_hidden_state = False + hparams.share_vocab = True + hparams.src = hparams.tgt + utils.print_out("For language modeling, we turn off attention and " + "pass_hidden_state; turn on share_vocab; set src to tgt.") + + # Vocab + # Get vocab file names first + if hparams.vocab_prefix: + src_vocab_file = hparams.vocab_prefix + "." + hparams.src + tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt + else: + raise ValueError("hparams.vocab_prefix must be provided.") + + # Source vocab + check_special_token = getattr(hparams, "check_special_token", True) + src_vocab_size, src_vocab_file = vocab_utils.check_vocab( + src_vocab_file, hparams.out_dir, check_special_token=check_special_token, sos=hparams.sos, eos=hparams.eos, unk=vocab_utils.UNK) - _add_argument(hparams, "src_vocab_size", src_vocab_size) - _add_argument(hparams, "tgt_vocab_size", tgt_vocab_size) - _add_argument(hparams, "src_vocab_file", src_vocab_file) - _add_argument(hparams, "tgt_vocab_file", tgt_vocab_file) - - # Num embedding partitions - num_embeddings_partitions = getattr(hparams, "num_embeddings_partitions", 0) - _add_argument(hparams, "num_enc_emb_partitions", num_embeddings_partitions) - _add_argument(hparams, "num_dec_emb_partitions", num_embeddings_partitions) - - # Pretrained Embeddings - _add_argument(hparams, "src_embed_file", "") - _add_argument(hparams, "tgt_embed_file", "") - if getattr(hparams, "embed_prefix", None): - src_embed_file = hparams.embed_prefix + "." + hparams.src - tgt_embed_file = hparams.embed_prefix + "." 
+ hparams.tgt - - if tf.gfile.Exists(src_embed_file): - utils.print_out(" src_embed_file %s exist" % src_embed_file) - hparams.src_embed_file = src_embed_file - - utils.print_out( - "For pretrained embeddings, set num_enc_emb_partitions to 1") - hparams.num_enc_emb_partitions = 1 - else: - utils.print_out(" src_embed_file %s doesn't exist" % src_embed_file) - - if tf.gfile.Exists(tgt_embed_file): - utils.print_out(" tgt_embed_file %s exist" % tgt_embed_file) - hparams.tgt_embed_file = tgt_embed_file - utils.print_out( - "For pretrained embeddings, set num_dec_emb_partitions to 1") - hparams.num_dec_emb_partitions = 1 + # Target vocab + if hparams.share_vocab: + utils.print_out(" using source vocab for target") + tgt_vocab_file = src_vocab_file + tgt_vocab_size = src_vocab_size else: - utils.print_out(" tgt_embed_file %s doesn't exist" % tgt_embed_file) + tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab( + tgt_vocab_file, + hparams.out_dir, + check_special_token=check_special_token, + sos=hparams.sos, + eos=hparams.eos, + unk=vocab_utils.UNK) + _add_argument(hparams, "src_vocab_size", src_vocab_size) + _add_argument(hparams, "tgt_vocab_size", tgt_vocab_size) + _add_argument(hparams, "src_vocab_file", src_vocab_file) + _add_argument(hparams, "tgt_vocab_file", tgt_vocab_file) + + # Num embedding partitions + num_embeddings_partitions = getattr(hparams, "num_embeddings_partitions", 0) + _add_argument(hparams, "num_enc_emb_partitions", num_embeddings_partitions) + _add_argument(hparams, "num_dec_emb_partitions", num_embeddings_partitions) + + # Pretrained Embeddings + _add_argument(hparams, "src_embed_file", "") + _add_argument(hparams, "tgt_embed_file", "") + if getattr(hparams, "embed_prefix", None): + src_embed_file = hparams.embed_prefix + "." + hparams.src + tgt_embed_file = hparams.embed_prefix + "." 
+ hparams.tgt + + if tf.gfile.Exists(src_embed_file): + utils.print_out(" src_embed_file %s exist" % src_embed_file) + hparams.src_embed_file = src_embed_file + + utils.print_out( + "For pretrained embeddings, set num_enc_emb_partitions to 1") + hparams.num_enc_emb_partitions = 1 + else: + utils.print_out(" src_embed_file %s doesn't exist" % src_embed_file) + + if tf.gfile.Exists(tgt_embed_file): + utils.print_out(" tgt_embed_file %s exist" % tgt_embed_file) + hparams.tgt_embed_file = tgt_embed_file + + utils.print_out( + "For pretrained embeddings, set num_dec_emb_partitions to 1") + hparams.num_dec_emb_partitions = 1 + else: + utils.print_out(" tgt_embed_file %s doesn't exist" % tgt_embed_file) - # Evaluation - for metric in hparams.metrics: - best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric) - tf.gfile.MakeDirs(best_metric_dir) - _add_argument(hparams, "best_" + metric, 0, update=False) - _add_argument(hparams, "best_" + metric + "_dir", best_metric_dir) + # Evaluation + for metric in hparams.metrics: + best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric) + tf.gfile.MakeDirs(best_metric_dir) + _add_argument(hparams, "best_" + metric, 0, update=False) + _add_argument(hparams, "best_" + metric + "_dir", best_metric_dir) - if getattr(hparams, "avg_ckpts", None): - best_metric_dir = os.path.join(hparams.out_dir, "avg_best_" + metric) - tf.gfile.MakeDirs(best_metric_dir) - _add_argument(hparams, "avg_best_" + metric, 0, update=False) - _add_argument(hparams, "avg_best_" + metric + "_dir", best_metric_dir) + if getattr(hparams, "avg_ckpts", None): + best_metric_dir = os.path.join(hparams.out_dir, "avg_best_" + metric) + tf.gfile.MakeDirs(best_metric_dir) + _add_argument(hparams, "avg_best_" + metric, 0, update=False) + _add_argument(hparams, "avg_best_" + metric + "_dir", best_metric_dir) - return hparams + return hparams def ensure_compatible_hparams(hparams, default_hparams, hparams_path=""): - """Make sure the loaded hparams is compatible with new changes.""" - default_hparams = utils.maybe_parse_standard_hparams( - default_hparams, hparams_path) - - # Set num encoder/decoder layers (for old checkpoints) - if hasattr(hparams, "num_layers"): - if not hasattr(hparams, "num_encoder_layers"): - hparams.add_hparam("num_encoder_layers", hparams.num_layers) - if not hasattr(hparams, "num_decoder_layers"): - hparams.add_hparam("num_decoder_layers", hparams.num_layers) - - # For compatible reason, if there are new fields in default_hparams, - # we add them to the current hparams - default_config = default_hparams.values() - config = hparams.values() - for key in default_config: - if key not in config: - hparams.add_hparam(key, default_config[key]) - - # Update all hparams' keys if override_loaded_hparams=True - if getattr(default_hparams, "override_loaded_hparams", None): - overwritten_keys = default_config.keys() - else: - # For inference - overwritten_keys = INFERENCE_KEYS - - for key in overwritten_keys: - if getattr(hparams, key) != default_config[key]: - utils.print_out("# Updating hparams.%s: %s -> %s" % - (key, str(getattr(hparams, key)), - str(default_config[key]))) - setattr(hparams, key, default_config[key]) - return hparams + """Make sure the loaded hparams is compatible with new changes.""" + default_hparams = utils.maybe_parse_standard_hparams( + default_hparams, hparams_path) + + # Set num encoder/decoder layers (for old checkpoints) + if hasattr(hparams, "num_layers"): + if not hasattr(hparams, "num_encoder_layers"): + 
hparams.add_hparam("num_encoder_layers", hparams.num_layers) + if not hasattr(hparams, "num_decoder_layers"): + hparams.add_hparam("num_decoder_layers", hparams.num_layers) + + # For compatible reason, if there are new fields in default_hparams, + # we add them to the current hparams + default_config = default_hparams.values() + config = hparams.values() + for key in default_config: + if key not in config: + hparams.add_hparam(key, default_config[key]) + + # Update all hparams' keys if override_loaded_hparams=True + if getattr(default_hparams, "override_loaded_hparams", None): + overwritten_keys = default_config.keys() + else: + # For inference + overwritten_keys = INFERENCE_KEYS + + for key in overwritten_keys: + if getattr(hparams, key) != default_config[key]: + utils.print_out("# Updating hparams.%s: %s -> %s" % + (key, str(getattr(hparams, key)), + str(default_config[key]))) + setattr(hparams, key, default_config[key]) + return hparams def create_or_load_hparams( - out_dir, default_hparams, hparams_path, save_hparams=True): - """Create hparams or load hparams from out_dir.""" - hparams = utils.load_hparams(out_dir) - if not hparams: - hparams = default_hparams - hparams = utils.maybe_parse_standard_hparams( - hparams, hparams_path) - else: - hparams = ensure_compatible_hparams(hparams, default_hparams, hparams_path) - hparams = extend_hparams(hparams) - - # Save HParams - if save_hparams: - utils.save_hparams(out_dir, hparams) - for metric in hparams.metrics: - utils.save_hparams(getattr(hparams, "best_" + metric + "_dir"), hparams) + out_dir, default_hparams, hparams_path, save_hparams=True): + """Create hparams or load hparams from out_dir.""" + hparams = utils.load_hparams(out_dir) + if not hparams: + hparams = default_hparams + hparams = utils.maybe_parse_standard_hparams( + hparams, hparams_path) + else: + hparams = ensure_compatible_hparams(hparams, default_hparams, hparams_path) + hparams = extend_hparams(hparams) - # Print HParams - utils.print_hparams(hparams) - return hparams + # Save HParams + if save_hparams: + utils.save_hparams(out_dir, hparams) + for metric in hparams.metrics: + utils.save_hparams(getattr(hparams, "best_" + metric + "_dir"), hparams) + # Print HParams + utils.print_hparams(hparams) + return hparams -def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""): - """Run main.""" - # Job - jobid = flags.jobid - num_workers = flags.num_workers - utils.print_out("# Job id %d" % jobid) - - # GPU device - utils.print_out( - "# Devices visible to TensorFlow: %s" % repr(tf.Session().list_devices())) - - # Random - random_seed = flags.random_seed - if random_seed is not None and random_seed > 0: - utils.print_out("# Set random seed to %d" % random_seed) - random.seed(random_seed + jobid) - np.random.seed(random_seed + jobid) - - # Model output directory - out_dir = flags.out_dir - if out_dir and not tf.gfile.Exists(out_dir): - utils.print_out("# Creating output directory %s ..." % out_dir) - tf.gfile.MakeDirs(out_dir) - - # Load hparams. 
- loaded_hparams = False - if flags.ckpt: # Try to load hparams from the same directory as ckpt - ckpt_dir = os.path.dirname(flags.ckpt) - ckpt_hparams_file = os.path.join(ckpt_dir, "hparams") - if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path: - hparams = create_or_load_hparams( - ckpt_dir, default_hparams, flags.hparams_path, - save_hparams=False) - loaded_hparams = True - if not loaded_hparams: # Try to load from out_dir - assert out_dir - hparams = create_or_load_hparams( - out_dir, default_hparams, flags.hparams_path, - save_hparams=(jobid == 0)) - - ## Train / Decode - if flags.inference_input_file: - # Inference output directory - trans_file = flags.inference_output_file - assert trans_file - trans_dir = os.path.dirname(trans_file) - if not tf.gfile.Exists(trans_dir): tf.gfile.MakeDirs(trans_dir) - - # Inference indices - hparams.inference_indices = None - if flags.inference_list: - (hparams.inference_indices) = ( - [int(token) for token in flags.inference_list.split(",")]) - # Inference - ckpt = flags.ckpt - if not ckpt: - ckpt = tf.train.latest_checkpoint(out_dir) - inference_fn(ckpt, flags.inference_input_file, - trans_file, hparams, num_workers, jobid) - - # Evaluation - ref_file = flags.inference_ref_file - if ref_file and tf.gfile.Exists(trans_file): - for metric in hparams.metrics: - score = evaluation_utils.evaluate( - ref_file, - trans_file, - metric, - hparams.subword_option) - utils.print_out(" %s: %.1f" % (metric, score)) - else: - # Train - train_fn(hparams, target_session=target_session) +def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""): + """Run main.""" + # Job + jobid = flags.jobid + num_workers = flags.num_workers + utils.print_out("# Job id %d" % jobid) + + # GPU device + utils.print_out( + "# Devices visible to TensorFlow: %s" % repr(tf.Session().list_devices())) + + # Random + random_seed = flags.random_seed + if random_seed is not None and random_seed > 0: + utils.print_out("# Set random seed to %d" % random_seed) + random.seed(random_seed + jobid) + np.random.seed(random_seed + jobid) + + # Model output directory + out_dir = flags.out_dir + if out_dir and not tf.gfile.Exists(out_dir): + utils.print_out("# Creating output directory %s ..." % out_dir) + tf.gfile.MakeDirs(out_dir) + + # Load hparams. 
+    loaded_hparams = False
+    if flags.ckpt:  # Try to load hparams from the same directory as ckpt
+        ckpt_dir = os.path.dirname(flags.ckpt)
+        ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
+        if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
+            hparams = create_or_load_hparams(
+                ckpt_dir, default_hparams, flags.hparams_path,
+                save_hparams=False)
+            loaded_hparams = True
+    if not loaded_hparams:  # Try to load from out_dir
+        assert out_dir
+        hparams = create_or_load_hparams(
+            out_dir, default_hparams, flags.hparams_path,
+            save_hparams=(jobid == 0))
+
+    ## Train / Decode
+    if flags.inference_input_file:
+        # Inference output directory
+        trans_file = flags.inference_output_file
+        assert trans_file
+        trans_dir = os.path.dirname(trans_file)
+        if not tf.gfile.Exists(trans_dir):
+            tf.gfile.MakeDirs(trans_dir)
+
+        # Inference indices
+        hparams.inference_indices = None
+        if flags.inference_list:
+            (hparams.inference_indices) = (
+                [int(token) for token in flags.inference_list.split(",")])
+
+        # Inference
+        ckpt = flags.ckpt
+        if not ckpt:
+            ckpt = tf.train.latest_checkpoint(out_dir)
+        inference_fn(ckpt, flags.inference_input_file,
+                     trans_file, hparams, num_workers, jobid)
+
+        # Evaluation
+        ref_file = flags.inference_ref_file
+        if ref_file and tf.gfile.Exists(trans_file):
+            for metric in hparams.metrics:
+                score = evaluation_utils.evaluate(
+                    ref_file,
+                    trans_file,
+                    metric,
+                    hparams.subword_option)
+                utils.print_out("  %s: %.1f" % (metric, score))
+    else:
+        # Train
+        train_fn(hparams, target_session=target_session)


 def main(unused_argv):
-  default_hparams = create_hparams(FLAGS)
-  train_fn = train.train
-  inference_fn = inference.inference
-  run_main(FLAGS, default_hparams, train_fn, inference_fn)
+    default_hparams = create_hparams(FLAGS)
+    train_fn = train.train
+    inference_fn = inference.inference
+    run_main(FLAGS, default_hparams, train_fn, inference_fn)


 if __name__ == "__main__":
-  nmt_parser = argparse.ArgumentParser()
-  add_arguments(nmt_parser)
-  FLAGS, unparsed = nmt_parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+    nmt_parser = argparse.ArgumentParser()
+    add_arguments(nmt_parser)
+    FLAGS, unparsed = nmt_parser.parse_known_args()
+    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/nmt/nmt_test.py b/nmt/nmt_test.py
index c12179d..962b3b7 100644
--- a/nmt/nmt_test.py
+++ b/nmt/nmt_test.py
@@ -29,79 +29,77 @@


 def _update_flags(flags, test_name):
-  """Update flags for basic training."""
-  flags.num_train_steps = 100
-  flags.steps_per_stats = 5
-  flags.src = "en"
-  flags.tgt = "vi"
-  flags.train_prefix = ("nmt/testdata/"
+    """Update flags for basic training."""
+    flags.num_train_steps = 100
+    flags.steps_per_stats = 5
+    flags.src = "en"
+    flags.tgt = "vi"
+    flags.train_prefix = ("nmt/testdata/"
+                          "iwslt15.tst2013.100")
+    flags.vocab_prefix = ("nmt/testdata/"
+                          "iwslt15.vocab.100")
+    flags.dev_prefix = ("nmt/testdata/"
                         "iwslt15.tst2013.100")
-  flags.vocab_prefix = ("nmt/testdata/"
-                        "iwslt15.vocab.100")
-  flags.dev_prefix = ("nmt/testdata/"
-                      "iwslt15.tst2013.100")
-  flags.test_prefix = ("nmt/testdata/"
-                       "iwslt15.tst2013.100")
-  flags.out_dir = os.path.join(tf.test.get_temp_dir(), test_name)
+    flags.test_prefix = ("nmt/testdata/"
+                         "iwslt15.tst2013.100")
+    flags.out_dir = os.path.join(tf.test.get_temp_dir(), test_name)


 class NMTTest(tf.test.TestCase):

-  def testTrain(self):
-    """Test the training loop is functional with basic hparams."""
-    nmt_parser = argparse.ArgumentParser()
-    nmt.add_arguments(nmt_parser)
-    FLAGS, unparsed = nmt_parser.parse_known_args()
+    def testTrain(self):
+        """Test the training loop is functional with basic hparams."""
+        nmt_parser = argparse.ArgumentParser()
+        nmt.add_arguments(nmt_parser)
+        FLAGS, unparsed = nmt_parser.parse_known_args()

-    _update_flags(FLAGS, "nmt_train_test")
+        _update_flags(FLAGS, "nmt_train_test")

-    default_hparams = nmt.create_hparams(FLAGS)
+        default_hparams = nmt.create_hparams(FLAGS)

-    train_fn = train.train
-    nmt.run_main(FLAGS, default_hparams, train_fn, None)
+        train_fn = train.train
+        nmt.run_main(FLAGS, default_hparams, train_fn, None)

+    def testTrainWithAvgCkpts(self):
+        """Test the training loop is functional with basic hparams."""
+        nmt_parser = argparse.ArgumentParser()
+        nmt.add_arguments(nmt_parser)
+        FLAGS, unparsed = nmt_parser.parse_known_args()

-  def testTrainWithAvgCkpts(self):
-    """Test the training loop is functional with basic hparams."""
-    nmt_parser = argparse.ArgumentParser()
-    nmt.add_arguments(nmt_parser)
-    FLAGS, unparsed = nmt_parser.parse_known_args()
+        _update_flags(FLAGS, "nmt_train_test_avg_ckpts")
+        FLAGS.avg_ckpts = True

-    _update_flags(FLAGS, "nmt_train_test_avg_ckpts")
-    FLAGS.avg_ckpts = True
+        default_hparams = nmt.create_hparams(FLAGS)

-    default_hparams = nmt.create_hparams(FLAGS)
+        train_fn = train.train
+        nmt.run_main(FLAGS, default_hparams, train_fn, None)

-    train_fn = train.train
-    nmt.run_main(FLAGS, default_hparams, train_fn, None)
+    def testInference(self):
+        """Test inference is function with basic hparams."""
+        nmt_parser = argparse.ArgumentParser()
+        nmt.add_arguments(nmt_parser)
+        FLAGS, unparsed = nmt_parser.parse_known_args()
+        _update_flags(FLAGS, "nmt_train_infer")

-  def testInference(self):
-    """Test inference is function with basic hparams."""
-    nmt_parser = argparse.ArgumentParser()
-    nmt.add_arguments(nmt_parser)
-    FLAGS, unparsed = nmt_parser.parse_known_args()
+        # Train one step so we have a checkpoint.
+        FLAGS.num_train_steps = 1
+        default_hparams = nmt.create_hparams(FLAGS)
+        train_fn = train.train
+        nmt.run_main(FLAGS, default_hparams, train_fn, None)

-    _update_flags(FLAGS, "nmt_train_infer")
+        # Update FLAGS for inference.
+        FLAGS.inference_input_file = ("nmt/testdata/"
+                                      "iwslt15.tst2013.100.en")
+        FLAGS.inference_output_file = os.path.join(FLAGS.out_dir, "output")
+        FLAGS.inference_ref_file = ("nmt/testdata/"
+                                    "iwslt15.tst2013.100.vi")

-    # Train one step so we have a checkpoint.
-    FLAGS.num_train_steps = 1
-    default_hparams = nmt.create_hparams(FLAGS)
-    train_fn = train.train
-    nmt.run_main(FLAGS, default_hparams, train_fn, None)
+        default_hparams = nmt.create_hparams(FLAGS)

-    # Update FLAGS for inference.
- FLAGS.inference_input_file = ("nmt/testdata/" - "iwslt15.tst2013.100.en") - FLAGS.inference_output_file = os.path.join(FLAGS.out_dir, "output") - FLAGS.inference_ref_file = ("nmt/testdata/" - "iwslt15.tst2013.100.vi") - - default_hparams = nmt.create_hparams(FLAGS) - - inference_fn = inference.inference - nmt.run_main(FLAGS, default_hparams, None, inference_fn) + inference_fn = inference.inference + nmt.run_main(FLAGS, default_hparams, None, inference_fn) if __name__ == "__main__": - tf.test.main() + tf.test.main() diff --git a/nmt/train.py b/nmt/train.py index d86d1c1..044523a 100644 --- a/nmt/train.py +++ b/nmt/train.py @@ -42,15 +42,15 @@ def run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, src_data, tgt_data): - """Sample decode a random sentence from src_data.""" - with infer_model.graph.as_default(): - loaded_infer_model, global_step = model_helper.create_or_load_model( - infer_model.model, model_dir, infer_sess, "infer") + """Sample decode a random sentence from src_data.""" + with infer_model.graph.as_default(): + loaded_infer_model, global_step = model_helper.create_or_load_model( + infer_model.model, model_dir, infer_sess, "infer") - _sample_decode(loaded_infer_model, global_step, infer_sess, hparams, - infer_model.iterator, src_data, tgt_data, - infer_model.src_placeholder, - infer_model.batch_size_placeholder, summary_writer) + _sample_decode(loaded_infer_model, global_step, infer_sess, hparams, + infer_model.iterator, src_data, tgt_data, + infer_model.src_placeholder, + infer_model.batch_size_placeholder, summary_writer) def run_internal_eval(eval_model, @@ -61,57 +61,57 @@ def run_internal_eval(eval_model, use_test_set=True, dev_eval_iterator_feed_dict=None, test_eval_iterator_feed_dict=None): - """Compute internal evaluation (perplexity) for both dev / test. - - Computes development and testing perplexities for given model. - - Args: - eval_model: Evaluation model for which to compute perplexities. - eval_sess: Evaluation TensorFlow session. - model_dir: Directory from which to load evaluation model from. - hparams: Model hyper-parameters. - summary_writer: Summary writer for logging metrics to TensorBoard. - use_test_set: Computes testing perplexity if true; does not otherwise. - Note that the development perplexity is always computed regardless of - value of this parameter. - dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. - Can be used to pass in additional inputs necessary for running the - development evaluation. - test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. - Can be used to pass in additional inputs necessary for running the - testing evaluation. - Returns: - Pair containing development perplexity and testing perplexity, in this - order. 
- """ - if dev_eval_iterator_feed_dict is None: - dev_eval_iterator_feed_dict = {} - if test_eval_iterator_feed_dict is None: - test_eval_iterator_feed_dict = {} - with eval_model.graph.as_default(): - loaded_eval_model, global_step = model_helper.create_or_load_model( - eval_model.model, model_dir, eval_sess, "eval") - - dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) - dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) - dev_eval_iterator_feed_dict[eval_model.src_file_placeholder] = dev_src_file - dev_eval_iterator_feed_dict[eval_model.tgt_file_placeholder] = dev_tgt_file - - dev_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess, - eval_model.iterator, dev_eval_iterator_feed_dict, - summary_writer, "dev") - test_ppl = None - if use_test_set and hparams.test_prefix: - test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src) - test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt) - test_eval_iterator_feed_dict[ - eval_model.src_file_placeholder] = test_src_file - test_eval_iterator_feed_dict[ - eval_model.tgt_file_placeholder] = test_tgt_file - test_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess, - eval_model.iterator, test_eval_iterator_feed_dict, - summary_writer, "test") - return dev_ppl, test_ppl + """Compute internal evaluation (perplexity) for both dev / test. + + Computes development and testing perplexities for given model. + + Args: + eval_model: Evaluation model for which to compute perplexities. + eval_sess: Evaluation TensorFlow session. + model_dir: Directory from which to load evaluation model from. + hparams: Model hyper-parameters. + summary_writer: Summary writer for logging metrics to TensorBoard. + use_test_set: Computes testing perplexity if true; does not otherwise. + Note that the development perplexity is always computed regardless of + value of this parameter. + dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + development evaluation. + test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + testing evaluation. + Returns: + Pair containing development perplexity and testing perplexity, in this + order. 
+ """ + if dev_eval_iterator_feed_dict is None: + dev_eval_iterator_feed_dict = {} + if test_eval_iterator_feed_dict is None: + test_eval_iterator_feed_dict = {} + with eval_model.graph.as_default(): + loaded_eval_model, global_step = model_helper.create_or_load_model( + eval_model.model, model_dir, eval_sess, "eval") + + dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) + dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) + dev_eval_iterator_feed_dict[eval_model.src_file_placeholder] = dev_src_file + dev_eval_iterator_feed_dict[eval_model.tgt_file_placeholder] = dev_tgt_file + + dev_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess, + eval_model.iterator, dev_eval_iterator_feed_dict, + summary_writer, "dev") + test_ppl = None + if use_test_set and hparams.test_prefix: + test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src) + test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt) + test_eval_iterator_feed_dict[ + eval_model.src_file_placeholder] = test_src_file + test_eval_iterator_feed_dict[ + eval_model.tgt_file_placeholder] = test_tgt_file + test_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess, + eval_model.iterator, test_eval_iterator_feed_dict, + summary_writer, "test") + return dev_ppl, test_ppl def run_external_eval(infer_model, @@ -124,100 +124,100 @@ def run_external_eval(infer_model, avg_ckpts=False, dev_infer_iterator_feed_dict=None, test_infer_iterator_feed_dict=None): - """Compute external evaluation for both dev / test. - - Computes development and testing external evaluation (e.g. bleu, rouge) for - given model. - - Args: - infer_model: Inference model for which to compute perplexities. - infer_sess: Inference TensorFlow session. - model_dir: Directory from which to load inference model from. - hparams: Model hyper-parameters. - summary_writer: Summary writer for logging metrics to TensorBoard. - use_test_set: Computes testing external evaluation if true; does not - otherwise. Note that the development external evaluation is always - computed regardless of value of this parameter. - dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. - Can be used to pass in additional inputs necessary for running the - development external evaluation. - test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. - Can be used to pass in additional inputs necessary for running the - testing external evaluation. - Returns: - Triple containing development scores, testing scores and the TensorFlow - Variable for the global step number, in this order. 
- """ - if dev_infer_iterator_feed_dict is None: - dev_infer_iterator_feed_dict = {} - if test_infer_iterator_feed_dict is None: - test_infer_iterator_feed_dict = {} - with infer_model.graph.as_default(): - loaded_infer_model, global_step = model_helper.create_or_load_model( - infer_model.model, model_dir, infer_sess, "infer") - - dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) - dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) - dev_infer_iterator_feed_dict[ - infer_model.src_placeholder] = inference.load_data(dev_src_file) - dev_infer_iterator_feed_dict[ - infer_model.batch_size_placeholder] = hparams.infer_batch_size - dev_scores = _external_eval( - loaded_infer_model, - global_step, - infer_sess, - hparams, - infer_model.iterator, - dev_infer_iterator_feed_dict, - dev_tgt_file, - "dev", - summary_writer, - save_on_best=save_best_dev, - avg_ckpts=avg_ckpts) - - test_scores = None - if use_test_set and hparams.test_prefix: - test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src) - test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt) - test_infer_iterator_feed_dict[ - infer_model.src_placeholder] = inference.load_data(test_src_file) - test_infer_iterator_feed_dict[ + """Compute external evaluation for both dev / test. + + Computes development and testing external evaluation (e.g. bleu, rouge) for + given model. + + Args: + infer_model: Inference model for which to compute perplexities. + infer_sess: Inference TensorFlow session. + model_dir: Directory from which to load inference model from. + hparams: Model hyper-parameters. + summary_writer: Summary writer for logging metrics to TensorBoard. + use_test_set: Computes testing external evaluation if true; does not + otherwise. Note that the development external evaluation is always + computed regardless of value of this parameter. + dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + development external evaluation. + test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + testing external evaluation. + Returns: + Triple containing development scores, testing scores and the TensorFlow + Variable for the global step number, in this order. 
+ """ + if dev_infer_iterator_feed_dict is None: + dev_infer_iterator_feed_dict = {} + if test_infer_iterator_feed_dict is None: + test_infer_iterator_feed_dict = {} + with infer_model.graph.as_default(): + loaded_infer_model, global_step = model_helper.create_or_load_model( + infer_model.model, model_dir, infer_sess, "infer") + + dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) + dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) + dev_infer_iterator_feed_dict[ + infer_model.src_placeholder] = inference.load_data(dev_src_file) + dev_infer_iterator_feed_dict[ infer_model.batch_size_placeholder] = hparams.infer_batch_size - test_scores = _external_eval( + dev_scores = _external_eval( loaded_infer_model, global_step, infer_sess, hparams, infer_model.iterator, - test_infer_iterator_feed_dict, - test_tgt_file, - "test", + dev_infer_iterator_feed_dict, + dev_tgt_file, + "dev", summary_writer, - save_on_best=False, + save_on_best=save_best_dev, avg_ckpts=avg_ckpts) - return dev_scores, test_scores, global_step + + test_scores = None + if use_test_set and hparams.test_prefix: + test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src) + test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt) + test_infer_iterator_feed_dict[ + infer_model.src_placeholder] = inference.load_data(test_src_file) + test_infer_iterator_feed_dict[ + infer_model.batch_size_placeholder] = hparams.infer_batch_size + test_scores = _external_eval( + loaded_infer_model, + global_step, + infer_sess, + hparams, + infer_model.iterator, + test_infer_iterator_feed_dict, + test_tgt_file, + "test", + summary_writer, + save_on_best=False, + avg_ckpts=avg_ckpts) + return dev_scores, test_scores, global_step def run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, summary_writer, global_step): - """Creates an averaged checkpoint and run external eval with it.""" - avg_dev_scores, avg_test_scores = None, None - if hparams.avg_ckpts: - # Convert VariableName:0 to VariableName. - global_step_name = infer_model.model.global_step.name.split(":")[0] - avg_model_dir = model_helper.avg_checkpoints( - model_dir, hparams.num_keep_ckpts, global_step, global_step_name) - - if avg_model_dir: - avg_dev_scores, avg_test_scores, _ = run_external_eval( - infer_model, - infer_sess, - avg_model_dir, - hparams, - summary_writer, - avg_ckpts=True) - - return avg_dev_scores, avg_test_scores + """Creates an averaged checkpoint and run external eval with it.""" + avg_dev_scores, avg_test_scores = None, None + if hparams.avg_ckpts: + # Convert VariableName:0 to VariableName. + global_step_name = infer_model.model.global_step.name.split(":")[0] + avg_model_dir = model_helper.avg_checkpoints( + model_dir, hparams.num_keep_ckpts, global_step, global_step_name) + + if avg_model_dir: + avg_dev_scores, avg_test_scores, _ = run_external_eval( + infer_model, + infer_sess, + avg_model_dir, + hparams, + summary_writer, + avg_ckpts=True) + + return avg_dev_scores, avg_test_scores def run_internal_and_external_eval(model_dir, @@ -232,79 +232,79 @@ def run_internal_and_external_eval(model_dir, test_eval_iterator_feed_dict=None, dev_infer_iterator_feed_dict=None, test_infer_iterator_feed_dict=None): - """Compute internal evaluation (perplexity) for both dev / test. - - Computes development and testing perplexities for given model. - - Args: - model_dir: Directory from which to load models from. - infer_model: Inference model for which to compute perplexities. - infer_sess: Inference TensorFlow session. 
- eval_model: Evaluation model for which to compute perplexities. - eval_sess: Evaluation TensorFlow session. - hparams: Model hyper-parameters. - summary_writer: Summary writer for logging metrics to TensorBoard. - avg_ckpts: Whether to compute average external evaluation scores. - dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. - Can be used to pass in additional inputs necessary for running the - internal development evaluation. - test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. - Can be used to pass in additional inputs necessary for running the - internal testing evaluation. - dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. - Can be used to pass in additional inputs necessary for running the - external development evaluation. - test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. - Can be used to pass in additional inputs necessary for running the - external testing evaluation. - Returns: - Triple containing results summary, global step Tensorflow Variable and - metrics in this order. - """ - dev_ppl, test_ppl = run_internal_eval( - eval_model, - eval_sess, - model_dir, - hparams, - summary_writer, - dev_eval_iterator_feed_dict=dev_eval_iterator_feed_dict, - test_eval_iterator_feed_dict=test_eval_iterator_feed_dict) - dev_scores, test_scores, global_step = run_external_eval( - infer_model, - infer_sess, - model_dir, - hparams, - summary_writer, - dev_infer_iterator_feed_dict=dev_infer_iterator_feed_dict, - test_infer_iterator_feed_dict=test_infer_iterator_feed_dict) - - metrics = { - "dev_ppl": dev_ppl, - "test_ppl": test_ppl, - "dev_scores": dev_scores, - "test_scores": test_scores, - } - - avg_dev_scores, avg_test_scores = None, None - if avg_ckpts: - avg_dev_scores, avg_test_scores = run_avg_external_eval( - infer_model, infer_sess, model_dir, hparams, summary_writer, - global_step) - metrics["avg_dev_scores"] = avg_dev_scores - metrics["avg_test_scores"] = avg_test_scores - - result_summary = _format_results("dev", dev_ppl, dev_scores, hparams.metrics) - if avg_dev_scores: - result_summary += ", " + _format_results("avg_dev", None, avg_dev_scores, - hparams.metrics) - if hparams.test_prefix: - result_summary += ", " + _format_results("test", test_ppl, test_scores, - hparams.metrics) - if avg_test_scores: - result_summary += ", " + _format_results("avg_test", None, - avg_test_scores, hparams.metrics) - - return result_summary, global_step, metrics + """Compute internal evaluation (perplexity) for both dev / test. + + Computes development and testing perplexities for given model. + + Args: + model_dir: Directory from which to load models from. + infer_model: Inference model for which to compute perplexities. + infer_sess: Inference TensorFlow session. + eval_model: Evaluation model for which to compute perplexities. + eval_sess: Evaluation TensorFlow session. + hparams: Model hyper-parameters. + summary_writer: Summary writer for logging metrics to TensorBoard. + avg_ckpts: Whether to compute average external evaluation scores. + dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + internal development evaluation. + test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + internal testing evaluation. + dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. 
+ Can be used to pass in additional inputs necessary for running the + external development evaluation. + test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session. + Can be used to pass in additional inputs necessary for running the + external testing evaluation. + Returns: + Triple containing results summary, global step Tensorflow Variable and + metrics in this order. + """ + dev_ppl, test_ppl = run_internal_eval( + eval_model, + eval_sess, + model_dir, + hparams, + summary_writer, + dev_eval_iterator_feed_dict=dev_eval_iterator_feed_dict, + test_eval_iterator_feed_dict=test_eval_iterator_feed_dict) + dev_scores, test_scores, global_step = run_external_eval( + infer_model, + infer_sess, + model_dir, + hparams, + summary_writer, + dev_infer_iterator_feed_dict=dev_infer_iterator_feed_dict, + test_infer_iterator_feed_dict=test_infer_iterator_feed_dict) + + metrics = { + "dev_ppl": dev_ppl, + "test_ppl": test_ppl, + "dev_scores": dev_scores, + "test_scores": test_scores, + } + + avg_dev_scores, avg_test_scores = None, None + if avg_ckpts: + avg_dev_scores, avg_test_scores = run_avg_external_eval( + infer_model, infer_sess, model_dir, hparams, summary_writer, + global_step) + metrics["avg_dev_scores"] = avg_dev_scores + metrics["avg_test_scores"] = avg_test_scores + + result_summary = _format_results("dev", dev_ppl, dev_scores, hparams.metrics) + if avg_dev_scores: + result_summary += ", " + _format_results("avg_dev", None, avg_dev_scores, + hparams.metrics) + if hparams.test_prefix: + result_summary += ", " + _format_results("test", test_ppl, test_scores, + hparams.metrics) + if avg_test_scores: + result_summary += ", " + _format_results("avg_test", None, + avg_test_scores, hparams.metrics) + + return result_summary, global_step, metrics def run_full_eval(model_dir, @@ -317,434 +317,434 @@ def run_full_eval(model_dir, sample_src_data, sample_tgt_data, avg_ckpts=False): - """Wrapper for running sample_decode, internal_eval and external_eval. - - Args: - model_dir: Directory from which to load models from. - infer_model: Inference model for which to compute perplexities. - infer_sess: Inference TensorFlow session. - eval_model: Evaluation model for which to compute perplexities. - eval_sess: Evaluation TensorFlow session. - hparams: Model hyper-parameters. - summary_writer: Summary writer for logging metrics to TensorBoard. - sample_src_data: sample of source data for sample decoding. - sample_tgt_data: sample of target data for sample decoding. - avg_ckpts: Whether to compute average external evaluation scores. - Returns: - Triple containing results summary, global step Tensorflow Variable and - metrics in this order. - """ - run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, - sample_src_data, sample_tgt_data) - return run_internal_and_external_eval(model_dir, infer_model, infer_sess, - eval_model, eval_sess, hparams, - summary_writer, avg_ckpts) + """Wrapper for running sample_decode, internal_eval and external_eval. + + Args: + model_dir: Directory from which to load models from. + infer_model: Inference model for which to compute perplexities. + infer_sess: Inference TensorFlow session. + eval_model: Evaluation model for which to compute perplexities. + eval_sess: Evaluation TensorFlow session. + hparams: Model hyper-parameters. + summary_writer: Summary writer for logging metrics to TensorBoard. + sample_src_data: sample of source data for sample decoding. + sample_tgt_data: sample of target data for sample decoding. 
+ avg_ckpts: Whether to compute average external evaluation scores. + Returns: + Triple containing results summary, global step Tensorflow Variable and + metrics in this order. + """ + run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, + sample_src_data, sample_tgt_data) + return run_internal_and_external_eval(model_dir, infer_model, infer_sess, + eval_model, eval_sess, hparams, + summary_writer, avg_ckpts) def init_stats(): - """Initialize statistics that we want to accumulate.""" - return {"step_time": 0.0, "train_loss": 0.0, - "predict_count": 0.0, # word count on the target side - "word_count": 0.0, # word counts for both source and target - "sequence_count": 0.0, # number of training examples processed - "grad_norm": 0.0} + """Initialize statistics that we want to accumulate.""" + return {"step_time": 0.0, "train_loss": 0.0, + "predict_count": 0.0, # word count on the target side + "word_count": 0.0, # word counts for both source and target + "sequence_count": 0.0, # number of training examples processed + "grad_norm": 0.0} def update_stats(stats, start_time, step_result): - """Update stats: write summary and accumulate statistics.""" - _, output_tuple = step_result + """Update stats: write summary and accumulate statistics.""" + _, output_tuple = step_result - # Update statistics - batch_size = output_tuple.batch_size - stats["step_time"] += time.time() - start_time - stats["train_loss"] += output_tuple.train_loss * batch_size - stats["grad_norm"] += output_tuple.grad_norm - stats["predict_count"] += output_tuple.predict_count - stats["word_count"] += output_tuple.word_count - stats["sequence_count"] += batch_size + # Update statistics + batch_size = output_tuple.batch_size + stats["step_time"] += time.time() - start_time + stats["train_loss"] += output_tuple.train_loss * batch_size + stats["grad_norm"] += output_tuple.grad_norm + stats["predict_count"] += output_tuple.predict_count + stats["word_count"] += output_tuple.word_count + stats["sequence_count"] += batch_size - return (output_tuple.global_step, output_tuple.learning_rate, - output_tuple.train_summary) + return (output_tuple.global_step, output_tuple.learning_rate, + output_tuple.train_summary) def print_step_info(prefix, global_step, info, result_summary, log_f): - """Print all info at the current global step.""" - utils.print_out( - "%sstep %d lr %g step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s, %s" % - (prefix, global_step, info["learning_rate"], info["avg_step_time"], - info["speed"], info["train_ppl"], info["avg_grad_norm"], result_summary, - time.ctime()), - log_f) + """Print all info at the current global step.""" + utils.print_out( + "%sstep %d lr %g step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s, %s" % + (prefix, global_step, info["learning_rate"], info["avg_step_time"], + info["speed"], info["train_ppl"], info["avg_grad_norm"], result_summary, + time.ctime()), + log_f) def add_info_summaries(summary_writer, global_step, info): - """Add stuffs in info to summaries.""" - excluded_list = ["learning_rate"] - for key in info: - if key not in excluded_list: - utils.add_summary(summary_writer, global_step, key, info[key]) + """Add stuffs in info to summaries.""" + excluded_list = ["learning_rate"] + for key in info: + if key not in excluded_list: + utils.add_summary(summary_writer, global_step, key, info[key]) def process_stats(stats, info, global_step, steps_per_stats, log_f): - """Update info and check for overflow.""" - # Per-step info - info["avg_step_time"] = stats["step_time"] / 
steps_per_stats - info["avg_grad_norm"] = stats["grad_norm"] / steps_per_stats - info["avg_sequence_count"] = stats["sequence_count"] / steps_per_stats - info["speed"] = stats["word_count"] / (1000 * stats["step_time"]) + """Update info and check for overflow.""" + # Per-step info + info["avg_step_time"] = stats["step_time"] / steps_per_stats + info["avg_grad_norm"] = stats["grad_norm"] / steps_per_stats + info["avg_sequence_count"] = stats["sequence_count"] / steps_per_stats + info["speed"] = stats["word_count"] / (1000 * stats["step_time"]) - # Per-predict info - info["train_ppl"] = ( - utils.safe_exp(stats["train_loss"] / stats["predict_count"])) + # Per-predict info + info["train_ppl"] = ( + utils.safe_exp(stats["train_loss"] / stats["predict_count"])) - # Check for overflow - is_overflow = False - train_ppl = info["train_ppl"] - if math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > 1e20: - utils.print_out(" step %d overflow, stop early" % global_step, - log_f) - is_overflow = True + # Check for overflow + is_overflow = False + train_ppl = info["train_ppl"] + if math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > 1e20: + utils.print_out(" step %d overflow, stop early" % global_step, + log_f) + is_overflow = True - return is_overflow + return is_overflow def before_train(loaded_train_model, train_model, train_sess, global_step, hparams, log_f): - """Misc tasks to do before training.""" - stats = init_stats() - info = {"train_ppl": 0.0, "speed": 0.0, - "avg_step_time": 0.0, - "avg_grad_norm": 0.0, - "avg_sequence_count": 0.0, - "learning_rate": loaded_train_model.learning_rate.eval( - session=train_sess)} - start_train_time = time.time() - utils.print_out("# Start step %d, lr %g, %s" % - (global_step, info["learning_rate"], time.ctime()), log_f) - - # Initialize all of the iterators - skip_count = hparams.batch_size * hparams.epoch_step - utils.print_out("# Init train iterator, skipping %d elements" % skip_count) - train_sess.run( - train_model.iterator.initializer, - feed_dict={train_model.skip_count_placeholder: skip_count}) - - return stats, info, start_train_time + """Misc tasks to do before training.""" + stats = init_stats() + info = {"train_ppl": 0.0, "speed": 0.0, + "avg_step_time": 0.0, + "avg_grad_norm": 0.0, + "avg_sequence_count": 0.0, + "learning_rate": loaded_train_model.learning_rate.eval( + session=train_sess)} + start_train_time = time.time() + utils.print_out("# Start step %d, lr %g, %s" % + (global_step, info["learning_rate"], time.ctime()), log_f) + + # Initialize all of the iterators + skip_count = hparams.batch_size * hparams.epoch_step + utils.print_out("# Init train iterator, skipping %d elements" % skip_count) + train_sess.run( + train_model.iterator.initializer, + feed_dict={train_model.skip_count_placeholder: skip_count}) + + return stats, info, start_train_time def get_model_creator(hparams): - """Get the right model class depending on configuration.""" - if (hparams.encoder_type == "gnmt" or - hparams.attention_architecture in ["gnmt", "gnmt_v2"]): - model_creator = gnmt_model.GNMTModel - elif hparams.attention and hparams.attention_architecture == "standard": - model_creator = attention_model.AttentionModel - elif not hparams.attention: - model_creator = nmt_model.Model - else: - raise ValueError("Unknown attention architecture %s" % - hparams.attention_architecture) - return model_creator + """Get the right model class depending on configuration.""" + if (hparams.encoder_type == "gnmt" or + hparams.attention_architecture in ["gnmt", 
"gnmt_v2"]): + model_creator = gnmt_model.GNMTModel + elif hparams.attention and hparams.attention_architecture == "standard": + model_creator = attention_model.AttentionModel + elif not hparams.attention: + model_creator = nmt_model.Model + else: + raise ValueError("Unknown attention architecture %s" % + hparams.attention_architecture) + return model_creator def train(hparams, scope=None, target_session=""): - """Train a translation model.""" - log_device_placement = hparams.log_device_placement - out_dir = hparams.out_dir - num_train_steps = hparams.num_train_steps - steps_per_stats = hparams.steps_per_stats - steps_per_external_eval = hparams.steps_per_external_eval - steps_per_eval = 10 * steps_per_stats - avg_ckpts = hparams.avg_ckpts - - if not steps_per_external_eval: - steps_per_external_eval = 5 * steps_per_eval - - # Create model - model_creator = get_model_creator(hparams) - train_model = model_helper.create_train_model(model_creator, hparams, scope) - eval_model = model_helper.create_eval_model(model_creator, hparams, scope) - infer_model = model_helper.create_infer_model(model_creator, hparams, scope) - - # Preload data for sample decoding. - dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) - dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) - sample_src_data = inference.load_data(dev_src_file) - sample_tgt_data = inference.load_data(dev_tgt_file) - - summary_name = "train_log" - model_dir = hparams.out_dir - - # Log and output files - log_file = os.path.join(out_dir, "log_%d" % time.time()) - log_f = tf.gfile.GFile(log_file, mode="a") - utils.print_out("# log_file=%s" % log_file, log_f) - - # TensorFlow model - config_proto = utils.get_config_proto( - log_device_placement=log_device_placement, - num_intra_threads=hparams.num_intra_threads, - num_inter_threads=hparams.num_inter_threads) - train_sess = tf.Session( - target=target_session, config=config_proto, graph=train_model.graph) - eval_sess = tf.Session( - target=target_session, config=config_proto, graph=eval_model.graph) - infer_sess = tf.Session( - target=target_session, config=config_proto, graph=infer_model.graph) - - with train_model.graph.as_default(): - loaded_train_model, global_step = model_helper.create_or_load_model( - train_model.model, model_dir, train_sess, "train") - - # Summary writer - summary_writer = tf.summary.FileWriter( - os.path.join(out_dir, summary_name), train_model.graph) - - # First evaluation - run_full_eval( - model_dir, infer_model, infer_sess, - eval_model, eval_sess, hparams, - summary_writer, sample_src_data, - sample_tgt_data, avg_ckpts) - - last_stats_step = global_step - last_eval_step = global_step - last_external_eval_step = global_step - - # This is the training loop. - stats, info, start_train_time = before_train( - loaded_train_model, train_model, train_sess, global_step, hparams, log_f) - while global_step < num_train_steps: - ### Run a step ### - start_time = time.time() - try: - step_result = loaded_train_model.train(train_sess) - hparams.epoch_step += 1 - except tf.errors.OutOfRangeError: - # Finished going through the training dataset. Go to next epoch. - hparams.epoch_step = 0 - utils.print_out( - "# Finished an epoch, step %d. 
Perform external evaluation" % - global_step) - run_sample_decode(infer_model, infer_sess, model_dir, hparams, - summary_writer, sample_src_data, sample_tgt_data) - run_external_eval(infer_model, infer_sess, model_dir, hparams, - summary_writer) - - if avg_ckpts: - run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, - summary_writer, global_step) - - train_sess.run( - train_model.iterator.initializer, - feed_dict={train_model.skip_count_placeholder: 0}) - continue - - # Process step_result, accumulate stats, and write summary - global_step, info["learning_rate"], step_summary = update_stats( - stats, start_time, step_result) - summary_writer.add_summary(step_summary, global_step) - - # Once in a while, we print statistics. - if global_step - last_stats_step >= steps_per_stats: - last_stats_step = global_step - is_overflow = process_stats( - stats, info, global_step, steps_per_stats, log_f) - print_step_info(" ", global_step, info, get_best_results(hparams), - log_f) - if is_overflow: - break - - # Reset statistics - stats = init_stats() - - if global_step - last_eval_step >= steps_per_eval: - last_eval_step = global_step - utils.print_out("# Save eval, global step %d" % global_step) - add_info_summaries(summary_writer, global_step, info) - - # Save checkpoint - loaded_train_model.saver.save( - train_sess, - os.path.join(out_dir, "translate.ckpt"), - global_step=global_step) - - # Evaluate on dev/test - run_sample_decode(infer_model, infer_sess, - model_dir, hparams, summary_writer, sample_src_data, - sample_tgt_data) - run_internal_eval( - eval_model, eval_sess, model_dir, hparams, summary_writer) - - if global_step - last_external_eval_step >= steps_per_external_eval: - last_external_eval_step = global_step - - # Save checkpoint - loaded_train_model.saver.save( - train_sess, - os.path.join(out_dir, "translate.ckpt"), - global_step=global_step) - run_sample_decode(infer_model, infer_sess, - model_dir, hparams, summary_writer, sample_src_data, - sample_tgt_data) - run_external_eval( - infer_model, infer_sess, model_dir, - hparams, summary_writer) - - if avg_ckpts: - run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, - summary_writer, global_step) - - # Done training - loaded_train_model.saver.save( - train_sess, - os.path.join(out_dir, "translate.ckpt"), - global_step=global_step) - - (result_summary, _, final_eval_metrics) = ( - run_full_eval( - model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, - summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)) - print_step_info("# Final, ", global_step, info, result_summary, log_f) - utils.print_time("# Done training!", start_train_time) - - summary_writer.close() - - utils.print_out("# Start evaluating saved best models.") - for metric in hparams.metrics: - best_model_dir = getattr(hparams, "best_" + metric + "_dir") + """Train a translation model.""" + log_device_placement = hparams.log_device_placement + out_dir = hparams.out_dir + num_train_steps = hparams.num_train_steps + steps_per_stats = hparams.steps_per_stats + steps_per_external_eval = hparams.steps_per_external_eval + steps_per_eval = 10 * steps_per_stats + avg_ckpts = hparams.avg_ckpts + + if not steps_per_external_eval: + steps_per_external_eval = 5 * steps_per_eval + + # Create model + model_creator = get_model_creator(hparams) + train_model = model_helper.create_train_model(model_creator, hparams, scope) + eval_model = model_helper.create_eval_model(model_creator, hparams, scope) + infer_model = 
model_helper.create_infer_model(model_creator, hparams, scope) + + # Preload data for sample decoding. + dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) + dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) + sample_src_data = inference.load_data(dev_src_file) + sample_tgt_data = inference.load_data(dev_tgt_file) + + summary_name = "train_log" + model_dir = hparams.out_dir + + # Log and output files + log_file = os.path.join(out_dir, "log_%d" % time.time()) + log_f = tf.gfile.GFile(log_file, mode="a") + utils.print_out("# log_file=%s" % log_file, log_f) + + # TensorFlow model + config_proto = utils.get_config_proto( + log_device_placement=log_device_placement, + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads) + train_sess = tf.Session( + target=target_session, config=config_proto, graph=train_model.graph) + eval_sess = tf.Session( + target=target_session, config=config_proto, graph=eval_model.graph) + infer_sess = tf.Session( + target=target_session, config=config_proto, graph=infer_model.graph) + + with train_model.graph.as_default(): + loaded_train_model, global_step = model_helper.create_or_load_model( + train_model.model, model_dir, train_sess, "train") + + # Summary writer summary_writer = tf.summary.FileWriter( - os.path.join(best_model_dir, summary_name), infer_model.graph) - result_summary, best_global_step, _ = run_full_eval( - best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, - summary_writer, sample_src_data, sample_tgt_data) - print_step_info("# Best %s, " % metric, best_global_step, info, - result_summary, log_f) - summary_writer.close() + os.path.join(out_dir, summary_name), train_model.graph) + + # First evaluation + run_full_eval( + model_dir, infer_model, infer_sess, + eval_model, eval_sess, hparams, + summary_writer, sample_src_data, + sample_tgt_data, avg_ckpts) + + last_stats_step = global_step + last_eval_step = global_step + last_external_eval_step = global_step + + # This is the training loop. + stats, info, start_train_time = before_train( + loaded_train_model, train_model, train_sess, global_step, hparams, log_f) + while global_step < num_train_steps: + ### Run a step ### + start_time = time.time() + try: + step_result = loaded_train_model.train(train_sess) + hparams.epoch_step += 1 + except tf.errors.OutOfRangeError: + # Finished going through the training dataset. Go to next epoch. + hparams.epoch_step = 0 + utils.print_out( + "# Finished an epoch, step %d. Perform external evaluation" % + global_step) + run_sample_decode(infer_model, infer_sess, model_dir, hparams, + summary_writer, sample_src_data, sample_tgt_data) + run_external_eval(infer_model, infer_sess, model_dir, hparams, + summary_writer) + + if avg_ckpts: + run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, + summary_writer, global_step) + + train_sess.run( + train_model.iterator.initializer, + feed_dict={train_model.skip_count_placeholder: 0}) + continue + + # Process step_result, accumulate stats, and write summary + global_step, info["learning_rate"], step_summary = update_stats( + stats, start_time, step_result) + summary_writer.add_summary(step_summary, global_step) + + # Once in a while, we print statistics. 
+ if global_step - last_stats_step >= steps_per_stats: + last_stats_step = global_step + is_overflow = process_stats( + stats, info, global_step, steps_per_stats, log_f) + print_step_info(" ", global_step, info, get_best_results(hparams), + log_f) + if is_overflow: + break + + # Reset statistics + stats = init_stats() + + if global_step - last_eval_step >= steps_per_eval: + last_eval_step = global_step + utils.print_out("# Save eval, global step %d" % global_step) + add_info_summaries(summary_writer, global_step, info) + + # Save checkpoint + loaded_train_model.saver.save( + train_sess, + os.path.join(out_dir, "translate.ckpt"), + global_step=global_step) + + # Evaluate on dev/test + run_sample_decode(infer_model, infer_sess, + model_dir, hparams, summary_writer, sample_src_data, + sample_tgt_data) + run_internal_eval( + eval_model, eval_sess, model_dir, hparams, summary_writer) + + if global_step - last_external_eval_step >= steps_per_external_eval: + last_external_eval_step = global_step + + # Save checkpoint + loaded_train_model.saver.save( + train_sess, + os.path.join(out_dir, "translate.ckpt"), + global_step=global_step) + run_sample_decode(infer_model, infer_sess, + model_dir, hparams, summary_writer, sample_src_data, + sample_tgt_data) + run_external_eval( + infer_model, infer_sess, model_dir, + hparams, summary_writer) + + if avg_ckpts: + run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, + summary_writer, global_step) + + # Done training + loaded_train_model.saver.save( + train_sess, + os.path.join(out_dir, "translate.ckpt"), + global_step=global_step) + + (result_summary, _, final_eval_metrics) = ( + run_full_eval( + model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, + summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)) + print_step_info("# Final, ", global_step, info, result_summary, log_f) + utils.print_time("# Done training!", start_train_time) - if avg_ckpts: - best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir") - summary_writer = tf.summary.FileWriter( - os.path.join(best_model_dir, summary_name), infer_model.graph) - result_summary, best_global_step, _ = run_full_eval( - best_model_dir, infer_model, infer_sess, eval_model, eval_sess, - hparams, summary_writer, sample_src_data, sample_tgt_data) - print_step_info("# Averaged Best %s, " % metric, best_global_step, info, - result_summary, log_f) - summary_writer.close() + summary_writer.close() - return final_eval_metrics, global_step + utils.print_out("# Start evaluating saved best models.") + for metric in hparams.metrics: + best_model_dir = getattr(hparams, "best_" + metric + "_dir") + summary_writer = tf.summary.FileWriter( + os.path.join(best_model_dir, summary_name), infer_model.graph) + result_summary, best_global_step, _ = run_full_eval( + best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, + summary_writer, sample_src_data, sample_tgt_data) + print_step_info("# Best %s, " % metric, best_global_step, info, + result_summary, log_f) + summary_writer.close() + + if avg_ckpts: + best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir") + summary_writer = tf.summary.FileWriter( + os.path.join(best_model_dir, summary_name), infer_model.graph) + result_summary, best_global_step, _ = run_full_eval( + best_model_dir, infer_model, infer_sess, eval_model, eval_sess, + hparams, summary_writer, sample_src_data, sample_tgt_data) + print_step_info("# Averaged Best %s, " % metric, best_global_step, info, + result_summary, log_f) + summary_writer.close() 
+ + return final_eval_metrics, global_step def _format_results(name, ppl, scores, metrics): - """Format results.""" - result_str = "" - if ppl: - result_str = "%s ppl %.2f" % (name, ppl) - if scores: - for metric in metrics: - if result_str: - result_str += ", %s %s %.1f" % (name, metric, scores[metric]) - else: - result_str = "%s %s %.1f" % (name, metric, scores[metric]) - return result_str + """Format results.""" + result_str = "" + if ppl: + result_str = "%s ppl %.2f" % (name, ppl) + if scores: + for metric in metrics: + if result_str: + result_str += ", %s %s %.1f" % (name, metric, scores[metric]) + else: + result_str = "%s %s %.1f" % (name, metric, scores[metric]) + return result_str def get_best_results(hparams): - """Summary of the current best results.""" - tokens = [] - for metric in hparams.metrics: - tokens.append("%s %.2f" % (metric, getattr(hparams, "best_" + metric))) - return ", ".join(tokens) + """Summary of the current best results.""" + tokens = [] + for metric in hparams.metrics: + tokens.append("%s %.2f" % (metric, getattr(hparams, "best_" + metric))) + return ", ".join(tokens) def _internal_eval(model, global_step, sess, iterator, iterator_feed_dict, summary_writer, label): - """Computing perplexity.""" - sess.run(iterator.initializer, feed_dict=iterator_feed_dict) - ppl = model_helper.compute_perplexity(model, sess, label) - utils.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl) - return ppl + """Computing perplexity.""" + sess.run(iterator.initializer, feed_dict=iterator_feed_dict) + ppl = model_helper.compute_perplexity(model, sess, label) + utils.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl) + return ppl def _sample_decode(model, global_step, sess, hparams, iterator, src_data, tgt_data, iterator_src_placeholder, iterator_batch_size_placeholder, summary_writer): - """Pick a sentence and decode.""" - decode_id = random.randint(0, len(src_data) - 1) - utils.print_out(" # %d" % decode_id) + """Pick a sentence and decode.""" + decode_id = random.randint(0, len(src_data) - 1) + utils.print_out(" # %d" % decode_id) - iterator_feed_dict = { - iterator_src_placeholder: [src_data[decode_id]], - iterator_batch_size_placeholder: 1, - } - sess.run(iterator.initializer, feed_dict=iterator_feed_dict) + iterator_feed_dict = { + iterator_src_placeholder: [src_data[decode_id]], + iterator_batch_size_placeholder: 1, + } + sess.run(iterator.initializer, feed_dict=iterator_feed_dict) - nmt_outputs, attention_summary = model.decode(sess) + nmt_outputs, attention_summary = model.decode(sess) - if hparams.infer_mode == "beam_search": - # get the top translation. - nmt_outputs = nmt_outputs[0] + if hparams.infer_mode == "beam_search": + # get the top translation. 
+        nmt_outputs = nmt_outputs[0]

-  translation = nmt_utils.get_translation(
-      nmt_outputs,
-      sent_id=0,
-      tgt_eos=hparams.eos,
-      subword_option=hparams.subword_option)
-  utils.print_out("    src: %s" % src_data[decode_id])
-  utils.print_out("    ref: %s" % tgt_data[decode_id])
-  utils.print_out(b"    nmt: " + translation)
+    translation = nmt_utils.get_translation(
+        nmt_outputs,
+        sent_id=0,
+        tgt_eos=hparams.eos,
+        subword_option=hparams.subword_option)
+    utils.print_out("    src: %s" % src_data[decode_id])
+    utils.print_out("    ref: %s" % tgt_data[decode_id])
+    utils.print_out(b"    nmt: " + translation)

-  # Summary
-  if attention_summary is not None:
-    summary_writer.add_summary(attention_summary, global_step)
+    # Summary
+    if attention_summary is not None:
+        summary_writer.add_summary(attention_summary, global_step)


 def _external_eval(model, global_step, sess, hparams, iterator,
                    iterator_feed_dict, tgt_file, label, summary_writer,
                    save_on_best, avg_ckpts=False):
-  """External evaluation such as BLEU and ROUGE scores."""
-  out_dir = hparams.out_dir
-  decode = global_step > 0
-
-  if avg_ckpts:
-    label = "avg_" + label
-
-  if decode:
-    utils.print_out("# External evaluation, global step %d" % global_step)
-
-  sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
-
-  output = os.path.join(out_dir, "output_%s" % label)
-  scores = nmt_utils.decode_and_evaluate(
-      label,
-      model,
-      sess,
-      output,
-      ref_file=tgt_file,
-      metrics=hparams.metrics,
-      subword_option=hparams.subword_option,
-      beam_width=hparams.beam_width,
-      tgt_eos=hparams.eos,
-      decode=decode,
-      infer_mode=hparams.infer_mode)
-  # Save on best metrics
-  if decode:
-    for metric in hparams.metrics:
-      if avg_ckpts:
-        best_metric_label = "avg_best_" + metric
-      else:
-        best_metric_label = "best_" + metric
-
-      utils.add_summary(summary_writer, global_step, "%s_%s" % (label, metric),
-                        scores[metric])
-      # metric: larger is better
-      if save_on_best and scores[metric] > getattr(hparams, best_metric_label):
-        setattr(hparams, best_metric_label, scores[metric])
-        model.saver.save(
-            sess,
-            os.path.join(
-                getattr(hparams, best_metric_label + "_dir"), "translate.ckpt"),
-            global_step=model.global_step)
-  utils.save_hparams(out_dir, hparams)
-  return scores
+    """External evaluation such as BLEU and ROUGE scores."""
+    out_dir = hparams.out_dir
+    decode = global_step > 0
+
+    if avg_ckpts:
+        label = "avg_" + label
+
+    if decode:
+        utils.print_out("# External evaluation, global step %d" % global_step)
+
+    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
+
+    output = os.path.join(out_dir, "output_%s" % label)
+    scores = nmt_utils.decode_and_evaluate(
+        label,
+        model,
+        sess,
+        output,
+        ref_file=tgt_file,
+        metrics=hparams.metrics,
+        subword_option=hparams.subword_option,
+        beam_width=hparams.beam_width,
+        tgt_eos=hparams.eos,
+        decode=decode,
+        infer_mode=hparams.infer_mode)
+    # Save on best metrics
+    if decode:
+        for metric in hparams.metrics:
+            if avg_ckpts:
+                best_metric_label = "avg_best_" + metric
+            else:
+                best_metric_label = "best_" + metric
+
+            utils.add_summary(summary_writer, global_step, "%s_%s" % (label, metric),
+                              scores[metric])
+            # metric: larger is better
+            if save_on_best and scores[metric] > getattr(hparams, best_metric_label):
+                setattr(hparams, best_metric_label, scores[metric])
+                model.saver.save(
+                    sess,
+                    os.path.join(
+                        getattr(hparams, best_metric_label + "_dir"), "translate.ckpt"),
+                    global_step=model.global_step)
+    utils.save_hparams(out_dir, hparams)
+    return scores
diff --git a/text_classification/data.py b/text_classification/data.py
index 0530c40..d3c643f 100644
--- a/text_classification/data.py
+++ b/text_classification/data.py
@@ -4,109 +4,110 @@


 def get_arrays(data_dir):
-  data = np.load(os.path.join(data_dir, "data-simplified.npz"))
-  try:
-    X_train = data["X_train"].astype(np.int32)
-  except:  # For spitted large X_train, such as in Amazon_review_polarity
-    X_train_part1 = data["X_train_part1"].astype(np.int32)
-    X_train_part2 = data["X_train_part2"].astype(np.int32)
-    X_train = np.concatenate((X_train_part1, X_train_part2))
-  X_test = data["X_test"].astype(np.int32)
-  y_train = data["y_train"].astype(np.int32)
-  y_test = data["y_test"].astype(np.int32)
-  vocab = data["vocab"]
-  return X_train, y_train, X_test, y_test, vocab
+    data = np.load(os.path.join(data_dir, "data-simplified.npz"))
+    try:
+        X_train = data["X_train"].astype(np.int32)
+    except:  # For spitted large X_train, such as in Amazon_review_polarity
+        X_train_part1 = data["X_train_part1"].astype(np.int32)
+        X_train_part2 = data["X_train_part2"].astype(np.int32)
+        X_train = np.concatenate((X_train_part1, X_train_part2))
+    X_test = data["X_test"].astype(np.int32)
+    y_train = data["y_train"].astype(np.int32)
+    y_test = data["y_test"].astype(np.int32)
+    vocab = data["vocab"]
+    return X_train, y_train, X_test, y_test, vocab


 def batchify_small(X, y, batch_size, num_epochs, reinitializer, is_train):
-  """Cannot deal with large X."""
-  y = tf.one_hot(y, len(set(y)))  # Transform to one-hot.
-  dataset = tf.data.Dataset.from_tensor_slices((X, y))
-  if is_train:
-    dataset = dataset.shuffle(buffer_size=X.shape[0])
-  dataset = dataset.batch(batch_size).repeat(num_epochs)
+    """Cannot deal with large X."""
+    y = tf.one_hot(y, len(set(y)))  # Transform to one-hot.
+    dataset = tf.data.Dataset.from_tensor_slices((X, y))
+    if is_train:
+        dataset = dataset.shuffle(buffer_size=X.shape[0])
+    dataset = dataset.batch(batch_size).repeat(num_epochs)

-  if reinitializer:
-    iterator = tf.data.Iterator.from_structure(
-        dataset.output_types, dataset.output_shapes)
-    initializer = iterator.make_initializer(dataset)
-    X, y = iterator.get_next()
-  else:
-    X, y = dataset.make_one_shot_iterator().get_next()
-    initializer = None
-  return X, y, initializer
+    if reinitializer:
+        iterator = tf.data.Iterator.from_structure(
+            dataset.output_types, dataset.output_shapes)
+        initializer = iterator.make_initializer(dataset)
+        X, y = iterator.get_next()
+    else:
+        X, y = dataset.make_one_shot_iterator().get_next()
+        initializer = None
+    return X, y, initializer


 def get_data_small(data_name, data_dir, batch_size):
-  """Cannot deal with large X."""
-  X_train, y_train, X_test, y_test, vocab = get_arrays(data_dir)
-  X_train, y_train, _ = batchify_small(
-      X_train, y_train, batch_size, None, False, True)
-  X_test, y_test, reinitializer = batchify_small(
-      X_test, y_test, batch_size, 1, True, False)
-  return X_train, y_train, X_test, y_test, reinitializer, vocab
+    """Cannot deal with large X."""
+    X_train, y_train, X_test, y_test, vocab = get_arrays(data_dir)
+    X_train, y_train, _ = batchify_small(
+        X_train, y_train, batch_size, None, False, True)
+    X_test, y_test, reinitializer = batchify_small(
+        X_test, y_test, batch_size, 1, True, False)
+    return X_train, y_train, X_test, y_test, reinitializer, vocab


 def batchify(X, y, batch_size, num_epochs, reinitializer, is_train):
-  y = tf.one_hot(y, len(set(y)))  # Transform to one-hot.
- N = X.shape[0] - dataset = tf.data.Dataset.range(N) - if is_train: - dataset = dataset.shuffle(buffer_size=N) - dataset = dataset.batch(batch_size).repeat(num_epochs) + y = tf.one_hot(y, len(set(y))) # Transform to one-hot. + N = X.shape[0] + dataset = tf.data.Dataset.range(N) + if is_train: + dataset = dataset.shuffle(buffer_size=N) + dataset = dataset.batch(batch_size).repeat(num_epochs) - if reinitializer: - iterator = tf.data.Iterator.from_structure( - dataset.output_types, dataset.output_shapes) - initializer = iterator.make_initializer(dataset) - idxs = iterator.get_next() - else: - idxs = dataset.make_one_shot_iterator().get_next() - initializer = None + if reinitializer: + iterator = tf.data.Iterator.from_structure( + dataset.output_types, dataset.output_shapes) + initializer = iterator.make_initializer(dataset) + idxs = iterator.get_next() + else: + idxs = dataset.make_one_shot_iterator().get_next() + initializer = None - X, y = tf.nn.embedding_lookup(X, idxs), tf.nn.embedding_lookup(y, idxs) - return X, y, initializer + X, y = tf.nn.embedding_lookup(X, idxs), tf.nn.embedding_lookup(y, idxs) + return X, y, initializer def get_data(data_name, data_dir, batch_size): - X_train_d, y_train, X_test_d, y_test, vocab = get_arrays(data_dir) - num_classes = len(set(y_train)) - with tf.device("/cpu:0"): # DEBUG - X_train_holder = tf.placeholder(tf.int32, shape=X_train_d.shape) - X_train = tf.Variable(X_train_holder, trainable=False) - X_test_holder = tf.placeholder(tf.int32, shape=X_test_d.shape) - X_test = tf.Variable(X_test_holder, trainable=False) - X_train, y_train, _ = batchify( - X_train, y_train, batch_size, None, False, True) - X_test, y_test, reinitializer = batchify( - X_test, y_test, batch_size, 1, True, False) - return (X_train_d, X_test_d, X_train_holder, X_test_holder, - X_train, y_train, X_test, y_test, reinitializer, vocab) + X_train_d, y_train, X_test_d, y_test, vocab = get_arrays(data_dir) + num_classes = len(set(y_train)) + with tf.device("/cpu:0"): # DEBUG + X_train_holder = tf.placeholder(tf.int32, shape=X_train_d.shape) + X_train = tf.Variable(X_train_holder, trainable=False) + X_test_holder = tf.placeholder(tf.int32, shape=X_test_d.shape) + X_test = tf.Variable(X_test_holder, trainable=False) + X_train, y_train, _ = batchify( + X_train, y_train, batch_size, None, False, True) + X_test, y_test, reinitializer = batchify( + X_test, y_test, batch_size, 1, True, False) + return (X_train_d, X_test_d, X_train_holder, X_test_holder, + X_train, y_train, X_test, y_test, reinitializer, vocab) if __name__ == "__main__": - import time - home = os.path.expanduser("~") - # data_name = "ag_news" - data_name = "yahoo_answers" - # data_name = "yelp_review_full" - batch_size = 128 - data_dir = os.path.join( - home, "corpus/text_classification_char_cnn/%s_csv" % data_name) - (X_train_d, X_test_d, X_train_holder, X_test_holder, X_train, y_train, - X_test, y_test, test_reinitializer, vocab) = get_data( - data_name, data_dir, batch_size) - with tf.Session() as sess: - sess.run(tf.global_variables_initializer(), - feed_dict={X_train_holder: X_train_d, X_test_holder: X_test_d}) - del X_train_d, X_test_d - sess.run(test_reinitializer) - start = time.time() - for _ in range(1000): - _ = sess.run([X_train, y_train]) - print("duration: ", time.time() - start) - train_results = sess.run([X_train, y_train]) - test_results = sess.run([X_test, y_test]) - get_info = lambda sent: " ".join([vocab[each] for each in sent]) - import pdb; pdb.set_trace() - print("ok") + import time + home = 
os.path.expanduser("~") + # data_name = "ag_news" + data_name = "yahoo_answers" + # data_name = "yelp_review_full" + batch_size = 128 + data_dir = os.path.join( + home, "corpus/text_classification_char_cnn/%s_csv" % data_name) + (X_train_d, X_test_d, X_train_holder, X_test_holder, X_train, y_train, + X_test, y_test, test_reinitializer, vocab) = get_data( + data_name, data_dir, batch_size) + with tf.Session() as sess: + sess.run(tf.global_variables_initializer(), + feed_dict={X_train_holder: X_train_d, X_test_holder: X_test_d}) + del X_train_d, X_test_d + sess.run(test_reinitializer) + start = time.time() + for _ in range(1000): + _ = sess.run([X_train, y_train]) + print("duration: ", time.time() - start) + train_results = sess.run([X_train, y_train]) + test_results = sess.run([X_test, y_test]) + def get_info(sent): return " ".join([vocab[each] for each in sent]) + import pdb + pdb.set_trace() + print("ok") diff --git a/text_classification/main.py b/text_classification/main.py index 29e08de..6c67b67 100644 --- a/text_classification/main.py +++ b/text_classification/main.py @@ -1,3 +1,6 @@ +import util +from model import Model +import data import os import sys import time @@ -6,10 +9,7 @@ import cPickle as pickle os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # TODO: avoid the OutOfRangeError msg. -import data -from model import Model sys.path.insert(0, "../lm") -import util flags = tf.flags FLAGS = flags.FLAGS @@ -35,69 +35,70 @@ flags.DEFINE_bool("kdq_share_subspace", False, "whether to share subspace") flags.DEFINE_bool("additive_quantization", False, "only work with smx") -def main(_): - (X_train_d, X_test_d, X_train_holder, X_test_holder, X_train, y_train, - X_test, y_test, test_reinitializer, vocab) = data.get_data( - FLAGS.dataset, FLAGS.data_dir, FLAGS.batch_size) - flags.DEFINE_integer("vocab_size", len(vocab), 'Auto add vocab size') - with tf.name_scope("Train"): - with tf.variable_scope("model", reuse=False): - m = Model() - loss_train, preds_train, train_op = m.forward( - X_train, y_train, is_training=True) - with tf.name_scope("Test"): - with tf.variable_scope("model", reuse=True): - loss_test, preds_test, _ = m.forward( - X_test, y_test, is_training=False) +def main(_): + (X_train_d, X_test_d, X_train_holder, X_test_holder, X_train, y_train, + X_test, y_test, test_reinitializer, vocab) = data.get_data( + FLAGS.dataset, FLAGS.data_dir, FLAGS.batch_size) + flags.DEFINE_integer("vocab_size", len(vocab), 'Auto add vocab size') - # Verbose. - print("FLAGS:") - for key, value in tf.flags.FLAGS.__flags.items(): - print(key, value._value) - print("Number of trainable params: {}".format(util.get_parameter_count())) - print(tf.trainable_variables()) + with tf.name_scope("Train"): + with tf.variable_scope("model", reuse=False): + m = Model() + loss_train, preds_train, train_op = m.forward( + X_train, y_train, is_training=True) + with tf.name_scope("Test"): + with tf.variable_scope("model", reuse=True): + loss_test, preds_test, _ = m.forward( + X_test, y_test, is_training=False) - # Training session. - init_feed_dict = {X_train_holder: X_train_d, X_test_holder: X_test_d} - sv = tf.train.Supervisor(saver=None, init_feed_dict=init_feed_dict) - config_proto = tf.ConfigProto(allow_soft_placement=True) - config_proto.gpu_options.allow_growth = True - with sv.managed_session(config=config_proto) as sess: - del X_train_d, X_test_d - print("Start training") - # Training loop. 
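
For reference, the accuracy bookkeeping inside the training loop that follows boils down to comparing argmax'ed one-hot labels against integer predictions; a minimal numpy-only sketch with made-up values:

import numpy as np
onehot_labels = np.eye(4)[[0, 2, 1, 3]]       # (batch, num_classes) one-hot rows
preds = np.array([0, 2, 2, 3])                # integer class predictions
hits = np.argmax(onehot_labels, 1) == preds   # per-example hit mask
accuracy = hits.mean()                        # 0.75 for this toy batch
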
- losses_train_ = [] - test_accs = [] - start = time.time() - for it in range(1 + FLAGS.max_iter): - loss_train_, _ = sess.run([loss_train, train_op]) - losses_train_.append(loss_train_) + # Verbose. + print("FLAGS:") + for key, value in tf.flags.FLAGS.__flags.items(): + print(key, value._value) + print("Number of trainable params: {}".format(util.get_parameter_count())) + print(tf.trainable_variables()) - # Evaluation. - if it % FLAGS.eval_every_num_iter == 0: - sess.run(test_reinitializer) - train_hits = [] - test_hits = [] - for _ in range(100): - results = sess.run([y_train, preds_train]) - train_hits.append(np.argmax(results[0], 1) == results[1]) - train_accuracy = np.concatenate(train_hits).mean() - while True: - try: - results = sess.run([y_test, preds_test]) - test_hits.append(np.argmax(results[0], 1) == results[1]) - except tf.errors.OutOfRangeError: - break - test_accuracy = np.concatenate(test_hits).mean() - test_accs.append(test_accuracy) - end = time.time() - print("Iter {:6}, {:.3} (secs), loss {:.4}, train acc {:.4} test acc {:.4}".format( - it, end - start, np.mean(losses_train_), train_accuracy, test_accuracy)) + # Training session. + init_feed_dict = {X_train_holder: X_train_d, X_test_holder: X_test_d} + sv = tf.train.Supervisor(saver=None, init_feed_dict=init_feed_dict) + config_proto = tf.ConfigProto(allow_soft_placement=True) + config_proto.gpu_options.allow_growth = True + with sv.managed_session(config=config_proto) as sess: + del X_train_d, X_test_d + print("Start training") + # Training loop. losses_train_ = [] + test_accs = [] start = time.time() - print("Best test accuracy {}".format(max(test_accs))) + for it in range(1 + FLAGS.max_iter): + loss_train_, _ = sess.run([loss_train, train_op]) + losses_train_.append(loss_train_) + + # Evaluation. + if it % FLAGS.eval_every_num_iter == 0: + sess.run(test_reinitializer) + train_hits = [] + test_hits = [] + for _ in range(100): + results = sess.run([y_train, preds_train]) + train_hits.append(np.argmax(results[0], 1) == results[1]) + train_accuracy = np.concatenate(train_hits).mean() + while True: + try: + results = sess.run([y_test, preds_test]) + test_hits.append(np.argmax(results[0], 1) == results[1]) + except tf.errors.OutOfRangeError: + break + test_accuracy = np.concatenate(test_hits).mean() + test_accs.append(test_accuracy) + end = time.time() + print("Iter {:6}, {:.3} (secs), loss {:.4}, train acc {:.4} test acc {:.4}".format( + it, end - start, np.mean(losses_train_), train_accuracy, test_accuracy)) + losses_train_ = [] + start = time.time() + print("Best test accuracy {}".format(max(test_accs))) if __name__ == "__main__": - tf.app.run() + tf.app.run() diff --git a/text_classification/model.py b/text_classification/model.py index 1b1a286..c5fbb3c 100644 --- a/text_classification/model.py +++ b/text_classification/model.py @@ -1,3 +1,5 @@ +from kdq_embedding import full_embed, kdq_embed, KDQhparam +from kd_quantizer import KDQuantizer import os import sys import numpy as np @@ -6,77 +8,75 @@ import util parent_path = "/".join(os.getcwd().split('/')[:-1]) sys.path.append(os.path.join(parent_path, "core")) -from kd_quantizer import KDQuantizer -from kdq_embedding import full_embed, kdq_embed, KDQhparam FLAGS = tf.flags.FLAGS class Model(object): - def __init__(self): - pass + def __init__(self): + pass - def forward(self, features, labels, is_training=False): - """Returns loss, preds, train_op. + def forward(self, features, labels, is_training=False): + """Returns loss, preds, train_op. 
- Args: - features: (batch_size, max_seq_length) - labels: (batch_size, num_classes) + Args: + features: (batch_size, max_seq_length) + labels: (batch_size, num_classes) - Returns: - loss: (batch_size, ) - preds: (batch_size, ) - train_op: op. - """ - num_classes = labels.shape.as_list()[-1] - batch_size = tf.shape(features)[0] - mask = tf.cast(tf.greater(features, 0), tf.float32) # (bs, max_seq_length) - lengths = tf.reduce_sum(mask, axis=1, keepdims=True) # (batch_size, 1) + Returns: + loss: (batch_size, ) + preds: (batch_size, ) + train_op: op. + """ + num_classes = labels.shape.as_list()[-1] + batch_size = tf.shape(features)[0] + mask = tf.cast(tf.greater(features, 0), tf.float32) # (bs, max_seq_length) + lengths = tf.reduce_sum(mask, axis=1, keepdims=True) # (batch_size, 1) - # Embedding - if FLAGS.kdq_type == "none": - inputs = full_embed(features, FLAGS.vocab_size, FLAGS.dims) - else: - kdq_hparam = KDQhparam( - K=FLAGS.K, D=FLAGS.D, kdq_type=FLAGS.kdq_type, - kdq_d_in=FLAGS.kdq_d_in, kdq_share_subspace=FLAGS.kdq_share_subspace, - additive_quantization=FLAGS.additive_quantization) - inputs = kdq_embed( - features, FLAGS.vocab_size, FLAGS.dims, kdq_hparam, is_training) - word_embs = inputs # (bs, length, emb_dim) - word_embs *= tf.expand_dims(mask, -1) + # Embedding + if FLAGS.kdq_type == "none": + inputs = full_embed(features, FLAGS.vocab_size, FLAGS.dims) + else: + kdq_hparam = KDQhparam( + K=FLAGS.K, D=FLAGS.D, kdq_type=FLAGS.kdq_type, + kdq_d_in=FLAGS.kdq_d_in, kdq_share_subspace=FLAGS.kdq_share_subspace, + additive_quantization=FLAGS.additive_quantization) + inputs = kdq_embed( + features, FLAGS.vocab_size, FLAGS.dims, kdq_hparam, is_training) + word_embs = inputs # (bs, length, emb_dim) + word_embs *= tf.expand_dims(mask, -1) - embs_maxpool = tf.reduce_max(word_embs, 1) # Max pooling. - embs_meanpool = tf.reduce_sum(word_embs, 1) / lengths # Mean pooling. - if FLAGS.concat_maxpooling: - embs = tf.concat([embs_meanpool, embs_maxpool], -1) - else: - embs = embs_meanpool - if FLAGS.hidden_layers > 0: - embs = tf.nn.relu( - tf.layers.batch_normalization(embs, training=is_training)) - embs = tf.layers.dense(embs, FLAGS.dims) - embs = tf.nn.relu( - tf.layers.batch_normalization(embs, training=is_training)) - logits = tf.layers.dense(embs, num_classes) - preds = tf.argmax(logits, -1)[:batch_size] - loss = tf.nn.softmax_cross_entropy_with_logits_v2( - labels=labels, logits=logits) + embs_maxpool = tf.reduce_max(word_embs, 1) # Max pooling. + embs_meanpool = tf.reduce_sum(word_embs, 1) / lengths # Mean pooling. + if FLAGS.concat_maxpooling: + embs = tf.concat([embs_meanpool, embs_maxpool], -1) + else: + embs = embs_meanpool + if FLAGS.hidden_layers > 0: + embs = tf.nn.relu( + tf.layers.batch_normalization(embs, training=is_training)) + embs = tf.layers.dense(embs, FLAGS.dims) + embs = tf.nn.relu( + tf.layers.batch_normalization(embs, training=is_training)) + logits = tf.layers.dense(embs, num_classes) + preds = tf.argmax(logits, -1)[:batch_size] + loss = tf.nn.softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits) - if is_training: - # Regular loss updater. 
- loss_scalar = tf.reduce_mean(loss) - loss_scalar += FLAGS.reg_weight * tf.reduce_mean(word_embs**2) - loss_scalar += tf.reduce_sum( - tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - train_op = tf.contrib.layers.optimize_loss( - loss=loss_scalar, - global_step=tf.train.get_or_create_global_step(), - learning_rate=FLAGS.learning_rate, - optimizer=util.get_optimizer(FLAGS.optimizer), - variables=tf.trainable_variables()) - else: - train_op = False - loss_scalar = None + if is_training: + # Regular loss updater. + loss_scalar = tf.reduce_mean(loss) + loss_scalar += FLAGS.reg_weight * tf.reduce_mean(word_embs**2) + loss_scalar += tf.reduce_sum( + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + train_op = tf.contrib.layers.optimize_loss( + loss=loss_scalar, + global_step=tf.train.get_or_create_global_step(), + learning_rate=FLAGS.learning_rate, + optimizer=util.get_optimizer(FLAGS.optimizer), + variables=tf.trainable_variables()) + else: + train_op = False + loss_scalar = None - return loss_scalar, preds, train_op + return loss_scalar, preds, train_op diff --git a/text_classification/util.py b/text_classification/util.py index d64c23d..6aa9092 100644 --- a/text_classification/util.py +++ b/text_classification/util.py @@ -29,72 +29,69 @@ def get_activation(name): - """Returns activation function given name.""" - name = name.lower() - if name == "relu": - return tf.nn.relu - elif name == "sigmoid": - return tf.nn.sigmoid - elif name == "tanh": - return tf.nn.sigmoid - elif name == "elu": - return tf.nn.elu - elif name == "linear": - return lambda x: x - else: - raise ValueError("Unknown activation name {}".format(name)) + """Returns activation function given name.""" + name = name.lower() + if name == "relu": + return tf.nn.relu + elif name == "sigmoid": + return tf.nn.sigmoid + elif name == "tanh": + return tf.nn.sigmoid + elif name == "elu": + return tf.nn.elu + elif name == "linear": + return lambda x: x + else: + raise ValueError("Unknown activation name {}".format(name)) - return name + return name def get_optimizer(name): - name = name.lower() - if name == "sgd": - optimizer = tf.train.GradientDescentOptimizer - elif name == "momentum": - optimizer = partial(tf.train.MomentumOptimizer, - momentum=0.05, use_nesterov=True) - elif name == "adam": - optimizer = tf.train.AdamOptimizer - # optimizer = partial(tf.train.AdamOptimizer, beta1=0.5, beta2=0.9) - elif name == "lazy_adam": - optimizer = tf.contrib.opt.LazyAdamOptimizer - # optimizer = partial(tf.contrib.opt.LazyAdamOptimizer, beta1=0.5, beta2=0.9) - elif name == "adagrad": - optimizer = tf.train.AdagradOptimizer - elif name == "rmsprop": - optimizer = tf.train.RMSPropOptimizer - else: - raise ValueError("Unknown optimizer name {}.".format(name)) - - return optimizer - - -def get_parameter_count(excludings=None, display_count=True): - trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) - count = 0 - for var in trainables: - ignored = False - if excludings is not None: - for excluding in excludings: - if var.name.find(excluding) >= 0: - ignored = True - break - if ignored: - continue - if var.shape == tf.TensorShape(None): - tf.logging.warn("var {} has unknown shape and it is not counted.".format( - var.name)) - continue - if var.shape.as_list() == []: - count_ = 1 + name = name.lower() + if name == "sgd": + optimizer = tf.train.GradientDescentOptimizer + elif name == "momentum": + optimizer = partial(tf.train.MomentumOptimizer, + momentum=0.05, use_nesterov=True) + elif name == "adam": + optimizer = 
tf.train.AdamOptimizer + # optimizer = partial(tf.train.AdamOptimizer, beta1=0.5, beta2=0.9) + elif name == "lazy_adam": + optimizer = tf.contrib.opt.LazyAdamOptimizer + # optimizer = partial(tf.contrib.opt.LazyAdamOptimizer, beta1=0.5, beta2=0.9) + elif name == "adagrad": + optimizer = tf.train.AdagradOptimizer + elif name == "rmsprop": + optimizer = tf.train.RMSPropOptimizer else: - count_ = reduce(lambda x, y: x * y, var.shape.as_list()) - count += count_ - if display_count: - print("{0:80} {1}".format( - var.name, count_)) - return count + raise ValueError("Unknown optimizer name {}.".format(name)) + return optimizer +def get_parameter_count(excludings=None, display_count=True): + trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) + count = 0 + for var in trainables: + ignored = False + if excludings is not None: + for excluding in excludings: + if var.name.find(excluding) >= 0: + ignored = True + break + if ignored: + continue + if var.shape == tf.TensorShape(None): + tf.logging.warn("var {} has unknown shape and it is not counted.".format( + var.name)) + continue + if var.shape.as_list() == []: + count_ = 1 + else: + count_ = reduce(lambda x, y: x * y, var.shape.as_list()) + count += count_ + if display_count: + print("{0:80} {1}".format( + var.name, count_)) + return count From aba01a17f3bc7d6211094976f0707cc23271a346 Mon Sep 17 00:00:00 2001 From: Sciroccogti Date: Thu, 29 Dec 2022 00:17:57 +0800 Subject: [PATCH 3/3] merge tf2 --- core/kd_quantizer.py | 16 +++--- core/kdq_embedding.py | 8 +-- lm/ptb_word_lm.py | 105 ++++++++++++++++++----------------- lm/reader.py | 12 ++-- lm/util.py | 32 +++++------ text_classification/data.py | 22 ++++---- text_classification/main.py | 22 ++++---- text_classification/model.py | 42 +++++++------- text_classification/util.py | 21 +++---- 9 files changed, 142 insertions(+), 138 deletions(-) diff --git a/core/kd_quantizer.py b/core/kd_quantizer.py index f114e23..aba0149 100644 --- a/core/kd_quantizer.py +++ b/core/kd_quantizer.py @@ -6,12 +6,12 @@ def safer_log(x, eps=1e-10): Note that if x.dtype=tf.float16, \forall eps, eps < 3e-8, is equal to zero. """ - return tf.log(x + eps) + return tf.compat.v1.log(x + eps) def sample_gumbel(shape): """Sample from Gumbel(0, 1)""" - U = tf.random_uniform(shape, minval=0, maxval=1) + U = tf.compat.v1.random_uniform(shape, minval=0, maxval=1) return -safer_log(-safer_log(U)) @@ -51,12 +51,12 @@ def __init__(self, K, D, d_in, d_out, tie_in_n_out, # Create centroids for keys and values. D_to_create = 1 if shared_centroids else D - centroids_k = tf.get_variable( + centroids_k = tf.compat.v1.get_variable( "centroids_k", [D_to_create, K, d_in]) if tie_in_n_out: centroids_v = centroids_k else: - centroids_v = tf.get_variable( + centroids_v = tf.compat.v1.get_variable( "centroids_v", [D_to_create, K, d_out]) if shared_centroids: centroids_k = tf.tile(centroids_k, [D, 1, 1]) @@ -110,7 +110,7 @@ def forward(self, # response = tf.contrib.layers.instance_norm( # response, scale=False, center=False, # trainable=False, data_format="NCHW") - response = tf.layers.batch_normalization( + response = tf.compat.v1.layers.batch_normalization( response, scale=False, center=False, training=is_training) # Layer norm as alternative to BN. 
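
The softmax_BN branch touched in this hunk batch-normalizes the (batch_size, D, K) response logits and then applies a tempered softmax over K. A rough standalone sketch under TF 2.x with v1 graph behaviour enabled (shapes and tau are illustrative, not taken from the configs):

import tensorflow as tf
tf.compat.v1.disable_eager_execution()
response = tf.random.normal([32, 10, 64])             # (batch_size, D, K) logits
response = tf.compat.v1.layers.batch_normalization(
    response, scale=False, center=False, training=True)
tau = 1.0                                             # softmax temperature
response_prob = tf.nn.softmax(response / tau, axis=-1)
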
# response = tf.contrib.layers.layer_norm( @@ -169,17 +169,17 @@ def forward(self, # entropy regularization # reg = - beta * tf.reduce_mean( # tf.reduce_sum(response_prob * safer_log(response_prob), [2])) - tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, reg) + tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, reg) return codes, outputs_final if __name__ == "__main__": # VQ - with tf.variable_scope("VQ"): + with tf.compat.v1.variable_scope("VQ"): kdq_demo = KDQuantizer(100, 10, 5, 5, True, "euclidean") codes_vq, outputs_vq = kdq_demo.forward(tf.random_normal([64, 10, 5])) # tempering softmax - with tf.variable_scope("tempering_softmax"): + with tf.compat.v1.variable_scope("tempering_softmax"): kdq_demo = KDQuantizer(100, 10, 5, 10, False, "dot") codes_ts, outputs_ts = kdq_demo.forward(tf.random_normal([64, 10, 5])) diff --git a/core/kdq_embedding.py b/core/kdq_embedding.py index 42e7d39..8f80dfa 100644 --- a/core/kdq_embedding.py +++ b/core/kdq_embedding.py @@ -14,8 +14,8 @@ def full_embed(input, vocab_size, emb_size, hparams=None, Returns: input_emb: float tensor, embedding for entity idxs. """ - with tf.variable_scope(name): - embedding = tf.get_variable("embedding", [vocab_size, emb_size]) + with tf.compat.v1.variable_scope(name): + embedding = tf.compat.v1.get_variable("embedding", [vocab_size, emb_size]) input_emb = tf.nn.embedding_lookup(embedding, input) return input_emb @@ -44,8 +44,8 @@ def kdq_embed(input, vocab_size, emb_size, hparams=None, d_out = d if hparams.additive_quantization else d//D out_size = [D, emb_size] if hparams.additive_quantization else [emb_size] - with tf.variable_scope(name, reuse=tf.AUTO_REUSE): - query_wemb = tf.get_variable( + with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): + query_wemb = tf.compat.v1.get_variable( "query_wemb", [vocab_size, D * d_in], dtype=tf.float32) idxs = tf.reshape(input, [-1]) input_emb = tf.nn.embedding_lookup(query_wemb, idxs) # (bs*len, d) diff --git a/lm/ptb_word_lm.py b/lm/ptb_word_lm.py index 95c91ce..a98e23b 100644 --- a/lm/ptb_word_lm.py +++ b/lm/ptb_word_lm.py @@ -19,8 +19,6 @@ http://arxiv.org/abs/1409.2329 """ from __future__ import absolute_import -from kdq_embedding import full_embed, kdq_embed, KDQhparam -from kd_quantizer import KDQuantizer from __future__ import division from __future__ import print_function @@ -32,18 +30,21 @@ from collections import Counter, defaultdict import numpy as np import tensorflow as tf +import tensorflow_addons as tfa import reader import util parent_path = "/".join(os.getcwd().split('/')[:-1]) sys.path.append(os.path.join(parent_path, "core")) +from kd_quantizer import KDQuantizer +from kdq_embedding import full_embed, kdq_embed, KDQhparam # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # disable when using tf.Print -flags = tf.flags -logging = tf.logging +flags = tf.compat.v1.flags +logging = tf.compat.v1.logging # Define basics. 
flags.DEFINE_string( @@ -142,14 +143,14 @@ def __init__(self, is_training, config, input_, vocab_size): # final softmax layer targets = input_.targets - softmax_w = tf.get_variable( + softmax_w = tf.compat.v1.get_variable( "softmax_w", [size, vocab_size], dtype=data_type()) - softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type()) - logits = tf.nn.xw_plus_b(outputs, softmax_w, softmax_b) + softmax_b = tf.compat.v1.get_variable("softmax_b", [vocab_size], dtype=data_type()) + logits = tf.compat.v1.nn.xw_plus_b(outputs, softmax_w, softmax_b) # Reshape logits to be a 3-D tensor for sequence loss logits = tf.reshape(logits, [self.batch_size, self.num_steps, vocab_size]) - loss = tf.contrib.seq2seq.sequence_loss( + loss = tfa.seq2seq.sequence_loss( logits, targets, tf.ones([self.batch_size, self.num_steps], dtype=data_type()), @@ -157,7 +158,7 @@ def __init__(self, is_training, config, input_, vocab_size): average_across_batch=False) # (batch_size, num_steps) # Update the cost - self._nll = tf.reduce_sum(tf.reduce_mean(loss, 0)) + self._nll = tf.reduce_sum(input_tensor=tf.reduce_mean(input_tensor=loss, axis=0)) self._cost = self._nll self._final_state = state @@ -167,32 +168,32 @@ def __init__(self, is_training, config, input_, vocab_size): tf.expand_dims(targets, -1), [1] * targets.shape.ndims + [FLAGS.eval_topk]) hits = tf.reduce_sum( - tf.cast(tf.equal(preds_topk, targets_topk), tf.float32), -1) - self._recall_at_k = tf.reduce_sum(tf.reduce_mean(hits, 0)) + input_tensor=tf.cast(tf.equal(preds_topk, targets_topk), tf.float32), axis=-1) + self._recall_at_k = tf.reduce_sum(input_tensor=tf.reduce_mean(input_tensor=hits, axis=0)) if not is_training: return # Add regularization. print("[INFO] regularization loss", - tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - self._cost += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)) + self._cost += sum(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)) # Optimizer self._lr = tf.Variable(0.0, trainable=False) - update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): - grads = tf.gradients(self._cost, tf.trainable_variables()) - tf.summary.scalar("global_grad_norm", tf.global_norm(grads)) + grads = tf.gradients(ys=self._cost, xs=tf.compat.v1.trainable_variables()) + tf.compat.v1.summary.scalar("global_grad_norm", tf.linalg.global_norm(grads)) grads, _ = tf.clip_by_global_norm(grads, config.max_grad_norm) - optimizer = tf.train.GradientDescentOptimizer(self._lr) + optimizer = tf.compat.v1.train.GradientDescentOptimizer(self._lr) self._train_op = optimizer.apply_gradients( - zip(grads, tf.trainable_variables()), - global_step=tf.train.get_or_create_global_step()) - self._new_lr = tf.placeholder( + zip(grads, tf.compat.v1.trainable_variables()),) + # global_step=tf.compat.v1.train.get_or_create_global_step()) + self._new_lr = tf.compat.v1.placeholder( tf.float32, shape=[], name="new_learning_rate") - self._lr_update = tf.assign(self._lr, self._new_lr) + self._lr_update = tf.compat.v1.assign(self._lr, self._new_lr) def _build_rnn_graph(self, inputs, config, is_training): if config.rnn_mode == CUDNN: @@ -203,55 +204,55 @@ def _build_rnn_graph(self, inputs, config, is_training): def _build_rnn_graph_cudnn(self, inputs, config, is_training): """Build the inference graph using CUDNN cell.""" if is_training and 
config.keep_prob < 1: - inputs = tf.nn.dropout(inputs, config.keep_prob) + inputs = tf.nn.dropout(inputs, rate=1 - (config.keep_prob)) - inputs = tf.transpose(inputs, [1, 0, 2]) + inputs = tf.transpose(a=inputs, perm=[1, 0, 2]) self._cell = tf.contrib.cudnn_rnn.CudnnLSTM( num_layers=config.num_layers, num_units=config.hidden_size, input_size=config.hidden_size, dropout=1 - config.keep_prob if is_training else 0) params_size_t = self._cell.params_size() - self._rnn_params = tf.get_variable( + self._rnn_params = tf.compat.v1.get_variable( "lstm_params", - initializer=tf.random_uniform( + initializer=tf.random.uniform( [params_size_t], -config.init_scale, config.init_scale), validate_shape=False) c = tf.zeros([config.num_layers, self.batch_size, config.hidden_size], tf.float32) h = tf.zeros([config.num_layers, self.batch_size, config.hidden_size], tf.float32) - self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),) + self._initial_state = (tf.nn.rnn_cell.LSTMStateTuple(h=h, c=c),) outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training) - outputs = tf.transpose(outputs, [1, 0, 2]) + outputs = tf.transpose(a=outputs, perm=[1, 0, 2]) outputs = tf.reshape(outputs, [-1, config.hidden_size]) - return outputs, (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),) + return outputs, (tf.nn.rnn_cell.LSTMStateTuple(h=h, c=c),) def _get_lstm_cell(self, config, is_training): if config.rnn_mode == BASIC: - return tf.contrib.rnn.BasicLSTMCell( + return tf.compat.v1.nn.rnn_cell.BasicLSTMCell( config.hidden_size, forget_bias=0.0, state_is_tuple=True, reuse=not is_training) if config.rnn_mode == BLOCK: - return tf.contrib.rnn.LSTMBlockCell( - config.hidden_size, forget_bias=0.0) + # https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTMCell + return tfa.rnn.LayerNormLSTMCell(config.hidden_size) raise ValueError("rnn_mode %s not supported" % config.rnn_mode) def _build_rnn_graph_lstm(self, inputs, config, is_training): """Build the inference graph using canonical LSTM cells without Wrapper.""" init_sates = [] final_states = [] - with tf.variable_scope("RNN", reuse=not is_training): + with tf.compat.v1.variable_scope("RNN", reuse=not is_training): for l in range(config.num_layers): - with tf.variable_scope("layer_%d" % l): + with tf.compat.v1.variable_scope("layer_%d" % l): cell = self._get_lstm_cell(config, is_training) - initial_state = cell.zero_state(self.batch_size, data_type()) + initial_state = cell.get_initial_state(batch_size=self.batch_size, dtype=data_type()) init_sates.append(initial_state) state = init_sates[-1] if is_training and config.keep_prob < 1: - inputs = tf.nn.dropout(inputs, config.keep_prob) + inputs = tf.nn.dropout(inputs, rate=1 - (config.keep_prob)) inputs = tf.unstack(inputs, num=self.num_steps, axis=1) - outputs, state = tf.contrib.rnn.static_rnn(cell, inputs, + outputs, state = tf.compat.v1.nn.static_rnn(cell, inputs, initial_state=init_sates[-1]) final_states.append(state) outputs = [tf.expand_dims(output, 1) for output in outputs] @@ -261,7 +262,7 @@ def _build_rnn_graph_lstm(self, inputs, config, is_training): self._initial_state = tuple(init_sates) state = tuple(final_states) if is_training and config.keep_prob < 1: - outputs = tf.nn.dropout(outputs, config.keep_prob) + outputs = tf.nn.dropout(outputs, rate=1 - (config.keep_prob)) return outputs, state def assign_lr(self, session, lr_value): @@ -380,8 +381,8 @@ def run_epoch(session, model, eval_op=None, verbose=False, mode=TRAIN_MODE): for step in range(model.input.epoch_size): feed_dict = {} for i, (c, 
h) in enumerate(model.initial_state): - feed_dict[c] = state[i].c - feed_dict[h] = state[i].h + feed_dict[c] = state[i][0] + feed_dict[h] = state[i][1] vals = session.run(fetches, feed_dict) cost = vals["cost"] @@ -438,32 +439,32 @@ def main(_): eval_config.num_steps = 1 with tf.Graph().as_default(): - initializer = tf.random_uniform_initializer(-config.init_scale, + initializer = tf.compat.v1.random_uniform_initializer(-config.init_scale, config.init_scale) - with tf.name_scope("Train"): + with tf.compat.v1.name_scope("Train"): train_input = PTBInput(config=config, data=train_data, name="TrainInput") - with tf.variable_scope("Model", reuse=None, initializer=initializer): + with tf.compat.v1.variable_scope("Model", reuse=None, initializer=initializer): m = PTBModel(is_training=True, config=config, input_=train_input, vocab_size=vocab_size) - tf.summary.scalar("Training Loss", m.cost) - tf.summary.scalar("Learning Rate", m.lr) + tf.compat.v1.summary.scalar("Training Loss", m.cost) + tf.compat.v1.summary.scalar("Learning Rate", m.lr) - with tf.name_scope("Valid"): + with tf.compat.v1.name_scope("Valid"): valid_input = PTBInput(config=config, data=valid_data, name="ValidInput") - with tf.variable_scope("Model", reuse=True, initializer=initializer): + with tf.compat.v1.variable_scope("Model", reuse=True, initializer=initializer): mvalid = PTBModel(is_training=False, config=config, input_=valid_input, vocab_size=vocab_size) - tf.summary.scalar("Validation Loss", mvalid.cost) + tf.compat.v1.summary.scalar("Validation Loss", mvalid.cost) - with tf.name_scope("Test"): + with tf.compat.v1.name_scope("Test"): test_input = PTBInput( config=eval_config, data=test_data, name="TestInput") - with tf.variable_scope("Model", reuse=True, initializer=initializer): + with tf.compat.v1.variable_scope("Model", reuse=True, initializer=initializer): mtest = PTBModel(is_training=False, config=eval_config, input_=test_input, @@ -472,10 +473,10 @@ def main(_): models = {"Train": m, "Valid": mvalid, "Test": mtest} print_at_beginning(config) - sv = tf.train.Supervisor(logdir=FLAGS.save_path, + sv = tf.compat.v1.train.Supervisor(logdir=FLAGS.save_path, save_model_secs=FLAGS.save_model_secs, save_summaries_secs=10) - config_proto = tf.ConfigProto(allow_soft_placement=True) + config_proto = tf.compat.v1.ConfigProto(allow_soft_placement=True) config_proto.gpu_options.allow_growth = True with sv.managed_session(config=config_proto) as session: for i in range(config.max_max_epoch): @@ -505,4 +506,4 @@ def main(_): if __name__ == "__main__": - tf.app.run() + tf.compat.v1.app.run() diff --git a/lm/reader.py b/lm/reader.py index 6558bf2..d60a09a 100644 --- a/lm/reader.py +++ b/lm/reader.py @@ -29,7 +29,7 @@ def _read_words(filename): - with tf.gfile.GFile(filename, "r") as f: + with tf.io.gfile.GFile(filename, "r") as f: if Py3: return f.read().replace("\n", "").split() else: @@ -122,22 +122,22 @@ def ptb_producer(raw_data, batch_size, num_steps, name=None): Raises: tf.errors.InvalidArgumentError: if batch_size or num_steps are too high. 
""" - with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]): - raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32) + with tf.compat.v1.name_scope(name): #, "PTBProducer", [raw_data, batch_size, num_steps]): + raw_data = tf.convert_to_tensor(value=raw_data, name="raw_data", dtype=tf.int32) - data_len = tf.size(raw_data) + data_len = tf.size(input=raw_data) batch_len = data_len // batch_size data = tf.reshape(raw_data[0: batch_size * batch_len], [batch_size, batch_len]) epoch_size = (batch_len - 1) // num_steps - assertion = tf.assert_positive( + assertion = tf.compat.v1.assert_positive( epoch_size, message="epoch_size == 0, decrease batch_size or num_steps") with tf.control_dependencies([assertion]): epoch_size = tf.identity(epoch_size, name="epoch_size") - i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue() + i = tf.compat.v1.train.range_input_producer(epoch_size, shuffle=False).dequeue() x = tf.strided_slice(data, [0, i * num_steps], [batch_size, (i + 1) * num_steps]) x.set_shape([batch_size, num_steps]) diff --git a/lm/util.py b/lm/util.py index a36a03f..ea25c0b 100644 --- a/lm/util.py +++ b/lm/util.py @@ -23,9 +23,9 @@ from tensorflow.core.framework import variable_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.contrib.tensorboard.plugins import projector +from tensorboard.plugins import projector -FLAGS = tf.flags.FLAGS +FLAGS = tf.compat.v1.flags.FLAGS eps_micro = 1e-15 # tf.float32 sensible. eps_tiny = 1e-10 # tf.float32 sensible. @@ -34,16 +34,16 @@ def export_state_tuples(state_tuples, name): for state_tuple in state_tuples: - tf.add_to_collection(name, state_tuple.c) - tf.add_to_collection(name, state_tuple.h) + tf.compat.v1.add_to_collection(name, state_tuple.c) + tf.compat.v1.add_to_collection(name, state_tuple.h) def import_state_tuples(state_tuples, name, num_replicas): restored = [] for i in range(len(state_tuples) * num_replicas): - c = tf.get_collection_ref(name)[2 * i + 0] - h = tf.get_collection_ref(name)[2 * i + 1] - restored.append(tf.contrib.rnn.LSTMStateTuple(c, h)) + c = tf.compat.v1.get_collection_ref(name)[2 * i + 0] + h = tf.compat.v1.get_collection_ref(name)[2 * i + 1] + restored.append(tf.nn.rnn_cell.LSTMStateTuple(c, h)) return tuple(restored) @@ -108,7 +108,7 @@ def safer_log(x, eps=eps_micro): Note that if x.dtype=tf.float16, \forall eps, eps < 3e-8, is equal to zero. 
""" - return tf.log(x + eps) + return tf.math.log(x + eps) def get_activation(name): @@ -153,20 +153,20 @@ def filter_actv(actv): return ( def get_optimizer(name): name = name.lower() if name == "sgd": - optimizer = tf.train.GradientDescentOptimizer + optimizer = tf.compat.v1.train.GradientDescentOptimizer elif name == "momentum": - optimizer = partial(tf.train.MomentumOptimizer, + optimizer = partial(tf.compat.v1.train.MomentumOptimizer, momentum=0.05, use_nesterov=True) elif name == "adam": - optimizer = tf.train.AdamOptimizer + optimizer = tf.compat.v1.train.AdamOptimizer # optimizer = partial(tf.train.AdamOptimizer, beta1=0.5, beta2=0.9) elif name == "lazy_adam": optimizer = tf.contrib.opt.LazyAdamOptimizer # optimizer = partial(tf.contrib.opt.LazyAdamOptimizer, beta1=0.5, beta2=0.9) elif name == "adagrad": - optimizer = tf.train.AdagradOptimizer + optimizer = tf.compat.v1.train.AdagradOptimizer elif name == "rmsprop": - optimizer = tf.train.RMSPropOptimizer + optimizer = tf.compat.v1.train.RMSPropOptimizer else: raise ValueError("Unknown optimizer name {}.".format(name)) @@ -178,7 +178,7 @@ def replace_list_element(data_list, x, y): def get_parameter_count(excludings=None, display_count=True): - trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) + trainables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES) count = 0 for var in trainables: ignored = False @@ -190,7 +190,7 @@ def get_parameter_count(excludings=None, display_count=True): if ignored: continue if var.shape == tf.TensorShape(None): - tf.logging.warn("var {} has unknown shape and it is not counted.".format( + tf.compat.v1.logging.warn("var {} has unknown shape and it is not counted.".format( var.name)) continue if var.shape.as_list() == []: @@ -237,7 +237,7 @@ def save_emb_visualize_meta(save_path, embedding = config.embeddings.add() embedding.tensor_name = emb_var_name embedding.metadata_path = metadata_path - summary_writer = tf.summary.FileWriter(save_path) + summary_writer = tf.compat.v1.summary.FileWriter(save_path) projector.visualize_embeddings(summary_writer, config) diff --git a/text_classification/data.py b/text_classification/data.py index d3c643f..3ae8d0f 100644 --- a/text_classification/data.py +++ b/text_classification/data.py @@ -2,6 +2,8 @@ import numpy as np import tensorflow as tf +tf.compat.v1.disable_eager_execution() + def get_arrays(data_dir): data = np.load(os.path.join(data_dir, "data-simplified.npz")) @@ -27,12 +29,12 @@ def batchify_small(X, y, batch_size, num_epochs, reinitializer, is_train): dataset = dataset.batch(batch_size).repeat(num_epochs) if reinitializer: - iterator = tf.data.Iterator.from_structure( + iterator = tf.compat.v1.data.Iterator.from_structure( dataset.output_types, dataset.output_shapes) initializer = iterator.make_initializer(dataset) X, y = iterator.get_next() else: - X, y = dataset.make_one_shot_iterator().get_next() + X, y = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() initializer = None return X, y, initializer @@ -56,15 +58,15 @@ def batchify(X, y, batch_size, num_epochs, reinitializer, is_train): dataset = dataset.batch(batch_size).repeat(num_epochs) if reinitializer: - iterator = tf.data.Iterator.from_structure( - dataset.output_types, dataset.output_shapes) + iterator = tf.compat.v1.data.Iterator.from_structure( + tf.compat.v1.data.get_output_types(dataset), tf.compat.v1.data.get_output_shapes(dataset)) initializer = iterator.make_initializer(dataset) idxs = iterator.get_next() else: - idxs = 
dataset.make_one_shot_iterator().get_next() + idxs = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() initializer = None - X, y = tf.nn.embedding_lookup(X, idxs), tf.nn.embedding_lookup(y, idxs) + X, y = tf.nn.embedding_lookup(params=X, ids=idxs), tf.nn.embedding_lookup(params=y, ids=idxs) return X, y, initializer @@ -72,9 +74,9 @@ def get_data(data_name, data_dir, batch_size): X_train_d, y_train, X_test_d, y_test, vocab = get_arrays(data_dir) num_classes = len(set(y_train)) with tf.device("/cpu:0"): # DEBUG - X_train_holder = tf.placeholder(tf.int32, shape=X_train_d.shape) + X_train_holder = tf.compat.v1.placeholder(tf.int32, shape=X_train_d.shape) X_train = tf.Variable(X_train_holder, trainable=False) - X_test_holder = tf.placeholder(tf.int32, shape=X_test_d.shape) + X_test_holder = tf.compat.v1.placeholder(tf.int32, shape=X_test_d.shape) X_test = tf.Variable(X_test_holder, trainable=False) X_train, y_train, _ = batchify( X_train, y_train, batch_size, None, False, True) @@ -96,8 +98,8 @@ def get_data(data_name, data_dir, batch_size): (X_train_d, X_test_d, X_train_holder, X_test_holder, X_train, y_train, X_test, y_test, test_reinitializer, vocab) = get_data( data_name, data_dir, batch_size) - with tf.Session() as sess: - sess.run(tf.global_variables_initializer(), + with tf.compat.v1.Session() as sess: + sess.run(tf.compat.v1.global_variables_initializer(), feed_dict={X_train_holder: X_train_d, X_test_holder: X_test_d}) del X_train_d, X_test_d sess.run(test_reinitializer) diff --git a/text_classification/main.py b/text_classification/main.py index 6c67b67..7bfcac7 100644 --- a/text_classification/main.py +++ b/text_classification/main.py @@ -6,12 +6,12 @@ import time import numpy as np import tensorflow as tf -import cPickle as pickle + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # TODO: avoid the OutOfRangeError msg. sys.path.insert(0, "../lm") -flags = tf.flags +flags = tf.compat.v1.flags FLAGS = flags.FLAGS flags.DEFINE_string("dataset", None, "") flags.DEFINE_string("data_dir", None, "") @@ -42,27 +42,27 @@ def main(_): FLAGS.dataset, FLAGS.data_dir, FLAGS.batch_size) flags.DEFINE_integer("vocab_size", len(vocab), 'Auto add vocab size') - with tf.name_scope("Train"): - with tf.variable_scope("model", reuse=False): + with tf.compat.v1.name_scope("Train"): + with tf.compat.v1.variable_scope("model", reuse=False): m = Model() loss_train, preds_train, train_op = m.forward( X_train, y_train, is_training=True) - with tf.name_scope("Test"): - with tf.variable_scope("model", reuse=True): + with tf.compat.v1.name_scope("Test"): + with tf.compat.v1.variable_scope("model", reuse=True): loss_test, preds_test, _ = m.forward( X_test, y_test, is_training=False) # Verbose. print("FLAGS:") - for key, value in tf.flags.FLAGS.__flags.items(): + for key, value in tf.compat.v1.flags.FLAGS.__flags.items(): print(key, value._value) print("Number of trainable params: {}".format(util.get_parameter_count())) - print(tf.trainable_variables()) + print(tf.compat.v1.trainable_variables()) # Training session. 
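
The iterator changes in text_classification/data.py above follow the usual TF2-compat pattern: build a structure-based iterator from the dataset's static output types/shapes, then re-initialize it for each evaluation pass. A small self-contained sketch with a toy dataset (not from the repo):

import tensorflow as tf
tf.compat.v1.disable_eager_execution()
dataset = tf.data.Dataset.range(10).batch(4)
iterator = tf.compat.v1.data.Iterator.from_structure(
    tf.compat.v1.data.get_output_types(dataset),
    tf.compat.v1.data.get_output_shapes(dataset))
initializer = iterator.make_initializer(dataset)
batch = iterator.get_next()
with tf.compat.v1.Session() as sess:
    sess.run(initializer)      # rerun this to restart the epoch
    print(sess.run(batch))     # [0 1 2 3]
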
init_feed_dict = {X_train_holder: X_train_d, X_test_holder: X_test_d} - sv = tf.train.Supervisor(saver=None, init_feed_dict=init_feed_dict) - config_proto = tf.ConfigProto(allow_soft_placement=True) + sv = tf.compat.v1.train.Supervisor(saver=None, init_feed_dict=init_feed_dict) + config_proto = tf.compat.v1.ConfigProto(allow_soft_placement=True) config_proto.gpu_options.allow_growth = True with sv.managed_session(config=config_proto) as sess: del X_train_d, X_test_d @@ -101,4 +101,4 @@ def main(_): if __name__ == "__main__": - tf.app.run() + tf.compat.v1.app.run() diff --git a/text_classification/model.py b/text_classification/model.py index c5fbb3c..69f4739 100644 --- a/text_classification/model.py +++ b/text_classification/model.py @@ -1,15 +1,16 @@ -from kdq_embedding import full_embed, kdq_embed, KDQhparam -from kd_quantizer import KDQuantizer import os import sys import numpy as np import tensorflow as tf +import tensorflow_addons as tfa import util parent_path = "/".join(os.getcwd().split('/')[:-1]) sys.path.append(os.path.join(parent_path, "core")) +from kd_quantizer import KDQuantizer +from kdq_embedding import full_embed, kdq_embed, KDQhparam -FLAGS = tf.flags.FLAGS +FLAGS = tf.compat.v1.flags.FLAGS class Model(object): @@ -29,9 +30,9 @@ def forward(self, features, labels, is_training=False): train_op: op. """ num_classes = labels.shape.as_list()[-1] - batch_size = tf.shape(features)[0] + batch_size = tf.shape(input=features)[0] mask = tf.cast(tf.greater(features, 0), tf.float32) # (bs, max_seq_length) - lengths = tf.reduce_sum(mask, axis=1, keepdims=True) # (batch_size, 1) + lengths = tf.reduce_sum(input_tensor=mask, axis=1, keepdims=True) # (batch_size, 1) # Embedding if FLAGS.kdq_type == "none": @@ -46,35 +47,34 @@ def forward(self, features, labels, is_training=False): word_embs = inputs # (bs, length, emb_dim) word_embs *= tf.expand_dims(mask, -1) - embs_maxpool = tf.reduce_max(word_embs, 1) # Max pooling. - embs_meanpool = tf.reduce_sum(word_embs, 1) / lengths # Mean pooling. + embs_maxpool = tf.reduce_max(input_tensor=word_embs, axis=1) # Max pooling. + embs_meanpool = tf.reduce_sum(input_tensor=word_embs, axis=1) / lengths # Mean pooling. if FLAGS.concat_maxpooling: embs = tf.concat([embs_meanpool, embs_maxpool], -1) else: embs = embs_meanpool if FLAGS.hidden_layers > 0: embs = tf.nn.relu( - tf.layers.batch_normalization(embs, training=is_training)) - embs = tf.layers.dense(embs, FLAGS.dims) + tf.compat.v1.layers.batch_normalization(embs, training=is_training)) + embs = tf.compat.v1.layers.dense(embs, FLAGS.dims) embs = tf.nn.relu( - tf.layers.batch_normalization(embs, training=is_training)) - logits = tf.layers.dense(embs, num_classes) - preds = tf.argmax(logits, -1)[:batch_size] - loss = tf.nn.softmax_cross_entropy_with_logits_v2( + tf.compat.v1.layers.batch_normalization(embs, training=is_training)) + logits = tf.compat.v1.layers.dense(embs, num_classes) + preds = tf.argmax(input=logits, axis=-1)[:batch_size] + loss = tf.nn.softmax_cross_entropy_with_logits( labels=labels, logits=logits) if is_training: # Regular loss updater. 
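
The text_classification/model.py hunk below drops tf.contrib.layers.optimize_loss in favour of a plain v1 Optimizer.minimize call. The equivalent pattern in isolation looks roughly like this, assuming a v1-style optimizer such as AdamOptimizer and a toy loss:

import tensorflow as tf
tf.compat.v1.disable_eager_execution()
w = tf.compat.v1.get_variable("w", [3], initializer=tf.compat.v1.zeros_initializer())
loss = tf.reduce_mean(input_tensor=(w - 1.0) ** 2)
opt = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
train_op = opt.minimize(
    loss,
    global_step=tf.compat.v1.train.get_or_create_global_step(),
    var_list=tf.compat.v1.trainable_variables())
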
- loss_scalar = tf.reduce_mean(loss) - loss_scalar += FLAGS.reg_weight * tf.reduce_mean(word_embs**2) + loss_scalar = tf.reduce_mean(input_tensor=loss) + loss_scalar += FLAGS.reg_weight * tf.reduce_mean(input_tensor=word_embs**2) loss_scalar += tf.reduce_sum( - tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - train_op = tf.contrib.layers.optimize_loss( + input_tensor=tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)) + op = util.get_optimizer(FLAGS.optimizer)(learning_rate=FLAGS.learning_rate) + train_op = op.minimize( loss=loss_scalar, - global_step=tf.train.get_or_create_global_step(), - learning_rate=FLAGS.learning_rate, - optimizer=util.get_optimizer(FLAGS.optimizer), - variables=tf.trainable_variables()) + global_step=tf.compat.v1.train.get_or_create_global_step(), + var_list=tf.compat.v1.trainable_variables()) else: train_op = False loss_scalar = None diff --git a/text_classification/util.py b/text_classification/util.py index 6aa9092..49bbd54 100644 --- a/text_classification/util.py +++ b/text_classification/util.py @@ -18,10 +18,11 @@ from __future__ import print_function import os -from functools import partial +from functools import partial, reduce import tensorflow as tf +import tensorflow_addons as tfa -FLAGS = tf.flags.FLAGS +FLAGS = tf.compat.v1.flags.FLAGS eps_micro = 1e-15 # tf.float32 sensible. eps_tiny = 1e-10 # tf.float32 sensible. @@ -50,20 +51,20 @@ def get_activation(name): def get_optimizer(name): name = name.lower() if name == "sgd": - optimizer = tf.train.GradientDescentOptimizer + optimizer = tf.compat.v1.train.GradientDescentOptimizer elif name == "momentum": - optimizer = partial(tf.train.MomentumOptimizer, + optimizer = partial(tf.compat.v1.train.MomentumOptimizer, momentum=0.05, use_nesterov=True) elif name == "adam": - optimizer = tf.train.AdamOptimizer + optimizer = tf.compat.v1.train.AdamOptimizer # optimizer = partial(tf.train.AdamOptimizer, beta1=0.5, beta2=0.9) elif name == "lazy_adam": - optimizer = tf.contrib.opt.LazyAdamOptimizer + optimizer = tfa.optimizers.LazyAdam # TODO: migrate to tf2, global_step not supported # optimizer = partial(tf.contrib.opt.LazyAdamOptimizer, beta1=0.5, beta2=0.9) elif name == "adagrad": - optimizer = tf.train.AdagradOptimizer + optimizer = tf.compat.v1.train.AdagradOptimizer elif name == "rmsprop": - optimizer = tf.train.RMSPropOptimizer + optimizer = tf.compat.v1.train.RMSPropOptimizer else: raise ValueError("Unknown optimizer name {}.".format(name)) @@ -71,7 +72,7 @@ def get_optimizer(name): def get_parameter_count(excludings=None, display_count=True): - trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) + trainables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES) count = 0 for var in trainables: ignored = False @@ -83,7 +84,7 @@ def get_parameter_count(excludings=None, display_count=True): if ignored: continue if var.shape == tf.TensorShape(None): - tf.logging.warn("var {} has unknown shape and it is not counted.".format( + tf.compat.v1.logging.warn("var {} has unknown shape and it is not counted.".format( var.name)) continue if var.shape.as_list() == []:
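
The parameter counting that util.py keeps relying on is simply a product over each variable's static shape via functools.reduce; a tiny sketch of that reduction (helper name is illustrative):

from functools import reduce

def num_elements(shape_list):
    # scalar variables report shape [] and still count as one parameter
    return 1 if shape_list == [] else reduce(lambda x, y: x * y, shape_list)

assert num_elements([]) == 1
assert num_elements([128, 300]) == 38400
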