diff --git a/llama-util.h b/llama-util.h
index 30a6c0eb53eac3..ba0eec70bdb433 100644
--- a/llama-util.h
+++ b/llama-util.h
@@ -544,166 +544,4 @@ struct llama_ctx_buffer {
 typedef llama_buffer llama_ctx_buffer;
 #endif
 
-struct llama_trie_node {
-    llama_trie_node(): is_terminator(false) {}
-
-    std::unordered_map<char, std::unique_ptr<llama_trie_node>> children;
-    bool is_terminator;
-};
-
-// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass
-// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52
-struct llama_trie {
-public:
-    llama_trie(): root_(new llama_trie_node()) {}
-
-    void add(const std::string & word) {
-        if (word.empty()) {
-            return;
-        }
-
-        llama_trie_node *ref = root_.get();
-        for (char c : word) {
-            if (ref->children.find(c) == ref->children.end()) {
-                ref->children[c].reset(new llama_trie_node());
-            }
-            ref = ref->children[c].get();
-        }
-        ref->is_terminator = true;
-    }
-
-    // Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
-    // Note that this trie will match the longest possible word first!
-    std::vector<size_t> split(const std::string & text) const {
-        std::map<size_t, llama_trie_node*> states;
-        std::vector<size_t> offsets{0};
-
-        size_t skip = 0;
-        for (size_t current = 0; current < text.size(); current++) {
-            char current_char = text[current];
-            if (skip > 0 && current < skip) {
-                // Prevents the lookahead for matching twice
-                // like extra_id_100 and id_100
-                continue;
-            }
-
-            // Whenever we found a match, we need to drop everything
-            // this is a greedy algorithm, it will match on the first found token
-            bool reset = false;
-
-            // In this case, we already have partial matches (But unfinished)
-            for (auto state = states.begin(); state != states.end(); ) {
-                size_t start = state->first;
-                llama_trie_node *trie_pointer = state->second;
-                if (trie_pointer->is_terminator) {
-                    // This is a final match, we need to reset and
-                    // store the results in `offsets`.
-
-                    // Lookahead to match longest first
-                    // Important in case of extra_id_1 vs extra_id_100
-                    // Here we are also actively looking for other earlier partial
-                    // matches
-                    // "[CLS]", "L", we need to match CLS even if L is special
-                    size_t end = 0;
-                    for (const auto & look : states) {
-                        size_t lookstart = look.first;
-                        llama_trie_node *looktrie_pointer = look.second;
-                        size_t lookahead_index = 0;
-                        if (lookstart > start) {
-                            // This partial match is later, we can stop looking
-                            break;
-                        }
-                        if (lookstart < start) {
-                            // This partial match is earlier, the trie pointer
-                            // was already updated, so index is + 1
-                            lookahead_index = current + 1;
-                            end = current + 1;
-                        } else {
-                            // Here lookstart == start and
-                            //      looktrie_pointer == trie_pointer
-                            // It wasn't updated yet so indices are current ones
-                            lookahead_index = current;
-                            end = current;
-                        }
-                        char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0';
-                        if (looktrie_pointer->is_terminator) {
-                            start = lookstart;
-                            end = lookahead_index;
-                            skip = lookahead_index;
-                        }
-
-                        auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
-                        while (looktrie_pointer_it != looktrie_pointer->children.end()) {
-                            looktrie_pointer = looktrie_pointer_it->second.get();
-                            lookahead_index++;
-                            if (looktrie_pointer->is_terminator) {
-                                start = lookstart;
-                                end = lookahead_index;
-                                skip = lookahead_index;
-                            }
-
-                            if (lookahead_index == text.size()) {
-                                // End of string
-                                break;
-                            }
-                            next_char = text[lookahead_index];
-                            looktrie_pointer_it = looktrie_pointer->children.find(next_char);
-                        }
-                    }
-
-                    offsets.push_back(start);
-                    offsets.push_back(end);
-                    reset = true;
-                    break;
-                }
-
-                auto trie_pointer_it = trie_pointer->children.find(current_char);
-                if (trie_pointer_it != trie_pointer->children.end()) {
-                    // The current character being looked at has a match within the trie
-                    // update the pointer (it will be stored back into states later).
-                    trie_pointer = trie_pointer_it->second.get();
-                    states[start] = trie_pointer;
-                    ++state;
-                } else {
-                    // The new character has not match in the trie, we need
-                    // to stop keeping track of this partial match.
-                    state = states.erase(state);
-                }
-            }
-
-            if (reset) {
-                // Clear the full start (we found a real match)
-                states.clear();
-            }
-
-            // If this character is a starting character within the trie
-            // start keeping track of this partial match.
-            auto children_it = root_->children.find(current_char);
-            if (current >= skip && children_it != root_->children.end()) {
-                states[current] = children_it->second.get();
-            }
-        }
-
-        // We have a cut at the end with states.
-        for (const auto & state : states) {
-            size_t start = state.first;
-            llama_trie_node *trie_pointer = state.second;
-            if (trie_pointer->is_terminator) {
-                // This is a final match, we need to reset and
-                // store the results in `offsets`.
-                size_t end = text.size();
-                offsets.push_back(start);
-                offsets.push_back(end);
-                break;
-            }
-        }
-
-        offsets.push_back(text.size());
-        return offsets;
-    }
-
-private:
-    std::unique_ptr<llama_trie_node> root_;
-};
-
 #endif
diff --git a/llama.cpp b/llama.cpp
index 3a50090f828817..3b6d23eac572c4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -278,12 +278,10 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
 
-    llama_trie special_token_trie;
     std::unordered_map<token, id> special_token_to_id;
     size_t max_special_token_length = 0;
 
     void add_special_token(const token & word, id token_id) {
-        special_token_trie.add(word);
         special_token_to_id[word] = token_id;
 
         if (max_special_token_length < word.size()) {
@@ -2090,6 +2088,38 @@ struct llama_tokenizer {
     llama_sp_bigram::queue work_queue_;
 };
 
+static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
+    std::vector<size_t> offsets{0};
+    size_t start = 0;
+
+    while (start < text.size()) {
+        size_t max_end = start;
+        const std::string * max_delimiter = nullptr;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = start + delimiter.size();
+            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
+                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
+                    max_end = end;
+                    max_delimiter = &delimiter;
+                }
+            }
+        }
+
+        if (max_delimiter != nullptr) {
+            offsets.push_back(start);
+            offsets.push_back(max_end);
+            start = max_end;
+        } else {
+            start++;
+        }
+    }
+
+    offsets.push_back(text.size());
+    return offsets;
+}
+
 static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
     llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
@@ -2107,7 +2137,7 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         return output;
     }
 
-    std::vector<size_t> offsets = vocab.special_token_trie.split(text);
+    std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
     size_t start = 0;
     for (size_t end : offsets) {
         if (start >= end) {