From 436787f170329b3f549e6c2c46593d2af8482e7c Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Sun, 25 Aug 2024 23:09:53 -0700 Subject: [PATCH] llama : fix time complexity of string replacement (#9163) This change fixes a bug where replacing text in a very long string could cause llama.cpp to hang indefinitely. This is because the algorithm used was quadratic, due to memmove() when s.replace() is called in a loop. It seems most search results and LLM responses actually provide the O(n**2) algorithm, which is a great tragedy. Using a builder string fixes things --- common/common.cpp | 16 +++++++++++----- examples/llava/clip.cpp | 16 +++++++++++----- src/llama-impl.h | 14 ++++++++++---- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d46df39b5ac3c..72859c9674418 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1861,13 +1861,19 @@ std::string string_get_sortable_timestamp() { void string_replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); - } + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); } void string_process_escapes(std::string & input) { diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 7e9fa320aec44..10e8765b4cd19 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -216,13 +216,19 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int static void replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); - } + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); } static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { diff --git a/src/llama-impl.h b/src/llama-impl.h index 399b134a7f9bc..9527740961da6 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -31,11 +31,17 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * static void replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { - return; // Avoid infinite loop if 'search' is an empty string + return; } + std::string builder; + builder.reserve(s.length()); size_t pos = 0; - while ((pos = s.find(search, pos)) != std::string::npos) { - s.replace(pos, search.length(), replace); - pos += replace.length(); + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); }