From 714fb8e47a6d4f2d422d8c404d2de5e95e43d2f0 Mon Sep 17 00:00:00 2001
From: Lukasz Kaiser
Date: Wed, 1 Jul 2020 15:37:55 -0700
Subject: [PATCH] Change order of arguments for the Embedding layer to match
 Keras and pytorch.

PiperOrigin-RevId: 319309769
---
 trax/layers/core.py                          | 14 +++++---------
 trax/layers/core_test.py                     | 10 +++++-----
 trax/models/neural_gpu.py                    |  2 +-
 trax/models/reformer/reformer.py             |  8 ++++----
 trax/models/research/bert.py                 |  4 ++--
 trax/models/research/skipping_transformer.py |  2 +-
 trax/models/rl.py                            |  3 ++-
 trax/models/rnn.py                           |  8 ++++----
 trax/models/transformer.py                   | 10 +++++-----
 9 files changed, 29 insertions(+), 32 deletions(-)

diff --git a/trax/layers/core.py b/trax/layers/core.py
index 9aa8b0deb..b6368d440 100644
--- a/trax/layers/core.py
+++ b/trax/layers/core.py
@@ -114,30 +114,26 @@ def init_weights_and_state(self, input_signature):
 class Embedding(base.Layer):
   """Trainable layer that maps discrete tokens/ids to vectors."""
 
-  # TODO(jonni): Consider reversing param order to: vocab_size, d_feature
   def __init__(self,
-               d_feature,
                vocab_size,
+               d_feature,
                kernel_initializer=init.RandomNormalInitializer(1.0)):
     """Returns an embedding layer with given vocabulary size and vector size.
 
     The layer clips input values (token ids) to the range `[0, vocab_size)`.
     That is, negative token ids all clip to `0` before being mapped to a
     vector, and token ids with value `vocab_size` or greater all clip to
-    `vocab_size - 1` before being mapped to a vector. In effect, both id `0`
-    and id `vocab_size - 1` are potentially overloaded as out-of-vocabulary
-    token ids.
-
-    TODO(jonni): Is this the behavior we want going forward?
+    `vocab_size - 1` before being mapped to a vector.
 
     Args:
-      d_feature: Dimensionality/depth of the output vectors.
       vocab_size: Size of the input vocabulary. The layer will assign a unique
           vector to each id in `range(vocab_size)`.
+      d_feature: Dimensionality/depth of the output vectors.
       kernel_initializer: Function that creates (random) initial vectors for
           the embedding.
     """
-    super().__init__()
+    # TODO(jonni): is the clipping behavior what we want going forward?
+    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
     self._d_feature = d_feature  # feature dimensionality
     self._vocab_size = vocab_size
     self._kernel_initializer = kernel_initializer
diff --git a/trax/layers/core_test.py b/trax/layers/core_test.py
index 8ce69878f..a89a1b451 100644
--- a/trax/layers/core_test.py
+++ b/trax/layers/core_test.py
@@ -114,7 +114,7 @@ def test_init_twice_weights_same_shape(self):
 class EmbeddingTest(absltest.TestCase):
 
   def test_forward(self):
-    layer = tl.Embedding(3, 10)  # d_feature=3, vocab_size=10
+    layer = tl.Embedding(10, 3)  # vocab_size=10, d_feature=3
     _, _ = layer.init(None)  # Embedding init doesn't use input signature.
     x = np.array([2, 3, 5, 3, 2])
     y = layer(x)
@@ -130,7 +130,7 @@ def test_forward(self):
     self.assertEqual(y[1].tolist(), y[3].tolist())
 
   def test_negative_inputs_clip_to_zero(self):
-    layer = tl.Embedding(3, 10)
+    layer = tl.Embedding(10, 3)
     _, _ = layer.init(None)
     x = np.array([0, 2, 3, -2, -3])
     y = layer(x)
@@ -140,7 +140,7 @@ def test_negative_inputs_clip_to_zero(self):
     self.assertEqual(y[0].tolist(), y[4].tolist())
 
   def test_large_inputs_clip_to_upper_bound(self):
-    layer = tl.Embedding(3, 10)
+    layer = tl.Embedding(10, 3)
     _, _ = layer.init(None)
     x = np.array([2, 3, 9, 10, 20])
     y = layer(x)
