Skip to content

Commit

Permalink
cont : store params in llama_sampling implementation
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Aug 12, 2024
1 parent d352b01 commit 6174762
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 50 deletions.
20 changes: 7 additions & 13 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,18 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
lp.penalize_nl = params.penalize_nl;
lp.ignore_eos = params.ignore_eos;

lp.grammar = params.grammar.c_str();
lp.grammar_root = "root";

lp.cfg_prompt = params.cfg_negative_prompt.c_str();
lp.cfg_scale = params.cfg_scale;

lp.n_logit_bias = params.logit_bias.size();
lp.logit_bias = params.logit_bias.data();

result->smpl = llama_sampling_init(model, lp);

llama_sampling_set_rng_seed (result->smpl, params.seed);
llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root");
llama_sampling_set_cfg (result->smpl, params.cfg_negative_prompt.c_str(), params.cfg_scale);
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
}

result->prev.resize(params.n_prev);

result->n_valid = 0;

llama_sampling_set_rng_seed(result->smpl, params.seed);

return result;
}

Expand All @@ -60,7 +54,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
}

void llama_sampling_reset(llama_sampling_context * ctx) {
llama_sampling_reset(ctx->smpl, ctx->params.grammar.c_str(), "root");
llama_sampling_reset(ctx->smpl);

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
Expand Down Expand Up @@ -378,7 +372,7 @@ static llama_token_data_array llama_sampling_prepare_impl(

if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale);
llama_sampling_cfg(smpl, logits, logits_guidance, params.cfg_scale);
}

cur.resize(n_vocab);
Expand Down
20 changes: 7 additions & 13 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,16 +384,6 @@ extern "C" {
float mirostat_tau; // target entropy
float mirostat_eta; // learning rate

// https://github.com/ggerganov/llama.cpp/pull/1773
const char * grammar;
const char * grammar_root;

const char * cfg_prompt; // string to help guidance in negative direction
float cfg_scale; // how strong is guidance

int32_t n_logit_bias;
const llama_logit_bias * logit_bias;

// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool penalize_nl; // consider newlines as a repeatable token
bool ignore_eos; // ignore the end-of-sequence token
Expand Down Expand Up @@ -1020,10 +1010,14 @@ extern "C" {

LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);

LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
//LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);

// Sets the current rng seed.
LLAMA_API void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed);
LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed);
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
LLAMA_API void llama_sampling_set_cfg (struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale);
LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);

