Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : tokenizer unicode codepoint categories #8606

Open
wants to merge 30 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3d16f64
Update bruteforce test:
Jul 20, 2024
5ceab90
Store all unicode codepoint categories
Jul 20, 2024
ba4bbbd
Reimplement 'codepoint_flags' as 'codepoint_categ'
Jul 20, 2024
8f9f05b
Update unicode data
Jul 20, 2024
2636cb6
Decode unicode data categories
Jul 20, 2024
23cf064
Replace 'codepoint_flags' with 'codepoint_categ'
Jul 20, 2024
ecebfc0
Update unicode data: sorted whitespaces
Jul 25, 2024
8c8e1af
Fix codepoint_categ return types
Jul 25, 2024
8f7d56e
Add unicode_data helper functions
Jul 25, 2024
1cd7ac0
Reimplement 'collapsed' unicode categories:
Jul 25, 2024
aeac342
Add more comments
Aug 4, 2024
8bd3749
Merge commit '978ba3d8' into tokenizer-codepoint-categs
Aug 4, 2024
85c59df
minor: remove trailing whitespaces and extra semicolons
Aug 5, 2024
735105e
Use GGML_ASSERT and GGML_ABORT
Aug 5, 2024
fd6d9b9
Update bruteforce test: fix pyright complaints
Aug 5, 2024
3b36703
Update bruteforce test:
Aug 5, 2024
d558c73
Binary constants are a C++14 feature
Aug 5, 2024
674f0fa
Fix copy/paste wrong variable
Aug 5, 2024
2ca3138
Fix compiler complaints
Aug 5, 2024
80f4123
Update bruteforce test: fix binary search
Aug 7, 2024
7afe6df
Unicode data whitespaces as ranges
Aug 7, 2024
c240638
Reimplement unicode_regex_split()
Aug 7, 2024
312c432
Remove invalid assert
Aug 13, 2024
b565148
Update codepoint_categ:
Aug 13, 2024
5a93d2e
Reimplement unicode_regex_split():
Aug 13, 2024
7ff916e
Original regex for 'tekken'
Aug 13, 2024
50e1b1e
Remove unused function
Aug 13, 2024
dcac747
Using 32bit wchar_t by default, uint32_t on Windows
Aug 13, 2024
b67c81d
Fix previous commit
Aug 13, 2024
db78320
Fix compiler complaints
Aug 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 60 additions & 66 deletions scripts/gen-unicode-data.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,53 +49,42 @@ def unicode_data_iter():
yield (cpt, cpt_lower, cpt_upper, categ, bidir)


# see definition in unicode.h
CODEPOINT_FLAG_UNDEFINED = 0x0001 #
CODEPOINT_FLAG_NUMBER = 0x0002 # \p{N}
CODEPOINT_FLAG_LETTER = 0x0004 # \p{L}
CODEPOINT_FLAG_SEPARATOR = 0x0008 # \p{Z}
CODEPOINT_FLAG_MARK = 0x0010 # \p{M}
CODEPOINT_FLAG_PUNCTUATION = 0x0020 # \p{P}
CODEPOINT_FLAG_SYMBOL = 0x0040 # \p{S}
CODEPOINT_FLAG_CONTROL = 0x0080 # \p{C}

