cont : move samplers to llama lib
ggml-ci
ggerganov committed Aug 29, 2024
1 parent 861ad6f commit 62984db
Showing 6 changed files with 83 additions and 81 deletions.
14 changes: 7 additions & 7 deletions common/common.cpp
@@ -615,7 +615,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "--typical") {
CHECK_ARG
- sparams.typical_p = std::stof(argv[i]);
+ sparams.typ_p = std::stof(argv[i]);
return true;
}
if (arg == "--repeat-last-n") {
@@ -1532,12 +1532,12 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() });
options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" });
options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" });
options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp });
options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp });
options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k });
options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p });
options.push_back({ "*", " --top-p P", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
options.push_back({ "*", " --min-p P", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
options.push_back({ "*", " --tfs P", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
options.push_back({ "*", " --typical P", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typ_p });
options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n });
options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
@@ -3316,7 +3316,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
}
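
Aside from the help-text placeholders (N → T/P for float-valued flags), the only change in this file is the `typical_p` → `typ_p` rename. For readers unfamiliar with the parsing pattern, here is a minimal self-contained sketch of the `--typical` handler; `sparams_sketch` and the explicit bounds check are stand-ins for `gpt_params` and the repo's `CHECK_ARG` macro, not repo code:

```cpp
// Sketch of the argument-parsing pattern above. sparams_sketch and the
// explicit bounds check stand in for gpt_params and CHECK_ARG.
#include <string>

struct sparams_sketch {
    float typ_p = 1.00f; // renamed from typical_p; 1.0 = disabled
};

bool find_arg(int argc, char ** argv, int & i, const std::string & arg, sparams_sketch & sp) {
    if (arg == "--typical") {
        if (++i >= argc) return false; // expanded form of what CHECK_ARG does
        sp.typ_p = std::stof(argv[i]); // stores into the renamed field
        return true;
    }
    return false;
}
```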
75 changes: 7 additions & 68 deletions common/sampling.cpp
@@ -18,7 +18,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_model * m
lparams.top_p = params.top_p;
lparams.min_p = params.min_p;
lparams.tfs_z = params.tfs_z;
- lparams.typical_p = params.typical_p;
+ lparams.typ_p = params.typ_p;
lparams.temp = params.temp;
lparams.dynatemp_range = params.dynatemp_range;
lparams.dynatemp_exponent = params.dynatemp_exponent;
@@ -94,7 +94,7 @@ std::string llama_sampling_print(const gpt_sampling_params & params) {
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
- params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
+ params.top_k, params.tfs_z, params.top_p, params.min_p, params.typ_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau);

return std::string(result);
@@ -132,7 +132,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
switch (sampler_type) {
case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k";
case LLAMA_SAMPLER_TYPE_TFS_Z: return "tfs_z";
- case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typical_p";
+ case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case LLAMA_SAMPLER_TYPE_TOP_P: return "top_p";
case LLAMA_SAMPLER_TYPE_MIN_P: return "min_p";
case LLAMA_SAMPLER_TYPE_TEMPERATURE: return "temperature";
@@ -144,7 +144,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
{ "top_k", LLAMA_SAMPLER_TYPE_TOP_K },
{ "top_p", LLAMA_SAMPLER_TYPE_TOP_P },
{ "typical_p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
{ "typ_p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", LLAMA_SAMPLER_TYPE_MIN_P },
{ "tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z },
{ "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE },
@@ -158,6 +158,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{ "nucleus", LLAMA_SAMPLER_TYPE_TOP_P },
{ "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
{ "typical", LLAMA_SAMPLER_TYPE_TYPICAL_P },
{ "typ-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
{ "typ", LLAMA_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", LLAMA_SAMPLER_TYPE_MIN_P },
{ "tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z },
{ "tfs", LLAMA_SAMPLER_TYPE_TFS_Z },
@@ -205,29 +207,6 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
return sampler_types;
}

- // no reasons to expose this function in header
- static void sampler_queue(
- struct llama_sampling_context * ctx_sampling,
- struct llama_token_data_array * cur_p) {
- llama_sampling * smpl = ctx_sampling->smpl;
-
- const gpt_sampling_params & params = ctx_sampling->params;
-
- const std::vector<llama_sampler_type> & samplers = params.samplers;
-
- for (const auto & sampler : samplers) {
- switch (sampler) {
- case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k (smpl, cur_p); break;
- case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free(smpl, cur_p); break;
- case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical (smpl, cur_p); break;
- case LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p (smpl, cur_p); break;
- case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p (smpl, cur_p); break;
- case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp (smpl, cur_p); break;
- default : break;
- }
- }
- }
-
void llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
@@ -238,47 +217,7 @@ void llama_sampling_prepare(
static llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_token_data_array * cur_p) {
- llama_sampling * smpl = ctx_sampling->smpl;
-
- const gpt_sampling_params & params = ctx_sampling->params;
-
- const float temp = params.temp;
- const int mirostat = params.mirostat;
-
- llama_token id = 0;
-
- if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
- // greedy sampling, with probs
- llama_sampling_softmax(smpl, cur_p);
- id = cur_p->data[0].id;
- } else if (temp == 0.0f) {
- // greedy sampling, no probs
- id = llama_sampling_sample_greedy(smpl, cur_p);
- } else {
- if (mirostat != 0) {
- llama_sampling_temp(smpl, cur_p);
- id = llama_sampling_sample_mirostat(smpl, cur_p);
- } else {
- sampler_queue(ctx_sampling, cur_p);
-
- id = llama_sampling_sample_dist(smpl, cur_p);
-
- //{
- // const int n_top = 10;
- // LOG("top %d candidates:\n", n_top);
-
- // for (int i = 0; i < n_top; i++) {
- // const llama_token id = cur_p.data[i].id;
- // (void)id; // To avoid a warning that id is unused when logging is disabled.
- // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
- // }
- //}
-
- //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(smpl, id).c_str());
- }
- }
-
- return id;
+ return llama_sampling_sample(ctx_sampling->smpl, cur_p);
}

llama_token llama_sampling_sample(
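
The alias table above now resolves "typ-p" and "typ" in addition to the existing spellings, and the canonical name follows the field rename. A minimal, self-contained sketch of that lookup under the enum shown in the diff; the helper name and the fallback behavior are assumptions, not the library's API:

```cpp
// Sketch of the name -> sampler resolution above; resolve_sampler_name is a
// hypothetical helper, not part of llama.cpp.
#include <optional>
#include <string>
#include <unordered_map>

enum class sampler_type { top_k, tfs_z, typical_p, top_p, min_p, temperature };

std::optional<sampler_type> resolve_sampler_name(const std::string & name) {
    static const std::unordered_map<std::string, sampler_type> aliases = {
        { "typ_p",     sampler_type::typical_p }, // canonical (renamed)
        { "typical-p", sampler_type::typical_p },
        { "typical",   sampler_type::typical_p },
        { "typ-p",     sampler_type::typical_p }, // new alias
        { "typ",       sampler_type::typical_p }, // new alias
    };
    const auto it = aliases.find(name);
    if (it == aliases.end()) {
        return std::nullopt; // unknown names are skipped by the caller
    }
    return it->second;
}
```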
2 changes: 1 addition & 1 deletion common/sampling.h
@@ -16,7 +16,7 @@ typedef struct gpt_sampling_params {
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
- float typical_p = 1.00f; // 1.0 = disabled
+ float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
4 changes: 2 additions & 2 deletions examples/server/server.cpp
@@ -914,7 +914,7 @@ struct server_context {
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
- slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
+ slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
@@ -1283,7 +1283,7 @@ struct server_context {
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typical_p},
{"typical_p", slot.sparams.typ_p},
{"repeat_last_n", slot.sparams.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
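
Note that the HTTP API is unchanged by the rename: requests and the reported settings still use the "typical_p" key; only the internal member became typ_p. A minimal sketch of that mapping, assuming nlohmann::json (which server.cpp already uses) and a trimmed stand-in for the slot's sampling params; server.cpp's own json_value() helper behaves roughly like value() below:

```cpp
// Sketch only: the external JSON key "typical_p" maps to the renamed
// internal field typ_p. sparams_sketch is a trimmed stand-in.
#include <nlohmann/json.hpp>

using json = nlohmann::json;

struct sparams_sketch { float typ_p = 1.00f; };

void apply_request(const json & data, sparams_sketch & sp) {
    sp.typ_p = data.value("typical_p", sp.typ_p); // key unchanged on the wire
}

json report_settings(const sparams_sketch & sp) {
    return json{{"typical_p", sp.typ_p}}; // still reported under the old key
}
```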
7 changes: 6 additions & 1 deletion include/llama.h
@@ -388,7 +388,7 @@ extern "C" {
float top_p; // 1.0 = disabled
float min_p; // 0.0 = disabled
float tfs_z; // 1.0 = disabled
- float typical_p; // 1.0 = disabled
+ float typ_p; // typical_p, 1.0 = disabled
float temp; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range; // 0.0 = disabled
float dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler
@@ -1106,6 +1106,11 @@ extern "C" {
struct llama_sampling * smpl,
llama_token_data_array * candidates);

+ /// @details Sample a token using the configured samplers.
+ LLAMA_API llama_token llama_sampling_sample(
+ struct llama_sampling * smpl,
+ llama_token_data_array * candidates);
+
/// @details Accepts the sampled token into the sampling context
LLAMA_API void llama_sampling_accept(
struct llama_sampling * smpl,
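
With llama_sampling_sample promoted to the public header, callers no longer need a sampler queue of their own: one call drives the full chain (greedy special cases, mirostat, or the configured queue). A minimal sketch of the intended call pattern, assuming a configured struct llama_sampling * and a logits buffer set up elsewhere; the apply_grammar flag on accept is an assumption, since that declaration is truncated in this view:

```cpp
// Sketch of driving the new public entry point; smpl, logits and n_vocab
// are assumed to be set up elsewhere. Not repo code.
#include <vector>
#include "llama.h"

llama_token sample_next(struct llama_sampling * smpl, const float * logits, int n_vocab) {
    // build the candidate array from the raw logits, the usual llama.cpp way
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id) {
        candidates.push_back({ id, logits[id], 0.0f });
    }
    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

    // one call runs greedy/mirostat/the configured queue, per the params
    const llama_token tok = llama_sampling_sample(smpl, &cur_p);

    // feed the choice back for penalties etc.; the bool flag is an assumption
    llama_sampling_accept(smpl, tok, /*apply_grammar=*/true);
    return tok;
}
```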
62 changes: 60 additions & 2 deletions src/llama.cpp
@@ -17418,7 +17418,7 @@ struct llama_sampling_params llama_sampling_default_params() {
/*.top_p =*/ 0.95f,
/*.min_p =*/ 0.05f,
/*.tfs_z =*/ 1.00f,
- /*.typical_p =*/ 1.00f,
+ /*.typ_p =*/ 1.00f,
/*.temp =*/ 0.80f,
/*.dynatemp_range =*/ 0.00f,
/*.dynatemp_exponent =*/ 1.00f,
@@ -20169,7 +20169,7 @@ void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_arr
void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
time_meas tm(smpl->t_sample_us);

- llama_sampling_typical_impl(candidates, smpl->params.typical_p, smpl->params.min_keep);
+ llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep);
}

void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
@@ -20271,6 +20271,64 @@ llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token
return res;
}

+ llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+ time_meas tm(smpl->t_sample_us);
+
+ const auto & params = smpl->params;
+
+ const float temp = params.temp;
+ const int mirostat = params.mirostat;
+
+ auto & cur_p = candidates;
+
+ llama_token res = 0;
+
+ if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
+ // greedy sampling, with probs
+ llama_sampling_softmax_impl(cur_p);
+ res = cur_p->data[0].id;
+ } else if (temp == 0.0f) {
+ // greedy sampling, no probs
+ res = llama_sampling_sample_greedy(smpl, cur_p);
+ } else {
+ if (mirostat != 0) {
+ llama_sampling_temp(smpl, cur_p);
+ res = llama_sampling_sample_mirostat(smpl, cur_p);
+ } else {
+ for (const auto & sampler : smpl->samplers) {
+ switch (sampler) {
+ case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break;
+ case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break;
+ case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break;
+ case LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break;
+ case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break;
+ case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break;
+ default : break;
+ }
+ }
+
+ res = llama_sampling_sample_dist(smpl, cur_p);
+
+ //{
+ // const int n_top = 10;
+ // LOG("top %d candidates:\n", n_top);
+
+ // for (int i = 0; i < n_top; i++) {
+ // const llama_token id = cur_p.data[i].id;
+ // (void)id; // To avoid a warning that id is unused when logging is disabled.
+ // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
+ // }
+ //}
+
+ //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
+ }
+ }
+
+ smpl->n_sample++;
+
+ return res;
+ }
+
void llama_sampling_accept(
struct llama_sampling * smpl,
llama_token token,
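
The new function applies the queue in the user-configured sampler order, and that order is observable: llama.cpp's top_p re-normalizes the surviving candidates (via softmax) before applying its cumulative cutoff, while top_k merely truncates, so top_k-then-top_p can keep a different number of candidates than the reverse. A toy illustration on made-up probabilities; this is not library code (the real samplers operate on llama_token_data_array and logits):

```cpp
// Toy illustration of sampler-order sensitivity; not library code.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

static void renormalize(std::vector<float> & p) {
    const float sum = std::accumulate(p.begin(), p.end(), 0.0f);
    for (float & v : p) v /= sum;
}

static void top_k(std::vector<float> & p, size_t k) {
    std::sort(p.begin(), p.end(), std::greater<float>());
    if (p.size() > k) p.resize(k); // truncate only, no renormalization
}

static void top_p(std::vector<float> & p, float thresh) {
    std::sort(p.begin(), p.end(), std::greater<float>());
    renormalize(p); // mirrors the softmax step before the cutoff
    float  cum  = 0.0f;
    size_t keep = 0;
    for (float v : p) { cum += v; ++keep; if (cum >= thresh) break; }
    p.resize(keep);
}

int main() {
    std::vector<float> a = {0.4f, 0.3f, 0.2f, 0.1f}, b = a;

    top_k(a, 3); top_p(a, 0.75f); // renorm -> {0.44,0.33,0.22}; cum 0.78 at 2 -> keeps 2
    top_p(b, 0.75f); top_k(b, 3); // cum reaches 0.9 at the 3rd token -> keeps 3

    std::printf("top_k->top_p: %zu tokens, top_p->top_k: %zu tokens\n",
                a.size(), b.size());
}
```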
