From 84e1c33cde9e0a7aafcda2d4f21ba51c300482d7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 26 Nov 2024 13:36:40 +0200
Subject: [PATCH] server : fix parallel speculative decoding (#10513)

ggml-ci
---
 examples/server/server.cpp | 63 +++++++++++++++++++------------------
 1 file changed, 31 insertions(+), 32 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c0ea4faf77d42..9c86407c28eba 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2267,50 +2267,49 @@ struct server_context {
                 continue; // continue loop of slots
             }
 
-            llama_token id;
+            llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
-            {
-                completion_token_output result;
-
-                id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+            slot.i_batch = -1;
 
-                slot.i_batch = -1;
+            common_sampler_accept(slot.smpl, id, true);
 
-                common_sampler_accept(slot.smpl, id, true);
-
-                slot.n_decoded += 1;
-                if (slot.n_decoded == 1) {
-                    slot.t_start_generation = ggml_time_us();
-                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
-                    metrics.on_prompt_eval(slot);
-                }
+            slot.n_decoded += 1;
+            if (slot.n_decoded == 1) {
+                slot.t_start_generation = ggml_time_us();
+                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                metrics.on_prompt_eval(slot);
+            }
 
-                result.tok = id;
+            completion_token_output result;
+            result.tok = id;
 
-                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
 
-                for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
-                    result.probs.push_back({
-                        cur_p->data[i].id,
-                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                    });
-                }
+            for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                });
+            }
 
-                if (!process_token(result, slot)) {
-                    // release slot because of stop condition
-                    slot.release();
-                    slot.print_timings();
-                    send_final_response(slot);
-                    metrics.on_prediction(slot);
-                    continue;
-                }
+            if (!process_token(result, slot)) {
+                // release slot because of stop condition
+                slot.release();
+                slot.print_timings();
+                send_final_response(slot);
+                metrics.on_prediction(slot);
+                continue;
             }
+        }
 
-            // check if the slot supports speculative decoding
-            if (!slot.can_speculate()) {
+        // do speculative decoding
+        for (auto & slot : slots) {
+            if (!slot.is_processing() || !slot.can_speculate()) {
                 continue;
             }
 
+            llama_token id = slot.sampled;
+
             struct common_speculative_params params_spec;
             params_spec.n_draft = slot.params.speculative.n_max;
             params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;