UNICODE_CATEGORY_TO_FLAG = {
"Cn": CODEPOINT_FLAG_UNDEFINED, # Undefined
"Cc": CODEPOINT_FLAG_CONTROL, # Control
"Cf": CODEPOINT_FLAG_CONTROL, # Format
"Co": CODEPOINT_FLAG_CONTROL, # Private Use
"Cs": CODEPOINT_FLAG_CONTROL, # Surrogate
"Ll": CODEPOINT_FLAG_LETTER, # Lowercase Letter
"Lm": CODEPOINT_FLAG_LETTER, # Modifier Letter
"Lo": CODEPOINT_FLAG_LETTER, # Other Letter
"Lt": CODEPOINT_FLAG_LETTER, # Titlecase Letter
"Lu": CODEPOINT_FLAG_LETTER, # Uppercase Letter
"L&": CODEPOINT_FLAG_LETTER, # Cased Letter
"Mc": CODEPOINT_FLAG_MARK, # Spacing Mark
"Me": CODEPOINT_FLAG_MARK, # Enclosing Mark
"Mn": CODEPOINT_FLAG_MARK, # Nonspacing Mark
"Nd": CODEPOINT_FLAG_NUMBER, # Decimal Number
"Nl": CODEPOINT_FLAG_NUMBER, # Letter Number
"No": CODEPOINT_FLAG_NUMBER, # Other Number
"Pc": CODEPOINT_FLAG_PUNCTUATION, # Connector Punctuation
"Pd": CODEPOINT_FLAG_PUNCTUATION, # Dash Punctuation
"Pe": CODEPOINT_FLAG_PUNCTUATION, # Close Punctuation
"Pf": CODEPOINT_FLAG_PUNCTUATION, # Final Punctuation
"Pi": CODEPOINT_FLAG_PUNCTUATION, # Initial Punctuation
"Po": CODEPOINT_FLAG_PUNCTUATION, # Other Punctuation
"Ps": CODEPOINT_FLAG_PUNCTUATION, # Open Punctuation
"Sc": CODEPOINT_FLAG_SYMBOL, # Currency Symbol
"Sk": CODEPOINT_FLAG_SYMBOL, # Modifier Symbol
"Sm": CODEPOINT_FLAG_SYMBOL, # Math Symbol
"So": CODEPOINT_FLAG_SYMBOL, # Other Symbol
"Zl": CODEPOINT_FLAG_SEPARATOR, # Line Separator
"Zp": CODEPOINT_FLAG_SEPARATOR, # Paragraph Separator
"Zs": CODEPOINT_FLAG_SEPARATOR, # Space Separator
# Maps each Unicode general-category code to the numeric index used by the
# C++ side; must stay in sync with codepoint_categ::from_index() in unicode.h
UNICODE_CATEGORY_TO_INDEX = {
    "Cn":  0,  # \p{Cn} Undefined
    "Cc":  1,  # \p{Cc} Control
    "Cf":  2,  # \p{Cf} Format
    "Co":  3,  # \p{Co} Private Use
    "Cs":  4,  # \p{Cs} Surrogate
    "Ll":  5,  # \p{Ll} Lowercase Letter
    "Lm":  6,  # \p{Lm} Modifier Letter
    "Lo":  7,  # \p{Lo} Other Letter
    "Lt":  8,  # \p{Lt} Titlecase Letter
    "Lu":  9,  # \p{Lu} Uppercase Letter
    "Mc": 10,  # \p{Mc} Spacing Mark
    "Me": 11,  # \p{Me} Enclosing Mark
    "Mn": 12,  # \p{Mn} Nonspacing Mark
    "Nd": 13,  # \p{Nd} Decimal Number
    "Nl": 14,  # \p{Nl} Letter Number
    "No": 15,  # \p{No} Other Number
    "Pc": 16,  # \p{Pc} Connector Punctuation
    "Pd": 17,  # \p{Pd} Dash Punctuation
    "Pe": 18,  # \p{Pe} Close Punctuation
    "Pf": 19,  # \p{Pf} Final Punctuation
    "Pi": 20,  # \p{Pi} Initial Punctuation
    "Po": 21,  # \p{Po} Other Punctuation
    "Ps": 22,  # \p{Ps} Open Punctuation
    "Sc": 23,  # \p{Sc} Currency Symbol
    "Sk": 24,  # \p{Sk} Modifier Symbol
    "Sm": 25,  # \p{Sm} Math Symbol
    "So": 26,  # \p{So} Other Symbol
    "Zl": 27,  # \p{Zl} Line Separator
    "Zp": 28,  # \p{Zp} Paragraph Separator
    "Zs": 29,  # \p{Zs} Space Separator
}


codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
table_whitespace = []
codepoint_categs = array.array('B', [0]) * MAX_CODEPOINTS # Undefined
table_lowercase = []
table_uppercase = []
table_nfd = []
Expand All @@ -105,7 +94,7 @@ def unicode_data_iter():
char = chr(cpt)

# codepoint category flags
codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]
codepoint_categs[cpt] = UNICODE_CATEGORY_TO_INDEX[categ]

# lowercase conversion
if cpt_lower:
Expand All @@ -121,25 +110,31 @@ def unicode_data_iter():
table_nfd.append((cpt, norm))


# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
table_whitespace.extend(range(0x0009, 0x000D + 1))
table_whitespace.extend(range(0x2000, 0x200A + 1))
table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])


# sort by codepoint
table_whitespace.sort()
table_lowercase.sort()
table_uppercase.sort()
table_nfd.sort()


