Skip to content

Commit

Permalink
grammars: early exit when no next_candidates to reject
Browse files Browse the repository at this point in the history
grammars: cache decoded tokens

grammars: faster llama_grammar_copy

grammars: fix bad merge

grammars: keep llama_grammar_copy non-quadratic optim for later

grammars: move token caches to llama_context

grammars: cache codepoints in llama_new_context_with_model

grammar: nit (layout)

grammars: nits (revert const grammar sig, fix comment)

Update llama.cpp

Co-authored-by: Clint Herron <[email protected]>

grammars: mutex-guarded lazy caching of token pieces in llama_sample_grammar

grammars: remove early exit --> ggerganov#7370
  • Loading branch information
ochafik committed May 21, 2024
1 parent 059031b commit 1fa5a4b
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2269,6 +2269,12 @@ struct llama_context {
// control vectors
struct llama_control_vector cvec;

// caching token pieces & their decoded codepoints.
std::mutex token_cache_mutex;
std::vector<std::string> token_pieces;
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>>
token_codepoints_without_partial_utf8_prefix;

#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
#endif
Expand Down Expand Up @@ -13833,21 +13839,41 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
}
}

{
// cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar-constrained sampling.
std::unique_lock<std::mutex> lock(ctx->token_cache_mutex);
if (ctx->token_pieces.empty()) {
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
ctx->token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
ctx->token_pieces.resize(n_vocab);
for (llama_token id = 0; id < n_vocab; ++id) {
const std::string piece = llama_token_to_piece(ctx, id, false);
ctx->token_pieces[id] = piece;
ctx->token_codepoints_without_partial_utf8_prefix[id] = decode_utf8(piece, {0, 0});
}
}
}

// Store decoded codepoints when they are not cached (happens when there's a partial utf8 string prefix).
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
candidates_decoded.reserve(candidates->size);
if (grammar->partial_utf8.n_remain > 0) {
candidates_decoded.reserve(candidates->size);
}
std::vector<llama_grammar_candidate> candidates_grammar;
candidates_grammar.reserve(candidates->size);

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id, false);

const auto & piece = ctx->token_pieces[id];
if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
} else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY;
} else if (grammar->partial_utf8.n_remain == 0){
const auto & decoded = ctx->token_codepoints_without_partial_utf8_prefix.at(id);
candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
} else {
candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
Expand Down Expand Up @@ -14040,10 +14066,12 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}

const std::string piece = llama_token_to_piece(ctx, token, false);
const auto & piece = ctx->token_pieces.at(token);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto decoded = grammar->partial_utf8.n_remain == 0
? ctx->token_codepoints_without_partial_utf8_prefix[token]
: decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first;
std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
Expand Down

0 comments on commit 1fa5a4b

Please sign in to comment.