Skip to content

Commit

Permalink
Refactor: Add llamma_ prefix in unicode.h unicode.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleTanPY committed Dec 14, 2024
1 parent ba1cb19 commit 0579e3b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 35 deletions.
39 changes: 20 additions & 19 deletions src/unicode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
// return result;
//}

static std::vector<codepoint_flags> unicode_cpt_flags_array() {
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
static std::vector<llama_codepoint_flags> unicode_cpt_flags_array() {
std::vector<llama_codepoint_flags> cpt_flags(MAX_CODEPOINTS, llama_codepoint_flags::LLAMA_UNDEFINED);

assert (unicode_ranges_flags.begin()[0].first == 0);
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
Expand Down Expand Up @@ -253,8 +253,8 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};

auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
auto _get_flags = [&](const size_t pos) -> llama_codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : llama_codepoint_flags{};
};

size_t _prev_end = offset_ini;
Expand Down Expand Up @@ -371,8 +371,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};

auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
auto _get_flags = [&](const size_t pos) -> llama_codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : llama_codepoint_flags{};
};

size_t _prev_end = offset_ini;
Expand Down Expand Up @@ -624,14 +624,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result;
}

codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
llama_codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static const llama_codepoint_flags undef(llama_codepoint_flags::LLAMA_UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array();
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
}

codepoint_flags unicode_cpt_flags(const std::string & utf8) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
llama_codepoint_flags unicode_cpt_flags(const std::string & utf8) {
static const llama_codepoint_flags undef(llama_codepoint_flags::LLAMA_UNDEFINED);
if (utf8.empty()) {
return undef; // undefined
}
Expand Down Expand Up @@ -664,21 +664,22 @@ uint32_t unicode_tolower(uint32_t cp) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_flags::NUMBER },
{ "\\p{L}", codepoint_flags::LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION },
{ "\\p{N}", llama_codepoint_flags::LLAMA_NUMBER },
{ "\\p{L}", llama_codepoint_flags::LLAMA_LETTER },
{ "\\p{P}", llama_codepoint_flags::LLAMA_PUNCTUATION },
};

static const std::map<int, int> k_ucat_cpt = {
{ codepoint_flags::NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 },
{ llama_codepoint_flags::LLAMA_NUMBER, 0xD1 },
{ llama_codepoint_flags::LLAMA_LETTER, 0xD2 },
{ llama_codepoint_flags::LLAMA_PUNCTUATION, 0xD3 },
};

static const std::map<int, std::string> k_ucat_map = {
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
{ llama_codepoint_flags::LLAMA_NUMBER, "\x30-\x39" }, // 0-9
{ llama_codepoint_flags::LLAMA_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ llama_codepoint_flags::LLAMA_PUNCTUATION,
"\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
};

// compute collapsed codepoints only if needed by at least one regex
Expand Down
30 changes: 14 additions & 16 deletions src/unicode.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,17 @@
#include <string>
#include <vector>

// TODO: prefix all symbols with "llama_"

struct codepoint_flags {
struct llama_codepoint_flags {
enum {
UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N}
LETTER = 0x0004, // regex: \p{L}
SEPARATOR = 0x0008, // regex: \p{Z}
ACCENT_MARK = 0x0010, // regex: \p{M}
PUNCTUATION = 0x0020, // regex: \p{P}
SYMBOL = 0x0040, // regex: \p{S}
CONTROL = 0x0080, // regex: \p{C}
MASK_CATEGORIES = 0x00FF,
LLAMA_UNDEFINED = 0x0001,
LLAMA_NUMBER = 0x0002, // regex: \p{N}
LLAMA_LETTER = 0x0004, // regex: \p{L}
LLAMA_SEPARATOR = 0x0008, // regex: \p{Z}
LLAMA_ACCENT_MARK = 0x0010, // regex: \p{M}
LLAMA_PUNCTUATION = 0x0020, // regex: \p{P}
LLAMA_SYMBOL = 0x0040, // regex: \p{S}
LLAMA_CONTROL = 0x0080, // regex: \p{C}
LLAMA_MASK_CATEGORIES = 0x00FF,
};

// codepoint type
Expand All @@ -35,7 +33,7 @@ struct codepoint_flags {
uint16_t is_nfd : 1;

// decode from uint16
inline codepoint_flags(const uint16_t flags=0) {
inline llama_codepoint_flags(const uint16_t flags = 0) {
*reinterpret_cast<uint16_t*>(this) = flags;
}

Expand All @@ -44,7 +42,7 @@ struct codepoint_flags {
}

inline uint16_t category_flag() const {
return this->as_uint() & MASK_CATEGORIES;
return this->as_uint() & LLAMA_MASK_CATEGORIES;
}
};

Expand All @@ -56,8 +54,8 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);

std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);

codepoint_flags unicode_cpt_flags(const uint32_t cp);
codepoint_flags unicode_cpt_flags(const std::string & utf8);
llama_codepoint_flags unicode_cpt_flags(const uint32_t cp);
llama_codepoint_flags unicode_cpt_flags(const std::string & utf8);

std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8);
Expand Down

0 comments on commit 0579e3b

Please sign in to comment.