common : simplify gpt_sampler
ggml-ci
ggerganov committed Sep 6, 2024
1 parent b2fc0f8 commit 43eda32
Showing 1 changed file with 19 additions and 30 deletions.
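
What the change does: the logit-bias and repetition-penalty samplers used to be standalone `gpt_sampler` members (`bias`, `pnlt`) that every code path had to apply, clone, and free individually. They are now added to the `llama_sampler` chain at init time, so the chain handles their lifetime and application together with the rest of the pipeline. Only the grammar sampler (`grmr`) keeps a dedicated field, because it is applied either before or after the chain depending on `grammar_first`. Below is a minimal sketch of the new construction pattern, assuming the llama.cpp sampling API at this commit; `model`, `params`, and `cur_p` come from the caller, and obtaining the chain params via `llama_sampler_chain_default_params()` is an assumption not shown in this diff:

    // build the chain; bias and penalties are now ordinary chain members
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_logit_bias(
            model,
            params.logit_bias.size(),
            params.logit_bias.data()));

    llama_sampler_chain_add(chain, llama_sampler_init_penalties(
            model,
            params.penalty_last_n,
            params.penalty_repeat,
            params.penalty_freq,
            params.penalty_present,
            params.penalize_nl,
            params.ignore_eos));

    // ... the user-selected samplers (top-k, temperature, dist, ...) are appended after this ...

    // one handle now covers the whole pipeline; the diff's removal of the
    // per-member free/clone calls implies the chain owns its members
    llama_sampler_apply(chain, &cur_p);   // applies bias -> penalties -> the rest, in order
    llama_sampler_free(chain);            // frees every sampler added to the chain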
common/sampling.cpp (19 additions, 30 deletions)
@@ -98,10 +98,7 @@ struct ring_buffer {
 struct gpt_sampler {
     gpt_sampler_params params;
 
-    struct llama_sampler * bias;
-    struct llama_sampler * pnlt;
     struct llama_sampler * grmr;
-
     struct llama_sampler * chain;
 
     ring_buffer<llama_token> prev;
@@ -140,11 +137,11 @@ std::string gpt_sampler_params::print() const {
 }
 
 std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
-    std::string result = "\tlogits";
+    std::string result = "\tlogits ";
 
     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string(" -> ") + llama_sampler_name(smpl) + " ";
+        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
     }
 
     return result;
@@ -157,25 +154,29 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
     auto * result = new gpt_sampler {
         /* .params = */ params,
-        /* .bias   = */ llama_sampler_init_logit_bias(
-            model,
-            params.logit_bias.size(),
-            params.logit_bias.data()),
-        /* .pnlt   = */ llama_sampler_init_penalties(
-            model,
-            params.penalty_last_n,
-            params.penalty_repeat,
-            params.penalty_freq,
-            params.penalty_present,
-            params.penalize_nl,
-            params.ignore_eos),
         /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
         /* .chain  = */ llama_sampler_chain_init(lparams),
         /* .prev   = */ ring_buffer<llama_token>(params.n_prev),
         /* .cur    = */ {},
         /* .cur_p  = */ {},
     };
 
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_logit_bias(
+                model,
+                params.logit_bias.size(),
+                params.logit_bias.data()));
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_penalties(
+                model,
+                params.penalty_last_n,
+                params.penalty_repeat,
+                params.penalty_freq,
+                params.penalty_present,
+                params.penalize_nl,
+                params.ignore_eos));
+
     if (params.temp > 0.0f) {
         if (params.mirostat == 0) {
             for (const auto & cnstr : params.samplers) {
@@ -223,8 +224,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
 void gpt_sampler_free(struct gpt_sampler * gsmpl) {
     if (gsmpl) {
-        llama_sampler_free(gsmpl->bias);
-        llama_sampler_free(gsmpl->pnlt);
         llama_sampler_free(gsmpl->grmr);
 
         llama_sampler_free(gsmpl->chain);
@@ -236,8 +235,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
 struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
     return new gpt_sampler {
         /* .params = */ gsmpl->params,
-        /* .bias   = */ llama_sampler_clone(gsmpl->bias),
-        /* .pnlt   = */ llama_sampler_clone(gsmpl->pnlt),
         /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
         /* .chain  = */ llama_sampler_clone(gsmpl->chain),
         /* .prev   = */ gsmpl->prev,
@@ -282,18 +279,13 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 }
 
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
-    auto & bias  = gsmpl->bias;
-    auto & pnlt  = gsmpl->pnlt;
     auto & grmr  = gsmpl->grmr;
     auto & chain = gsmpl->chain;
 
     gsmpl->set_logits(ctx, idx);
 
     auto & cur_p = gsmpl->cur_p;
 
-    llama_sampler_apply(bias, &cur_p);
-    llama_sampler_apply(pnlt, &cur_p);
-
     if (grammar_first) {
         llama_sampler_apply(grmr, &cur_p);
     }
@@ -325,10 +317,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
         // if the token is not valid, sample again, first apply the grammar samplers and then sample
         gsmpl->set_logits(ctx, idx);
 
-        llama_sampler_apply(bias, &cur_p);
-        llama_sampler_apply(pnlt, &cur_p);
-        llama_sampler_apply(grmr, &cur_p);
-
+        llama_sampler_apply(grmr,  &cur_p);
         llama_sampler_apply(chain, &cur_p);
 
         GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
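Downstream effect on the hot path: `gpt_sampler_sample` no longer applies `bias` and `pnlt` by hand before the grammar and chain; a single `llama_sampler_apply(chain, ...)` now covers them, so both the first pass and the resample-after-grammar-rejection pass shrink to two applies. A condensed sketch of the resample path, using only names that appear in the diff above:

    // second pass: the grammar rejected the first sampled token
    gsmpl->set_logits(ctx, idx);          // restore the raw logits

    llama_sampler_apply(grmr,  &cur_p);   // constrain to grammar-valid tokens first
    llama_sampler_apply(chain, &cur_p);   // bias and penalties run inside the chain now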
