diff --git a/src/unicode.cpp b/src/unicode.cpp index 3d459263525dc..f7ef7b8de1091 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) { // return result; //} -static std::vector unicode_cpt_flags_array() { - std::vector cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED); +static std::vector unicode_cpt_flags_array() { + std::vector cpt_flags(MAX_CODEPOINTS, llama_codepoint_flags::LLAMA_UNDEFINED); assert (unicode_ranges_flags.begin()[0].first == 0); assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS); @@ -253,8 +253,8 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; }; - auto _get_flags = [&] (const size_t pos) -> codepoint_flags { - return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; + auto _get_flags = [&](const size_t pos) -> llama_codepoint_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : llama_codepoint_flags{}; }; size_t _prev_end = offset_ini; @@ -371,8 +371,8 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; }; - auto _get_flags = [&] (const size_t pos) -> codepoint_flags { - return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; + auto _get_flags = [&](const size_t pos) -> llama_codepoint_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : llama_codepoint_flags{}; }; size_t _prev_end = offset_ini; @@ -624,14 +624,14 @@ std::vector unicode_cpts_from_utf8(const std::string & utf8) { return result; } -codepoint_flags unicode_cpt_flags(const uint32_t cp) { - static const codepoint_flags undef(codepoint_flags::UNDEFINED); +llama_codepoint_flags unicode_cpt_flags(const uint32_t cp) { + static const llama_codepoint_flags undef(llama_codepoint_flags::LLAMA_UNDEFINED); static const auto cpt_flags = unicode_cpt_flags_array(); return cp < cpt_flags.size() ? cpt_flags[cp] : undef; } -codepoint_flags unicode_cpt_flags(const std::string & utf8) { - static const codepoint_flags undef(codepoint_flags::UNDEFINED); +llama_codepoint_flags unicode_cpt_flags(const std::string & utf8) { + static const llama_codepoint_flags undef(llama_codepoint_flags::LLAMA_UNDEFINED); if (utf8.empty()) { return undef; // undefined } @@ -664,21 +664,22 @@ uint32_t unicode_tolower(uint32_t cp) { std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { // unicode categories static const std::map k_ucat_enum = { - { "\\p{N}", codepoint_flags::NUMBER }, - { "\\p{L}", codepoint_flags::LETTER }, - { "\\p{P}", codepoint_flags::PUNCTUATION }, + { "\\p{N}", llama_codepoint_flags::LLAMA_NUMBER }, + { "\\p{L}", llama_codepoint_flags::LLAMA_LETTER }, + { "\\p{P}", llama_codepoint_flags::LLAMA_PUNCTUATION }, }; static const std::map k_ucat_cpt = { - { codepoint_flags::NUMBER, 0xD1 }, - { codepoint_flags::LETTER, 0xD2 }, - { codepoint_flags::PUNCTUATION, 0xD3 }, + { llama_codepoint_flags::LLAMA_NUMBER, 0xD1 }, + { llama_codepoint_flags::LLAMA_LETTER, 0xD2 }, + { llama_codepoint_flags::LLAMA_PUNCTUATION, 0xD3 }, }; static const std::map k_ucat_map = { - { codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9 - { codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z - { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} + { llama_codepoint_flags::LLAMA_NUMBER, "\x30-\x39" }, // 0-9 + { llama_codepoint_flags::LLAMA_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z + { llama_codepoint_flags::LLAMA_PUNCTUATION, + "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} }; // compute collapsed codepoints only if needed by at least one regex diff --git a/src/unicode.h b/src/unicode.h index 008532a242ab8..af7c2128d6aa4 100644 --- a/src/unicode.h +++ b/src/unicode.h @@ -4,19 +4,17 @@ #include #include -// TODO: prefix all symbols with "llama_" - -struct codepoint_flags { +struct llama_codepoint_flags { enum { - UNDEFINED = 0x0001, - NUMBER = 0x0002, // regex: \p{N} - LETTER = 0x0004, // regex: \p{L} - SEPARATOR = 0x0008, // regex: \p{Z} - ACCENT_MARK = 0x0010, // regex: \p{M} - PUNCTUATION = 0x0020, // regex: \p{P} - SYMBOL = 0x0040, // regex: \p{S} - CONTROL = 0x0080, // regex: \p{C} - MASK_CATEGORIES = 0x00FF, + LLAMA_UNDEFINED = 0x0001, + LLAMA_NUMBER = 0x0002, // regex: \p{N} + LLAMA_LETTER = 0x0004, // regex: \p{L} + LLAMA_SEPARATOR = 0x0008, // regex: \p{Z} + LLAMA_ACCENT_MARK = 0x0010, // regex: \p{M} + LLAMA_PUNCTUATION = 0x0020, // regex: \p{P} + LLAMA_SYMBOL = 0x0040, // regex: \p{S} + LLAMA_CONTROL = 0x0080, // regex: \p{C} + LLAMA_MASK_CATEGORIES = 0x00FF, }; // codepoint type @@ -35,7 +33,7 @@ struct codepoint_flags { uint16_t is_nfd : 1; // decode from uint16 - inline codepoint_flags(const uint16_t flags=0) { + inline llama_codepoint_flags(const uint16_t flags = 0) { *reinterpret_cast(this) = flags; } @@ -44,7 +42,7 @@ struct codepoint_flags { } inline uint16_t category_flag() const { - return this->as_uint() & MASK_CATEGORIES; + return this->as_uint() & LLAMA_MASK_CATEGORIES; } }; @@ -56,8 +54,8 @@ std::vector unicode_cpts_from_utf8(const std::string & utf8); std::vector unicode_cpts_normalize_nfd(const std::vector & cpts); -codepoint_flags unicode_cpt_flags(const uint32_t cp); -codepoint_flags unicode_cpt_flags(const std::string & utf8); +llama_codepoint_flags unicode_cpt_flags(const uint32_t cp); +llama_codepoint_flags unicode_cpt_flags(const std::string & utf8); std::string unicode_byte_to_utf8(uint8_t byte); uint8_t unicode_utf8_to_byte(const std::string & utf8);