Allow embedding extension to load from pre-trained embeddings file. (allenai#2387)

* Rough attempt for Embedder/Embedding extension.

* fix some mistakes.

* add tests for token-embedder and text-field-embedder extension.

* fix vocab_namespace usage in embedding.py

* update names and change some comments.

* update embedding tests.

* fix some typos.

* add more tests.

* update some comments.

* fix minor pylint issue.

* Implement extend_vocab for TokenCharactersEncoder.

* minor simplification.

* Update help text for --extend-vocab in fine-tune command.

* Shift location of model tests appropriately.

* Allow root Embedding in model to be extendable.

* Incorporate PR comments in embedding.py.

* Fix annotations.

* Add appropriate docstrings and minor cleanup.

* Resolve pylint complaints.

* shift disable pylint protected-access to top of tests.

* Add a test to increase coverage.

* minor update in TokenEmbedder docstrings.

* Allow passing pretrained_file in embedding extension (with tests).

* Remove a blank line.

* Add a blank line before Returns block in Embedding docstring.

* Fix pylint complaints.

* Allow pretrained file to be passed in token_characters_encoder also.

* Fix pylint complaints and update some comments.

* Test to ensure trained embeddings do not get overridden.

* PR feedback: update comments, fix annotation.
HarshTrivedi authored and matt-gardner committed Jan 18, 2019
1 parent 174f539 commit 9719b5c
Showing 3 changed files with 71 additions and 9 deletions.
28 changes: 22 additions & 6 deletions allennlp/modules/token_embedders/embedding.py
@@ -147,10 +147,14 @@ def forward(self, inputs): # pylint: disable=arguments-differ
return embedded

@overrides
def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = None):
def extend_vocab(self, # pylint: disable=arguments-differ
extended_vocab: Vocabulary,
vocab_namespace: str = None,
pretrained_file: str = None) -> None:
"""
Extends the embedding matrix according to the extended vocabulary.
Extended weight would be initialized with xavier uniform.
If pretrained_file is available, it will be used for initializing the new words
in the extended vocabulary; otherwise they will be initialized with xavier uniform.
Parameters
----------
@@ -162,6 +166,10 @@ def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = None):
can pass it. If not passed, it will check if vocab_namespace used at the
time of ``Embedding`` construction is available. If so, this namespace
will be used or else default 'tokens' namespace will be used.
pretrained_file : str, (optional, default=None)
A file containing pretrained embeddings can be specified here. It can be
the path to a local file or a URL of a (cached) remote file. Check format
details in ``from_params`` of ``Embedding`` class.
"""
# Caveat: For allennlp v0.8.1 and below, we weren't storing vocab_namespace as an attribute,
# which we need to know at the time of embedding vocab extension. So old archive models are
@@ -172,11 +180,19 @@ def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = None):
vocab_namespace = "tokens"
logging.warning("No vocab_namespace provided to Embedder.extend_vocab. Defaulting to 'tokens'.")

extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace)
extra_num_embeddings = extended_num_embeddings - self.num_embeddings
embedding_dim = self.weight.data.shape[-1]
extra_weight = torch.FloatTensor(extra_num_embeddings, embedding_dim)
torch.nn.init.xavier_uniform_(extra_weight)
if not pretrained_file:
extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace)
extra_num_embeddings = extended_num_embeddings - self.num_embeddings
extra_weight = torch.FloatTensor(extra_num_embeddings, embedding_dim)
torch.nn.init.xavier_uniform_(extra_weight)
else:
# It's easiest to just reload the embeddings for the entire vocab,
# then only keep the ones we need.
whole_weight = _read_pretrained_embeddings_file(pretrained_file, embedding_dim,
extended_vocab, vocab_namespace)
extra_weight = whole_weight[self.num_embeddings:, :]

extended_weight = torch.cat([self.weight.data, extra_weight], dim=0)
self.weight = torch.nn.Parameter(extended_weight, requires_grad=self.weight.requires_grad)

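The diff above adds both initialization paths to Embedding.extend_vocab: xavier-uniform rows when no file is given, and rows read back from a pretrained embeddings file otherwise. Below is a minimal usage sketch of the extended API; only Embedding, Vocabulary, and extend_vocab come from this commit, while the surrounding setup (token names, dimensions, file path) is illustrative.

# Sketch (assumed setup): extend an Embedding after the vocabulary has grown, e.g. during fine-tuning.
from allennlp.data import Vocabulary
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()
vocab.add_token_to_namespace("word1")   # index 2 (0 = padding, 1 = OOV)
vocab.add_token_to_namespace("word2")   # index 3

embedder = Embedding(num_embeddings=vocab.get_vocab_size("tokens"), embedding_dim=3)

# ... training happens, then new tokens show up ...
vocab.add_token_to_namespace("word3")   # index 4

# Path 1: no pretrained file, so the single new row is xavier-uniform initialized.
embedder.extend_vocab(vocab, vocab_namespace="tokens")

# Path 2: with a pretrained file, new rows are read from the file instead,
# and the already-trained rows are left untouched (hypothetical path below).
# embedder.extend_vocab(vocab, vocab_namespace="tokens",
#                       pretrained_file="/path/to/embeddings.txt.gz")

assert embedder.weight.shape[0] == vocab.get_vocab_size("tokens")  # 5 rows now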
16 changes: 13 additions & 3 deletions allennlp/modules/token_embedders/token_characters_encoder.py
@@ -37,10 +37,14 @@ def forward(self, token_characters: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ
return self._dropout(self._encoder(self._embedding(token_characters), mask))

@overrides
def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = "token_characters"):
def extend_vocab(self, # pylint: disable=arguments-differ
extended_vocab: Vocabulary,
vocab_namespace: str = "token_characters",
pretrained_file: str = None) -> None:
"""
Extends the embedding module according to the extended vocabulary.
Extended weight would be initialized with xavier uniform.
If pretrained_file is available, it will be used for initializing the new words
in the extended vocabulary; otherwise they will be initialized with xavier uniform.
Parameters
----------
@@ -52,11 +56,17 @@ def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = "token_characters"):
you can pass it here. If not passed, it will check if vocab_namespace used
at the time of ``TokenCharactersEncoder`` construction is available. If so, this
namespace will be used or else default 'token_characters' namespace will be used.
pretrained_file : str, (optional, default=None)
A file containing pretrained embeddings can be specified here. It can be
the path to a local file or a URL of a (cached) remote file. Check format
details in ``from_params`` of ``Embedding`` class.
"""
# Caveat: For allennlp v0.8.1 and below, we weren't storing vocab_namespace as an attribute, which
# we need to know at the time of token_characters_encoder vocab extension. So old archive models are
# currently unextendable unless the user used default vocab_namespace 'token_characters' for it.
self._embedding._module.extend_vocab(extended_vocab, vocab_namespace) # pylint: disable=protected-access
self._embedding._module.extend_vocab(extended_vocab, # pylint: disable=protected-access
vocab_namespace=vocab_namespace,
pretrained_file=pretrained_file)

# The setdefault requires a custom from_params
@classmethod
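TokenCharactersEncoder.extend_vocab itself only forwards the namespace and the optional pretrained file to the Embedding it wraps. A short sketch of that call follows, under the assumption that the encoder was built directly from an Embedding and a CnnEncoder; the construction details are illustrative, only the extend_vocab call reflects this diff.

# Sketch (assumed setup): extending the character vocabulary of a TokenCharactersEncoder.
from allennlp.data import Vocabulary
from allennlp.modules.seq2vec_encoders import CnnEncoder
from allennlp.modules.token_embedders import Embedding, TokenCharactersEncoder

vocab = Vocabulary()
vocab.add_token_to_namespace("a", namespace="token_characters")
vocab.add_token_to_namespace("b", namespace="token_characters")

char_embedding = Embedding(num_embeddings=vocab.get_vocab_size("token_characters"), embedding_dim=4)
encoder = TokenCharactersEncoder(char_embedding, CnnEncoder(embedding_dim=4, num_filters=5))

vocab.add_token_to_namespace("c", namespace="token_characters")

# Both arguments are passed straight through to the wrapped Embedding's extend_vocab;
# pretrained_file is optional and defaults to None (xavier-uniform init for new rows).
encoder.extend_vocab(vocab, vocab_namespace="token_characters")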
36 changes: 36 additions & 0 deletions allennlp/tests/modules/token_embedders/embedding_test.py
@@ -281,3 +281,39 @@ def test_embedding_vocab_extension_without_stored_namespace(self):
extended_weight = embedder.weight
assert extended_weight.shape[0] == 5
assert torch.all(extended_weight[:4, :] == original_weight[:4, :])

def test_embedding_vocab_extension_works_with_pretrained_embedding_file(self):
vocab = Vocabulary()
vocab.add_token_to_namespace('word1')
vocab.add_token_to_namespace('word2')

embeddings_filename = str(self.TEST_DIR / "embeddings2.gz")
with gzip.open(embeddings_filename, 'wb') as embeddings_file:
embeddings_file.write("word3 0.5 0.3 -6.0\n".encode('utf-8'))
embeddings_file.write("word4 1.0 2.3 -1.0\n".encode('utf-8'))
embeddings_file.write("word2 0.1 0.4 -4.0\n".encode('utf-8'))
embeddings_file.write("word1 1.0 2.3 -1.0\n".encode('utf-8'))

embedding_params = Params({"vocab_namespace": "tokens", "embedding_dim": 3,
"pretrained_file": embeddings_filename})
embedder = Embedding.from_params(vocab, embedding_params)

# Change weight to simulate embedding training
embedder.weight.data += 1
assert torch.all(embedder.weight[2:, :] == torch.Tensor([[2.0, 3.3, 0.0], [1.1, 1.4, -3.0]]))
original_weight = embedder.weight

assert tuple(original_weight.size()) == (4, 3) # 4 because of padding and OOV

vocab.add_token_to_namespace('word3')
embedder.extend_vocab(vocab, pretrained_file=embeddings_filename) # default namespace
extended_weight = embedder.weight

# Make sure extension happened for extra token in extended vocab
assert tuple(extended_weight.size()) == (5, 3)

# Make sure extension doesn't change original trained weights.
assert torch.all(original_weight[:4, :] == extended_weight[:4, :])

# Make sure extended weight is taken from the embedding file.
assert torch.all(extended_weight[4, :] == torch.Tensor([0.5, 0.3, -6.0]))
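For clarity on the indices checked above: 'tokens' is a padded namespace, so rows 0 and 1 hold padding and OOV, word1 and word2 occupy rows 2 and 3, and the newly added word3 lands in row 4, which is why that row is compared against its vector from the embeddings file. A tiny sketch of that bookkeeping, using the same vocabulary as the test and shown only for illustration:

# Sketch: the extended row index equals the vocab index of the newly added token.
from allennlp.data import Vocabulary

vocab = Vocabulary()
vocab.add_token_to_namespace("word1")   # index 2 (0 = padding, 1 = OOV)
vocab.add_token_to_namespace("word2")   # index 3
vocab.add_token_to_namespace("word3")   # index 4: the row filled from the pretrained file

assert vocab.get_token_index("word3") == 4
assert vocab.get_vocab_size("tokens") == 5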
