diff --git a/llama.cpp b/llama.cpp
index 8b675ea993a38..49191aa347c89 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2368,6 +2368,11 @@ struct llama_context {
 
     // control vectors
     struct llama_control_vector cvec;
+
+    // caching token pieces & their decoded codepoints.
+    std::mutex token_cache_mutex;
+    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>>
+        token_codepoints_without_partial_utf8_prefix;
 };
 
 static size_t llama_get_device_count(const llama_model & model) {
@@ -14496,9 +14501,24 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
-    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
+    {
+        // cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar-constrained sampling.
+        std::unique_lock<std::mutex> lock(ctx->token_cache_mutex);
+        if (ctx->token_codepoints_without_partial_utf8_prefix.empty()) {
+            auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+            ctx->token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
+            for (llama_token id = 0; id < n_vocab; ++id) {
+                const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
+                ctx->token_codepoints_without_partial_utf8_prefix[id] = decode_utf8(piece, {0, 0});
+            }
+        }
+    }
+
+    // Store decoded codepoints when they are not cached (happens when there's a partial utf8 string prefix).
+    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
+    if (grammar->partial_utf8.n_remain > 0) {
+        candidates_decoded.reserve(candidates->size);
+    }
 
     std::vector<llama_grammar_candidate> candidates_grammar;
     candidates_grammar.reserve(candidates->size);
 
@@ -14512,6 +14532,9 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
             }
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
+        } else if (grammar->partial_utf8.n_remain == 0) {
+            const auto & decoded = ctx->token_codepoints_without_partial_utf8_prefix.at(id);
+            candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
         } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
@@ -14707,7 +14730,9 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece, grammar->partial_utf8);
+    const auto decoded = grammar->partial_utf8.n_remain == 0
+        ? ctx->token_codepoints_without_partial_utf8_prefix[token]
+        : decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
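
Note on the pattern in this patch: decoding a token piece into codepoints is deterministic whenever there is no pending partial-UTF-8 prefix, so that common-case decoding can be computed once per vocabulary entry (guarded by a mutex) and reused across sampling calls, falling back to decode_utf8 only while grammar->partial_utf8.n_remain > 0. Below is a minimal, self-contained C++ sketch of that caching scheme, not the real llama.cpp code: the partial_utf8 struct, the token_cache type, and the byte-wise decode_utf8 stand-in are simplified assumptions for illustration only.

// Standalone sketch of the lazy, mutex-guarded codepoint cache above.
// Assumptions: partial_utf8, token_cache, and this byte-wise decode_utf8
// are simplified stand-ins, not the real llama.cpp definitions.
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

struct partial_utf8 {
    uint32_t value;    // bits collected so far for an unfinished codepoint
    int      n_remain; // continuation bytes still expected (0 = none pending)
};

// Stand-in decoder: maps each byte to one "codepoint" to keep the sketch
// short; the real decoder walks multi-byte UTF-8 sequences.
static std::pair<std::vector<uint32_t>, partial_utf8>
decode_utf8(const std::string & piece, partial_utf8 partial) {
    std::vector<uint32_t> cps;
    for (unsigned char c : piece) {
        cps.push_back(c);
    }
    cps.push_back(0); // terminating 0, as noted in the patch
    return { std::move(cps), partial };
}

struct token_cache {
    std::mutex mtx;
    // one precomputed decoding per vocab token, filled lazily on first use
    std::vector<std::pair<std::vector<uint32_t>, partial_utf8>> decoded;

    void ensure_filled(const std::vector<std::string> & pieces) {
        std::unique_lock<std::mutex> lock(mtx);
        if (!decoded.empty()) {
            return; // another caller already built the cache
        }
        decoded.resize(pieces.size());
        for (size_t id = 0; id < pieces.size(); ++id) {
            // cache the common case: no partial utf8 prefix ({0, 0})
            decoded[id] = decode_utf8(pieces[id], {0, 0});
        }
    }
};

int main() {
    const std::vector<std::string> pieces = { "he", "llo", " wo", "rld" };
    token_cache cache;
    cache.ensure_filled(pieces);

    partial_utf8 state = {0, 0};
    for (size_t id = 0; id < pieces.size(); ++id) {
        // fast path: reuse the cached decoding when no prefix is pending;
        // slow path: re-decode against the carried-over prefix state
        const auto decoded = state.n_remain == 0
            ? cache.decoded[id]
            : decode_utf8(pieces[id], state);
        std::printf("token %zu: %zu codepoints\n", id, decoded.first.size() - 1);
        state = decoded.second;
    }
    return 0;
}

The lock-then-check-empty() shape mirrors the patch: the cache is built at most once, and any caller that loses the race simply finds it already populated.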