/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sampling_softmax(
Expand Down Expand Up @@ -1098,7 +1092,7 @@ extern "C" {
/// @param logits Logits extracted from the original generation context.
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
LLAMA_API void llama_sampling_apply_guidance(
LLAMA_API void llama_sampling_cfg(
struct llama_sampling * smpl,
float * logits,
float * logits_guidance,
Expand Down
63 changes: 51 additions & 12 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,9 @@ llama_sampling::~llama_sampling() {
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) {
auto * result = new llama_sampling(vocab);

// TODO: store params
result->params = params;

if (params.grammar != nullptr && params.grammar[0] != '\0') {
result->grammar = llama_grammar_init_impl(result->vocab, params.grammar, params.grammar_root);
}
llama_sampling_set_rng_seed_impl(*result, params.seed);

return result;
}
Expand All @@ -52,26 +50,31 @@ void llama_sampling_free_impl(struct llama_sampling * sampling) {
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
auto * result = new llama_sampling(smpl.vocab);

result->params = smpl.params;

result->grammar_str = smpl.grammar_str;
result->grammar_root = smpl.grammar_root;

result->cfg_prompt = smpl.cfg_prompt;
result->cfg_scale = smpl.cfg_scale;

result->logit_bias = smpl.logit_bias;

if (smpl.grammar) {
result->grammar = llama_grammar_copy_impl(*smpl.grammar);
}

return result;
}

void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
// TODO: this is dumb, need to fix
const struct llama_vocab * vocab = nullptr;

void llama_sampling_reset_impl(struct llama_sampling & smpl) {
if (smpl.grammar) {
vocab = &smpl.grammar->vocab;

llama_grammar_free_impl(smpl.grammar);
smpl.grammar = nullptr;
}

if (grammar_str != nullptr && grammar_str[0] != '\0') {
smpl.grammar = llama_grammar_init_impl(*vocab, grammar_str, grammar_root);
if (!smpl.grammar_str.empty()) {
smpl.grammar = llama_grammar_init_impl(smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
}
}

Expand All @@ -83,6 +86,42 @@ void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t see
smpl.rng.seed(seed);
}

void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) {
if (smpl.grammar) {
llama_grammar_free_impl(smpl.grammar);
smpl.grammar = nullptr;
}

if (grammar_str != nullptr && grammar_str[0] != '\0') {
smpl.grammar_str = grammar_str;
smpl.grammar_root = grammar_root;

smpl.grammar = llama_grammar_init_impl(smpl.vocab, grammar_str, grammar_root);
} else {
smpl.grammar_str.clear();
smpl.grammar_root.clear();
}
}

void llama_sampling_set_cfg_impl(struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale) {
if (cfg_prompt != nullptr && cfg_prompt[0] != '\0') {
smpl.cfg_prompt = cfg_prompt;
} else {
smpl.cfg_prompt.clear();
}

smpl.cfg_scale = cfg_scale;
}

void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
smpl.logit_bias.clear();
smpl.logit_bias.reserve(n_logit_bias);

for (int32_t i = 0; i < n_logit_bias; ++i) {
smpl.logit_bias.push_back(logit_bias[i]);
}
}

void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
GGML_ASSERT(candidates->size > 0);

Expand Down
21 changes: 18 additions & 3 deletions src/llama-sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,22 @@ struct llama_sampling {
llama_sampling(const struct llama_vocab & vocab);
~llama_sampling();

const llama_vocab & vocab;
llama_sampling_params params;

std::string grammar_str;
std::string grammar_root;

std::string cfg_prompt;
float cfg_scale = 1.0f;

std::vector<llama_logit_bias> logit_bias; // logit biases to apply

// state

std::mt19937 rng;

const struct llama_vocab & vocab;

struct llama_grammar * grammar = nullptr;

mutable int64_t t_total_us = 0;
Expand All @@ -30,10 +42,13 @@ void llama_sampling_free_impl(struct llama_sampling * sampling);

struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl);

void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root);
void llama_sampling_reset_impl(struct llama_sampling & smpl);

// TODO: move the API below as member functions of llama_sampling
void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed);
void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed);
void llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root);
void llama_sampling_set_cfg_impl (struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale);
void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);

void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
void llama_sampling_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
Expand Down
24 changes: 15 additions & 9 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16511,12 +16511,6 @@ struct llama_sampling_params llama_sampling_default_params() {
/*.mirostat =*/ 0,
/*.mirostat_tau =*/ 5.00f,
/*.mirostat_eta =*/ 0.10f,
/*.grammar =*/ nullptr,
/*.grammar_root =*/ nullptr,
/*.cfg_prompt =*/ nullptr,
/*.cfg_scale =*/ 1.00f,
/*.n_logit_bias =*/ 0,
/*.logit_bias =*/ nullptr,
/*.penalize_nl =*/ false,
/*.ignore_eos =*/ false,
};
Expand Down Expand Up @@ -19109,14 +19103,26 @@ struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) {
return llama_sampling_cp_impl(*smpl);
}

void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
llama_sampling_reset_impl(*smpl, grammar_str, grammar_root);
void llama_sampling_reset(struct llama_sampling * smpl) {
llama_sampling_reset_impl(*smpl);
}

void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
llama_sampling_set_rng_seed_impl(*smpl, seed);
}

void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
}

void llama_sampling_set_cfg(struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale) {
llama_sampling_set_cfg_impl(*smpl, cfg_prompt, cfg_scale);
}

void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias);
}

void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
time_meas tm(smpl->t_total_us);

Expand Down Expand Up @@ -19186,7 +19192,7 @@ void llama_sampling_repetition_penalties(
llama_sampling_repetition_penalties_impl(*smpl, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
}

void llama_sampling_apply_guidance(
void llama_sampling_cfg(
struct llama_sampling * smpl,
float * logits,
float * logits_guidance,
Expand Down

0 comments on commit 6174762

Please sign in to comment.