diff --git a/llama-util.h b/llama-util.h
index 30a6c0eb53eac3..ba0eec70bdb433 100644
--- a/llama-util.h
+++ b/llama-util.h
@@ -544,166 +544,4 @@ struct llama_ctx_buffer {
 typedef llama_buffer llama_ctx_buffer;
 #endif
 
-struct llama_trie_node {
-    llama_trie_node(): is_terminator(false) {}
-
-    std::unordered_map<char, std::unique_ptr<llama_trie_node>> children;
-    bool is_terminator;
-};
-
-// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass
-// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52
-struct llama_trie {
-public:
-    llama_trie(): root_(new llama_trie_node()) {}
-
-    void add(const std::string & word) {
-        if (word.empty()) {
-            return;
-        }
-
-        llama_trie_node *ref = root_.get();
-        for (char c : word) {
-            if (ref->children.find(c) == ref->children.end()) {
-                ref->children[c].reset(new llama_trie_node());
-            }
-            ref = ref->children[c].get();
-        }
-        ref->is_terminator = true;
-    }
-
-    // Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
-    // Note that this trie will match the longest possible word first!
-    std::vector<size_t> split(const std::string & text) const {
-        std::map<size_t, llama_trie_node*> states;
-        std::vector<size_t> offsets{0};
-
-        size_t skip = 0;
-        for (size_t current = 0; current < text.size(); current++) {
-            char current_char = text[current];
-            if (skip > 0 && current < skip) {
-                // Prevents the lookahead for matching twice
-                // like extra_id_100 and id_100
-                continue;
-            }
-
-            // Whenever we found a match, we need to drop everything
-            // this is a greedy algorithm, it will match on the first found token
-            bool reset = false;
-
-            // In this case, we already have partial matches (But unfinished)
-            for (auto state = states.begin(); state != states.end(); ) {
-                size_t start = state->first;
-                llama_trie_node *trie_pointer = state->second;
-                if (trie_pointer->is_terminator) {
-                    // This is a final match, we need to reset and
-                    // store the results in `offsets`.
-
-                    // Lookahead to match longest first
-                    // Important in case of extra_id_1 vs extra_id_100
-                    // Here we are also actively looking for other earlier partial
-                    // matches
-                    // "[CLS]", "L", we need to match CLS even if L is special
-                    size_t end = 0;
-                    for (const auto & look : states) {
-                        size_t lookstart = look.first;
-                        llama_trie_node *looktrie_pointer = look.second;
-                        size_t lookahead_index = 0;
-                        if (lookstart > start) {
-                            // This partial match is later, we can stop looking
-                            break;
-                        }
-                        if (lookstart < start) {
-                            // This partial match is earlier, the trie pointer
-                            // was already updated, so index is + 1
-                            lookahead_index = current + 1;
-                            end = current + 1;
-                        } else {
-                            // Here lookstart == start and
-                            //      looktrie_pointer == trie_pointer
-                            // It wasn't updated yet so indices are current ones
-                            lookahead_index = current;
-                            end = current;
-                        }
-                        char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0';
-                        if (looktrie_pointer->is_terminator) {
-                            start = lookstart;
-                            end = lookahead_index;
-                            skip = lookahead_index;
-                        }
-
-                        auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
-                        while (looktrie_pointer_it != looktrie_pointer->children.end()) {
-                            looktrie_pointer = looktrie_pointer_it->second.get();
-                            lookahead_index++;
-                            if (looktrie_pointer->is_terminator) {
-                                start = lookstart;
-                                end = lookahead_index;
-                                skip = lookahead_index;
-                            }
-
-                            if (lookahead_index == text.size()) {
-                                // End of string
-                                break;
-                            }
-                            next_char = text[lookahead_index];
-                            looktrie_pointer_it = looktrie_pointer->children.find(next_char);
-                        }
-                    }
-
-                    offsets.push_back(start);
-                    offsets.push_back(end);
-                    reset = true;
-                    break;
-                }
-
-                auto trie_pointer_it = trie_pointer->children.find(current_char);
-                if (trie_pointer_it != trie_pointer->children.end()) {
-                    // The current character being looked at has a match within the trie
-                    // update the pointer (it will be stored back into states later).
-                    trie_pointer = trie_pointer_it->second.get();
-                    states[start] = trie_pointer;
-                    ++state;
-                } else {
-                    // The new character has not match in the trie, we need
-                    // to stop keeping track of this partial match.
-                    state = states.erase(state);
-                }
-            }
-
-            if (reset) {
-                // Clear the full start (we found a real match)
-                states.clear();
-            }
-
-            // If this character is a starting character within the trie
-            // start keeping track of this partial match.
-            auto children_it = root_->children.find(current_char);
-            if (current >= skip && children_it != root_->children.end()) {
-                states[current] = children_it->second.get();
-            }
-        }
-
-        // We have a cut at the end with states.
-        for (const auto & state : states) {
-            size_t start = state.first;
-            llama_trie_node *trie_pointer = state.second;
-            if (trie_pointer->is_terminator) {
-                // This is a final match, we need to reset and
-                // store the results in `offsets`.
-                size_t end = text.size();
-                offsets.push_back(start);
-                offsets.push_back(end);
-                break;
-            }
-        }
-
-        offsets.push_back(text.size());
-        return offsets;
-    }
-
-private:
-    std::unique_ptr<llama_trie_node> root_;
-};
-
 #endif
diff --git a/llama.cpp b/llama.cpp
index 3a50090f828817..3b6d23eac572c4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -278,12 +278,10 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
 
-    llama_trie special_token_trie;
     std::unordered_map<token, id> special_token_to_id;
     size_t max_special_token_length = 0;
 
     void add_special_token(const token & word, id token_id) {
-        special_token_trie.add(word);
         special_token_to_id[word] = token_id;
 
         if (max_special_token_length < word.size()) {
@@ -2090,6 +2088,38 @@ struct llama_tokenizer {
     llama_sp_bigram::queue work_queue_;
 };
 
+static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
+    std::vector<size_t> offsets{0};
+    size_t start = 0;
+
+    while (start < text.size()) {
+        size_t max_end = start;
+        const std::string * max_delimiter = nullptr;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = start + delimiter.size();
+            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
+                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
+                    max_end = end;
+                    max_delimiter = &delimiter;
+                }
+            }
+        }
+
+        if (max_delimiter != nullptr) {
+            offsets.push_back(start);
+            offsets.push_back(max_end);
+            start = max_end;
+        } else {
+            start++;
+        }
+    }
+
+    offsets.push_back(text.size());
+    return offsets;
+}
+
 static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
     llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
@@ -2107,7 +2137,7 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         return output;
     }
 
-    std::vector<size_t> offsets = vocab.special_token_trie.split(text);
+    std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
     size_t start = 0;
     for (size_t end : offsets) {
         if (start >= end) {
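
For reviewers: below is a minimal standalone sketch (not part of the patch) of the replacement splitter. The names `split_special_tokens` and `special_tokens` are hypothetical stand-ins for the patch's `llama_split_special_tokens` and `vocab.special_token_to_id`; only the map keys matter for splitting. It demonstrates that the plain scan preserves the longest-match behavior the trie guaranteed, e.g. for `extra_id_1` vs `extra_id_100`.

```cpp
// Standalone sketch; split_special_tokens mirrors the llama_split_special_tokens
// added in this patch, with a plain map standing in for vocab.special_token_to_id.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

static std::vector<size_t> split_special_tokens(
        const std::unordered_map<std::string, int> & special_tokens,
        const std::string & text) {
    std::vector<size_t> offsets{0};
    size_t start = 0;

    while (start < text.size()) {
        size_t max_end = start;
        const std::string * max_delimiter = nullptr;

        // Try every special token at the current position and keep the longest
        // one that matches (extra_id_100 must win over extra_id_1).
        for (const auto & it : special_tokens) {
            const std::string & delimiter = it.first;
            size_t end = start + delimiter.size();
            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
                    max_end = end;
                    max_delimiter = &delimiter;
                }
            }
        }

        if (max_delimiter != nullptr) {
            // Emit the match as a [start, end) boundary pair and jump past it.
            offsets.push_back(start);
            offsets.push_back(max_end);
            start = max_end;
        } else {
            start++;
        }
    }

    offsets.push_back(text.size());
    return offsets;
}

int main() {
    const std::unordered_map<std::string, int> special_tokens = {
        {"extra_id_1",   1},
        {"extra_id_100", 2},
    };
    // Both tokens match at offset 4; the longer one wins, so the output is
    // "0 4 16 20", i.e. the segments "foo " | "extra_id_100" | " bar".
    for (size_t off : split_special_tokens(special_tokens, "foo extra_id_100 bar")) {
        printf("%zu ", off);
    }
    printf("\n");
    return 0;
}
```

The tokenizer then walks consecutive offsets, as in the last hunk above: each segment is looked up in `special_token_to_id`, and segments that are not special tokens fall through to normal tokenization. Unlike the trie, this rescans every special token at each position, roughly O(text length × number of special tokens) in the worst case, but special-token vocabularies are small enough that the simpler code seems like a fair trade.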