Commit
removed torchtext
m.habedank committed Oct 28, 2024
1 parent 8e90e70 commit 92a3ec0
Showing 1 changed file with 25 additions and 317 deletions.
342 changes: 25 additions & 317 deletions ludwig/utils/tokenizers.py
@@ -15,24 +15,16 @@

import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union
from typing import Any, List, Union

import torch
import torchtext

from ludwig.constants import PADDING_SYMBOL, UNKNOWN_SYMBOL
from ludwig.utils.data_utils import load_json
from ludwig.utils.hf_utils import load_pretrained_hf_tokenizer
from ludwig.utils.nlp_utils import load_nlp_pipeline, process_text

logger = logging.getLogger(__name__)
torchtext_version = torch.torch_version.TorchVersion(torchtext.__version__)

TORCHSCRIPT_COMPATIBLE_TOKENIZERS = {"space", "space_punct", "comma", "underscore", "characters"}
TORCHTEXT_0_12_0_TOKENIZERS = {"sentencepiece", "clip", "gpt2bpe"}
TORCHTEXT_0_13_0_TOKENIZERS = {"bert"}

HF_TOKENIZER_SAMPLE_INPUTS = ["UNwant\u00E9d,running", "ah\u535A\u63A8zz", " \tHeLLo!how \n Are yoU? [UNK]"]


class BaseTokenizer:
@@ -913,7 +905,7 @@ def convert_token_to_id(self, token: str) -> int:


tokenizer_registry = {
# Torchscript-compatible tokenizers. Torchtext tokenizers are also available below (requires torchtext>=0.12.0).
# Torchscript-compatible tokenizers.
"space": SpaceStringToListTokenizer,
"space_punct": SpacePunctuationStringToListTokenizer,
"ngram": NgramTokenizer,
@@ -1021,231 +1013,40 @@ def convert_token_to_id(self, token: str) -> int:
"multi_lemmatize_remove_stopwords": MultiLemmatizeRemoveStopwordsTokenizer,
}
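# Usage sketch (not part of the commit): registry values are tokenizer classes, so lookup,
# instantiation, and invocation are separate steps. Assumes SpaceStringToListTokenizer
# splits on whitespace.
#
#     tokenizer = tokenizer_registry["space"]()
#     tokenizer("hello world")  # -> ["hello", "world"]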

"""torchtext 0.12.0 tokenizers.
Only available with torchtext>=0.12.0.
"""


class SentencePieceTokenizer(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.tokenizer = load_pretrained_hf_tokenizer("FacebookAI/xlm-roberta-base")

def forward(self, v: Union[str, List[str], torch.Tensor]):
if isinstance(v, torch.Tensor):
raise ValueError(f"Unsupported input: {v}")
return self.tokenizer.tokenize(v)


class _BPETokenizer(torch.nn.Module):
"""Superclass for tokenizers that use BPE, such as CLIPTokenizer and GPT2BPETokenizer."""

def __init__(self, pretrained_model_name_or_path: str, vocab_file: str):
super().__init__()
self.str2idx, self.idx2str = self._init_vocab(vocab_file)
self.tokenizer = self._init_tokenizer(pretrained_model_name_or_path, vocab_file)

def _init_vocab(self, vocab_file: str) -> Dict[str, str]:
"""Loads the vocab from the vocab file."""
str2idx = load_json(torchtext.utils.get_asset_local_path(vocab_file))
_, idx2str = zip(*sorted((v, k) for k, v in str2idx.items()))
return str2idx, idx2str

def _init_tokenizer(self, pretrained_model_name_or_path: str, vocab_file: str) -> Any:
"""Initializes and returns the tokenizer."""
raise NotImplementedError

def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:
"""Implements forward pass for tokenizer.
BPE tokenizers from torchtext return ids directly, which is inconsistent with the Ludwig tokenizer API. The
below implementation works around this by converting the ids back to their original string tokens.
"""
if isinstance(v, torch.Tensor):
raise ValueError(f"Unsupported input: {v}")

inputs: List[str] = []
# Ludwig calls map on List[str] objects, so we need to handle individual strings as well.
if isinstance(v, str):
inputs.append(v)
else:
inputs.extend(v)

token_ids = self.tokenizer(inputs)
assert torch.jit.isinstance(token_ids, List[List[str]])

tokens = [[self.idx2str[int(unit_idx)] for unit_idx in sequence] for sequence in token_ids]
return tokens[0] if isinstance(v, str) else tokens
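    # Toy illustration (not in the original diff) of the id-to-token conversion in forward():
    #     str2idx = {"hello": 0, "world": 1}  ->  idx2str = ("hello", "world")
    #     token_ids = [["0", "1"]]            ->  tokens = [["hello", "world"]]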

    def get_vocab(self) -> Dict[str, str]:
        return self.str2idx


class CLIPTokenizer(_BPETokenizer):
    def __init__(self, pretrained_model_name_or_path: Optional[str] = None, vocab_file: Optional[str] = None, **kwargs):
        if pretrained_model_name_or_path is None:
            pretrained_model_name_or_path = "http://download.pytorch.org/models/text/clip_merges.bpe"
        if vocab_file is None:
            vocab_file = "http://download.pytorch.org/models/text/clip_encoder.json"
        super().__init__(pretrained_model_name_or_path, vocab_file)

    def _init_tokenizer(self, pretrained_model_name_or_path: str, vocab_file: str):
        return torchtext.transforms.CLIPTokenizer(
            encoder_json_path=vocab_file, merges_path=pretrained_model_name_or_path
        )


class GPT2BPETokenizer(_BPETokenizer):
    def __init__(self, pretrained_model_name_or_path: Optional[str] = None, vocab_file: Optional[str] = None, **kwargs):
        if pretrained_model_name_or_path is None:
            pretrained_model_name_or_path = "https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe"
        if vocab_file is None:
            vocab_file = "https://download.pytorch.org/models/text/gpt2_bpe_encoder.json"
        super().__init__(pretrained_model_name_or_path, vocab_file)

    def _init_tokenizer(self, pretrained_model_name_or_path: str, vocab_file: str):
        return torchtext.transforms.GPT2BPETokenizer(
            encoder_json_path=vocab_file, vocab_bpe_path=pretrained_model_name_or_path
        )


tokenizer_registry.update(
    {
        "sentencepiece": SentencePieceTokenizer,
        "clip": CLIPTokenizer,
        "gpt2bpe": GPT2BPETokenizer,
    }
)
TORCHSCRIPT_COMPATIBLE_TOKENIZERS.update(TORCHTEXT_0_12_0_TOKENIZERS)


class HFTokenizerShortcutFactory:
    """This factory can be used to build HuggingFace tokenizers from a shortcut string.

    Those shortcuts were originally used for torchtext tokenizers. They also guarantee backward compatibility.
    """

    MODELS = {
        "sentencepiece": "FacebookAI/xlm-roberta-base",
        "clip": "openai/clip-vit-base-patch32",
        "gpt2bpe": "openai-community/gpt2",
        "bert": "bert-base-uncased",
    }

class BERTTokenizer(torch.nn.Module):
def __init__(
self,
vocab_file: Optional[str] = None,
is_hf_tokenizer: Optional[bool] = False,
hf_tokenizer_attrs: Optional[Dict[str, Any]] = None,
**kwargs,
):
super().__init__()

if vocab_file is None:
# If vocab_file not passed in, use default "bert-base-uncased" vocab and kwargs.
kwargs = _get_bert_config("bert-base-uncased")
vocab_file = kwargs["vocab_file"]
vocab = self._init_vocab(vocab_file)
hf_tokenizer_attrs = {
"pad_token": "[PAD]",
"unk_token": "[UNK]",
"sep_token_id": vocab["[SEP]"],
"cls_token_id": vocab["[CLS]"],
}
else:
vocab = self._init_vocab(vocab_file)

self.vocab = vocab

self.is_hf_tokenizer = is_hf_tokenizer
if self.is_hf_tokenizer:
# Values used by Ludwig extracted from the corresponding HF model.
self.pad_token = hf_tokenizer_attrs["pad_token"] # Used as padding symbol
self.unk_token = hf_tokenizer_attrs["unk_token"] # Used as unknown symbol
self.cls_token_id = hf_tokenizer_attrs["cls_token_id"] # Used as start symbol. Only used if HF.
self.sep_token_id = hf_tokenizer_attrs["sep_token_id"] # Used as stop symbol. Only used if HF.
self.never_split = hf_tokenizer_attrs["all_special_tokens"]
else:
self.pad_token = PADDING_SYMBOL
self.unk_token = UNKNOWN_SYMBOL
self.cls_token_id = None
self.sep_token_id = None
self.never_split = [UNKNOWN_SYMBOL]

tokenizer_kwargs = {}
if "do_lower_case" in kwargs:
tokenizer_kwargs["do_lower_case"] = kwargs["do_lower_case"]
if "strip_accents" in kwargs:
tokenizer_kwargs["strip_accents"] = kwargs["strip_accents"]

# Return tokens as raw tokens only if not being used as a HF tokenizer.
self.return_tokens = not self.is_hf_tokenizer

tokenizer_init_kwargs = {
**tokenizer_kwargs,
"vocab_path": vocab_file,
"return_tokens": self.return_tokens,
}
if torchtext_version >= (0, 14, 0):
# never_split kwarg added in torchtext 0.14.0
tokenizer_init_kwargs["never_split"] = self.never_split

self.tokenizer = torchtext.transforms.BERTTokenizer(**tokenizer_init_kwargs)

def _init_vocab(self, vocab_file: str) -> Dict[str, int]:
from transformers.models.bert.tokenization_bert import load_vocab

return load_vocab(vocab_file)

def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:
"""Implements forward pass for tokenizer.
If the is_hf_tokenizer flag is set to True, then the output follows the HF convention, i.e. the output is a
List[List[int]] of token ids, and the cls and sep tokens are automatically added as the start and stop symbols.
If the is_hf_tokenizer flag is set to False, then the output follows the Ludwig convention, i.e. the output
is a List[List[str]] of tokens.
"""
if isinstance(v, torch.Tensor):
raise ValueError(f"Unsupported input: {v}")

inputs: List[str] = []
# Ludwig calls map on List[str] objects, so we need to handle individual strings as well.
if isinstance(v, str):
inputs.append(v)
else:
inputs.extend(v)

if self.is_hf_tokenizer:
token_ids_str = self.tokenizer(inputs)
assert torch.jit.isinstance(token_ids_str, List[List[str]])
# Must cast token_ids to ints because they are used directly as indices.
token_ids: List[List[int]] = []
for token_ids_str_i in token_ids_str:
token_ids_i = [int(token_id_str) for token_id_str in token_ids_str_i]
token_ids_i = self._add_special_token_ids(token_ids_i)
token_ids.append(token_ids_i)
return token_ids[0] if isinstance(v, str) else token_ids

tokens = self.tokenizer(inputs)
assert torch.jit.isinstance(tokens, List[List[str]])
return tokens[0] if isinstance(v, str) else tokens
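    # Illustrative note (not in the original diff): for the input "hello world",
    # is_hf_tokenizer=True returns ids with special tokens added, e.g. [101, 7592, 2088, 102],
    # while is_hf_tokenizer=False returns raw string tokens, e.g. ["hello", "world"].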
    def get_vocab(self) -> Dict[str, int]:
        return self.vocab

    def get_pad_token(self) -> str:
        return self.pad_token

    def get_unk_token(self) -> str:
        return self.unk_token

    def _add_special_token_ids(self, token_ids: List[int]) -> List[int]:
        """Adds special token ids to the token_ids list."""
        if torch.jit.isinstance(self.cls_token_id, int) and torch.jit.isinstance(self.sep_token_id, int):
            token_ids.insert(0, self.cls_token_id)
            token_ids.append(self.sep_token_id)
        return token_ids

    def convert_token_to_id(self, token: str) -> int:
        return self.vocab[token]


    @classmethod
    def create_class(cls, model_name: str):
        """Creating a tokenizer class from a model name."""

        class DynamicHFTokenizer(torch.nn.Module):
            def __init__(self, **kwargs):
                super().__init__()
                self.tokenizer = load_pretrained_hf_tokenizer(model_name, use_fast=False)

            def forward(self, v: Union[str, List[str], torch.Tensor]):
                if isinstance(v, torch.Tensor):
                    raise ValueError(f"Unsupported input: {v}")
                return self.tokenizer.tokenize(v)

        return DynamicHFTokenizer


tokenizer_registry.update(
    {
        "bert": BERTTokenizer,
    }
)
TORCHSCRIPT_COMPATIBLE_TOKENIZERS.update(TORCHTEXT_0_13_0_TOKENIZERS)

tokenizer_registry.update(
    {name: HFTokenizerShortcutFactory.create_class(model) for name, model in HFTokenizerShortcutFactory.MODELS.items()}
)
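# Sketch (not part of the commit) of what the new shortcut path does. Model names come from
# HFTokenizerShortcutFactory.MODELS; load_pretrained_hf_tokenizer is assumed to wrap
# transformers.AutoTokenizer.from_pretrained with use_fast=False.
#
#     Gpt2Shortcut = HFTokenizerShortcutFactory.create_class("openai-community/gpt2")
#     tokenizer = Gpt2Shortcut()      # instantiates and loads the slow HF tokenizer
#     tokenizer("hello world")        # -> subword strings, e.g. ["hello", "Ġworld"]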


def get_hf_tokenizer(pretrained_model_name_or_path, **kwargs):
@@ -1256,82 +1057,8 @@ def get_hf_tokenizer(pretrained_model_name_or_path, **kwargs):
Returns:
A torchscript-able HF tokenizer if it is available. Else, returns vanilla HF tokenizer.
"""
from transformers import BertTokenizer, DistilBertTokenizer, ElectraTokenizer

# HuggingFace has implemented a DO Repeat Yourself policy for models
# https://github.com/huggingface/transformers/issues/19303
# We now need to manually track BERT-like tokenizers to map onto the TorchText implementation
# until PyTorch improves TorchScript to be able to compile HF tokenizers. This would require
# 1. Support for string inputs for torch.jit.trace, or
# 2. Support for `kwargs` in torch.jit.script
# This is populated in the `get_hf_tokenizer` since the set requires `transformers` to be installed
HF_BERTLIKE_TOKENIZER_CLS_SET = {BertTokenizer, DistilBertTokenizer, ElectraTokenizer}

hf_name = pretrained_model_name_or_path
# use_fast=False to leverage python class inheritance
# cannot tokenize HF tokenizers directly because HF lacks strict typing and List[str] cannot be traced
hf_tokenizer = load_pretrained_hf_tokenizer(hf_name, use_fast=False)

torchtext_tokenizer = None
if "bert" in TORCHSCRIPT_COMPATIBLE_TOKENIZERS and any(
isinstance(hf_tokenizer, cls) for cls in HF_BERTLIKE_TOKENIZER_CLS_SET
):
tokenizer_kwargs = _get_bert_config(hf_name)
torchtext_tokenizer = BERTTokenizer(
**tokenizer_kwargs,
is_hf_tokenizer=True,
hf_tokenizer_attrs={
"pad_token": hf_tokenizer.pad_token,
"unk_token": hf_tokenizer.unk_token,
"cls_token_id": hf_tokenizer.cls_token_id,
"sep_token_id": hf_tokenizer.sep_token_id,
"all_special_tokens": hf_tokenizer.all_special_tokens,
},
)

use_torchtext = torchtext_tokenizer is not None
if use_torchtext:
# If a torchtext tokenizer is instantiable, we will tentatively use it. However,
# if the tokenizer does not pass (lightweight) validation, then we will fall back to the vanilla HF tokenizer.
# TODO(geoffrey): can we better validate tokenizer parity before swapping in the TorchText tokenizer?
# Samples from https://github.com/huggingface/transformers/blob/main/tests/models/bert/test_tokenization_bert.py
for sample_input in HF_TOKENIZER_SAMPLE_INPUTS:
hf_output = hf_tokenizer.encode(sample_input)
tt_output = torchtext_tokenizer(sample_input)
if hf_output != tt_output:
use_torchtext = False
logger.warning("Falling back to HuggingFace tokenizer because TorchText tokenizer failed validation.")
logger.warning(f"Sample input: {sample_input}\nHF output: {hf_output}\nTT output: {tt_output}")
break

if use_torchtext:
logger.info(f"Loaded TorchText implementation of {hf_name} tokenizer")
return torchtext_tokenizer
else:
# If hf_name does not have a torchtext equivalent implementation, load the
# HuggingFace implementation.
logger.info(f"Loaded HuggingFace implementation of {hf_name} tokenizer")
return HFTokenizer(hf_name)


def _get_bert_config(hf_name):
"""Gets configs from BERT tokenizers in HuggingFace.
`vocab_file` is required for BERT tokenizers. `tokenizer_config.json` holds optional keyword arguments used to
initialize the tokenizer object. If no `tokenizer_config.json` is found, then we instantiate the tokenizer with
default arguments.
"""
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

vocab_file = hf_hub_download(repo_id=hf_name, filename="vocab.txt")

try:
tokenizer_config = load_json(hf_hub_download(repo_id=hf_name, filename="tokenizer_config.json"))
except EntryNotFoundError:
tokenizer_config = {}

return {"vocab_file": vocab_file, **tokenizer_config}
return HFTokenizer(pretrained_model_name_or_path)


tokenizer_registry.update(
Expand All @@ -1349,24 +1076,5 @@ def get_tokenizer_from_registry(tokenizer_name: str) -> torch.nn.Module:
"""
if tokenizer_name in tokenizer_registry:
return tokenizer_registry[tokenizer_name]

if (
torch.torch_version.TorchVersion(torchtext.__version__) < (0, 12, 0)
and tokenizer_name in TORCHTEXT_0_12_0_TOKENIZERS
):
raise KeyError(
f"torchtext>=0.12.0 is not installed, so '{tokenizer_name}' and the following tokenizers are not "
f"available: {TORCHTEXT_0_12_0_TOKENIZERS}"
)

if (
torch.torch_version.TorchVersion(torchtext.__version__) < (0, 13, 0)
and tokenizer_name in TORCHTEXT_0_13_0_TOKENIZERS
):
raise KeyError(
f"torchtext>=0.13.0 is not installed, so '{tokenizer_name}' and the following tokenizers are not "
f"available: {TORCHTEXT_0_13_0_TOKENIZERS}"
)

# Tokenizer does not exist or is unavailable.
raise KeyError(f"Invalid tokenizer name: '{tokenizer_name}'. Available tokenizers: {tokenizer_registry.keys()}")
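# Usage sketch (not part of the commit): a successful lookup returns the tokenizer class;
# unknown names raise the KeyError above, whose message lists the available tokenizers.
#
#     tokenizer_cls = get_tokenizer_from_registry("space")
#     tokenizer_cls()("hello world")          # -> ["hello", "world"]
#     get_tokenizer_from_registry("no_such")  # hypothetical bad name -> KeyError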