@@ -152,7 +152,7 @@ def test_large_inputs_clip_to_upper_bound(self):
     self.assertEqual(y[2].tolist(), y[4].tolist())
 
   def test_new_weights(self):
-    layer = tl.Embedding(5, 20)
+    layer = tl.Embedding(20, 5)
     _, _ = layer.init(None)
 
     # Default weights sampled from Gaussian, mu = 0, sigma = 1.
@@ -167,7 +167,7 @@ def f(shape, rng):
       n_elements = np.prod(shape)
       return np.arange(n_elements).reshape(shape)
 
-    layer = tl.Embedding(2, 5, kernel_initializer=f)
+    layer = tl.Embedding(5, 2, kernel_initializer=f)
     _, _ = layer.init(None)
     x = np.array([0, 1, 2, 3, 4])
     y = layer(x)
diff --git a/trax/models/neural_gpu.py b/trax/models/neural_gpu.py
index ad877c656..a497d2917 100644
--- a/trax/models/neural_gpu.py
+++ b/trax/models/neural_gpu.py
@@ -72,7 +72,7 @@ def NeuralGPU(d_feature=96, steps=16, vocab_size=2, mode='train'):
   core = ConvDiagonalGRU(units=d_feature)
 
   return tl.Serial(
-      tl.Embedding(d_feature=d_feature, vocab_size=vocab_size),
+      tl.Embedding(vocab_size=vocab_size, d_feature=d_feature),
       [core] * steps,
       tl.Dense(vocab_size),
       tl.LogSoftmax(),
diff --git a/trax/models/reformer/reformer.py b/trax/models/reformer/reformer.py
index 3bc10c2df..8aa54d21e 100644
--- a/trax/models/reformer/reformer.py
+++ b/trax/models/reformer/reformer.py
@@ -325,7 +325,7 @@ def ReformerLM(vocab_size,
         dropout=dropout, mode=mode)
 
   positional_embedder = [
-      tl.Embedding(d_emb, vocab_size),
+      tl.Embedding(vocab_size, d_emb),
       tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
       positional_encoding,
   ]
@@ -431,7 +431,7 @@ def ReformerShortenLM(vocab_size,
         dropout=dropout, mode=mode)
 
   positional_embedder = [
-      tl.Embedding(d_embedding, vocab_size),
+      tl.Embedding(vocab_size, d_embedding),
       tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
       positional_encoding,
   ]
@@ -637,7 +637,7 @@ def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
     positional_encoding = tl.PositionalEncoding(
         max_len=max_len, dropout=dropout, mode=mode)
     return [
-        tl.Embedding(d_model, vocab_size),
+        tl.Embedding(vocab_size, d_model),
         tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
         positional_encoding,
     ]
@@ -789,7 +789,7 @@ def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
         dropout=dropout, mode=mode)
     return [
-        tl.Embedding(d_model, vocab_size),
+        tl.Embedding(vocab_size, d_model),
         tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
         positional_encoding,
     ]
