cont : fix grammar sampling speed-up
ggml-ci
ggerganov committed Aug 16, 2024
1 parent fd24e68 commit acaf26c
Showing 4 changed files with 44 additions and 14 deletions.
44 changes: 36 additions & 8 deletions common/sampling.cpp
@@ -230,12 +230,8 @@ llama_token_data_array llama_sampling_prepare(
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
         int idx) {
-    llama_sampling * smpl = ctx_sampling->smpl;
-
     const gpt_sampling_params & params = ctx_sampling->params;

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
     auto & cur = ctx_sampling->cur;

     // Get a pointer to the logits
@@ -250,11 +246,15 @@ llama_token_data_array llama_sampling_prepare(
         logits[llama_token_eos(llama_get_model(ctx_main))] = -INFINITY;
     }

+    llama_sampling * smpl = ctx_sampling->smpl;
+
     if (ctx_cfg) {
         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
         llama_sampling_cfg(smpl, logits, logits_guidance);
     }

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
     cur.resize(n_vocab);

     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -279,12 +279,10 @@ llama_token_data_array llama_sampling_prepare(
         }
     }

-    llama_sampling_grammar(smpl, &cur_p);
-
     return cur_p;
 }

-llama_token llama_sampling_sample(
+static llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_token_data_array * cur_p) {
     llama_sampling * smpl = ctx_sampling->smpl;
@@ -339,7 +337,37 @@ llama_token llama_sampling_sample(
         int idx) {
     llama_token_data_array cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx);

-    return llama_sampling_sample(ctx_sampling, &cur_p);
+    if (ctx_sampling->params.grammar.empty()) {
+        return llama_sampling_sample(ctx_sampling, &cur_p);
+    }
+
+    // TODO: this logic is confusing, try to figure out a better way to handle this
+
+    // store the original candidates
+    ctx_sampling->org = ctx_sampling->cur;
+    llama_token_data_array org_p = { ctx_sampling->org.data(), ctx_sampling->org.size(), false };
+
+    llama_token id = llama_sampling_sample(ctx_sampling, &cur_p);
+
+    // Create an array with a single token data element for the sampled id
+    llama_token_data single_token_data = { id, 1.0f, 0.0f };
+    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+
+    // Apply grammar constraints to the single token
+    llama_sampling_grammar(ctx_sampling->smpl, &single_token_data_array);
+
+    // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+    const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+
+    if (!is_valid) {
+        llama_sampling_grammar(ctx_sampling->smpl, &org_p);
+
+        id = llama_sampling_sample(ctx_sampling, &org_p);
+
+        ctx_sampling->cur = std::move(ctx_sampling->org);
+    }
+
+    return id;
 }

 void llama_sampling_accept(
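
The change above implements the speed-up named in the commit title: instead of running the grammar over the entire vocabulary on every step, the sampler first samples optimistically from the unconstrained candidates and only falls back to a full grammar pass when the sampled token turns out to be grammar-invalid. A minimal self-contained sketch of this pattern follows; Candidate, apply_grammar, and sample_token are illustrative stand-ins rather than llama.cpp API, and the "even ids only" grammar rule is a toy assumption for demonstration:

#include <cmath>
#include <vector>

// illustrative stand-in for llama_token_data: token id plus its logit
struct Candidate {
    int   id;
    float logit;
};

// stand-in for a grammar pass: invalid tokens get their logit set to -INFINITY
// (toy rule: only even token ids are grammar-valid)
static void apply_grammar(std::vector<Candidate> & cands) {
    for (auto & c : cands) {
        if (c.id % 2 != 0) {
            c.logit = -INFINITY;
        }
    }
}

// stand-in for the actual sampler: greedy argmax over the candidates
static int sample_token(const std::vector<Candidate> & cands) {
    int   best_id = cands[0].id;
    float best_lg = cands[0].logit;
    for (const auto & c : cands) {
        if (c.logit > best_lg) {
            best_lg = c.logit;
            best_id = c.id;
        }
    }
    return best_id;
}

// sample first, constrain later: the grammar only sees a single token unless
// that token is rejected, in which case the full candidate list is filtered
static int optimistic_grammar_sample(std::vector<Candidate> cur) {
    std::vector<Candidate> org = cur; // keep the unconstrained candidates for the fallback

    int id = sample_token(cur);

    // cheap validity check on just the sampled token
    std::vector<Candidate> single = { { id, 1.0f } };
    apply_grammar(single);

    if (single[0].logit != -INFINITY) {
        return id; // common case: no full-vocabulary grammar pass needed
    }

    // slow path: constrain the original candidate list, then resample
    apply_grammar(org);
    return sample_token(org);
}

The win comes from the common case: when the model already tends to emit grammar-valid tokens, the grammar machinery touches one token per step instead of the whole vocabulary.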
7 changes: 4 additions & 3 deletions common/sampling.h
@@ -67,6 +67,7 @@ struct llama_sampling_context {
     llama_sampling * smpl;

     std::vector<llama_token_data> cur;
+    std::vector<llama_token_data> org;

     size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
@@ -125,9 +126,9 @@ llama_token_data_array llama_sampling_prepare(
 // - token: sampled token
 // - candidates: vector of candidate tokens
 //
-llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_token_data_array * cur_p);
+//llama_token llama_sampling_sample(
+//        struct llama_sampling_context * ctx_sampling,
+//        struct llama_token_data_array * cur_p);

 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
2 changes: 1 addition & 1 deletion examples/speculative/speculative.cpp
@@ -181,7 +181,6 @@ int main(int argc, char ** argv) {
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);

-    params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
     if (params.sparams.temp == 0) {
         params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
     }
@@ -232,6 +231,7 @@ int main(int argc, char ** argv) {
             // stochastic verification

             llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+            llama_sampling_grammar(ctx_sampling->smpl, &dist_tgt);
             llama_sampling_softmax(ctx_sampling->smpl, &dist_tgt);

             float p_tgt = 0.0f;
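
Because llama_sampling_prepare no longer applies the grammar itself (the call was removed in common/sampling.cpp above), callers that need the fully constrained target distribution must now invoke llama_sampling_grammar explicitly. A sketch of the required ordering in this verification path, with idx_tgt standing in for drafts[s_keep].i_batch_tgt[i_dft]:

// grammar must now be applied by the caller, between prepare and softmax
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, idx_tgt);
llama_sampling_grammar(ctx_sampling->smpl, &dist_tgt); // mask grammar-invalid tokens with -INFINITY
llama_sampling_softmax(ctx_sampling->smpl, &dist_tgt); // renormalize the surviving candidates
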
5 changes: 3 additions & 2 deletions src/llama-grammar.cpp
@@ -683,8 +683,8 @@ static void llama_grammar_advance_stack(
 }

 static llama_grammar_candidates llama_grammar_reject_candidates(
-        const llama_grammar_rules & rules,
-        const llama_grammar_stacks & stacks,
+        const llama_grammar_rules & rules,
+        const llama_grammar_stacks & stacks,
         const llama_grammar_candidates & candidates) {
     GGML_ASSERT(!stacks.empty()); // REVIEW

@@ -697,6 +697,7 @@ static llama_grammar_candidates llama_grammar_reject_candidates(
     for (size_t i = 1, size = stacks.size(); i < size; ++i) {
         rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
     }
+
     return rejects;
 }

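For context on the function touched above: llama_grammar_reject_candidates computes the set of candidates rejected by every parse stack. The set is seeded from stack 0 (just above the shown hunk) and then re-filtered against each remaining stack, so a candidate accepted by any single stack survives. A toy sketch of that fold, using illustrative stand-ins (Stack, Candidate, reject_for_stack) rather than the real grammar types:

#include <string>
#include <vector>

// toy stand-ins: a "stack" accepts any candidate containing its required char
using Stack     = char;
using Candidate = std::string;

// stand-in for llama_grammar_reject_candidates_for_stack: returns the subset
// of cands that this one stack rejects
static std::vector<Candidate> reject_for_stack(Stack s, const std::vector<Candidate> & cands) {
    std::vector<Candidate> rejects;
    for (const auto & c : cands) {
        if (c.find(s) == std::string::npos) {
            rejects.push_back(c);
        }
    }
    return rejects;
}

// a candidate is rejected overall only if every stack rejects it, hence the
// fold: start from stack 0's rejects and re-filter against each later stack
static std::vector<Candidate> reject_candidates(const std::vector<Stack> & stacks,
                                                const std::vector<Candidate> & cands) {
    // mirrors GGML_ASSERT(!stacks.empty()) in the real code
    std::vector<Candidate> rejects = reject_for_stack(stacks[0], cands);
    for (size_t i = 1; i < stacks.size(); ++i) {
        rejects = reject_for_stack(stacks[i], rejects);
    }
    return rejects;
}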
