
Commit

from_pretrained flag in TransformersEmbedder, to avoid downloading the model if you want to start from a checkpoint.
Riccorl committed Jul 7, 2022
1 parent d8e6971 commit 09b6359
Showing 4 changed files with 18 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,6 +3,8 @@ stuff
 /test.ipynb
 /test.py

+# Fleet
+.fleet

 # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm+all,vscode,macos,linux,windows
 # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm+all,vscode,macos,linux,windows
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,2 +1,2 @@
 torch>=1.7,<1.13
-transformers>=4.3,<4.21
+transformers>=4.14,<4.21
4 changes: 2 additions & 2 deletions setup.py
@@ -7,11 +7,11 @@
extras["torch"] = ["torch>=1.5,<1.13"]
extras["all"] = extras["torch"]

-install_requires = ["transformers>=4.3,<4.21"]
+install_requires = ["transformers>=4.14,<4.21"]

 setuptools.setup(
     name="transformers_embedder",
-    version="3.0.2",
+    version="3.0.3",
     author="Riccardo Orlando",
     author_email="[email protected]",
     description="Word level transformer based embeddings",
16 changes: 13 additions & 3 deletions transformers_embedder/embedder.py
@@ -52,6 +52,10 @@ class TransformersEmbedder(torch.nn.Module):
         If ``True``, the transformer model is fine-tuned during training.
     return_all (:obj:`bool`, optional, defaults to :obj:`False`):
         If ``True``, returns all the outputs from the HuggingFace model.
+    from_pretrained (:obj:`bool`, optional, defaults to :obj:`True`):
+        If ``True``, the model is loaded with its pre-trained weights; otherwise it is
+        initialized with random weights. Useful when you want to load a model from a
+        specific checkpoint, without having to download the entire model.
"""

def __init__(
@@ -62,6 +66,7 @@ def __init__(
     output_layers: Sequence[int] = (-4, -3, -2, -1),
     fine_tune: bool = True,
     return_all: bool = False,
+    from_pretrained: bool = True,
     *args,
     **kwargs,
 ) -> None:
@@ -74,9 +79,14 @@
         *args,
         **kwargs,
     )
-    self.transformer_model = tr.AutoModel.from_pretrained(
-        model, config=config, *args, **kwargs
-    )
+    if from_pretrained:
+        self.transformer_model = tr.AutoModel.from_pretrained(
+            model, config=config, *args, **kwargs
+        )
+    else:
+        self.transformer_model = tr.AutoModel.from_config(
+            config, *args, **kwargs
+        )
 else:
     self.transformer_model = model
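
For context, a minimal usage sketch of the new flag: build the embedder without downloading the pre-trained weights (note that the configuration is still fetched, since ``tr.AutoConfig.from_pretrained`` is called unconditionally), then restore the weights from a local checkpoint. This assumes the class is importable from the package top level; the model name and the ``checkpoint.pt`` path are illustrative placeholders.

import torch
from transformers_embedder import TransformersEmbedder

# Instantiate the architecture from the config alone, with random
# weights (internally this takes the tr.AutoModel.from_config path).
embedder = TransformersEmbedder("bert-base-cased", from_pretrained=False)

# Restore weights from a checkpoint saved earlier, e.g. with
# torch.save(embedder.state_dict(), "checkpoint.pt").
# "checkpoint.pt" is a hypothetical path used for illustration.
state_dict = torch.load("checkpoint.pt", map_location="cpu")
embedder.load_state_dict(state_dict)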