# group ranges with same flags
ranges_flags: list[tuple[int, int]] = [(0, codepoint_flags[0])] # start, flags
for codepoint, flags in enumerate(codepoint_flags):
if flags != ranges_flags[-1][1]:
ranges_flags.append((codepoint, flags))
ranges_flags.append((MAX_CODEPOINTS, 0x0000))
# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
# stored as inclusive (start, last) ranges; single codepoints become (cpt, cpt)
whitespace_ranges: list[tuple[int, int]] = []  # start, last
whitespace_ranges.append((0x0009, 0x000D))  # TAB .. CR
whitespace_ranges.append((0x2000, 0x200A))  # EN QUAD .. HAIR SPACE
for whitespace in [0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000]:
    whitespace_ranges.append((whitespace, whitespace))


# run length encoding, see unicode_cpt_category() in unicode.cpp
# each uint16 packs: low 5 bits = category index, high 11 bits = (run length - 1)
assert (max(UNICODE_CATEGORY_TO_INDEX.values()) < 32)  # category must fit in 5 bits
codepoint_categs_runs = [codepoint_categs[0]]  # 5 bits categ + 11 bits length
for cpt, categ in enumerate(codepoint_categs[1:], 1):
    prev = codepoint_categs_runs[-1]
    if prev <= (0xFFFF - 32) and (prev & 31) == categ:
        codepoint_categs_runs[-1] += 32  # increment run length
    else:
        codepoint_categs_runs.append(categ)  # new run value
# NOTE(review): the guard above lets an intermediate run reach exactly 0xFFFF,
# but only the final run is asserted < 0xFFFF — confirm 0xFFFF is not reserved
assert (codepoint_categs_runs[-1] < 0xFFFF)
# the decoded runs must cover every codepoint exactly once
assert (MAX_CODEPOINTS == sum((rle >> 5) + 1 for rle in codepoint_categs_runs))


# group ranges with same nfd
Expand All @@ -153,7 +148,7 @@ def unicode_data_iter():


# Generate 'unicode-data.cpp':
# python ./scripts//gen-unicode-data.py > unicode-data.cpp
# python ./scripts/gen-unicode-data.py > ./src/unicode-data.cpp

def out(line=""):
    """Write a single line of the generated C++ source to stdout."""
    print(line, end='\n')  # noqa
Expand All @@ -167,17 +162,16 @@ def out(line=""):
#include <cstdint>
#include <vector>
#include <unordered_map>
#include <unordered_set>
""")

out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
for codepoint, flags in ranges_flags:
out("{0x%06X, 0x%04X}," % (codepoint, flags))
# emit the RLE-encoded category table (see decoding in unicode.cpp)
out("const std::vector<uint16_t> unicode_rle_codepoints_categs = { // run length encoding, 5 bits categ + 11 bits length")
for rle in codepoint_categs_runs:
    out("0x%04X," % rle)
out("};\n")

out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
for codepoint in table_whitespace:
out("0x%06X," % codepoint)
# emit the whitespace ranges as inclusive {start, last} pairs
out("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace = {")
for (start, last) in whitespace_ranges:
    out("{0x%06X, 0x%06X}," % (start, last))
out("};\n")

out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
Expand Down
15 changes: 6 additions & 9 deletions src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,10 +440,8 @@ struct llm_tokenizer_bpe {
};
break;
case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
// original regex from tokenizer.json
// "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
regex_exprs = {
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
default:
Expand Down Expand Up @@ -701,22 +699,21 @@ struct llm_tokenizer_wpm {
std::vector<std::string> words(1, "");

for (const uint32_t cpt : cpts_nfd) {
const auto flags = unicode_cpt_flags(cpt);
const auto categ = unicode_cpt_category(cpt);

if (flags.is_whitespace) {
if (categ.is_whitespace()) {
if (words.back().size()) { // finish previous word if any
words.emplace_back();
}
continue;
}

assert (!flags.is_separator);
if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
if (cpt == 0 || cpt == 0xFFFD || categ.is_C()) {
continue;
}

const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
if (categ.is_P() || (cpt < 0x7F && categ.is_S()) || is_chinese_char(cpt)) {
if (words.back().size()) { // finish previous word if any
words.emplace_back();
}
Expand All @@ -734,7 +731,7 @@ struct llm_tokenizer_wpm {
return words;
}

static bool is_chinese_char(uint32_t cpt) {
static bool is_chinese_char(uint32_t cpt) { //TODO: move to unicode-data.cpp? unicode_cpt_category(cpt).is_chinese()?
return
(cpt >= 0x04E00 && cpt <= 0x09FFF) ||
(cpt >= 0x03400 && cpt <= 0x04DBF) ||
Expand Down
Loading
Loading