Skip to content

Commit

Permalink
llama : redirect external API to internal APIs
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Jul 19, 2024
1 parent 749e1c2 commit 8c5f2c2
Show file tree
Hide file tree
Showing 9 changed files with 838 additions and 519 deletions.
6 changes: 3 additions & 3 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl(
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

// Apply grammar constraints to the single token
llama_grammar_sample(ctx_main, &single_token_data_array, ctx_sampling->grammar);
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);

// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
Expand Down Expand Up @@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(

// apply grammar checks before sampling logic
if (apply_grammar && ctx_sampling->grammar != NULL) {
llama_grammar_sample(ctx_main, &cur_p, ctx_sampling->grammar);
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
}

return cur_p;
Expand Down Expand Up @@ -455,6 +455,6 @@ void llama_sampling_accept(
ctx_sampling->prev.push_back(id);

if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
}
}
14 changes: 9 additions & 5 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,10 @@ extern "C" {
bool remove_special,
bool unparse_special);

//
// Chat templates
//

/// Apply chat template. Inspired by hf apply_chat_template() on python.
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
Expand Down Expand Up @@ -1002,19 +1006,19 @@ extern "C" {

/// @details Apply constraints from grammar
LLAMA_API void llama_grammar_sample(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar);
LLAMA_API DEPRECATED(bool llama_sample_grammar(
const struct llama_grammar * grammar,
const struct llama_context * ctx,
llama_token_data_array * candidates);
LLAMA_API DEPRECATED(void llama_sample_grammar(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar),
"use llama_grammar_sample instead");

/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(
struct llama_context * ctx,
struct llama_grammar * grammar,
struct llama_context * ctx,
llama_token token);

//
Expand Down
26 changes: 14 additions & 12 deletions src/llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ static bool llama_grammar_detect_left_recursion(
// grammar - external
//

struct llama_grammar * llama_grammar_init(
struct llama_grammar * llama_grammar_init_impl(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
Expand Down Expand Up @@ -441,11 +441,11 @@ struct llama_grammar * llama_grammar_init(
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
}

void llama_grammar_free(struct llama_grammar * grammar) {
void llama_grammar_free_impl(struct llama_grammar * grammar) {
delete grammar;
}

struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };

// redirect elements in stacks to point to new rules
Expand All @@ -464,8 +464,10 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar)
return result;
}

void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
GGML_ASSERT(ctx);
void llama_grammar_sample(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
GGML_ASSERT(grammar);
GGML_ASSERT(vocab);

int64_t t_start_sample_us = ggml_time_us();

bool allow_eog = false;
Expand All @@ -484,9 +486,9 @@ void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * c

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string & piece = llama_get_vocab(ctx)->cache_token_to_piece.at(id);
const std::string & piece = vocab->cache_token_to_piece.at(id);

if (llama_token_is_eog(llama_get_model(ctx), id)) {
if (llama_token_is_eog(*vocab, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
Expand All @@ -503,13 +505,13 @@ void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * c
candidates->data[reject.index].logit = -INFINITY;
}

llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}

void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
void llama_grammar_accept_token(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us();

if (llama_token_is_eog(llama_get_model(ctx), token)) {
if (llama_token_is_eog(*vocab, token)) {
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
return;
Expand All @@ -518,7 +520,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}

const std::string & piece = llama_get_vocab(ctx)->cache_token_to_piece.at(token);
const std::string & piece = vocab->cache_token_to_piece.at(token);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
Expand All @@ -533,5 +535,5 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty());

llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
22 changes: 22 additions & 0 deletions src/llama-grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "llama-impl.h"

struct llama_vocab;
struct llama_sampling;

struct llama_grammar {
const llama_grammar_rules rules;
Expand All @@ -13,3 +14,24 @@ struct llama_grammar {
};

struct llama_grammar * llama_get_grammar(struct llama_context * ctx);

struct llama_grammar * llama_grammar_init_impl(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);

void llama_grammar_free_impl(struct llama_grammar * grammar);

struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);

void llama_grammar_sample(
const struct llama_grammar * grammar,
const struct llama_vocab * vocab,
const struct llama_sampling * smpl,
llama_token_data_array * candidates);

void llama_grammar_accept_token(
struct llama_grammar * grammar,
const struct llama_vocab * vocab,
const struct llama_sampling * smpl,
llama_token token);
Loading

0 comments on commit 8c5f2c2

Please sign in to comment.