sampling : convert mirostat samplers to constraints
ggml-ci
ggerganov committed Sep 4, 2024
1 parent 697a20f commit 23f0802
Showing 5 changed files with 305 additions and 215 deletions.
88 changes: 45 additions & 43 deletions common/sampling.cpp
@@ -47,9 +47,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
     lparams.seed         = params.seed;
     lparams.n_prev       = params.n_prev;
-    lparams.mirostat     = params.mirostat;
-    lparams.mirostat_tau = params.mirostat_tau;
-    lparams.mirostat_eta = params.mirostat_eta;
 
     auto * result = new gpt_sampler {
         /* .params = */ params,
@@ -69,29 +66,39 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
         /* .smpl = */ llama_sampler_init(model, lparams)
     };
 
-    for (const auto & cnstr : params.constraints) {
-        switch (cnstr) {
-            case GPT_CONSTRAINT_TYPE_TOP_K:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k    (params.top_k, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TOP_P:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p    (params.top_p, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_MIN_P:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p    (params.min_p, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TFS_Z:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TYPICAL_P:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical  (params.typ_p, params.min_keep));
-                break;
-            case GPT_CONSTRAINT_TYPE_TEMPERATURE:
-                llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                break;
-            default:
-                GGML_ASSERT(false && "unknown constraint type");
-        }
-    }
+    if (params.mirostat == 0) {
+        for (const auto & cnstr : params.constraints) {
+            switch (cnstr) {
+                case GPT_CONSTRAINT_TYPE_TOP_K:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k    (params.top_k, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TOP_P:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p    (params.top_p, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_MIN_P:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p    (params.min_p, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TFS_Z:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TYPICAL_P:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical  (params.typ_p, params.min_keep));
+                    break;
+                case GPT_CONSTRAINT_TYPE_TEMPERATURE:
+                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    break;
+                default:
+                    GGML_ASSERT(false && "unknown constraint type");
+            }
+        }
+    } else if (params.mirostat == 1) {
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
+    } else if (params.mirostat == 2) {
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
+        llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
+    } else {
+        GGML_ASSERT(false && "unknown mirostat version");
+    }
 
     return result;
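
With this change, mirostat is no longer a separate sampling path: gpt_sampler_init branches on params.mirostat and, for modes 1 and 2, composes a plain temperature constraint followed by the matching mirostat constraint. A minimal caller-side sketch of selecting mirostat v2 through the common params (values are illustrative; the field names are the ones used in the hunk above):

    // illustrative configuration only
    gpt_sampler_params sparams;
    sparams.mirostat     = 2;    // 0 = constraint list, 1 = mirostat v1, 2 = mirostat v2
    sparams.mirostat_tau = 5.0f; // target surprise
    sparams.mirostat_eta = 0.1f; // learning rate
    sparams.temp         = 0.8f; // added as llama_constraint_init_temp(...) ahead of the mirostat constraint

    // internally this now builds: temp constraint -> mirostat_v2 constraint
    struct gpt_sampler * gsmpl = gpt_sampler_init(model, sparams);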
@@ -153,7 +160,6 @@ static llama_token gpt_sampler_sample(
         struct llama_sampler * smpl,
         struct llama_token_data_array * cur_p,
         float temp,
-        int mirostat,
         int n_probs) {
     llama_token res = 0;
 
@@ -167,24 +173,20 @@
         // apply all sampling constraints and then sample
         llama_sampler_apply(smpl, cur_p);
 
-        if (mirostat != 0) {
-            res = llama_sampler_sample_mirostat(smpl, cur_p);
-        } else {
-            res = llama_sampler_sample_dist(smpl, cur_p);
+        res = llama_sampler_sample_dist(smpl, cur_p);
 
-            //{
-            //    const int n_top = 10;
-            //    LOG("top %d candidates:\n", n_top);
+        //{
+        //    const int n_top = 10;
+        //    LOG("top %d candidates:\n", n_top);
 
-            //    for (int i = 0; i < n_top; i++) {
-            //        const llama_token id = cur_p.data[i].id;
-            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
-            //        LOG("    - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
-            //    }
-            //}
+        //    for (int i = 0; i < n_top; i++) {
+        //        const llama_token id = cur_p.data[i].id;
+        //        (void)id; // To avoid a warning that id is unused when logging is disabled.
+        //        LOG("    - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
+        //    }
+        //}
 
-            //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
-        }
+        //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
     }
 
     return res;
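
Since mirostat now runs as a constraint, llama_sampler_sample_mirostat disappears and every non-greedy path ends in llama_sampler_sample_dist. Conceptually, a mirostat v2 constraint truncates the candidate list against a dynamic surprise threshold mu in its apply step and nudges mu toward the target tau on accept. A rough sketch of that idea, not the actual implementation (the state layout and helper names are assumptions):

    #include <math.h>

    struct mirostat_v2_state {
        float tau; // target surprise (cross-entropy), e.g. 5.0
        float eta; // learning rate, e.g. 0.1
        float mu;  // dynamic threshold, initialized to 2*tau
    };

    // apply: keep only candidates whose surprise -log2(p) stays below mu
    // (assumes cur_p holds normalized probabilities sorted in descending order)
    static void mirostat_v2_apply(mirostat_v2_state * st, llama_token_data_array * cur_p) {
        size_t n = 1; // always keep the most probable token
        while (n < cur_p->size && -log2f(cur_p->data[n].p) <= st->mu) {
            n++;
        }
        cur_p->size = n; // renormalization is left to the subsequent softmax/dist step
    }

    // accept: move mu toward tau based on the surprise of the sampled token
    static void mirostat_v2_accept(mirostat_v2_state * st, float p_sampled) {
        st->mu -= st->eta * (-log2f(p_sampled) - st->tau);
    }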
@@ -208,7 +210,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     llama_constraint_apply(pnlt, cur_p);
 
     // first, sample the token without any grammar constraints
-    const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs);
+    const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs);
 
     // check if the sampled token fits the grammar
     {
@@ -231,7 +233,7 @@
     llama_constraint_apply(pnlt, cur_p);
     llama_constraint_apply(grmr, cur_p);
 
-    return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);
+    return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs);
 }
 
 void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) {
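
The two-pass grammar strategy in the last two hunks is unchanged by this commit: sample once without the grammar, and only when the token turns out to be grammar-illegal re-apply the penalties plus the grammar to the full candidate set and resample. The collapsed lines between those hunks perform roughly the following check (a sketch; the llama_token_data initializer layout is an assumption):

    // fast path: test only the sampled token against the grammar
    llama_token_data       single_token       = { id, 1.0f, 0.0f };
    llama_token_data_array single_token_array = { &single_token, 1, false };

    llama_constraint_apply(grmr, &single_token_array);

    if (single_token_array.data[0].logit != -INFINITY) {
        return id; // grammar-legal, no resampling needed
    }
    // otherwise fall through to the slow path shown above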
47 changes: 29 additions & 18 deletions include/llama.h
@@ -367,16 +367,18 @@ extern "C" {
         float bias;
     } llama_logit_bias;
 
+    enum llama_sampler_type {
+        LLAMA_SAMPLER_TYPE_GREEDY = 0,
+        LLAMA_SAMPLER_TYPE_DIST   = 1,
+    };
+
     typedef struct llama_sampler_params {
         uint32_t seed; // the seed used to initialize the rng of the sampler
 
         int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API)
 
-        int32_t mirostat;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-        float   mirostat_tau; // target entropy
-        float   mirostat_eta; // learning rate
-
-        // TODO: add type of sampler: greedy, dist, mirostat, etc.
+        // TODO: will be used by the llama_decode_with_sampler() API in the future
+        enum llama_sampler_type type;
     } llama_sampler_params;
 
     // performance timing information
@@ -1003,17 +1005,18 @@ extern "C" {
     //
     // - Samplers
     //   The llama_sampler samples a token based on the candidate token probabilities. Before the actual sampling, the
-    //   sampler can apply a sequence of constraints to the candidate tokens.
+    //   sampler can apply a sequence of constraints in order to modify the probabilities of the candidates.
     //
     // The llama_sampler object contains the entire sampling information:
     //
     // - RNG state (seed and generator)
     // - Custom set of constraints (see llama_sampler_add_constraint)
-    // - Sampling method (greedy, dist, mirostat)
+    // - Sampling method (greedy, dist)
     // - Previous tokens
     //
     // In the future, it will be utilized to offload the sampling to the backends (e.g. GPU).
     //
+    // TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model
 
     // constraints
 
@@ -1039,14 +1042,23 @@ extern "C" {
         llama_constraint_context_t ctx;
     };
 
-    LLAMA_API struct llama_constraint * llama_constraint_init_softmax  (void);
-    LLAMA_API struct llama_constraint * llama_constraint_init_top_k    (int32_t k, int32_t min_keep);
-    LLAMA_API struct llama_constraint * llama_constraint_init_top_p    (float p, int32_t min_keep);
-    LLAMA_API struct llama_constraint * llama_constraint_init_min_p    (float p, int32_t min_keep);
-    LLAMA_API struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep);
-    LLAMA_API struct llama_constraint * llama_constraint_init_typical  (float p, int32_t min_keep);
-    LLAMA_API struct llama_constraint * llama_constraint_init_temp     (float t);
-    LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent);
+    LLAMA_API struct llama_constraint * llama_constraint_init_softmax    (void);
+    LLAMA_API struct llama_constraint * llama_constraint_init_top_k      (int32_t k, int32_t min_keep);
+    LLAMA_API struct llama_constraint * llama_constraint_init_top_p      (float p, int32_t min_keep);
+    LLAMA_API struct llama_constraint * llama_constraint_init_min_p      (float p, int32_t min_keep);
+    LLAMA_API struct llama_constraint * llama_constraint_init_tail_free  (float z, int32_t min_keep);
+    LLAMA_API struct llama_constraint * llama_constraint_init_typical    (float p, int32_t min_keep);
+    LLAMA_API struct llama_constraint * llama_constraint_init_temp       (float t);
+    LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext   (float t, float delta, float exponent);
+
+    LLAMA_API struct llama_constraint * llama_constraint_init_mirostat(
+            const struct llama_model * model,
+            float tau,
+            float eta);
+
+    LLAMA_API struct llama_constraint * llama_constraint_init_mirostat_v2(
+            float tau,
+            float eta);
+
     LLAMA_API struct llama_constraint * llama_constraint_init_grammar(
             const struct llama_model * model,
@@ -1093,9 +1105,8 @@ extern "C" {
     LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
     LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);
 
-    LLAMA_API llama_token llama_sampler_sample_dist    (struct llama_sampler * smpl, llama_token_data_array * cur_p);
-    LLAMA_API llama_token llama_sampler_sample_greedy  (struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs);
-    LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API llama_token llama_sampler_sample_dist  (struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs);
 
     /// @details Get the number of accepted tokens so far (max of n_prev)
     LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl);
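
Put together, the reworked header composes mirostat like any other constraint and always samples with llama_sampler_sample_dist. An end-to-end sketch against the declarations above (llama_sampler_default_params and the construction of cur_p from the logits are assumptions; context setup is omitted):

    llama_sampler_params sparams = llama_sampler_default_params(); // assumed helper
    sparams.type = LLAMA_SAMPLER_TYPE_DIST;

    struct llama_sampler * smpl = llama_sampler_init(model, sparams);
    llama_sampler_add_constraint(smpl, llama_constraint_init_temp(0.8f));
    llama_sampler_add_constraint(smpl, llama_constraint_init_mirostat_v2(5.0f, 0.1f));

    // per generated token: fill cur_p from the logits, then:
    llama_sampler_apply(smpl, &cur_p);                              // runs temp, then mirostat v2
    const llama_token id = llama_sampler_sample_dist(smpl, &cur_p);
    llama_sampler_accept(smpl, id);                                 // updates history and mirostat state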