diff --git a/trax/models/research/bert.py b/trax/models/research/bert.py
index d49943e73..cc94a18b5 100644
--- a/trax/models/research/bert.py
+++ b/trax/models/research/bert.py
@@ -75,8 +75,8 @@ def BERT(d_model=768,
   layer_norm_eps = 1e-12
   d_head = d_model // n_heads
 
-  word_embeddings = tl.Embedding(d_model, vocab_size)
-  type_embeddings = tl.Embedding(d_model, type_vocab_size)
+  word_embeddings = tl.Embedding(vocab_size, d_model)
+  type_embeddings = tl.Embedding(type_vocab_size, d_model)
   position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
   embeddings = [
       tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
diff --git a/trax/models/research/skipping_transformer.py b/trax/models/research/skipping_transformer.py
index f9e0622d4..fca6b5909 100644
--- a/trax/models/research/skipping_transformer.py
+++ b/trax/models/research/skipping_transformer.py
@@ -153,7 +153,7 @@ def SkippingTransformerLM(vocab_size,
     to activations over a vocab set.
   """
   embedder = [
-      tl.Embedding(d_model, vocab_size),
+      tl.Embedding(vocab_size, d_model),
       tl.Dropout(rate=dropout, mode=mode),
       tl.PositionalEncoding(max_len=max_len, mode=mode),
   ]
diff --git a/trax/models/rl.py b/trax/models/rl.py
index fe5dbaf57..b6e03e334 100644
--- a/trax/models/rl.py
+++ b/trax/models/rl.py
@@ -78,7 +78,8 @@ def ActionInjector(mode):
     if is_discrete:
       encode_layer = tl.Parallel(
           tl.Dense(inject_actions_dim),
-          tl.Embedding(inject_actions_dim, vocab_size=vocab_size))
+          tl.Embedding(vocab_size, inject_actions_dim)
+      )
     else:
       encode_layer = tl.Parallel(
           tl.Dense(inject_actions_dim),
diff --git a/trax/models/rnn.py b/trax/models/rnn.py
index 02e366adb..e311d99f8 100644
--- a/trax/models/rnn.py
+++ b/trax/models/rnn.py
@@ -60,7 +60,7 @@ def MultiRNNCell():
 
   return tl.Serial(
       tl.ShiftRight(mode=mode),
-      tl.Embedding(d_model, vocab_size),
+      tl.Embedding(vocab_size, d_model),
       tl.Dropout(rate=dropout, mode=mode),
       tl.Branch([], zero_state),
       tl.Scan(MultiRNNCell(), axis=1),
@@ -90,7 +90,7 @@ def GRULM(vocab_size=256,
   """
   return tl.Serial(
       tl.ShiftRight(mode=mode),
-      tl.Embedding(d_model, vocab_size),
+      tl.Embedding(vocab_size, d_model),
       [tl.GRU(d_model) for _ in range(n_layers)],
       tl.Dense(vocab_size),
       tl.LogSoftmax()
   )
@@ -133,13 +133,13 @@ def LSTMSeq2SeqAttn(input_vocab_size=256,
     An LSTM sequence-to-sequence model with attention.
   """
   input_encoder = tl.Serial(
-      tl.Embedding(d_model, input_vocab_size),
+      tl.Embedding(input_vocab_size, d_model),
       [tl.LSTM(d_model) for _ in range(n_encoder_layers)],
   )
 
   pre_attention_decoder = tl.Serial(
       tl.ShiftRight(mode=mode),
-      tl.Embedding(d_model, target_vocab_size),
+      tl.Embedding(target_vocab_size, d_model),
       tl.LSTM(d_model),
   )
diff --git a/trax/models/transformer.py b/trax/models/transformer.py
index 4f1f6b29a..90d9a123b 100644
--- a/trax/models/transformer.py
+++ b/trax/models/transformer.py
@@ -55,7 +55,7 @@ def TransformerEncoder(vocab_size,
     activations over a set of output classes.
   """
   positional_encoder = [
-      tl.Embedding(d_model, vocab_size),
+      tl.Embedding(vocab_size, d_model),
       tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
       tl.PositionalEncoding(max_len=max_len)]
 
@@ -114,7 +114,7 @@ def TransformerDecoder(vocab_size=None,
     tensor to a continuous tensor.
   """
   positional_encoder = [
-      (tl.Embedding(d_model, vocab_size) if vocab_size is not None
+      (tl.Embedding(vocab_size, d_model) if vocab_size is not None
       else tl.Dense(d_model)),
       tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
       tl.PositionalEncoding(max_len=max_len)]
@@ -165,7 +165,7 @@ def TransformerLM(vocab_size,
     to activations over a vocab set.
   """
   positional_encoder = [
-      tl.Embedding(d_model, vocab_size),
+      tl.Embedding(vocab_size, d_model),
       tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
       tl.PositionalEncoding(max_len=max_len, mode=mode)]
 
@@ -223,7 +223,7 @@ def Transformer(input_vocab_size,
   """
   def PositionalEncoder(vocab_size):  # tokens --> vectors
     return [
-        tl.Embedding(d_model, vocab_size),
+        tl.Embedding(vocab_size, d_model),
         tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
         tl.PositionalEncoding(max_len=max_len),
     ]
@@ -315,7 +315,7 @@ def TransformerNoEncDecAttention(input_vocab_size,
   """
   def PositionalEncoder(vocab_size):  # tokens --> vectors
     return [
-        tl.Embedding(d_model, vocab_size),
+        tl.Embedding(vocab_size, d_model),
         tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
         tl.PositionalEncoding(max_len=max_len),
     ]