Replace trie with linear search

Igoorx committed Aug 8, 2023
1 parent 4fc3776 commit 863a440

Showing 2 changed files with 33 additions and 165 deletions.
162 changes: 0 additions & 162 deletions llama-util.h
@@ -544,166 +544,4 @@ struct llama_ctx_buffer {
typedef llama_buffer llama_ctx_buffer;
#endif

struct llama_trie_node {
    llama_trie_node(): is_terminator(false) {}

    std::unordered_map<char, std::unique_ptr<llama_trie_node>> children;
    bool is_terminator;
};

// Trie in C++. Creates a trie out of a list of words. The trie is used to split on multiple delimiters in one pass.
// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52
struct llama_trie {
public:
    llama_trie(): root_(new llama_trie_node()) {}

    void add(const std::string & word) {
        if (word.empty()) {
            return;
        }

        llama_trie_node *ref = root_.get();
        for (char c : word) {
            if (ref->children.find(c) == ref->children.end()) {
                ref->children[c].reset(new llama_trie_node());
            }
            ref = ref->children[c].get();
        }
        ref->is_terminator = true;
    }

    // Looks for the words added to the trie within `text`. The output is the boundaries of the words found.
    // Note that this trie will match the longest possible word first!
    std::vector<size_t> split(const std::string & text) const {
        std::map<size_t, llama_trie_node*> states;
        std::vector<size_t> offsets{0};

        size_t skip = 0;
        for (size_t current = 0; current < text.size(); current++) {
            char current_char = text[current];
            if (skip > 0 && current < skip) {
                // Prevents the lookahead from matching twice,
                // e.g. extra_id_100 and id_100
                continue;
            }

            // Whenever we find a match, we need to drop everything:
            // this is a greedy algorithm that matches on the first token found
            bool reset = false;

            // In this case, we already have partial matches (but unfinished)
            for (auto state = states.begin(); state != states.end(); ) {
                size_t start = state->first;
                llama_trie_node *trie_pointer = state->second;
                if (trie_pointer->is_terminator) {
                    // This is a final match: we need to reset and
                    // store the results in `offsets`.

                    // Lookahead to match the longest token first.
                    // Important in case of extra_id_1 vs extra_id_100.
                    // Here we are also actively looking for other earlier partial matches:
                    // given "[CLS]" and "L", we need to match "[CLS]" even if "L" is special.
                    size_t end = 0;
                    for (const auto & look : states) {
                        size_t lookstart = look.first;
                        llama_trie_node *looktrie_pointer = look.second;
                        size_t lookahead_index = 0;
                        if (lookstart > start) {
                            // This partial match is later, we can stop looking
                            break;
                        }
                        if (lookstart < start) {
                            // This partial match is earlier; the trie pointer
                            // was already updated, so the index is current + 1
                            lookahead_index = current + 1;
                            end = current + 1;
                        } else {
                            // Here lookstart == start and looktrie_pointer == trie_pointer.
                            // It wasn't updated yet, so the indices are the current ones
                            lookahead_index = current;
                            end = current;
                        }
                        char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0';
                        if (looktrie_pointer->is_terminator) {
                            start = lookstart;
                            end = lookahead_index;
                            skip = lookahead_index;
                        }

                        auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
                        while (looktrie_pointer_it != looktrie_pointer->children.end()) {
                            looktrie_pointer = looktrie_pointer_it->second.get();
                            lookahead_index++;
                            if (looktrie_pointer->is_terminator) {
                                start = lookstart;
                                end = lookahead_index;
                                skip = lookahead_index;
                            }

                            if (lookahead_index == text.size()) {
                                // End of string
                                break;
                            }
                            next_char = text[lookahead_index];
                            looktrie_pointer_it = looktrie_pointer->children.find(next_char);
                        }
                    }

                    offsets.push_back(start);
                    offsets.push_back(end);
                    reset = true;
                    break;
                }

                auto trie_pointer_it = trie_pointer->children.find(current_char);
                if (trie_pointer_it != trie_pointer->children.end()) {
                    // The current character has a match within the trie:
                    // update the pointer (it will be stored back into `states` later).
                    trie_pointer = trie_pointer_it->second.get();
                    states[start] = trie_pointer;
                    ++state;
                } else {
                    // The new character has no match in the trie, so we need
                    // to stop keeping track of this partial match.
                    state = states.erase(state);
                }
            }

            if (reset) {
                // Clear all partial matches (we found a real match)
                states.clear();
            }

            // If this character is a starting character within the trie,
            // start keeping track of this partial match.
            auto children_it = root_->children.find(current_char);
            if (current >= skip && children_it != root_->children.end()) {
                states[current] = children_it->second.get();
            }
        }

        // Handle a partial match that reaches the end of the text.
        for (const auto & state : states) {
            size_t start = state.first;
            llama_trie_node *trie_pointer = state.second;
            if (trie_pointer->is_terminator) {
                // This is a final match: store the results in `offsets`.
                size_t end = text.size();
                offsets.push_back(start);
                offsets.push_back(end);
                break;
            }
        }

        offsets.push_back(text.size());
        return offsets;
    }

private:
    std::unique_ptr<llama_trie_node> root_;
};

#endif
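
Both the removed trie and its replacement (below, in llama.cpp) return offsets in the same flat format: the vector starts with 0, ends with text.size(), and each matched special token contributes a start/end pair, so consecutive offsets always delimit a [start, end) segment. A minimal standalone sketch of how such offsets are consumed, assuming for illustration that "<|endoftext|>" has been registered as a special token:

#include <cstdio>
#include <string>
#include <vector>

int main() {
    // Assume "<|endoftext|>" was added as a special token.
    std::string text = "Hello<|endoftext|>world";
    // Both implementations split this text at: 0, 5, 18, 23
    std::vector<size_t> offsets = {0, 5, 18, 23};

    size_t start = 0;
    for (size_t end : offsets) {
        if (start >= end) {
            continue; // skip empty segments, e.g. the leading 0
        }
        printf("segment: '%s'\n", text.substr(start, end - start).c_str());
        start = end;
    }
    // Prints: 'Hello', '<|endoftext|>', 'world'
    return 0;
}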
36 changes: 33 additions & 3 deletions llama.cpp
@@ -278,12 +278,10 @@ struct llama_vocab {
    std::unordered_map<token, id> token_to_id;
    std::vector<token_score> id_to_token;

    llama_trie special_token_trie;
    std::unordered_map<token, id> special_token_to_id;
    size_t max_special_token_length = 0;

    void add_special_token(const token & word, id token_id) {
        special_token_trie.add(word);
        special_token_to_id[word] = token_id;

        if (max_special_token_length < word.size()) {
@@ -2090,6 +2088,38 @@ struct llama_tokenizer {
    llama_sp_bigram::queue work_queue_;
};

// Finds all special tokens in `text` by linear search over the vocab's special
// tokens. Returns a flat list of segment boundaries: it starts with 0, ends
// with text.size(), and each match contributes a start/end pair. At any given
// position, the longest matching special token wins.
static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
    std::vector<size_t> offsets{0};
    size_t start = 0;

    while (start < text.size()) {
        size_t max_end = start;
        const std::string * max_delimiter = nullptr;

        // Check every special token for a match at `start`, keeping the longest one
        for (const auto & mit : vocab.special_token_to_id) {
            const std::string & delimiter = mit.first;
            size_t end = start + delimiter.size();
            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
                    max_end = end;
                    max_delimiter = &delimiter;
                }
            }
        }

        if (max_delimiter != nullptr) {
            // Record the boundaries of the match and continue after it
            offsets.push_back(start);
            offsets.push_back(max_end);
            start = max_end;
        } else {
            start++;
        }
    }

    offsets.push_back(text.size());
    return offsets;
}
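
Compared to the trie's single pass, the linear scan is O(text length × number of special tokens) in the worst case, since every position may be checked against every registered token; in practice special-token sets are tiny, so the simpler code wins. A self-contained sketch of the same longest-match scan (a standalone restatement with assumed example tokens, not code from this commit), showing that overlapping tokens such as extra_id_1 vs extra_id_100 still split correctly:

#include <cstdio>
#include <string>
#include <vector>

// Same longest-match linear scan as llama_split_special_tokens, restated over
// a plain vector of special tokens so it can run standalone.
static std::vector<size_t> split_demo(const std::vector<std::string> & specials, const std::string & text) {
    std::vector<size_t> offsets{0};
    size_t start = 0;
    while (start < text.size()) {
        size_t max_end = start;
        const std::string * max_delimiter = nullptr;
        for (const auto & delimiter : specials) {
            size_t end = start + delimiter.size();
            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0 &&
                (max_delimiter == nullptr || delimiter.size() > max_delimiter->size())) {
                max_end = end;
                max_delimiter = &delimiter;
            }
        }
        if (max_delimiter != nullptr) {
            offsets.push_back(start);
            offsets.push_back(max_end);
            start = max_end;
        } else {
            start++;
        }
    }
    offsets.push_back(text.size());
    return offsets;
}

int main() {
    // "extra_id_1" is a prefix of "extra_id_100"; the longer token must win at position 2.
    for (size_t off : split_demo({"extra_id_1", "extra_id_100"}, "a<extra_id_100>")) {
        printf("%zu ", off); // prints: 0 2 14 15
    }
    printf("\n");
    return 0;
}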

static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
    llama_tokenizer tokenizer(vocab);
    std::vector<llama_vocab::id> output;
@@ -2107,7 +2137,7 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
        return output;
    }

    std::vector<size_t> offsets = vocab.special_token_trie.split(text);
    std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
    size_t start = 0;
    for (size_t end : offsets) {
        if (start >= end) {
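The remainder of llama_tokenize is collapsed above. From the visible context, consecutive offsets are paired into [start, end) segments; a rough sketch of that consuming loop follows, where the direct special-token lookup and the tokenizer.tokenize call are assumptions based on the surrounding code, not lines confirmed by this diff:

    size_t start = 0;
    for (size_t end : offsets) {
        if (start >= end) {
            continue; // empty segment, e.g. the leading 0
        }

        const std::string segment = text.substr(start, end - start);
        auto it = vocab.special_token_to_id.find(segment);
        if (it != vocab.special_token_to_id.end()) {
            // assumed: a special token maps directly to a single id
            output.push_back(it->second);
        } else {
            // assumed: ordinary text goes through the regular tokenizer
            tokenizer.tokenize(segment, output);
        }
        start = end;
    }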
