sampling : avoid llama_model in few samplers
ggml-ci
ggerganov committed Sep 7, 2024
1 parent 19c3696 commit 0e6d170
Showing 5 changed files with 117 additions and 86 deletions.
8 changes: 5 additions & 3 deletions common/sampling.cpp
@@ -152,13 +152,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st

llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
model,
llama_n_vocab(model),
params.logit_bias.size(),
params.logit_bias.data()));

llama_sampler_chain_add(result->chain,
llama_sampler_init_penalties(
model,
llama_n_vocab (model),
llama_token_eos(model),
llama_token_nl (model),
params.penalty_last_n,
params.penalty_repeat,
params.penalty_freq,
@@ -196,7 +198,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta, 100));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
11 changes: 7 additions & 4 deletions include/llama.h
@@ -1003,12 +1003,13 @@ extern "C" {
// // sample from the logits of the last token in the batch
// const llama_token id = llama_sampler_sample(smpl, ctx, -1);
//
// // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
// llama_sampler_accept(smpl, id);
// ...
// }
//
// llama_sampler_free(smpl);
//
//
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
// TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
//
@@ -1086,7 +1087,7 @@ extern "C" {
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
const struct llama_model * model,
int32_t n_vocab,
uint32_t seed,
float tau,
float eta,
@@ -1108,7 +1109,9 @@ extern "C" {
const char * grammar_root);

LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
const struct llama_model * model,
int32_t n_vocab, // llama_n_vocab()
llama_token special_eos_id, // llama_token_eos()
llama_token linefeed_id, // llama_token_nl()
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat, // 1.0 = disabled
float penalty_freq, // 0.0 = disabled
@@ -1117,7 +1120,7 @@ extern "C" {
bool ignore_eos); // ignore the end-of-sequence token

LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
const struct llama_model * model,
int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);

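For context, a minimal sketch of how a caller might use the updated declarations above, mirroring the usage comment near the top of the header and the wiring in common/sampling.cpp. This is a sketch under assumptions, not code from the commit: llama_sampler_chain_init and llama_sampler_chain_default_params are assumed to be available as elsewhere in this API, and all numeric parameters are made-up example values.

#include "llama.h"

// Sketch only: assumes a loaded `model` and `ctx`; parameter values are illustrative.
static llama_token sample_one(const struct llama_model * model, struct llama_context * ctx) {
    struct llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // vocab-derived values are now passed directly instead of the llama_model pointer
    const llama_logit_bias bias[] = { { llama_token_eos(model), -1.0f } };
    llama_sampler_chain_add(smpl, llama_sampler_init_logit_bias(llama_n_vocab(model), 1, bias));

    llama_sampler_chain_add(smpl, llama_sampler_init_penalties(
        llama_n_vocab (model),
        llama_token_eos(model),
        llama_token_nl (model),
        /*penalty_last_n=*/64,  /*penalty_repeat=*/1.1f,
        /*penalty_freq=*/0.0f,  /*penalty_present=*/0.0f,
        /*penalize_nl=*/false,  /*ignore_eos=*/false));

    llama_sampler_chain_add(smpl, llama_sampler_init_mirostat(
        llama_n_vocab(model), /*seed=*/42, /*tau=*/5.0f, /*eta=*/0.1f, /*m=*/100));

    // sample from the logits of the last token in the batch, then accept it
    const llama_token id = llama_sampler_sample(smpl, ctx, -1);
    llama_sampler_accept(smpl, id);

    llama_sampler_free(smpl);
    return id;
}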
138 changes: 105 additions & 33 deletions src/llama-sampling.cpp
@@ -3,6 +3,7 @@
#include "llama-vocab.h"
#include "llama-grammar.h"

#include <cassert>
#include <algorithm>
#include <cstring>
#include <ctime>
@@ -926,7 +927,7 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
// mirostat

struct llama_sampler_mirostat {
const struct llama_vocab * vocab;
const int32_t n_vocab;

const uint32_t seed;

@@ -964,7 +965,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {

// Compute k from the estimated s_hat and target surprise value
float epsilon_hat = s_hat - 1;
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);

llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
llama_sampler_softmax_impl(cur_p);
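Read as a formula, the truncation size computed in the hunk above is (writing N for ctx->n_vocab):

$$\hat{\epsilon} = \hat{s} - 1, \qquad k = \left( \frac{\hat{\epsilon} \cdot 2^{\mu}}{1 - N^{-\hat{\epsilon}}} \right)^{1/\hat{s}}$$

The vocabulary enters only through its size N, which is why the sampler can keep a plain int32_t n_vocab instead of a pointer to the full llama_vocab.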
@@ -986,25 +987,25 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
return llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_mirostat *) smpl->ctx;
},
};

struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) {
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat {
/* .vocab = */ &vocab,
/* .seed = */ seed,
/* .tau = */ tau,
/* .eta = */ eta,
/* .m = */ m,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed),
/* .probs = */ {},
/* .n_vocab = */ n_vocab,
/* .seed = */ seed,
/* .tau = */ tau,
/* .eta = */ eta,
/* .m = */ m,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed),
/* .probs = */ {},
},
};
}
@@ -1172,7 +1173,9 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
// penalties

struct llama_sampler_penalties {
const struct llama_vocab * vocab;
const int32_t n_vocab;
const llama_token special_eos_id;
const llama_token linefeed_id;

const int32_t penalty_last_n;
const float penalty_repeat;
@@ -1194,18 +1197,51 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_penalties *) smpl->ctx;

GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' sampler must be applied on the full vocabulary");

if (ctx->ignore_eos) {
cur_p->data[ctx->vocab->special_eos_id].logit = -INFINITY;
assert(ctx->special_eos_id >= 0);

// optimistically check if the candidates are not yet sorted/shuffled/truncated
if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
cur_p->data[ctx->special_eos_id].logit = -INFINITY;
} else {
// else, search for the special EOS token
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].id == ctx->special_eos_id) {
cur_p->data[i].logit = -INFINITY;
break;
}
}
}
}

if ((ctx->penalty_last_n == 0) ||
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
return;
}

const float nl_logit = !ctx->penalize_nl ? cur_p->data[ctx->vocab->linefeed_id].logit : -INFINITY;
bool nl_found = false;
size_t nl_idx = 0;
float nl_logit = -INFINITY;
if (!ctx->penalize_nl) {
assert(ctx->linefeed_id >= 0);

// optimistically check if the candidates are not yet sorted/shuffled/truncated
if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
nl_found = true;
nl_idx = ctx->linefeed_id;
nl_logit = cur_p->data[ctx->linefeed_id].logit;
} else {
// else, search for the linefeed token
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].id == ctx->linefeed_id) {
nl_found = true;
nl_idx = i;
nl_logit = cur_p->data[i].logit;
break;
}
}
}
}

// Create a frequency map to count occurrences of each token in last_tokens
// TODO: optimize this by maintaining the token count in the sampler context
@@ -1216,9 +1252,9 @@ static struct llama_sampler_i llama_sampler_penalties_i = {

llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present);

if (!ctx->penalize_nl) {
if (!ctx->penalize_nl && nl_found) {
// restore the logit of the newline token if it was penalized
cur_p->data[ctx->vocab->linefeed_id].logit = nl_logit;
cur_p->data[nl_idx].logit = nl_logit;
}
},
/* .reset = */ [](struct llama_sampler * smpl) {
@@ -1227,8 +1263,10 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx;
auto * result = llama_sampler_init_penalties_impl(
*ctx_src->vocab,
auto * result = llama_sampler_init_penalties(
ctx_src->n_vocab,
ctx_src->special_eos_id,
ctx_src->linefeed_id,
ctx_src->penalty_last_n,
ctx_src->penalty_repeat,
ctx_src->penalty_freq,
@@ -1246,14 +1284,30 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
},
};

struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) {
GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL);
GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL);
struct llama_sampler * llama_sampler_init_penalties(
int32_t n_vocab,
llama_token special_eos_id,
llama_token linefeed_id,
int32_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present,
bool penalize_nl,
bool ignore_eos) {
if (linefeed_id == LLAMA_TOKEN_NULL) {
penalize_nl = false;
}

if (special_eos_id == LLAMA_TOKEN_NULL) {
ignore_eos = true;
}

return new llama_sampler {
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
/* .vocab = */ &vocab,
/* .n_vocab = */ n_vocab,
/* .special_eos_id = */ special_eos_id,
/* .linefeed_id = */ linefeed_id,
/* .penalty_last_n = */ penalty_last_n,
/* .penalty_repeat = */ penalty_repeat,
/* .penalty_freq = */ penalty_freq,
@@ -1268,9 +1322,11 @@ struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_voca
// logit-bias

struct llama_sampler_logit_bias {
const struct llama_vocab * vocab;
const int32_t n_vocab;

const std::vector<llama_logit_bias> logit_bias;

std::vector<llama_logit_bias> logit_bias;
std::vector<llama_logit_bias> to_search;
};

static struct llama_sampler_i llama_sampler_logit_bias_i = {
@@ -1279,31 +1335,47 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = {
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;

GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' sampler must be applied on the full vocabulary");
ctx->to_search.clear();

// update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
for (const auto & lb : ctx->logit_bias) {
cur_p->data[lb.token].logit += lb.bias;
if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
cur_p->data[lb.token].logit += lb.bias;
} else {
ctx->to_search.push_back(lb);
}
}

// search for the remaining candidates that were not found in the previous step
for (size_t i = 0; i < cur_p->size; ++i) {
for (const auto & lb : ctx->to_search) {
if (cur_p->data[i].id == lb.token) {
cur_p->data[i].logit += lb.bias;
break;
}
}
}
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx;
return llama_sampler_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data());
return llama_sampler_init_logit_bias(ctx_src->n_vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data());
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_logit_bias *) smpl->ctx;
},
};

struct llama_sampler * llama_sampler_init_logit_bias_impl(
const struct llama_vocab & vocab,
struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias) {
return new llama_sampler {
/* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias {
/* .vocab = */ &vocab,
/* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {},
},
};
}
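The penalties and logit-bias hunks above share one lookup idea: the candidate array usually still mirrors the vocabulary order (cur_p->data[i].id == i), so a token can be found by direct indexing, with a linear scan only as a fallback once the array has been sorted, shuffled or truncated. A minimal standalone sketch of that pattern, using a hypothetical helper name (find_token_index is not part of the library):

#include <cstddef>
#include <cstdint>

#include "llama.h"

// Hypothetical helper illustrating the lookup pattern used above: try the direct
// index first (valid while the candidates are still in vocabulary order), then
// fall back to a linear scan. Returns -1 if the token is not present.
static int64_t find_token_index(const llama_token_data_array * cur_p, llama_token id) {
    if (id >= 0 && (size_t) id < cur_p->size && cur_p->data[id].id == id) {
        return id; // fast path: candidates not yet sorted/shuffled/truncated
    }
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].id == id) {
            return (int64_t) i; // slow path: scan the reordered candidates
        }
    }
    return -1; // token was filtered out earlier in the chain
}

Doing the lookup this way lets these samplers sit anywhere in the chain, instead of requiring the full, unsorted vocabulary as the removed GGML_ASSERT lines did.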
21 changes: 0 additions & 21 deletions src/llama-sampling.h
@@ -33,28 +33,7 @@ void llama_sampler_penalties_impl(
float penalty_freq,
float penalty_present);

struct llama_sampler * llama_sampler_init_mirostat_impl(
const struct llama_vocab & vocab,
uint32_t seed,
float tau,
float eta,
int32_t m);

struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);

struct llama_sampler * llama_sampler_init_penalties_impl(
const struct llama_vocab & vocab,
int32_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present,
bool penalize_nl,
bool ignore_eos);

LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl(
const struct llama_vocab & vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);
25 changes: 0 additions & 25 deletions src/llama.cpp
@@ -20592,36 +20592,11 @@ int32_t llama_chat_apply_template(
// sampling
//

// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) {
return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m);
}

// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
}

// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
struct llama_sampler * llama_sampler_init_penalties(
const struct llama_model * model,
int32_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present,
bool penalize_nl,
bool ignore_eos) {
return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos);
}

// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
struct llama_sampler * llama_sampler_init_logit_bias(
const struct llama_model * model,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias) {
return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias);
}

//
// model split
//
