Commit 7a9bf68

cont : update sampling API

ggml-ci

ggerganov committed Aug 15, 2024
1 parent 9def2a6 commit 7a9bf68

Showing 11 changed files with 170 additions and 263 deletions.
50 changes: 13 additions & 37 deletions common/sampling.cpp
@@ -29,14 +29,14 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
lp.mirostat = params.mirostat;
lp.mirostat_tau = params.mirostat_tau;
lp.mirostat_eta = params.mirostat_eta;
lp.cfg_scale = params.cfg_scale;
lp.penalize_nl = params.penalize_nl;
lp.ignore_eos = params.ignore_eos;

result->smpl = llama_sampling_init(model, lp);

llama_sampling_set_rng_seed (result->smpl, params.seed);
llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root");
llama_sampling_set_cfg (result->smpl, params.cfg_negative_prompt.c_str(), params.cfg_scale);
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
}
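
For orientation, a minimal caller-side sketch of the configure-at-init pattern this hunk moves to, assembled only from calls visible in this commit; the concrete values and the seed / grammar variables are placeholders, not part of the change:

// sketch: sampler behaviour is configured once through llama_sampling_params,
// instead of being passed to each individual sampler call
auto sparams = llama_sampling_default_params();
sparams.top_k     = 40;     // placeholder value
sparams.top_p     = 0.9f;   // placeholder value
sparams.temp      = 0.4f;   // placeholder value
sparams.mirostat  = 0;      // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
sparams.cfg_scale = 1.0f;   // classifier-free guidance scale, now set at init

llama_sampling * smpl = llama_sampling_init(model, sparams);

llama_sampling_set_rng_seed(smpl, seed);                     // seed: placeholder
llama_sampling_set_grammar (smpl, grammar.c_str(), "root");  // grammar: placeholder std::string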

@@ -202,38 +202,21 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
// no reasons to expose this function in header
static void sampler_queue(
struct llama_sampling_context * ctx_sampling,
llama_token_data_array & cur_p,
size_t min_keep) {
llama_token_data_array & cur_p) {
llama_sampling * smpl = ctx_sampling->smpl;

const gpt_sampling_params & params = ctx_sampling->params;

const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent;
const int32_t top_k = params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;

for (auto sampler_type : samplers_sequence) {
switch (sampler_type) {
case llama_sampler_type::TOP_K : llama_sampling_top_k (smpl, &cur_p, top_k, min_keep); break;
case llama_sampler_type::TFS_Z : llama_sampling_tail_free(smpl, &cur_p, tfs_z, min_keep); break;
case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, &cur_p, typical_p, min_keep); break;
case llama_sampler_type::TOP_P : llama_sampling_top_p (smpl, &cur_p, top_p, min_keep); break;
case llama_sampler_type::MIN_P : llama_sampling_min_p (smpl, &cur_p, min_p, min_keep); break;
case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
llama_sampling_entropy(smpl, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
} else {
llama_sampling_temp(smpl, &cur_p, temp);
}
break;
case llama_sampler_type::TOP_K: llama_sampling_top_k (smpl, &cur_p); break;
case llama_sampler_type::TFS_Z: llama_sampling_tail_free(smpl, &cur_p); break;
case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, &cur_p); break;
case llama_sampler_type::TOP_P: llama_sampling_top_p (smpl, &cur_p); break;
case llama_sampler_type::MIN_P: llama_sampling_min_p (smpl, &cur_p); break;
case llama_sampler_type::TEMPERATURE: llama_sampling_temp (smpl, &cur_p); break;
default : break;
}
}
@@ -269,18 +252,11 @@ static llama_token llama_sampling_sample_impl(
// greedy sampling, no probs
id = llama_sampling_sample_greedy(smpl, &cur_p);
} else {
if (mirostat == 1) {
const int mirostat_m = 100;
llama_sampling_temp(smpl, &cur_p, temp);
id = llama_sampling_sample_mirostat(smpl, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
} else if (mirostat == 2) {
llama_sampling_temp(smpl, &cur_p, temp);
id = llama_sampling_sample_mirostat_v2(smpl, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
if (mirostat != 0) {
llama_sampling_temp(smpl, &cur_p);
id = llama_sampling_sample_mirostat(smpl, &cur_p);
} else {
// temperature sampling
size_t min_keep = std::max(1, params.min_keep);

sampler_queue(ctx_sampling, cur_p, min_keep);
sampler_queue(ctx_sampling, cur_p);

id = llama_sampling_sample(smpl, &cur_p);
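
The separate mirostat v1/v2 branches above collapse into a single call; a sketch of the resulting flow, assuming the mu state now lives inside llama_sampling (the mirostat_mu field is dropped from the common context below) and that tau, eta and temp are read from the init-time parameters:

// sketch, assuming mirostat state and parameters are owned by the sampler object
auto sparams = llama_sampling_default_params();
sparams.mirostat     = 2;      // 1 = mirostat v1, 2 = mirostat v2
sparams.mirostat_tau = 5.0f;   // target entropy (placeholder value)
sparams.mirostat_eta = 0.1f;   // learning rate (placeholder value)

llama_sampling * smpl = llama_sampling_init(model, sparams);
// ... build cur_p (llama_token_data_array) from the current logits ...
llama_sampling_temp(smpl, &cur_p);                              // temperature read from sparams.temp
llama_token id = llama_sampling_sample_mirostat(smpl, &cur_p);  // v1/v2 chosen via sparams.mirostat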

@@ -372,7 +348,7 @@ static llama_token_data_array llama_sampling_prepare_impl(

if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
llama_sampling_cfg(smpl, logits, logits_guidance, params.cfg_scale);
llama_sampling_cfg(smpl, logits, logits_guidance);
}

cur.resize(n_vocab);
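
The classifier-free guidance scale appears to move to init time as well; a sketch under that assumption, with ctx, ctx_cfg and idx standing in for the main context, the guidance context and the batch index used by the surrounding code:

// sketch: guidance scale is set once at init, the per-call API takes only the two logit buffers
auto sparams = llama_sampling_default_params();
sparams.cfg_scale = 1.5f;   // placeholder guidance scale
llama_sampling * smpl = llama_sampling_init(model, sparams);

float * logits          = llama_get_logits_ith(ctx,     idx);
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
llama_sampling_cfg(smpl, logits, logits_guidance);   // no scale argument anymore
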
3 changes: 0 additions & 3 deletions common/sampling.h
@@ -165,9 +165,6 @@ struct llama_sampling_context {
// parameters that will be used for sampling
gpt_sampling_params params;

// mirostat sampler state
float mirostat_mu;

llama_sampling * smpl;

ring_buffer<llama_token> prev;
18 changes: 10 additions & 8 deletions examples/batched/batched.cpp
@@ -64,7 +64,13 @@ int main(int argc, char ** argv) {
ctx_params.n_batch = std::max(n_predict, n_parallel);

llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());

auto sparams = llama_sampling_default_params();
sparams.top_k = 40;
sparams.top_p = 0.9f;
sparams.temp = 0.4f;

llama_sampling * smpl = llama_sampling_init(model, sparams);

if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@@ -177,13 +183,9 @@ int main(int argc, char ** argv) {

llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

const int top_k = 40;
const float top_p = 0.9f;
const float temp = 0.4f;

llama_sampling_top_k(smpl, &candidates_p, top_k, 1);
llama_sampling_top_p(smpl, &candidates_p, top_p, 1);
llama_sampling_temp (smpl, &candidates_p, temp);
llama_sampling_top_k(smpl, &candidates_p);
llama_sampling_top_p(smpl, &candidates_p);
llama_sampling_temp (smpl, &candidates_p);

const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p);

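
Putting the two batched.cpp hunks together, and assuming each sampler call now reads its parameter from the llama_sampling_params given at init, the per-token step reduces to roughly the following; candidates is filled from the logits as in the original example:

// sketch of the simplified per-token sampling step in the batched example
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

llama_sampling_top_k(smpl, &candidates_p);   // k    taken from sparams.top_k
llama_sampling_top_p(smpl, &candidates_p);   // p    taken from sparams.top_p
llama_sampling_temp (smpl, &candidates_p);   // temp taken from sparams.temp

const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p);
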
7 changes: 4 additions & 3 deletions examples/server/server.cpp
@@ -2358,9 +2358,10 @@ struct server_context {
const size_t n_valid = slot.ctx_sampling->n_valid;

// Make sure at least n_probs top tokens are at the front of the vector:
if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0);
}
// TODO: decide how to handle this after the refactoring
//if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
// llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0);
//}

if (slot.sparams.temp == 0.0f) {
// With greedy sampling the probabilities have possibly not been calculated.
65 changes: 13 additions & 52 deletions include/llama.h
@@ -382,6 +382,7 @@ extern "C" {
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau; // target entropy
float mirostat_eta; // learning rate
float cfg_scale; // classifier-free guidance scale

// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool penalize_nl; // consider newlines as a repeatable token
@@ -1012,7 +1013,6 @@ extern "C" {
// Sets the current rng seed.
LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed);
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
LLAMA_API void llama_sampling_set_cfg (struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale);
LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);

/// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
@@ -1023,50 +1023,32 @@ extern "C" {
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sampling_top_k(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
int32_t k,
size_t min_keep);
llama_token_data_array * candidates);

/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sampling_top_p(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
float p,
size_t min_keep);
llama_token_data_array * candidates);

/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
LLAMA_API void llama_sampling_min_p(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
float p,
size_t min_keep);
llama_token_data_array * candidates);

/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sampling_tail_free(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
float z,
size_t min_keep);
llama_token_data_array * candidates);

/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API void llama_sampling_typical(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
float p,
size_t min_keep);

/// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
LLAMA_API void llama_sampling_entropy(
struct llama_sampling * smpl,
llama_token_data_array * candidates_p,
float min_temp,
float max_temp,
float exponent_val);
llama_token_data_array * candidates);

/// @details Apply temperature and entropy
LLAMA_API void llama_sampling_temp(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
float temp);
llama_token_data_array * candidates);

/// @details Apply constraints from grammar
LLAMA_API void llama_sampling_grammar(
@@ -1075,6 +1057,7 @@ extern "C" {

/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
// TODO: update signature
LLAMA_API void llama_sampling_repetition_penalties(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
@@ -1091,34 +1074,12 @@ extern "C" {
LLAMA_API void llama_sampling_cfg(
struct llama_sampling * smpl,
float * logits,
float * logits_guidance,
float scale);

/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
float * logits_guidance);

/// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
LLAMA_API llama_token llama_sampling_sample_mirostat(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
float tau,
float eta,
int32_t m,
float * mu);

/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API llama_token llama_sampling_sample_mirostat_v2(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
float tau,
float eta,
float * mu);
llama_token_data_array * candidates);

/// @details Selects the token with the highest probability.
/// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
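
Taken together, the samplers that remain in the public header share a uniform two-argument shape. A compact, purely illustrative caller-side sketch of the surviving entry points; the dispatch on temp and mirostat mirrors the common/sampling.cpp logic above and is not part of the header itself:

// sketch: uniform two-argument sampler calls, selection driven by the init-time parameters;
// smpl and cur_p (llama_token_data_array built from the logits) are assumed to exist
llama_token id;
if (sparams.temp <= 0.0f) {
    // greedy: pick the most likely token, probabilities are not needed
    id = llama_sampling_sample_greedy(smpl, &cur_p);
} else if (sparams.mirostat != 0) {
    llama_sampling_temp(smpl, &cur_p);
    id = llama_sampling_sample_mirostat(smpl, &cur_p);
} else {
    llama_sampling_top_k(smpl, &cur_p);
    llama_sampling_top_p(smpl, &cur_p);
    llama_sampling_temp (smpl, &cur_p);
    id = llama_sampling_sample(smpl, &cur_p);
}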