sampling : option to use internal set of candidates
ggml-ci
ggerganov committed Aug 29, 2024
1 parent 9dd2061 commit 77998ae
Showing 8 changed files with 82 additions and 25 deletions.
22 changes: 19 additions & 3 deletions common/sampling.cpp
@@ -199,9 +199,25 @@ llama_token llama_sampling_sample(
         int idx) {
     llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
 
-    auto * cur_p = llama_sampling_get_candidates(smpl);
+    // first, sample the token without any grammar constraints
+    auto id = llama_sampling_sample(smpl, nullptr);
+
+    // create an array with a single token data element for the sampled id
+    llama_token_data single_token_data = {id, 1.0f, 0.0f};
+    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
 
-    llama_sampling_grammar(smpl, cur_p);
+    llama_sampling_grammar(smpl, &single_token_data_array);
+
+    // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+    const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+    if (is_valid) {
+        return id;
+    }
 
-    return llama_sampling_sample(smpl, cur_p);
+    // if the token is not valid, sample again, after applying the grammar constraints
+    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
+
+    llama_sampling_grammar(smpl, nullptr);
+
+    return llama_sampling_sample(smpl, nullptr);
 }
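The validity check above works because the grammar pass marks rejected candidates by forcing their logit to -INFINITY, so a one-element candidate array is enough to test a single token; the full constrained pass over the vocabulary only runs when that test fails. A self-contained toy sketch of the convention (the `grammar_rejects` predicate and `token_data` struct are hypothetical stand-ins for the real grammar machinery and `llama_token_data`):

```cpp
#include <cmath>   // INFINITY
#include <cstddef> // size_t
#include <cstdint> // int32_t
#include <cstdio>

typedef int32_t llama_token;

// mirrors llama.h's candidate element: { id, logit, p }
struct token_data {
    llama_token id;
    float       logit;
    float       p;
};

// hypothetical stand-in for the real grammar check
static bool grammar_rejects(llama_token id) {
    return id % 2 != 0; // toy rule: only even token ids are grammar-valid
}

// what the grammar pass does to a candidate list: rejected tokens
// keep their slot, but their logit is forced to -INFINITY
static void apply_grammar(token_data * data, size_t size) {
    for (size_t i = 0; i < size; ++i) {
        if (grammar_rejects(data[i].id)) {
            data[i].logit = -INFINITY;
        }
    }
}

int main() {
    token_data single = { 3, 1.0f, 0.0f }; // the unconstrained sample
    apply_grammar(&single, 1);

    const bool is_valid = single.logit != -INFINITY;
    printf("token %d: %s\n", single.id,
           is_valid ? "accepted" : "rejected -> resample with grammar applied");
    return 0;
}
```

In the common case where the unconstrained sample is already grammar-valid, this turns grammar sampling into a single cheap check per token.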
2 changes: 1 addition & 1 deletion common/sampling.h
@@ -67,4 +67,4 @@ std::vector<enum llama_sampler_type> llama_sampling_types_from_chars(const std::
 llama_token llama_sampling_sample(
         struct llama_sampling * smpl,
         struct llama_context * ctx,
-        int idx = -1);
+        int idx);
23 changes: 6 additions & 17 deletions examples/batched.swift/Sources/main.swift
@@ -136,28 +136,17 @@ while n_cur <= n_len {
                 continue
             }
 
-            var n_vocab = llama_n_vocab(model)
             var logits = llama_get_logits_ith(context, i_batch[i])
 
-            var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
+            llama_sampling_set_logits(smpl, logits)
 
-            for token_id in 0 ..< n_vocab {
-                candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
-            }
-
-            var candidates_p: llama_token_data_array = .init(
-                data: &candidates,
-                size: candidates.count,
-                sorted: false
-            )
-
-            llama_sampling_top_k(smpl, &candidates_p)
-            llama_sampling_top_p(smpl, &candidates_p)
-            llama_sampling_temp (smpl, &candidates_p)
+            llama_sampling_top_k(smpl, nil)
+            llama_sampling_top_p(smpl, nil)
+            llama_sampling_temp (smpl, nil)
 
-            let new_token_id = llama_sampling_sample_dist(smpl, &candidates_p)
+            let new_token_id = llama_sampling_sample_dist(smpl, nil)
 
-            // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
+            // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil);
 
             // is it an end of stream? -> mark the stream as finished
             if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
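For readers following along in C++, this is the same pattern the Swift code now uses: set the logits once, then pass nil (nullptr in C++) so each sampler call operates on the sampler's internal candidate set, as implemented in src/llama.cpp below. A minimal sketch, assuming `smpl` and `ctx` were initialized as in the other examples and that the `llama_sampling_*` declarations are visible via `llama.h` as of this commit:

```cpp
#include "llama.h"

// sketch: one decode step driven entirely by the sampler's
// internal candidate set (the nullptr convention from this commit)
static llama_token sample_with_internal_candidates(
        struct llama_sampling * smpl,
        struct llama_context  * ctx,
        int                     idx) {
    // copy this position's logits into the sampler's internal array
    llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));

    // passing nullptr makes each filter mutate the internal candidates
    llama_sampling_top_k(smpl, nullptr);
    llama_sampling_top_p(smpl, nullptr);
    llama_sampling_temp (smpl, nullptr);

    // draw a token from the filtered distribution
    return llama_sampling_sample_dist(smpl, nullptr);

    // greedy alternative:
    // return llama_sampling_sample_greedy(smpl, nullptr);
}
```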
2 changes: 1 addition & 1 deletion examples/infill/infill.cpp
@@ -417,7 +417,7 @@ int main(int argc, char ** argv) {
             embd.clear();
 
             if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-                const llama_token id = llama_sampling_sample(smpl, ctx);
+                const llama_token id = llama_sampling_sample(smpl, ctx, -1);
 
                 llama_sampling_accept(smpl, id, true);
 
2 changes: 1 addition & 1 deletion examples/llava/llava-cli.cpp
@@ -43,7 +43,7 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
 static const char * sample(struct llama_sampling * smpl,
                            struct llama_context * ctx_llama,
                            int * n_past) {
-    const llama_token id = llama_sampling_sample(smpl, ctx_llama);
+    const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
     llama_sampling_accept(smpl, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
2 changes: 1 addition & 1 deletion examples/llava/minicpmv-cli.cpp
@@ -166,7 +166,7 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
 static const char * sample(struct llama_sampling * smpl,
                            struct llama_context * ctx_llama,
                            int * n_past) {
-    const llama_token id = llama_sampling_sample(smpl, ctx_llama);
+    const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
     llama_sampling_accept(smpl, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
2 changes: 1 addition & 1 deletion examples/main/main.cpp
@@ -650,7 +650,7 @@ int main(int argc, char ** argv) {
             LOG("saved session to %s\n", path_session.c_str());
         }
 
-        const llama_token id = llama_sampling_sample(smpl, ctx);
+        const llama_token id = llama_sampling_sample(smpl, ctx, -1);
 
         llama_sampling_accept(smpl, id, /* apply_grammar= */ true);
 
52 changes: 52 additions & 0 deletions src/llama.cpp
@@ -20139,42 +20139,70 @@ llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * s
 void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     llama_sampling_softmax_impl(candidates);
 }
 
 void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep);
 }
 
 void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep);
 }
 
 void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep);
 }
 
 void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep);
 }
 
 void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep);
 }
 
 void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     if (smpl->params.dynatemp_range > 0) {
         const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range);
         const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range);
@@ -20188,6 +20216,10 @@ void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array *
 void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_grammar_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     if (smpl->grammar) {
         llama_sampling_grammar_impl(candidates, *smpl->grammar);
 
@@ -20200,6 +20232,10 @@ void llama_sampling_penalties(
         llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     const size_t penalty_last_n = std::min<size_t>(smpl->params.penalty_last_n, smpl->prev.size());
 
     const float penalty_repeat = smpl->params.penalty_repeat;
@@ -20224,6 +20260,10 @@ void llama_sampling_penalties(
 llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     const auto type = smpl->params.mirostat;
 
     llama_token res;
@@ -20254,6 +20294,10 @@ llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_t
 llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     auto res = llama_sampling_sample_greedy_impl(candidates);
 
     smpl->n_sample++;
@@ -20264,6 +20308,10 @@ llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_tok
 llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);
 
     smpl->n_sample++;
@@ -20274,6 +20322,10 @@ llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token
 llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
     time_meas tm(smpl->t_sample_us);
 
+    if (candidates == nullptr) {
+        candidates = &smpl->cur_p;
+    }
+
     const auto & params = smpl->params;
 
     const float temp = params.temp;
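Since the fallback only triggers on nullptr, the pre-existing calling convention is untouched: a caller can still build its own llama_token_data_array (as the Swift example used to) and pass it explicitly. A sketch of that external-buffer path, assuming the usual model/context setup and the public declarations from llama.h:

```cpp
#include "llama.h"

#include <vector>

// sketch: the pre-existing external-candidates path still works
static llama_token sample_with_external_candidates(
        struct llama_sampling    * smpl,
        struct llama_context     * ctx,
        const struct llama_model * model,
        int                        idx) {
    const int     n_vocab = llama_n_vocab(model);
    const float * logits  = llama_get_logits_ith(ctx, idx);

    // build a caller-owned candidate array from the raw logits
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
        candidates.push_back({ token_id, logits[token_id], 0.0f });
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // the same functions accept an explicit array instead of nullptr
    llama_sampling_top_k(smpl, &candidates_p);
    llama_sampling_temp (smpl, &candidates_p);

    return llama_sampling_sample_dist(smpl, &candidates_p);
}
```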
