Fix torch import when padding without tensors
Riccorl committed Jul 8, 2021
1 parent 26722c4 commit c1593d0
Showing 2 changed files with 19 additions and 38 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@
 
 setuptools.setup(
     name="transformer_embedder",  # Replace with your own username
-    version="1.7.12",
+    version="1.7.13",
     author="Riccardo Orlando",
     author_email="[email protected]",
     description="Word level transformer based embeddings",
55 changes: 18 additions & 37 deletions transformer_embedder/tokenizer.py
@@ -3,16 +3,16 @@
 from functools import partial
 from typing import List, Dict, Union, Tuple, Any
 
-
 import transformers as tr
 
 from transformer_embedder import MODELS_WITH_STARTING_TOKEN, MODELS_WITH_DOUBLE_SEP
 from transformer_embedder import utils
+from transformer_embedder.utils import is_torch_available, is_spacy_available
 
-if utils.is_torch_available():
+if is_torch_available():
     import torch
 
-if utils.is_spacy_available():
+if is_spacy_available():
     import spacy
     from spacy.cli.download import download as spacy_download
 
@@ -35,9 +35,7 @@ def __init__(
             self.config = tr.AutoConfig.from_pretrained(model)
         else:
             self.huggingface_tokenizer = model
-            self.config = tr.AutoConfig.from_pretrained(
-                self.huggingface_tokenizer.name_or_path
-            )
+            self.config = tr.AutoConfig.from_pretrained(self.huggingface_tokenizer.name_or_path)
         # spacy tokenizer, lazy load. None at first
         self.spacy_tokenizer = None
         # default multilingual model
@@ -131,16 +129,10 @@ def __call__(
             )
 
         # if text is str or a list of str and they are not split, then text needs to be tokenized
-        if isinstance(text, str) or (
-            not is_split_into_words and isinstance(text[0], str)
-        ):
+        if isinstance(text, str) or (not is_split_into_words and isinstance(text[0], str)):
             if not is_batched:
                 text = self.pretokenize(text, use_spacy=use_spacy)
-                text_pair = (
-                    self.pretokenize(text_pair, use_spacy=use_spacy)
-                    if text_pair
-                    else None
-                )
+                text_pair = self.pretokenize(text_pair, use_spacy=use_spacy) if text_pair else None
             else:
                 text = [self.pretokenize(t, use_spacy=use_spacy) for t in text]
                 text_pair = (
@@ -224,17 +216,13 @@ def build_tokens(
         Returns:
             a dictionary with A and B encoded
         """
-        words, input_ids, token_type_ids, offsets = self._build_tokens(
-            text, max_len=max_len
-        )
+        words, input_ids, token_type_ids, offsets = self._build_tokens(text, max_len=max_len)
         if text_pair:
             words_b, input_ids_b, token_type_ids_b, offsets_b = self._build_tokens(
                 text_pair, True, max_len
             )
             # align offsets of sentence b
-            offsets_b = [
-                (o[0] + len(input_ids), o[1] + len(input_ids)) for o in offsets_b
-            ]
+            offsets_b = [(o[0] + len(input_ids), o[1] + len(input_ids)) for o in offsets_b]
             offsets = offsets + offsets_b
             input_ids += input_ids_b
             token_type_ids += token_type_ids_b
@@ -302,9 +290,7 @@ def _build_tokens(
             token_type_ids += [token_type_id]
         return words, input_ids, token_type_ids, offsets
 
-    def pad_batch(
-        self, batch: Dict[str, list], max_length: int = None
-    ) -> Dict[str, list]:
+    def pad_batch(self, batch: Dict[str, list], max_length: int = None) -> Dict[str, list]:
         """
         Pad the batch to its maximum length.
@@ -358,17 +344,17 @@ def pad_sequence(
                 f"`length` must be an `int`, `subtoken` or `word`. Current value is `{length}`"
             )
         padding = [value] * abs(length - len(sequence))
-        if isinstance(sequence, torch.Tensor):
+        if is_torch_available() and isinstance(sequence, torch.Tensor):
             if len(sequence.shape) > 1:
                 raise ValueError(
                     f"Sequence tensor must be 1D. Current shape is `{len(sequence.shape)}`"
                 )
             padding = torch.as_tensor(padding)
         if pad_to_left:
-            if isinstance(sequence, torch.Tensor):
+            if is_torch_available() and isinstance(sequence, torch.Tensor):
                 return torch.cat((padding, sequence), -1)
             return padding + sequence
-        if isinstance(sequence, torch.Tensor):
+        if is_torch_available() and isinstance(sequence, torch.Tensor):
             return torch.cat((sequence, padding), -1)
         return sequence + padding

@@ -390,9 +376,7 @@ def pretokenize(self, text: str, use_spacy: bool = False) -> List[str]:
             return [t.text for t in text]
         return text.split(" ")
 
-    def add_special_tokens(
-        self, special_tokens_dict: Dict[str, Union[str, tr.AddedToken]]
-    ) -> int:
+    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, tr.AddedToken]]) -> int:
         """
         Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder.
         If special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last
@@ -458,8 +442,7 @@ def to_tensor(self, batch: Union[List[dict], dict]) -> Dict[str, "torch.Tensor"]
         """
         # convert to tensor
         batch = {
-            k: torch.as_tensor(v) if k in self.to_tensor_inputs else v
-            for k, v in batch.items()
+            k: torch.as_tensor(v) if k in self.to_tensor_inputs else v for k, v in batch.items()
         }
         return batch

@@ -474,9 +457,7 @@ def _load_spacy(self) -> "spacy.tokenizer.Tokenizer":
         try:
             spacy_tagger = spacy.load(self.language, exclude=["ner", "parser"])
         except OSError:
-            logger.info(
-                f"Spacy model '{self.language}' not found. Downloading and installing."
-            )
+            logger.info(f"Spacy model '{self.language}' not found. Downloading and installing.")
             spacy_download(self.language)
             spacy_tagger = spacy.load(self.language, exclude=["ner", "parser"])
         self.spacy_tokenizer = spacy_tagger.tokenizer
@@ -583,9 +564,9 @@ def num_special_tokens(self) -> int:
             int: the number of special tokens
         """
-        if isinstance(
-            self.huggingface_tokenizer, MODELS_WITH_DOUBLE_SEP
-        ) and isinstance(self.huggingface_tokenizer, MODELS_WITH_STARTING_TOKEN):
+        if isinstance(self.huggingface_tokenizer, MODELS_WITH_DOUBLE_SEP) and isinstance(
+            self.huggingface_tokenizer, MODELS_WITH_STARTING_TOKEN
+        ):
             return 4
         if isinstance(
             self.huggingface_tokenizer,
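Note on the fix: before this commit, `torch` was only imported at module load when `utils.is_torch_available()` returned True, but `pad_sequence` still evaluated `isinstance(sequence, torch.Tensor)` unconditionally, so padding plain Python lists in an environment without torch raised a `NameError` on the bare `torch` name. The new `is_torch_available() and ...` guard short-circuits before `torch.Tensor` is ever referenced. A minimal, self-contained sketch of the pattern (hypothetical, simplified signatures; not the library's actual code):

from typing import List, Union

try:
    import torch
    _TORCH_AVAILABLE = True
except ImportError:  # torch is treated as an optional dependency
    _TORCH_AVAILABLE = False


def is_torch_available() -> bool:
    # stand-in for transformer_embedder.utils.is_torch_available
    return _TORCH_AVAILABLE


def pad_sequence(sequence: Union[List[int], "torch.Tensor"], value: int, length: int):
    """Right-pad `sequence` to `length` with `value`, preserving its type."""
    padding = [value] * (length - len(sequence))
    # short-circuit: `torch.Tensor` is only touched when torch was actually imported
    if is_torch_available() and isinstance(sequence, torch.Tensor):
        return torch.cat((sequence, torch.as_tensor(padding)), -1)
    return sequence + padding


print(pad_sequence([1, 2, 3], 0, 5))  # [1, 2, 3, 0, 0], with or without torch installed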
