Commit
Change order of arguments for the Embedding layer to match Keras and PyTorch.

PiperOrigin-RevId: 319309769
Lukasz Kaiser authored and copybara-github committed Jul 1, 2020
1 parent faaec34 commit 714fb8e
Showing 9 changed files with 29 additions and 32 deletions.
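For orientation, a minimal before/after sketch of the signature change (illustrative sizes; assumes `trax.layers` is imported as `tl`, as in the diffs below):

from trax import layers as tl

# Before this commit: feature depth first, then vocabulary size.
#   layer = tl.Embedding(d_feature=256, vocab_size=32000)

# After this commit: vocabulary size first, matching Keras and PyTorch.
layer = tl.Embedding(32000, 256)  # vocab_size=32000, d_feature=256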
trax/layers/core.py (5 additions, 9 deletions)
@@ -114,30 +114,26 @@ def init_weights_and_state(self, input_signature):
class Embedding(base.Layer):
"""Trainable layer that maps discrete tokens/ids to vectors."""

# TODO(jonni): Consider reversing param order to: vocab_size, d_feature
def __init__(self,
d_feature,
vocab_size,
d_feature,
kernel_initializer=init.RandomNormalInitializer(1.0)):
"""Returns an embedding layer with given vocabulary size and vector size.
The layer clips input values (token ids) to the range `[0, vocab_size)`.
That is, negative token ids all clip to `0` before being mapped to a
vector, and token ids with value `vocab_size` or greater all clip to
`vocab_size - 1` before being mapped to a vector. In effect, both id `0`
and id `vocab_size - 1` are potentially overloaded as out-of-vocabulary
token ids.
TODO(jonni): Is this the behavior we want going forward?
`vocab_size - 1` before being mapped to a vector.
Args:
d_feature: Dimensionality/depth of the output vectors.
vocab_size: Size of the input vocabulary. The layer will assign a unique
vector to each id in `range(vocab_size)`.
d_feature: Dimensionality/depth of the output vectors.
kernel_initializer: Function that creates (random) initial vectors for
the embedding.
"""
super().__init__()
# TODO(jonni): is the clipping behavior what we want going forward?
super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
self._d_feature = d_feature # feature dimensionality
self._vocab_size = vocab_size
self._kernel_initializer = kernel_initializer
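A minimal usage sketch of the new constructor order and the clipping behavior described in the docstring above (values are illustrative and mirror the tests in core_test.py below):

import numpy as np
from trax import layers as tl

layer = tl.Embedding(10, 3)       # vocab_size=10, d_feature=3 (new order)
layer.init(None)                  # Embedding init ignores the input signature
x = np.array([-2, 0, 5, 9, 20])   # out-of-range ids clip into [0, 10)
y = layer(x)
print(y.shape)                         # (5, 3)
print(y[0].tolist() == y[1].tolist())  # True: -2 clips to id 0
print(y[3].tolist() == y[4].tolist())  # True: 20 clips to id 9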
trax/layers/core_test.py (5 additions, 5 deletions)
@@ -114,7 +114,7 @@ def test_init_twice_weights_same_shape(self):
class EmbeddingTest(absltest.TestCase):

def test_forward(self):
layer = tl.Embedding(3, 10) # d_feature=3, vocab_size=10
layer = tl.Embedding(10, 3) # vocab_size=10, d_feature=3
_, _ = layer.init(None) # Embedding init doesn't use input signature.
x = np.array([2, 3, 5, 3, 2])
y = layer(x)
@@ -130,7 +130,7 @@ def test_forward(self):
self.assertEqual(y[1].tolist(), y[3].tolist())

def test_negative_inputs_clip_to_zero(self):
layer = tl.Embedding(3, 10)
layer = tl.Embedding(10, 3)
_, _ = layer.init(None)
x = np.array([0, 2, 3, -2, -3])
y = layer(x)
@@ -140,7 +140,7 @@ def test_negative_inputs_clip_to_zero(self):
self.assertEqual(y[0].tolist(), y[4].tolist())

def test_large_inputs_clip_to_upper_bound(self):
layer = tl.Embedding(3, 10)
layer = tl.Embedding(10, 3)
_, _ = layer.init(None)
x = np.array([2, 3, 9, 10, 20])
y = layer(x)
@@ -152,7 +152,7 @@ def test_large_inputs_clip_to_upper_bound(self):
self.assertEqual(y[2].tolist(), y[4].tolist())

def test_new_weights(self):
layer = tl.Embedding(5, 20)
layer = tl.Embedding(20, 5)
_, _ = layer.init(None)

# Default weights sampled from Gaussian, mu = 0, sigma = 1.
@@ -167,7 +167,7 @@ def f(shape, rng):
n_elements = np.prod(shape)
return np.arange(n_elements).reshape(shape)

layer = tl.Embedding(2, 5, kernel_initializer=f)
layer = tl.Embedding(5, 2, kernel_initializer=f)
_, _ = layer.init(None)
x = np.array([0, 1, 2, 3, 4])
y = layer(x)
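For reference, a sketch of what the custom-initializer test above computes, assuming the deterministic initializer `f` shown there: the weight matrix is np.arange(10).reshape(5, 2), so looking up ids 0..4 returns its rows in order.

import numpy as np
from trax import layers as tl

def f(shape, rng):  # deterministic initializer, as in the test above
  return np.arange(np.prod(shape)).reshape(shape)

layer = tl.Embedding(5, 2, kernel_initializer=f)  # vocab_size=5, d_feature=2
layer.init(None)
y = layer(np.array([0, 1, 2, 3, 4]))
print(y.tolist())  # rows of arange(10).reshape(5, 2): [[0, 1], [2, 3], ...]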
trax/models/neural_gpu.py (1 addition, 1 deletion)
@@ -72,7 +72,7 @@ def NeuralGPU(d_feature=96, steps=16, vocab_size=2, mode='train'):

core = ConvDiagonalGRU(units=d_feature)
return tl.Serial(
tl.Embedding(d_feature=d_feature, vocab_size=vocab_size),
tl.Embedding(vocab_size=vocab_size, d_feature=d_feature),
[core] * steps,
tl.Dense(vocab_size),
tl.LogSoftmax(),
trax/models/reformer/reformer.py (4 additions, 4 deletions)
@@ -325,7 +325,7 @@ def ReformerLM(vocab_size,
dropout=dropout, mode=mode)

positional_embedder = [
tl.Embedding(d_emb, vocab_size),
tl.Embedding(vocab_size, d_emb),
tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter
positional_encoding,
]
@@ -431,7 +431,7 @@ def ReformerShortenLM(vocab_size,
dropout=dropout, mode=mode)

positional_embedder = [
tl.Embedding(d_embedding, vocab_size),
tl.Embedding(vocab_size, d_embedding),
tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter
positional_encoding,
]
@@ -637,7 +637,7 @@ def PositionalEncoder(vocab_size, mode): # tokens --> vectors
positional_encoding = tl.PositionalEncoding(
max_len=max_len, dropout=dropout, mode=mode)
return [
tl.Embedding(d_model, vocab_size),
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
positional_encoding,
]
@@ -789,7 +789,7 @@ def PositionalEncoder(vocab_size, mode): # tokens --> vectors
dropout=dropout, mode=mode)

return [
tl.Embedding(d_model, vocab_size),
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
positional_encoding,
]
trax/models/research/bert.py (2 additions, 2 deletions)
@@ -75,8 +75,8 @@ def BERT(d_model=768,
layer_norm_eps = 1e-12
d_head = d_model // n_heads

word_embeddings = tl.Embedding(d_model, vocab_size)
type_embeddings = tl.Embedding(d_model, type_vocab_size)
word_embeddings = tl.Embedding(vocab_size, d_model)
type_embeddings = tl.Embedding(type_vocab_size, d_model)
position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
embeddings = [
tl.Select([0, 1, 0], n_in=3), # Drops 'idx' input.
trax/models/research/skipping_transformer.py (1 addition, 1 deletion)
@@ -153,7 +153,7 @@ def SkippingTransformerLM(vocab_size,
to activations over a vocab set.
"""
embedder = [
tl.Embedding(d_model, vocab_size),
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, mode=mode),
tl.PositionalEncoding(max_len=max_len, mode=mode),
]
trax/models/rl.py (2 additions, 1 deletion)
@@ -78,7 +78,8 @@ def ActionInjector(mode):
if is_discrete:
encode_layer = tl.Parallel(
tl.Dense(inject_actions_dim),
tl.Embedding(inject_actions_dim, vocab_size=vocab_size))
tl.Embedding(vocab_size, inject_actions_dim)
)
else:
encode_layer = tl.Parallel(
tl.Dense(inject_actions_dim),
trax/models/rnn.py (4 additions, 4 deletions)
@@ -60,7 +60,7 @@ def MultiRNNCell():

return tl.Serial(
tl.ShiftRight(mode=mode),
tl.Embedding(d_model, vocab_size),
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, mode=mode),
tl.Branch([], zero_state),
tl.Scan(MultiRNNCell(), axis=1),
@@ -90,7 +90,7 @@ def GRULM(vocab_size=256,
"""
return tl.Serial(
tl.ShiftRight(mode=mode),
tl.Embedding(d_model, vocab_size),
tl.Embedding(vocab_size, d_model),
[tl.GRU(d_model) for _ in range(n_layers)],
tl.Dense(vocab_size),
tl.LogSoftmax()
@@ -133,13 +133,13 @@ def LSTMSeq2SeqAttn(input_vocab_size=256,
An LSTM sequence-to-sequence model with attention.
"""
input_encoder = tl.Serial(
tl.Embedding(d_model, input_vocab_size),
tl.Embedding(input_vocab_size, d_model),
[tl.LSTM(d_model) for _ in range(n_encoder_layers)],
)

pre_attention_decoder = tl.Serial(
tl.ShiftRight(mode=mode),
tl.Embedding(d_model, target_vocab_size),
tl.Embedding(target_vocab_size, d_model),
tl.LSTM(d_model),
)

trax/models/transformer.py (5 additions, 5 deletions)
@@ -55,7 +55,7 @@ def TransformerEncoder(vocab_size,
activations over a set of output classes.
"""
positional_encoder = [
tl.Embedding(d_model, vocab_size),
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
tl.PositionalEncoding(max_len=max_len)]

@@ -114,7 +114,7 @@ def TransformerDecoder(vocab_size=None,
tensor to a continuous tensor.
"""
positional_encoder = [
(tl.Embedding(d_model, vocab_size) if vocab_size is not None
(tl.Embedding(vocab_size, d_model) if vocab_size is not None
else tl.Dense(d_model)),
tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
tl.PositionalEncoding(max_len=max_len)]
@@ -165,7 +165,7 @@ def TransformerLM(vocab_size,
to activations over a vocab set.
"""
positional_encoder = [
tl.Embedding(d_model, vocab_size),
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
tl.PositionalEncoding(max_len=max_len, mode=mode)]

@@ -223,7 +223,7 @@ def Transformer(input_vocab_size,
"""
def PositionalEncoder(vocab_size): # tokens --> vectors
return [
tl.Embedding(d_model, vocab_size),
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
tl.PositionalEncoding(max_len=max_len),
]
@@ -315,7 +315,7 @@ def TransformerNoEncDecAttention(input_vocab_size,
"""
def PositionalEncoder(vocab_size): # tokens --> vectors
return [
tl.Embedding(d_model, vocab_size),
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
tl.PositionalEncoding(max_len=max_len),
]
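Across the model files, the call-site update follows a single pattern: every `tl.Embedding(d_model, vocab_size)` becomes `tl.Embedding(vocab_size, d_model)`. As an illustrative sketch (hypothetical helper name, mirroring the embedder blocks in the diffs above):

from trax import layers as tl

def token_embedder(vocab_size, d_model, dropout, max_len, mode='train'):
  # Embedding now takes the vocabulary size first, then the feature depth.
  return [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]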
