From 8180fd5a30ab1e7b3d71bcc25d5d62e7e0458ab0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Krystian=20Chachu=C5=82a?= <krystian@krystianch.com>
Date: Mon, 16 Dec 2024 14:45:06 +0100
Subject: [PATCH] server : fill usage info in reranking response

---
 examples/server/server.cpp | 8 ++++++--
 examples/server/utils.hpp  | 9 ++++++---
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 08be3b59e3c006..76965c148163e5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -738,14 +738,17 @@ struct server_task_result_rerank : server_task_result {
     int index = 0;
     float score = -1e6;
 
+    int32_t n_prompt_tokens;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
         return json {
-            {"index", index},
-            {"score", score},
+            {"index",            index},
+            {"score",            score},
+            {"tokens_evaluated", n_prompt_tokens},
         };
     }
 };
@@ -2034,6 +2037,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id    = slot.id_task;
         res->index = slot.index;
+        res->n_prompt_tokens = slot.n_prompt_tokens;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index c66368d802d72c..2dceb74806f7c3 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -587,20 +587,23 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
 
 static json format_response_rerank(const json & request, const json & ranks) {
     json data = json::array();
+    int32_t n_prompt_tokens = 0;
     int i = 0;
     for (const auto & rank : ranks) {
         data.push_back(json{
             {"index", i++},
             {"relevance_score", json_value(rank, "score", 0.0)},
         });
+
+        n_prompt_tokens += json_value(rank, "tokens_evaluated", 0);
     }
 
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json { // TODO: fill
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+        {"usage", json {
+            {"prompt_tokens", n_prompt_tokens},
+            {"total_tokens", n_prompt_tokens}
         }},
         {"results", data}
     };