server : bring back info of final chunk in stream mode (ggerganov#10722)
* server : bring back info of final chunk in stream mode

* clarify a bit

* trailing space
ngxson authored Dec 8, 2024
1 parent 06d7014 commit e52522b
Showing 2 changed files with 94 additions and 86 deletions.
174 changes: 88 additions & 86 deletions examples/server/server.cpp
@@ -392,7 +392,7 @@ struct server_task_result {
return false;
}
virtual bool is_stop() {
// only used by server_task_result_cmpl_partial
// only used by server_task_result_cmpl_*
return false;
}
virtual int get_index() {
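For orientation, a minimal sketch (not code from this diff) of the dispatch pattern these overrides enable: the streaming loop sends every result as a chunk and uses is_stop() to decide when the stream is done. The receive_next() helper is hypothetical, standing in for the server's actual result queue:

#include <cstdio>
#include <memory>

struct server_task_result {
    virtual ~server_task_result() = default;
    virtual bool is_stop() { return false; } // partial results keep the stream open
};

struct server_task_result_cmpl_final : server_task_result {
    bool is_stop() override { return true; } // the final result ends the stream
};

// hypothetical stand-in for the server's result queue: two partials, then a final
std::unique_ptr<server_task_result> receive_next() {
    static int n = 0;
    if (++n < 3) {
        return std::make_unique<server_task_result>();
    }
    return std::make_unique<server_task_result_cmpl_final>();
}

int main() {
    while (true) {
        auto res = receive_next();
        // the real server would serialize res->to_json() into one SSE chunk here
        std::printf("chunk sent (is_stop=%s)\n", res->is_stop() ? "true" : "false");
        if (res->is_stop()) {
            break; // final chunk delivered: close the stream
        }
    }
    return 0;
}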
@@ -478,14 +478,20 @@ struct server_task_result_cmpl_final : server_task_result {
return index;
}

virtual bool is_stop() override {
return true; // in stream mode, final responses are considered stop
}

virtual json to_json() override {
return oaicompat ? to_json_oaicompat_chat() : to_json_non_oaicompat();
return oaicompat
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
: to_json_non_oaicompat();
}

json to_json_non_oaicompat() {
json res = json {
{"index", index},
{"content", content},
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"id_slot", id_slot},
{"stop", true},
{"model", oaicompat_model},
@@ -546,18 +552,46 @@

return res;
}

json to_json_oaicompat_chat_stream() {
std::time_t t = std::time(0);
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}

json choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});

json ret = json {
{"choices", choices},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"},
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
};

if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
}

return ret;
}
};
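For illustration, the final chunk built by to_json_oaicompat_chat_stream() lands on the wire roughly as follows (the server frames each chunk as an SSE data: line; the id, model, timestamp, and token counts here are hypothetical, and a timings object is appended when timings are enabled):

data: {"choices":[{"finish_reason":"stop","index":0,"delta":{}}],"created":1733612345,"id":"chatcmpl-abc123","model":"llama-3.2-1b","object":"chat.completion.chunk","usage":{"completion_tokens":12,"prompt_tokens":8,"total_tokens":20}}

Note that usage now travels on this final chunk; the matching removal from partial chunks appears further down.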

struct server_task_result_cmpl_partial : server_task_result {
int index = 0;
std::string content;

bool truncated;
int32_t n_decoded;
int32_t n_prompt_tokens;

stop_type stop = STOP_TYPE_NONE;

std::vector<completion_token_output> probs_output;
result_timings timings;

@@ -573,20 +607,19 @@ struct server_task_result_cmpl_partial : server_task_result {
}

virtual bool is_stop() override {
return stop != STOP_TYPE_NONE;
return false; // in stream mode, partial responses are not considered stop
}

virtual json to_json() override {
if (oaicompat) {
return to_json_oaicompat();
}
bool is_stop = stop != STOP_TYPE_NONE;
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
}

json to_json_non_oaicompat() {
// non-OAI-compat JSON
json res = json {
{"index", index},
{"content", content},
{"stop_type", stop_type_to_str(stop)},
{"stop", is_stop},
{"stop", false},
{"id_slot", id_slot},
{"tokens_predicted", n_decoded},
{"tokens_evaluated", n_prompt_tokens},
@@ -598,72 +631,54 @@
if (!probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
}
if (is_stop) {
res.push_back({"truncated", truncated});
}
return res;
}

json to_json_oaicompat() {
bool first = n_decoded == 0;

std::string finish_reason;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
} else if (stop == STOP_TYPE_LIMIT) {
finish_reason = "length";
}

std::time_t t = std::time(0);

json choices;

if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};

json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};

return std::vector<json>({initial_ret, second_ret});
}
{"delta", json{{"role", "assistant"}}}}});
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};

json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};

return std::vector<json>({initial_ret, second_ret});
}
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
}

json ret = json {
Expand All @@ -678,14 +693,6 @@ struct server_task_result_cmpl_partial : server_task_result {
ret.push_back({"timings", timings.to_json()});
}

if (!finish_reason.empty()) {
ret.push_back({"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}});
}

return std::vector<json>({ret});
}
};
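To make the partial-chunk shape concrete: on the first token with non-empty content, the code above emits two updates to match OpenAI behavior, a role-only delta followed by a content delta, and partials no longer carry a finish_reason or usage. Roughly (all values hypothetical):

data: {"choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant"}}],"created":1733612345,"id":"chatcmpl-abc123","model":"llama-3.2-1b","object":"chat.completion.chunk"}
data: {"choices":[{"finish_reason":null,"index":0,"delta":{"content":"Hello"}}],"created":1733612345,"id":"chatcmpl-abc123","model":"llama-3.2-1b","object":"chat.completion.chunk"}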
@@ -1888,12 +1895,9 @@ struct server_context {
res->index = slot.index;
res->content = tkn.text_to_send;

res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;

res->stop = slot.stop;

res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
@@ -1924,12 +1928,6 @@
}

void send_final_response(server_slot & slot) {
if (slot.params.stream) {
// if in stream mode, send the last partial response
send_partial_response(slot, {0, "", {}});
return;
}

auto res = std::make_unique<server_task_result_cmpl_final>();
res->id = slot.id_task;
res->id_slot = slot.id;
Expand All @@ -1948,6 +1946,7 @@ struct server_context {
res->stop = slot.stop;

res->verbose = slot.params.verbose;
res->stream = slot.params.stream;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
res->oaicompat_model = slot.params.oaicompat_model;
@@ -2100,7 +2099,10 @@ struct server_context {
return;
}

GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr);
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
if (!result_handler(result)) {
cancel_tasks(id_tasks);
break;
6 changes: 6 additions & 0 deletions examples/server/tests/unit/test_completion.py
@@ -42,10 +42,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
})
content = ""
for data in res:
assert "stop" in data and type(data["stop"]) == bool
if data["stop"]:
assert data["timings"]["prompt_n"] == n_prompt
assert data["timings"]["predicted_n"] == n_predicted
assert data["truncated"] == truncated
assert data["stop_type"] == "limit"
assert "generation_settings" in data
assert server.n_predict is not None
assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
assert data["generation_settings"]["seed"] == server.seed
assert match_regex(re_content, content)
else:
content += data["content"]
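Concretely, the /completion stream this test consumes now ends with a dedicated final chunk; a rough sketch of the tail of such a stream (token text and counts hypothetical), where every chunk carries a boolean stop and only the last one carries stop_type, truncated, timings, and generation_settings:

{"index":0,"content":"Hel","stop":false,"id_slot":0,"tokens_predicted":1,"tokens_evaluated":8}
{"index":0,"content":"lo","stop":false,"id_slot":0,"tokens_predicted":2,"tokens_evaluated":8}
{"index":0,"content":"","stop":true,"stop_type":"limit","truncated":false,"id_slot":0,"timings":{"prompt_n":8,"predicted_n":2},"generation_settings":{"n_predict":2,"seed":42}}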
