From 8a82f388cdc32aa677f272054328683819b81d91 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Sep 2024 14:38:00 +0300 Subject: [PATCH] sampling : fix state cloning ggml-ci --- src/llama-sampling.cpp | 88 ++++++++++++++++++++++++++++++------------ 1 file changed, 63 insertions(+), 25 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index e53b3d3a77edc..02b93b64c6575 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -643,7 +643,16 @@ static struct llama_sampler_i llama_sampler_dist_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_dist *) smpl->ctx; - return llama_sampler_init_dist(ctx->seed); + auto * result = llama_sampler_init_dist(ctx->seed); + + // copy the state + { + auto * result_ctx = (llama_sampler_dist *) result->ctx; + + result_ctx->rng = ctx->rng; + } + + return result; }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_dist *) smpl->ctx; @@ -987,7 +996,17 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; - return llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); + auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); + + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; + } + + return result; }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_mirostat *) smpl->ctx; @@ -1062,7 +1081,18 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; - return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); + + auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); + + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; + } + + return result; }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_mirostat_v2 *) smpl->ctx; @@ -1120,16 +1150,20 @@ static struct llama_sampler_i llama_sampler_grammar_i = { ctx->grammar = grammar_new; }, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx_src = (const llama_sampler_grammar *) smpl->ctx; + const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; + + auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr); - auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); + // copy the state + { + auto * result_ctx = (llama_sampler_grammar *) result->ctx; - auto * ctx_dst = (llama_sampler_grammar *) result->ctx; - if (ctx_src->grammar) { - ctx_dst->grammar_str = ctx_src->grammar_str; - ctx_dst->grammar_root = ctx_src->grammar_root; + if (ctx->grammar) { + result_ctx->grammar_str = ctx->grammar_str; + result_ctx->grammar_root = ctx->grammar_root; - ctx_dst->grammar = llama_grammar_clone_impl(*ctx_src->grammar); + result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar); + } } return result; @@ -1262,20 +1296,24 @@ static struct llama_sampler_i llama_sampler_penalties_i = { ctx->prev.clear(); }, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx; + const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; auto * result = llama_sampler_init_penalties( - ctx_src->n_vocab, - ctx_src->special_eos_id, - ctx_src->linefeed_id, - ctx_src->penalty_last_n, - ctx_src->penalty_repeat, - ctx_src->penalty_freq, - ctx_src->penalty_present, - ctx_src->penalize_nl, - ctx_src->ignore_eos); - - auto * ctx_dst = (llama_sampler_penalties *) result->ctx; - ctx_dst->prev = ctx_src->prev; + ctx->n_vocab, + ctx->special_eos_id, + ctx->linefeed_id, + ctx->penalty_last_n, + ctx->penalty_repeat, + ctx->penalty_freq, + ctx->penalty_present, + ctx->penalize_nl, + ctx->ignore_eos); + + // copy the state + { + auto * result_ctx = (llama_sampler_penalties *) result->ctx; + + result_ctx->prev = ctx->prev; + } return result; }, @@ -1358,8 +1396,8 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = { }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx; - return llama_sampler_init_logit_bias(ctx_src->n_vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); + const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; + return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_logit_bias *) smpl->ctx;