diff --git a/transformer_embedder/tokenizer.py b/transformer_embedder/tokenizer.py
index e9c8457..e19781a 100644
--- a/transformer_embedder/tokenizer.py
+++ b/transformer_embedder/tokenizer.py
@@ -35,7 +35,9 @@ def __init__(
             self.config = tr.AutoConfig.from_pretrained(model)
         else:
             self.huggingface_tokenizer = model
-            self.config = tr.AutoConfig.from_pretrained(self.huggingface_tokenizer.name_or_path)
+            self.config = tr.AutoConfig.from_pretrained(
+                self.huggingface_tokenizer.name_or_path
+            )
         # spacy tokenizer, lazy load. None at first
         self.spacy_tokenizer = None
         # default multilingual model
@@ -129,10 +131,16 @@ def __call__(
         )
 
         # if text is str or a list of str and they are not split, then text needs to be tokenized
-        if isinstance(text, str) or (not is_split_into_words and isinstance(text[0], str)):
+        if isinstance(text, str) or (
+            not is_split_into_words and isinstance(text[0], str)
+        ):
             if not is_batched:
                 text = self.pretokenize(text, use_spacy=use_spacy)
-                text_pair = self.pretokenize(text_pair, use_spacy=use_spacy) if text_pair else None
+                text_pair = (
+                    self.pretokenize(text_pair, use_spacy=use_spacy)
+                    if text_pair
+                    else None
+                )
             else:
                 text = [self.pretokenize(t, use_spacy=use_spacy) for t in text]
                 text_pair = (
@@ -216,13 +224,17 @@ def build_tokens(
         Returns:
             a dictionary with A and B encoded
         """
-        words, input_ids, token_type_ids, offsets = self._build_tokens(text, max_len=max_len)
+        words, input_ids, token_type_ids, offsets = self._build_tokens(
+            text, max_len=max_len
+        )
         if text_pair:
             words_b, input_ids_b, token_type_ids_b, offsets_b = self._build_tokens(
                 text_pair, True, max_len
             )
             # align offsets of sentence b
-            offsets_b = [(o[0] + len(input_ids), o[1] + len(input_ids)) for o in offsets_b]
+            offsets_b = [
+                (o[0] + len(input_ids), o[1] + len(input_ids)) for o in offsets_b
+            ]
             offsets = offsets + offsets_b
             input_ids += input_ids_b
             token_type_ids += token_type_ids_b
@@ -290,7 +302,9 @@ def _build_tokens(
             token_type_ids += [token_type_id]
         return words, input_ids, token_type_ids, offsets
 
-    def pad_batch(self, batch: Dict[str, list], max_length: int = None) -> Dict[str, list]:
+    def pad_batch(
+        self, batch: Dict[str, list], max_length: int = None
+    ) -> Dict[str, list]:
         """
         Pad the batch to its maximum length.
 
@@ -376,7 +390,9 @@ def pretokenize(self, text: str, use_spacy: bool = False) -> List[str]:
             return [t.text for t in text]
         return text.split(" ")
 
-    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, tr.AddedToken]]) -> int:
+    def add_special_tokens(
+        self, special_tokens_dict: Dict[str, Union[str, tr.AddedToken]]
+    ) -> int:
         """
         Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder.
         If special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last
@@ -442,7 +458,8 @@ def to_tensor(self, batch: Union[List[dict], dict]) -> Dict[str, "torch.Tensor"]
         """
         # convert to tensor
         batch = {
-            k: torch.as_tensor(v) if k in self.to_tensor_inputs else v for k, v in batch.items()
+            k: torch.as_tensor(v) if k in self.to_tensor_inputs else v
+            for k, v in batch.items()
         }
         return batch
 
@@ -457,7 +474,9 @@ def _load_spacy(self) -> "spacy.tokenizer.Tokenizer":
         try:
             spacy_tagger = spacy.load(self.language, exclude=["ner", "parser"])
         except OSError:
-            logger.info(f"Spacy model '{self.language}' not found. Downloading and installing.")
+            logger.info(
+                f"Spacy model '{self.language}' not found. Downloading and installing."
+            )
             spacy_download(self.language)
             spacy_tagger = spacy.load(self.language, exclude=["ner", "parser"])
         self.spacy_tokenizer = spacy_tagger.tokenizer
@@ -564,9 +583,9 @@ def num_special_tokens(self) -> int:
             int: the number of special tokens
 
         """
-        if isinstance(self.huggingface_tokenizer, MODELS_WITH_DOUBLE_SEP) and isinstance(
-            self.huggingface_tokenizer, MODELS_WITH_STARTING_TOKEN
-        ):
+        if isinstance(
+            self.huggingface_tokenizer, MODELS_WITH_DOUBLE_SEP
+        ) and isinstance(self.huggingface_tokenizer, MODELS_WITH_STARTING_TOKEN):
             return 4
         if isinstance(
             self.huggingface_tokenizer,