Merge pull request #283 from Modalities/fix_only_append_eod_token_once
Bug fix: Only append eod token once when packing / tokenizing
le1nux authored Jan 13, 2025
2 parents e0b4274 + 4648f1b commit 13f1a26
Showing 9 changed files with 439 additions and 11 deletions.
12 changes: 11 additions & 1 deletion CHANGELOG_DEV.md
@@ -153,4 +153,14 @@ https://github.com/Modalities/modalities/blob/0483362abac93e45850e56adaea7921e96

I added a switch case that maps to the respective byte sizes when packing the data.

This adds some inefficiency, as a vocabulary size > 65536 already requires 4 bytes per token, effectively doubling the storage requirements.
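
A rough illustration of the byte-size switch described above (the function name and exact mapping here are assumptions for the sketch, not the repository's implementation):

```python
def bytes_per_token(vocab_size: int) -> int:
    # Illustrative sketch: pick the smallest fixed byte width used for packing,
    # given the vocabulary size.
    if vocab_size <= 2**8:       # up to 256 token IDs fit in 1 byte
        return 1
    elif vocab_size <= 2**16:    # up to 65536 token IDs fit in 2 bytes
        return 2
    else:                        # larger vocabularies are stored with 4 bytes per token
        return 4
```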


## PR #283: Bug fix: Only append eod token once when packing / tokenizing

Some HF tokenizers, such as `xlm-roberta-large`, automatically add special tokens (e.g., the eod token) when encoding text, whereas others, such as `gpt2`, do not add special tokens.

This side effect in the transformers library led to the eod token being appended twice when tokenizing / packing our data. We added a check for this and now append the eod token only once:
https://github.com/Modalities/modalities/blob/1c1ccdc973283c45bc8c9fadf4d20f03e435cd04/src/modalities/dataloader/create_packed_data.py#L327-L330
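
To illustrate the differing behavior, here is a minimal sketch using the two tokenizers named above (it assumes `transformers` is installed and the models can be downloaded):

```python
from transformers import AutoTokenizer

# gpt2 does not add any special tokens when encoding plain text
gpt2 = AutoTokenizer.from_pretrained("gpt2")
print(gpt2("Hello world")["input_ids"])  # plain token IDs only

# xlm-roberta-large wraps the text in special tokens (<s> ... </s>)
xlmr = AutoTokenizer.from_pretrained("xlm-roberta-large")
ids = xlmr("Hello world")["input_ids"]
print(ids[0] == xlmr.bos_token_id, ids[-1] == xlmr.eos_token_id)  # True True
```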

Additionally, I added a script that verifies the consistency of the indexation and tokenization of a given JSONL file. We run the indexation and tokenization routines in modalities and compare the result to the JSONL file tokenized by applying the HF tokenizer directly.
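
A minimal usage sketch based on the function signatures added in this PR (the JSONL path, tokenizer name, and eod token below are placeholders):

```python
from pathlib import Path

from modalities.utils.verify_tokenization_consistency import (
    build_hf_tokenization_components,
    verify_tokenization_consistency,
)

# Placeholders: point src_path at your JSONL file and choose the matching tokenizer / eod token.
src_path = Path("data.jsonl")
tokenizer_callable, tokenizer_config, eod_token_id = build_hf_tokenization_components(
    tokenizer_path_or_name="gpt2", eod_token="<|endoftext|>"
)

# Runs indexation and tokenization in a temporary directory and asserts that both are
# consistent with applying the HF tokenizer to the JSONL file directly.
verify_tokenization_consistency(
    src_path=src_path,
    eod_token="<|endoftext|>",
    eod_token_id=eod_token_id,
    tokenizer=tokenizer_callable,
    tokenizer_config=tokenizer_config,
    jsonl_text_key="text",
)
```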
Binary file not shown.
9 changes: 6 additions & 3 deletions src/modalities/dataloader/create_packed_data.py
@@ -63,8 +63,8 @@ def __init__(
self.tokenizer = tokenizer
self.eod_token = eod_token
self._token_size_in_bytes = self._get_required_num_of_bytes_to_repr(self.tokenizer.vocab_size)
encoded_eod_token = self.tokenizer.get_token_id(self.eod_token)
self._encoded_eos_token_as_bytes = self._encoded_token_to_bytes(encoded_eod_token)
eod_token_id = self.tokenizer.get_token_id(self.eod_token)
self._encoded_eod_token_as_bytes = self._encoded_token_to_bytes(eod_token_id)
self.jq_filter = jq.compile(jq_pattern)
self._number_of_processes = number_of_processes
self._reader = LargeFileLinesReader(src_path, index_path=index_path) # reads string with utf-8 encoding
@@ -323,7 +323,10 @@ def _process_line(self, line: str, process_id: int) -> bytes:
tokens = self.tokenizer.tokenize(jq_retrieved_text)
if len(tokens) == 0:
raise EmptySampleError("Received empty sample...")
return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes
token_byte_string = b"".join(map(self._encoded_token_to_bytes, tokens))
if not token_byte_string.endswith(self._encoded_eod_token_as_bytes):
token_byte_string = token_byte_string + self._encoded_eod_token_as_bytes
return token_byte_string


class EmbeddedStreamData:
57 changes: 52 additions & 5 deletions src/modalities/tokenization/tokenizer_wrapper.py
@@ -1,3 +1,4 @@
import warnings
from abc import ABC
from typing import Optional

@@ -62,6 +63,20 @@ def get_token_id(self, token: str) -> int:
"""
raise NotImplementedError

def is_special_token_id(self, token_id: int) -> bool:
"""Returns whether a token ID is a special token ID.
Args:
token_id (int): Token ID to check.
Raises:
NotImplementedError: Must be implemented by a subclass.
Returns:
bool: Flag whether the token ID is a special token ID.
"""
raise NotImplementedError


class PreTrainedHFTokenizer(TokenizerWrapper):
"""Wrapper for pretrained Hugging Face tokenizers."""
@@ -102,6 +117,7 @@ def __init__(
self.max_length = max_length
self.truncation = truncation
self.padding = padding
self.special_token_ids = set(self.tokenizer.all_special_ids)

@property
def vocab_size(self) -> int:
@@ -163,10 +179,25 @@ def get_token_id(self, token: str) -> int:
int: Token ID.
"""
token_id = self.tokenizer.convert_tokens_to_ids(token)
if isinstance(token_id, list):
if not isinstance(token_id, int):
raise ValueError("Token is not represented by a single token id!")
if token_id is None:
raise ValueError("Token is not represented by a single token id!")
elif token_id == self.tokenizer.unk_token_id:
warnings.warn(f"The provided eod token {token} has the same token id ({token_id}) as the unk token")
return token_id

def is_special_token_id(self, token_id: int) -> bool:
"""Returns whether a token ID is a special token ID.
Args:
token_id (int): Token ID to check.
Returns:
bool: Flag whether the token ID is a special token ID.
"""
return token_id in self.special_token_ids


class PreTrainedSPTokenizer(TokenizerWrapper):
"""Wrapper for pretrained SentencePiece tokenizers."""
@@ -189,8 +220,8 @@ def tokenize(self, text: str) -> list[int]:
Returns:
list[int]: List of token IDs.
"""
tokens = self.tokenizer.encode(text)
return tokens
token_ids = self.tokenizer.Encode(text)
return token_ids

def decode(self, token_ids: list[int]) -> str:
"""Decodes a list of token IDs into the original text.
@@ -201,7 +232,7 @@ def decode(self, token_ids: list[int]) -> str:
Returns:
str: Decoded text.
"""
decoded_text = self.tokenizer.decode(token_ids)
decoded_text = self.tokenizer.Decode(token_ids)
return decoded_text

@property
@@ -226,6 +257,22 @@ def get_token_id(self, token: str) -> int:
int: Token ID.
"""
piece_id = self.tokenizer.PieceToId(token)
if not isinstance(piece_id, int):
raise ValueError("Token cannot be represented by a single token ID!")
if piece_id == self.tokenizer.unk_id():
raise ValueError("Token is not represented by a single token id!")
raise ValueError("Token cannot be represented by a single token id!")
return piece_id

def is_special_token_id(self, token_id: int) -> bool:
"""Returns whether a token ID is a special token ID.
Args:
token_id (int): Token ID to check.
Returns:
bool: Flag whether the token ID is a special token ID.
"""
return self.tokenizer.IsControl(token_id)
205 changes: 205 additions & 0 deletions src/modalities/utils/verify_tokenization_consistency.py
@@ -0,0 +1,205 @@
import json
import os
import pickle
import tempfile
import warnings
from enum import Enum
from pathlib import Path
from typing import Callable

import sentencepiece as spm
import tqdm
from transformers import AutoTokenizer

from modalities.api import create_raw_data_index, pack_encoded_data
from modalities.dataloader.dataset import PackedMemMapDatasetBase


class TokenizerTypes(Enum):
sentence_piece = "sentence_piece"
hugging_face = "hugging_face"


def _run_tokenization(
src_path: Path, index_path: Path, pbin_path: Path, eod_token: str, tokenizer_config: dict, jq_pattern: str = ".text"
):
# create index
create_raw_data_index(src_path=src_path, index_path=index_path)
# run tokenization
num_cpus = os.cpu_count()

tokenization_config_dict = {
"settings": {
"src_path": src_path,
"dst_path": pbin_path,
"index_path": index_path,
"jq_pattern": jq_pattern,
"num_cpus": num_cpus,
"eod_token": eod_token,
"processing_batch_size": 10,
"raw_samples_queue_size": 300,
"processed_samples_queue_size": 300,
},
"tokenizer": {**tokenizer_config},
}

pack_encoded_data(config_dict=tokenization_config_dict)


def _verify_index(src_path: Path, index_path: Path):
with open(src_path, "rb") as f:
jsonl_binary_string = f.read()

with open(src_path, "rb") as f:
binary_string_list = f.readlines()

with open(src_path, "r", encoding="utf-8") as f:
string_list = f.readlines()

with open(index_path, "rb") as f:
jsonl_index = pickle.load(f)

assert (
len(jsonl_binary_string.split(b"\n")) - int(jsonl_binary_string.endswith(b"\n"))
== len(binary_string_list)
== len(string_list)
== len(jsonl_index)
)

for i, (offset, length) in tqdm.tqdm(enumerate(jsonl_index), desc="Verifying index"):
# check that the index works correctly on the binary data
binary_string = binary_string_list[i]
if binary_string.endswith(b"\n"):
binary_string = binary_string[:-1]
assert jsonl_binary_string[offset : offset + length] == binary_string

# check that string when encoded with utf-8 matches the binary data
string = string_list[i]
if string.endswith("\n"):
string = string[:-1]
assert jsonl_binary_string[offset : offset + length] == string.encode("utf-8")


def _verify_pbin(
src_path: Path,
pbin_path: Path,
eod_token_id: int,
tokenizer: Callable[[str], list[int]],
jsonl_text_key: str,
):
dataset = PackedMemMapDatasetBase(raw_data_path=pbin_path, sample_key="text", load_index=True)

with open(src_path, "r", encoding="utf-8") as f:
string_list = f.readlines()
string_list_tokenized = [tokenizer(json.loads(string)[jsonl_text_key]) for string in string_list]

for i in tqdm.tqdm(range(len(dataset)), desc="Verifying pbin"):
pbin_sample = dataset[i]["text"]
recomputed_sample = string_list_tokenized[i]

# make sure that only the last token is the eod token
# and that the second last token is not the eod token
assert pbin_sample[-1] == eod_token_id
assert pbin_sample[-2] != eod_token_id

# we need to check if tokenizer adds the eod token as
# some tokenizers don't add the eod token at the end of the string
# whereas modalities always adds the eod token at the end of the string
if recomputed_sample[-1] != eod_token_id:
if i == 0:
warnings.warn("The tokenizer does not add the eod token at the end of the string!")
assert len(pbin_sample) - 1 == len(recomputed_sample)
assert all(pbin_sample[:-1] == recomputed_sample)
else:
assert len(pbin_sample) == len(recomputed_sample)
assert all(pbin_sample == recomputed_sample)


def build_hf_tokenization_components(tokenizer_path_or_name: str, eod_token: str):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_or_name)

def tokenizer_callable(text: str) -> list[int]:
return tokenizer(text, add_special_tokens=True, max_length=51200000, padding=False, truncation=False)[
"input_ids"
]

tokenizer_config = {
"component_key": "tokenizer",
"variant_key": "pretrained_hf_tokenizer",
"config": {
"pretrained_model_name_or_path": tokenizer_path_or_name,
"padding": False,
"max_length": 51200000,
},
}

eod_token_id = tokenizer.convert_tokens_to_ids(eod_token)
return tokenizer_callable, tokenizer_config, eod_token_id


def build_sp_tokenization_components(tokenizer_path: Path, eod_token: str):
tokenizer = spm.SentencePieceProcessor()
tokenizer.Load(tokenizer_path)

def tokenizer_callable(text: str) -> list[int]:
return tokenizer.Encode(text)

tokenizer_config = {
"component_key": "tokenizer",
"variant_key": "pretrained_sp_tokenizer",
"config": {
"tokenizer_model_file": tokenizer_path,
},
}

eod_token_id = tokenizer.PieceToId(eod_token)
return tokenizer_callable, tokenizer_config, eod_token_id


def verify_tokenization_consistency(
src_path: Path,
eod_token: str,
eod_token_id: int,
tokenizer: Callable[[str], list[int]],
tokenizer_config: dict,
jsonl_text_key: str,
):
"""Verifies that the indexation and tokenization is consistent.
This function applies the indexation and tokenization routines and then verifies
that the index always captures entire samples and that the tokens in the JSON
are correctly determined.
For an example verification, check out the test_end_to_end_indexation_and_tokenization_consistency test.
Args:
src_path (Path): Path to the JSONL file
eod_token (str): end of document token
eod_token_id (int): The token id of the end of document token
tokenizer (Callable[[str], list[int]]): Callable executing the tokenization
tokenizer_config (dict): Tokenizer config (same as used in the tokenization entry point)
jsonl_text_key (str): The key mapping to the text of interest in each JSON file
"""
# run indexing and tokenization
with tempfile.TemporaryDirectory() as tmp_dir:
index_path = Path(tmp_dir) / "index.idx"
pbin_path = Path(tmp_dir) / "data.pbin"
_run_tokenization(
src_path=src_path,
index_path=index_path,
pbin_path=pbin_path,
eod_token=eod_token,
tokenizer_config=tokenizer_config,
jq_pattern=f".{jsonl_text_key}",
)

# verify the index
_verify_index(src_path=src_path, index_path=index_path)
print("Index verified")
# verify the tokenized data
_verify_pbin(
src_path=src_path,
pbin_path=pbin_path,
eod_token_id=eod_token_id,
tokenizer=tokenizer,
jsonl_text_key=jsonl_text_key,
)
print("Tokenization verified")