From c27d1232461f1c05915477731770b44f0a2a7ee7 Mon Sep 17 00:00:00 2001
From: "m.habedank"
Date: Sat, 26 Oct 2024 01:24:22 +0200
Subject: [PATCH 1/4] Removed torchtext from NGramTokenizer

---
 ludwig/utils/tokenizers.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/ludwig/utils/tokenizers.py b/ludwig/utils/tokenizers.py
index 6fb2d2117a4..41f761fdf28 100644
--- a/ludwig/utils/tokenizers.py
+++ b/ludwig/utils/tokenizers.py
@@ -140,9 +140,25 @@ def __init__(self, ngram_size: int = 2, **kwargs):
         self.n = ngram_size or 2
 
     def get_tokens(self, tokens: List[str]) -> List[str]:
-        from torchtext.data.utils import ngrams_iterator
+        return list(self._ngrams_iterator(tokens, ngrams=self.n))
 
-        return list(ngrams_iterator(tokens, ngrams=self.n))
+    def _ngrams_iterator(self, token_list, ngrams):
+        """Return an iterator that yields the given tokens and their ngrams. This code is taken from
+        https://pytorch.org/text/stable/_modules/torchtext/data/utils.html#ngrams_iterator.
+
+        Args:
+            token_list: A list of tokens.
+            ngrams: the maximum n-gram length; the tokens themselves and all 2- to `ngrams`-grams are yielded.
+        """
+
+        def _get_ngrams(n):
+            return zip(*[token_list[i:] for i in range(n)])
+
+        for x in token_list:
+            yield x
+        for n in range(2, ngrams + 1):
+            for x in _get_ngrams(n):
+                yield " ".join(x)
 
 
 class SpacePunctuationStringToListTokenizer(torch.nn.Module):

From 8e90e70c0efd4c0ee4370f6029a31c755d406d70 Mon Sep 17 00:00:00 2001
From: "m.habedank"
Date: Sat, 26 Oct 2024 02:51:41 +0200
Subject: [PATCH 2/4] Refactored SentencePieceTokenizer

See: #4032
---
 ludwig/utils/tokenizers.py            |  8 +++-----
 tests/ludwig/utils/test_tokenizers.py | 14 +++++++++++++-
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/ludwig/utils/tokenizers.py b/ludwig/utils/tokenizers.py
index 41f761fdf28..a9005142115 100644
--- a/ludwig/utils/tokenizers.py
+++ b/ludwig/utils/tokenizers.py
@@ -1028,16 +1028,14 @@ def convert_token_to_id(self, token: str) -> int:
 
 
 class SentencePieceTokenizer(torch.nn.Module):
-    def __init__(self, pretrained_model_name_or_path: Optional[str] = None, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__()
-        if pretrained_model_name_or_path is None:
-            pretrained_model_name_or_path = "https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
-        self.tokenizer = torchtext.transforms.SentencePieceTokenizer(sp_model_path=pretrained_model_name_or_path)
+        self.tokenizer = load_pretrained_hf_tokenizer("FacebookAI/xlm-roberta-base")
 
     def forward(self, v: Union[str, List[str], torch.Tensor]):
         if isinstance(v, torch.Tensor):
             raise ValueError(f"Unsupported input: {v}")
-        return self.tokenizer(v)
+        return self.tokenizer.tokenize(v)
 
 
 class _BPETokenizer(torch.nn.Module):
diff --git a/tests/ludwig/utils/test_tokenizers.py b/tests/ludwig/utils/test_tokenizers.py
index 82f6d86bdff..ad7b6f732ce 100644
--- a/tests/ludwig/utils/test_tokenizers.py
+++ b/tests/ludwig/utils/test_tokenizers.py
@@ -4,7 +4,12 @@
 import torch
 import torchtext
 
-from ludwig.utils.tokenizers import EnglishLemmatizeFilterTokenizer, NgramTokenizer, StringSplitTokenizer
+from ludwig.utils.tokenizers import (
+    EnglishLemmatizeFilterTokenizer,
+    NgramTokenizer,
+    SentencePieceTokenizer,
+    StringSplitTokenizer,
+)
 
 TORCHTEXT_0_14_0_HF_NAMES = [
     "bert-base-uncased",
@@ -85,3 +90,10 @@ def test_english_lemmatize_filter_tokenizer():
     tokenizer = EnglishLemmatizeFilterTokenizer()
     tokens = tokenizer(inputs)
     assert len(tokens) > 0
+
+
+def test_sentence_piece_tokenizer():
+    inputs = "This is a sentence. And this is another one."
+    tokenizer = SentencePieceTokenizer()
+    tokens = tokenizer(inputs)
+    assert tokens == ["▁This", "▁is", "▁a", "▁sentence", ".", "▁And", "▁this", "▁is", "▁another", "▁one", "."]

From 92a3ec096572f613afc546cb514e534874d9e77c Mon Sep 17 00:00:00 2001
From: "m.habedank"
Date: Mon, 28 Oct 2024 19:01:09 +0100
Subject: [PATCH 3/4] Removed torchtext

---
 ludwig/utils/tokenizers.py | 342 +++----------------------------------
 1 file changed, 25 insertions(+), 317 deletions(-)

diff --git a/ludwig/utils/tokenizers.py b/ludwig/utils/tokenizers.py
index a9005142115..99cde68d51a 100644
--- a/ludwig/utils/tokenizers.py
+++ b/ludwig/utils/tokenizers.py
@@ -15,24 +15,16 @@
 import logging
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, List, Union
 
 import torch
-import torchtext
 
-from ludwig.constants import PADDING_SYMBOL, UNKNOWN_SYMBOL
-from ludwig.utils.data_utils import load_json
 from ludwig.utils.hf_utils import load_pretrained_hf_tokenizer
 from ludwig.utils.nlp_utils import load_nlp_pipeline, process_text
 
 logger = logging.getLogger(__name__)
 
-torchtext_version = torch.torch_version.TorchVersion(torchtext.__version__)
 
 TORCHSCRIPT_COMPATIBLE_TOKENIZERS = {"space", "space_punct", "comma", "underscore", "characters"}
-TORCHTEXT_0_12_0_TOKENIZERS = {"sentencepiece", "clip", "gpt2bpe"}
-TORCHTEXT_0_13_0_TOKENIZERS = {"bert"}
-
-HF_TOKENIZER_SAMPLE_INPUTS = ["UNwant\u00E9d,running", "ah\u535A\u63A8zz", " \tHeLLo!how \n Are yoU? [UNK]"]
 
 
 class BaseTokenizer:
@@ -913,7 +905,7 @@ def convert_token_to_id(self, token: str) -> int:
 
 
 tokenizer_registry = {
-    # Torchscript-compatible tokenizers. Torchtext tokenizers are also available below (requires torchtext>=0.12.0).
+    # Torchscript-compatible tokenizers.
     "space": SpaceStringToListTokenizer,
     "space_punct": SpacePunctuationStringToListTokenizer,
     "ngram": NgramTokenizer,
@@ -1021,231 +1013,40 @@ def convert_token_to_id(self, token: str) -> int:
     "multi_lemmatize_remove_stopwords": MultiLemmatizeRemoveStopwordsTokenizer,
 }
 
-"""torchtext 0.12.0 tokenizers.
-
-Only available with torchtext>=0.12.0.
-"""
-
-
-class SentencePieceTokenizer(torch.nn.Module):
-    def __init__(self, **kwargs):
-        super().__init__()
-        self.tokenizer = load_pretrained_hf_tokenizer("FacebookAI/xlm-roberta-base")
-
-    def forward(self, v: Union[str, List[str], torch.Tensor]):
-        if isinstance(v, torch.Tensor):
-            raise ValueError(f"Unsupported input: {v}")
-        return self.tokenizer.tokenize(v)
-
-
-class _BPETokenizer(torch.nn.Module):
-    """Superclass for tokenizers that use BPE, such as CLIPTokenizer and GPT2BPETokenizer."""
-
-    def __init__(self, pretrained_model_name_or_path: str, vocab_file: str):
-        super().__init__()
-        self.str2idx, self.idx2str = self._init_vocab(vocab_file)
-        self.tokenizer = self._init_tokenizer(pretrained_model_name_or_path, vocab_file)
-
-    def _init_vocab(self, vocab_file: str) -> Dict[str, str]:
-        """Loads the vocab from the vocab file."""
-        str2idx = load_json(torchtext.utils.get_asset_local_path(vocab_file))
-        _, idx2str = zip(*sorted((v, k) for k, v in str2idx.items()))
-        return str2idx, idx2str
-
-    def _init_tokenizer(self, pretrained_model_name_or_path: str, vocab_file: str) -> Any:
-        """Initializes and returns the tokenizer."""
-        raise NotImplementedError
-
-    def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:
-        """Implements forward pass for tokenizer.
-
-        BPE tokenizers from torchtext return ids directly, which is inconsistent with the Ludwig tokenizer API.
-        The below implementation works around this by converting the ids back to their original string tokens.
-        """
-        if isinstance(v, torch.Tensor):
-            raise ValueError(f"Unsupported input: {v}")
-
-        inputs: List[str] = []
-        # Ludwig calls map on List[str] objects, so we need to handle individual strings as well.
-        if isinstance(v, str):
-            inputs.append(v)
-        else:
-            inputs.extend(v)
-
-        token_ids = self.tokenizer(inputs)
-        assert torch.jit.isinstance(token_ids, List[List[str]])
-
-        tokens = [[self.idx2str[int(unit_idx)] for unit_idx in sequence] for sequence in token_ids]
-        return tokens[0] if isinstance(v, str) else tokens
-
-    def get_vocab(self) -> Dict[str, str]:
-        return self.str2idx
+class HFTokenizerShortcutFactory:
+    """This factory can be used to build HuggingFace tokenizers from a shortcut string.
 
-class CLIPTokenizer(_BPETokenizer):
-    def __init__(self, pretrained_model_name_or_path: Optional[str] = None, vocab_file: Optional[str] = None, **kwargs):
-        if pretrained_model_name_or_path is None:
-            pretrained_model_name_or_path = "http://download.pytorch.org/models/text/clip_merges.bpe"
-        if vocab_file is None:
-            vocab_file = "http://download.pytorch.org/models/text/clip_encoder.json"
-        super().__init__(pretrained_model_name_or_path, vocab_file)
-
-    def _init_tokenizer(self, pretrained_model_name_or_path: str, vocab_file: str):
-        return torchtext.transforms.CLIPTokenizer(
-            encoder_json_path=vocab_file, merges_path=pretrained_model_name_or_path
-        )
-
-
-class GPT2BPETokenizer(_BPETokenizer):
-    def __init__(self, pretrained_model_name_or_path: Optional[str] = None, vocab_file: Optional[str] = None, **kwargs):
-        if pretrained_model_name_or_path is None:
-            pretrained_model_name_or_path = "https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe"
-        if vocab_file is None:
-            vocab_file = "https://download.pytorch.org/models/text/gpt2_bpe_encoder.json"
-        super().__init__(pretrained_model_name_or_path, vocab_file)
-
-    def _init_tokenizer(self, pretrained_model_name_or_path: str, vocab_file: str):
-        return torchtext.transforms.GPT2BPETokenizer(
-            encoder_json_path=vocab_file, vocab_bpe_path=pretrained_model_name_or_path
-        )
-
+    Those shortcuts were originally used for torchtext tokenizers. They are kept to guarantee backward
+    compatibility.
+    """
 
-tokenizer_registry.update(
-    {
-        "sentencepiece": SentencePieceTokenizer,
-        "clip": CLIPTokenizer,
-        "gpt2bpe": GPT2BPETokenizer,
+    MODELS = {
+        "sentencepiece": "FacebookAI/xlm-roberta-base",
+        "clip": "openai/clip-vit-base-patch32",
+        "gpt2bpe": "openai-community/gpt2",
+        "bert": "bert-base-uncased",
     }
-)
-TORCHSCRIPT_COMPATIBLE_TOKENIZERS.update(TORCHTEXT_0_12_0_TOKENIZERS)
 
+    @classmethod
+    def create_class(cls, model_name: str):
+        """Create a tokenizer class from a model name."""
 
-class BERTTokenizer(torch.nn.Module):
-    def __init__(
-        self,
-        vocab_file: Optional[str] = None,
-        is_hf_tokenizer: Optional[bool] = False,
-        hf_tokenizer_attrs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ):
-        super().__init__()
-
-        if vocab_file is None:
-            # If vocab_file not passed in, use default "bert-base-uncased" vocab and kwargs.
-            kwargs = _get_bert_config("bert-base-uncased")
-            vocab_file = kwargs["vocab_file"]
-            vocab = self._init_vocab(vocab_file)
-            hf_tokenizer_attrs = {
-                "pad_token": "[PAD]",
-                "unk_token": "[UNK]",
-                "sep_token_id": vocab["[SEP]"],
-                "cls_token_id": vocab["[CLS]"],
-            }
-        else:
-            vocab = self._init_vocab(vocab_file)
-
-        self.vocab = vocab
-
-        self.is_hf_tokenizer = is_hf_tokenizer
-        if self.is_hf_tokenizer:
-            # Values used by Ludwig extracted from the corresponding HF model.
-            self.pad_token = hf_tokenizer_attrs["pad_token"]  # Used as padding symbol
-            self.unk_token = hf_tokenizer_attrs["unk_token"]  # Used as unknown symbol
-            self.cls_token_id = hf_tokenizer_attrs["cls_token_id"]  # Used as start symbol. Only used if HF.
-            self.sep_token_id = hf_tokenizer_attrs["sep_token_id"]  # Used as stop symbol. Only used if HF.
-            self.never_split = hf_tokenizer_attrs["all_special_tokens"]
-        else:
-            self.pad_token = PADDING_SYMBOL
-            self.unk_token = UNKNOWN_SYMBOL
-            self.cls_token_id = None
-            self.sep_token_id = None
-            self.never_split = [UNKNOWN_SYMBOL]
-
-        tokenizer_kwargs = {}
-        if "do_lower_case" in kwargs:
-            tokenizer_kwargs["do_lower_case"] = kwargs["do_lower_case"]
-        if "strip_accents" in kwargs:
-            tokenizer_kwargs["strip_accents"] = kwargs["strip_accents"]
-
-        # Return tokens as raw tokens only if not being used as a HF tokenizer.
-        self.return_tokens = not self.is_hf_tokenizer
-
-        tokenizer_init_kwargs = {
-            **tokenizer_kwargs,
-            "vocab_path": vocab_file,
-            "return_tokens": self.return_tokens,
-        }
-        if torchtext_version >= (0, 14, 0):
-            # never_split kwarg added in torchtext 0.14.0
-            tokenizer_init_kwargs["never_split"] = self.never_split
-
-        self.tokenizer = torchtext.transforms.BERTTokenizer(**tokenizer_init_kwargs)
-
-    def _init_vocab(self, vocab_file: str) -> Dict[str, int]:
-        from transformers.models.bert.tokenization_bert import load_vocab
-
-        return load_vocab(vocab_file)
-
-    def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:
-        """Implements forward pass for tokenizer.
-
-        If the is_hf_tokenizer flag is set to True, then the output follows the HF convention, i.e. the output is an
-        List[List[int]] of tokens and the cls and sep tokens are automatically added as the start and stop symbols.
-
-        If the is_hf_tokenizer flag is set to False, then the output follows the Ludwig convention, i.e. the output
-        is a List[List[str]] of tokens.
-        """
-        if isinstance(v, torch.Tensor):
-            raise ValueError(f"Unsupported input: {v}")
-
-        inputs: List[str] = []
-        # Ludwig calls map on List[str] objects, so we need to handle individual strings as well.
-        if isinstance(v, str):
-            inputs.append(v)
-        else:
-            inputs.extend(v)
-
-        if self.is_hf_tokenizer:
-            token_ids_str = self.tokenizer(inputs)
-            assert torch.jit.isinstance(token_ids_str, List[List[str]])
-            # Must cast token_ids to ints because they are used directly as indices.
-            token_ids: List[List[int]] = []
-            for token_ids_str_i in token_ids_str:
-                token_ids_i = [int(token_id_str) for token_id_str in token_ids_str_i]
-                token_ids_i = self._add_special_token_ids(token_ids_i)
-                token_ids.append(token_ids_i)
-            return token_ids[0] if isinstance(v, str) else token_ids
-
-        tokens = self.tokenizer(inputs)
-        assert torch.jit.isinstance(tokens, List[List[str]])
-        return tokens[0] if isinstance(v, str) else tokens
+        class DynamicHFTokenizer(torch.nn.Module):
+            def __init__(self, **kwargs):
+                super().__init__()
+                self.tokenizer = load_pretrained_hf_tokenizer(model_name, use_fast=False)
 
-    def get_vocab(self) -> Dict[str, int]:
-        return self.vocab
+            def forward(self, v: Union[str, List[str], torch.Tensor]):
+                if isinstance(v, torch.Tensor):
+                    raise ValueError(f"Unsupported input: {v}")
+                return self.tokenizer.tokenize(v)
 
-    def get_pad_token(self) -> str:
-        return self.pad_token
+        return DynamicHFTokenizer
 
-    def get_unk_token(self) -> str:
-        return self.unk_token
-
-    def _add_special_token_ids(self, token_ids: List[int]) -> List[int]:
-        """Adds special token ids to the token_ids list."""
-        if torch.jit.isinstance(self.cls_token_id, int) and torch.jit.isinstance(self.sep_token_id, int):
-            token_ids.insert(0, self.cls_token_id)
-            token_ids.append(self.sep_token_id)
-        return token_ids
-
-    def convert_token_to_id(self, token: str) -> int:
-        return self.vocab[token]
 
 
 tokenizer_registry.update(
-    {
-        "bert": BERTTokenizer,
-    }
+    {name: HFTokenizerShortcutFactory.create_class(model) for name, model in HFTokenizerShortcutFactory.MODELS.items()}
 )
-TORCHSCRIPT_COMPATIBLE_TOKENIZERS.update(TORCHTEXT_0_13_0_TOKENIZERS)
 
 
 def get_hf_tokenizer(pretrained_model_name_or_path, **kwargs):
@@ -1256,82 +1057,8 @@ def get_hf_tokenizer(pretrained_model_name_or_path, **kwargs):
 
     Returns:
         A torchscript-able HF tokenizer if it is available. Else, returns vanilla HF tokenizer.
     """
-    from transformers import BertTokenizer, DistilBertTokenizer, ElectraTokenizer
-
-    # HuggingFace has implemented a DO Repeat Yourself policy for models
-    # https://github.com/huggingface/transformers/issues/19303
-    # We now need to manually track BERT-like tokenizers to map onto the TorchText implementation
-    # until PyTorch improves TorchScript to be able to compile HF tokenizers. This would require
-    # 1. Support for string inputs for torch.jit.trace, or
-    # 2. Support for `kwargs` in torch.jit.script
-    # This is populated in the `get_hf_tokenizer` since the set requires `transformers` to be installed
-    HF_BERTLIKE_TOKENIZER_CLS_SET = {BertTokenizer, DistilBertTokenizer, ElectraTokenizer}
-
-    hf_name = pretrained_model_name_or_path
-    # use_fast=False to leverage python class inheritance
-    # cannot tokenize HF tokenizers directly because HF lacks strict typing and List[str] cannot be traced
-    hf_tokenizer = load_pretrained_hf_tokenizer(hf_name, use_fast=False)
-
-    torchtext_tokenizer = None
-    if "bert" in TORCHSCRIPT_COMPATIBLE_TOKENIZERS and any(
-        isinstance(hf_tokenizer, cls) for cls in HF_BERTLIKE_TOKENIZER_CLS_SET
-    ):
-        tokenizer_kwargs = _get_bert_config(hf_name)
-        torchtext_tokenizer = BERTTokenizer(
-            **tokenizer_kwargs,
-            is_hf_tokenizer=True,
-            hf_tokenizer_attrs={
-                "pad_token": hf_tokenizer.pad_token,
-                "unk_token": hf_tokenizer.unk_token,
-                "cls_token_id": hf_tokenizer.cls_token_id,
-                "sep_token_id": hf_tokenizer.sep_token_id,
-                "all_special_tokens": hf_tokenizer.all_special_tokens,
-            },
-        )
-
-    use_torchtext = torchtext_tokenizer is not None
-    if use_torchtext:
-        # If a torchtext tokenizer is instantiable, tenatively we will use it. However,
-        # if the tokenizer does not pass (lightweight) validation, then we will fall back to the vanilla HF tokenizer.
-        # TODO(geoffrey): can we better validate tokenizer parity before swapping in the TorchText tokenizer?
-        # Samples from https://github.com/huggingface/transformers/blob/main/tests/models/bert/test_tokenization_bert.py
-        for sample_input in HF_TOKENIZER_SAMPLE_INPUTS:
-            hf_output = hf_tokenizer.encode(sample_input)
-            tt_output = torchtext_tokenizer(sample_input)
-            if hf_output != tt_output:
-                use_torchtext = False
-                logger.warning("Falling back to HuggingFace tokenizer because TorchText tokenizer failed validation.")
-                logger.warning(f"Sample input: {sample_input}\nHF output: {hf_output}\nTT output: {tt_output}")
-                break
-
-    if use_torchtext:
-        logger.info(f"Loaded TorchText implementation of {hf_name} tokenizer")
-        return torchtext_tokenizer
-    else:
-        # If hf_name does not have a torchtext equivalent implementation, load the
-        # HuggingFace implementation.
-        logger.info(f"Loaded HuggingFace implementation of {hf_name} tokenizer")
-        return HFTokenizer(hf_name)
-
-
-def _get_bert_config(hf_name):
-    """Gets configs from BERT tokenizers in HuggingFace.
-
-    `vocab_file` is required for BERT tokenizers. `tokenizer_config.json` are optional keyword arguments used to
-    initialize the tokenizer object. If no `tokenizer_config.json` is found, then we instantiate the tokenizer with
-    default arguments.
- """ - from huggingface_hub import hf_hub_download - from huggingface_hub.utils import EntryNotFoundError - - vocab_file = hf_hub_download(repo_id=hf_name, filename="vocab.txt") - - try: - tokenizer_config = load_json(hf_hub_download(repo_id=hf_name, filename="tokenizer_config.json")) - except EntryNotFoundError: - tokenizer_config = {} - return {"vocab_file": vocab_file, **tokenizer_config} + return HFTokenizer(pretrained_model_name_or_path) tokenizer_registry.update( @@ -1349,24 +1076,5 @@ def get_tokenizer_from_registry(tokenizer_name: str) -> torch.nn.Module: """ if tokenizer_name in tokenizer_registry: return tokenizer_registry[tokenizer_name] - - if ( - torch.torch_version.TorchVersion(torchtext.__version__) < (0, 12, 0) - and tokenizer_name in TORCHTEXT_0_12_0_TOKENIZERS - ): - raise KeyError( - f"torchtext>=0.12.0 is not installed, so '{tokenizer_name}' and the following tokenizers are not " - f"available: {TORCHTEXT_0_12_0_TOKENIZERS}" - ) - - if ( - torch.torch_version.TorchVersion(torchtext.__version__) < (0, 13, 0) - and tokenizer_name in TORCHTEXT_0_13_0_TOKENIZERS - ): - raise KeyError( - f"torchtext>=0.13.0 is not installed, so '{tokenizer_name}' and the following tokenizers are not " - f"available: {TORCHTEXT_0_13_0_TOKENIZERS}" - ) - # Tokenizer does not exist or is unavailable. raise KeyError(f"Invalid tokenizer name: '{tokenizer_name}'. Available tokenizers: {tokenizer_registry.keys()}") From 01f308ef523d40dc668c68bc1bfc364463114df2 Mon Sep 17 00:00:00 2001 From: "m.habedank" Date: Tue, 29 Oct 2024 22:46:55 +0100 Subject: [PATCH 4/4] Rewrote tests --- tests/ludwig/utils/test_tokenizers.py | 94 +++++++++++---------------- 1 file changed, 38 insertions(+), 56 deletions(-) diff --git a/tests/ludwig/utils/test_tokenizers.py b/tests/ludwig/utils/test_tokenizers.py index ad7b6f732ce..0fa8104ed10 100644 --- a/tests/ludwig/utils/test_tokenizers.py +++ b/tests/ludwig/utils/test_tokenizers.py @@ -1,64 +1,10 @@ -import os - -import pytest -import torch -import torchtext - from ludwig.utils.tokenizers import ( EnglishLemmatizeFilterTokenizer, + get_tokenizer_from_registry, NgramTokenizer, - SentencePieceTokenizer, StringSplitTokenizer, ) -TORCHTEXT_0_14_0_HF_NAMES = [ - "bert-base-uncased", - "distilbert-base-uncased", - "google/electra-small-discriminator", - "dbmdz/bert-base-italian-cased", # Community model - "nreimers/MiniLM-L6-H384-uncased", # Community model - "emilyalsentzer/Bio_ClinicalBERT", # Community model - "bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12", # Community model -] - - -@pytest.mark.parametrize( - "pretrained_model_name_or_path", - [ - pytest.param( - model_name, - marks=[ - pytest.mark.skipif( - torch.torch_version.TorchVersion(torchtext.__version__) < (0, 14, 0), - reason="requires torchtext 0.14.0 or higher", - ), - ], - ) - for model_name in TORCHTEXT_0_14_0_HF_NAMES - ], -) -def test_bert_hf_tokenizer_parity(tmpdir, pretrained_model_name_or_path): - """Tests the BERTTokenizer implementation. - - Asserts both tokens and token IDs are the same by initializing the BERTTokenizer as a standalone tokenizer and as a - HF tokenizer. - """ - from ludwig.utils.tokenizers import get_hf_tokenizer, HFTokenizer - - inputs = "Hello, ``I'm'' ónë of 1,205,000 sentences!" 
-    hf_tokenizer = HFTokenizer(pretrained_model_name_or_path)
-    torchtext_tokenizer = get_hf_tokenizer(pretrained_model_name_or_path)
-
-    # Ensure that the tokenizer is scriptable
-    tokenizer_path = os.path.join(tmpdir, "tokenizer.pt")
-    torch.jit.script(torchtext_tokenizer).save(tokenizer_path)
-    torchtext_tokenizer = torch.jit.load(tokenizer_path)
-
-    token_ids_expected = hf_tokenizer(inputs)
-    token_ids = torchtext_tokenizer(inputs)
-
-    assert token_ids_expected == token_ids
-
 
 def test_ngram_tokenizer():
     inputs = "Hello, I'm a single sentence!"
@@ -94,6 +40,42 @@ def test_english_lemmatize_filter_tokenizer():
     tokenizer = EnglishLemmatizeFilterTokenizer()
     tokens = tokenizer(inputs)
     assert len(tokens) > 0
 
 
 def test_sentence_piece_tokenizer():
     inputs = "This is a sentence. And this is another one."
-    tokenizer = SentencePieceTokenizer()
+    tokenizer = get_tokenizer_from_registry("sentencepiece")()
     tokens = tokenizer(inputs)
     assert tokens == ["▁This", "▁is", "▁a", "▁sentence", ".", "▁And", "▁this", "▁is", "▁another", "▁one", "."]
+
+
+def test_clip_tokenizer():
+    inputs = "This is a sentence. And this is another one."
+    tokenizer = get_tokenizer_from_registry("clip")()
+    tokens = tokenizer(inputs)
+    print(tokens)
+    assert tokens == [
+        "this",
+        "is",
+        "a",
+        "sentence",
+        ".",
+        "and",
+        "this",
+        "is",
+        "another",
+        "one",
+        ".",
+    ]
+
+
+def test_gpt2_bpe_tokenizer():
+    inputs = "This is a sentence. And this is another one."
+    tokenizer = get_tokenizer_from_registry("gpt2bpe")()
+    tokens = tokenizer(inputs)
+    print(tokens)
+    assert tokens == ["This", "Ġis", "Ġa", "Ġsentence", ".", "ĠAnd", "Ġthis", "Ġis", "Ġanother", "Ġone", "."]
+
+
+def test_bert_tokenizer():
+    inputs = "This is a sentence. And this is another one."
+    tokenizer = get_tokenizer_from_registry("bert")()
+    tokens = tokenizer(inputs)
+    print(tokens)
+    assert tokens == ["this", "is", "a", "sentence", ".", "and", "this", "is", "another", "one", "."]
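
Usage note (not part of the patch series): the shortcut registry introduced in PATCH 3 can be smoke-tested locally with a minimal sketch like the one below. It assumes all four patches are applied, `transformers` is installed, and the pretrained vocabularies can be downloaded from the HuggingFace Hub on first use; the expected output for the "sentencepiece" shortcut is taken from the tests in PATCH 4.

# Minimal sketch: exercise each HF-backed registry shortcut end to end.
from ludwig.utils.tokenizers import get_tokenizer_from_registry

for shortcut in ("sentencepiece", "clip", "gpt2bpe", "bert"):
    # The registry maps a shortcut name to a tokenizer *class*; instantiate it first.
    tokenizer = get_tokenizer_from_registry(shortcut)()
    print(shortcut, tokenizer("This is a sentence. And this is another one."))

# The "sentencepiece" shortcut (backed by FacebookAI/xlm-roberta-base) should print:
# ['▁This', '▁is', '▁a', '▁sentence', '.', '▁And', '▁this', '▁is', '▁another', '▁one', '.']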