From e52522b8694ae73abf12feb18d29168674aa1c1b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 8 Dec 2024 20:38:51 +0100 Subject: [PATCH] server : bring back info of final chunk in stream mode (#10722) * server : bring back into to final chunk in stream mode * clarify a bit * traling space --- examples/server/server.cpp | 174 +++++++++--------- examples/server/tests/unit/test_completion.py | 6 + 2 files changed, 94 insertions(+), 86 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1c21e55aaa011..1d9c0533d4c40 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -392,7 +392,7 @@ struct server_task_result { return false; } virtual bool is_stop() { - // only used by server_task_result_cmpl_partial + // only used by server_task_result_cmpl_* return false; } virtual int get_index() { @@ -478,14 +478,20 @@ struct server_task_result_cmpl_final : server_task_result { return index; } + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + virtual json to_json() override { - return oaicompat ? to_json_oaicompat_chat() : to_json_non_oaicompat(); + return oaicompat + ? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat()) + : to_json_non_oaicompat(); } json to_json_non_oaicompat() { json res = json { {"index", index}, - {"content", content}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk {"id_slot", id_slot}, {"stop", true}, {"model", oaicompat_model}, @@ -546,18 +552,46 @@ struct server_task_result_cmpl_final : server_task_result { return res; } + + json to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + + json choices = json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}}}); + + json ret = json { + {"choices", choices}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return ret; + } }; struct server_task_result_cmpl_partial : server_task_result { int index = 0; std::string content; - bool truncated; int32_t n_decoded; int32_t n_prompt_tokens; - stop_type stop = STOP_TYPE_NONE; - std::vector probs_output; result_timings timings; @@ -573,20 +607,19 @@ struct server_task_result_cmpl_partial : server_task_result { } virtual bool is_stop() override { - return stop != STOP_TYPE_NONE; + return false; // in stream mode, partial responses are not considered stop } virtual json to_json() override { - if (oaicompat) { - return to_json_oaicompat(); - } - bool is_stop = stop != STOP_TYPE_NONE; + return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { // non-OAI-compat JSON json res = json { {"index", index}, {"content", content}, - {"stop_type", stop_type_to_str(stop)}, - {"stop", is_stop}, + {"stop", false}, {"id_slot", id_slot}, {"tokens_predicted", n_decoded}, {"tokens_evaluated", n_prompt_tokens}, @@ -598,72 +631,54 @@ struct server_task_result_cmpl_partial : server_task_result { if (!probs_output.empty()) { res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output); } - if (is_stop) { - res.push_back({"truncated", truncated}); - } return res; } json to_json_oaicompat() { bool first = n_decoded == 0; - - std::string finish_reason; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = "stop"; - } else if (stop == STOP_TYPE_LIMIT) { - finish_reason = "length"; - } - std::time_t t = std::time(0); - json choices; - if (!finish_reason.empty()) { - choices = json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}); - } else { - if (first) { - if (content.empty()) { - choices = json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}}); - } else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, + if (first) { + if (content.empty()) { + choices = json::array({json{{"finish_reason", nullptr}, {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"object", "chat.completion.chunk"}}; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } + {"delta", json{{"role", "assistant"}}}}}); } else { - choices = json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - }}); + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{{"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + json second_ret = json{ + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"content", content}}} + }})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + return std::vector({initial_ret, second_ret}); } + } else { + choices = json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json{ + {"content", content}, + }}, + }}); } json ret = json { @@ -678,14 +693,6 @@ struct server_task_result_cmpl_partial : server_task_result { ret.push_back({"timings", timings.to_json()}); } - if (!finish_reason.empty()) { - ret.push_back({"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}); - } - return std::vector({ret}); } }; @@ -1888,12 +1895,9 @@ struct server_context { res->index = slot.index; res->content = tkn.text_to_send; - res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; res->n_prompt_tokens = slot.n_prompt_tokens; - res->stop = slot.stop; - res->verbose = slot.params.verbose; res->oaicompat = slot.params.oaicompat; res->oaicompat_chat = slot.params.oaicompat_chat; @@ -1924,12 +1928,6 @@ struct server_context { } void send_final_response(server_slot & slot) { - if (slot.params.stream) { - // if in stream mode, send the last partial response - send_partial_response(slot, {0, "", {}}); - return; - } - auto res = std::make_unique(); res->id = slot.id_task; res->id_slot = slot.id; @@ -1948,6 +1946,7 @@ struct server_context { res->stop = slot.stop; res->verbose = slot.params.verbose; + res->stream = slot.params.stream; res->oaicompat = slot.params.oaicompat; res->oaicompat_chat = slot.params.oaicompat_chat; res->oaicompat_model = slot.params.oaicompat_model; @@ -2100,7 +2099,10 @@ struct server_context { return; } - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); if (!result_handler(result)) { cancel_tasks(id_tasks); break; diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 1c3aa77de5bba..7f4f9cd038be4 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -42,10 +42,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp }) content = "" for data in res: + assert "stop" in data and type(data["stop"]) == bool if data["stop"]: assert data["timings"]["prompt_n"] == n_prompt assert data["timings"]["predicted_n"] == n_predicted assert data["truncated"] == truncated + assert data["stop_type"] == "limit" + assert "generation_settings" in data + assert server.n_predict is not None + assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict) + assert data["generation_settings"]["seed"] == server.seed assert match_regex(re_content, content) else: content += data["content"]