Skip to content

Commit

Permalink
sampling : fix state cloning
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Sep 7, 2024
1 parent 0e6d170 commit 8a82f38
Showing 1 changed file with 63 additions and 25 deletions.
88 changes: 63 additions & 25 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,16 @@ static struct llama_sampler_i llama_sampler_dist_i = {
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
return llama_sampler_init_dist(ctx->seed);
auto * result = llama_sampler_init_dist(ctx->seed);

// copy the state
{
auto * result_ctx = (llama_sampler_dist *) result->ctx;

result_ctx->rng = ctx->rng;
}

return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_dist *) smpl->ctx;
Expand Down Expand Up @@ -987,7 +996,17 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
return llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);

// copy the state
{
auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;

result_ctx->mu = ctx->mu;
result_ctx->rng = ctx->rng;
}

return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_mirostat *) smpl->ctx;
Expand Down Expand Up @@ -1062,7 +1081,18 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);

auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);

// copy the state
{
auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;

result_ctx->mu = ctx->mu;
result_ctx->rng = ctx->rng;
}

return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_mirostat_v2 *) smpl->ctx;
Expand Down Expand Up @@ -1120,16 +1150,20 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
ctx->grammar = grammar_new;
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_grammar *) smpl->ctx;
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;

auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);

auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr);
// copy the state
{
auto * result_ctx = (llama_sampler_grammar *) result->ctx;

auto * ctx_dst = (llama_sampler_grammar *) result->ctx;
if (ctx_src->grammar) {
ctx_dst->grammar_str = ctx_src->grammar_str;
ctx_dst->grammar_root = ctx_src->grammar_root;
if (ctx->grammar) {
result_ctx->grammar_str = ctx->grammar_str;
result_ctx->grammar_root = ctx->grammar_root;

ctx_dst->grammar = llama_grammar_clone_impl(*ctx_src->grammar);
result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
}
}

return result;
Expand Down Expand Up @@ -1262,20 +1296,24 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
ctx->prev.clear();
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx;
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
auto * result = llama_sampler_init_penalties(
ctx_src->n_vocab,
ctx_src->special_eos_id,
ctx_src->linefeed_id,
ctx_src->penalty_last_n,
ctx_src->penalty_repeat,
ctx_src->penalty_freq,
ctx_src->penalty_present,
ctx_src->penalize_nl,
ctx_src->ignore_eos);

auto * ctx_dst = (llama_sampler_penalties *) result->ctx;
ctx_dst->prev = ctx_src->prev;
ctx->n_vocab,
ctx->special_eos_id,
ctx->linefeed_id,
ctx->penalty_last_n,
ctx->penalty_repeat,
ctx->penalty_freq,
ctx->penalty_present,
ctx->penalize_nl,
ctx->ignore_eos);

// copy the state
{
auto * result_ctx = (llama_sampler_penalties *) result->ctx;

result_ctx->prev = ctx->prev;
}

return result;
},
Expand Down Expand Up @@ -1358,8 +1396,8 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = {
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx;
return llama_sampler_init_logit_bias(ctx_src->n_vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data());
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_logit_bias *) smpl->ctx;
Expand Down

0 comments on commit 8a82f38

Please sign in to comment.