
Commit 05c3a44

server : fill usage info in embeddings and rerank responses (ggerganov#10852)

* server : fill usage info in embeddings response

* server : fill usage info in reranking response
krystiancha authored Dec 17, 2024
1 parent 382bc7f commit 05c3a44
Showing 4 changed files with 77 additions and 10 deletions.
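
For context, here is a minimal client-side sketch of what this change exposes; it is not part of the diff, and it assumes a llama-server instance already running locally in embedding mode (the host, port, and exact token counts are illustrative and depend on your setup and model):

import requests

# Ask the server to embed a short prompt.
res = requests.post("http://localhost:8080/embeddings", json={"input": "This is a test"})
body = res.json()

# Before this commit, usage was hard-coded to zeros (see the removed TODOs in
# utils.hpp below); it now reflects the prompt tokens actually evaluated.
print(body["usage"])  # e.g. {'prompt_tokens': 4, 'total_tokens': 4}
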
16 changes: 12 additions & 4 deletions examples/server/server.cpp
@@ -719,14 +719,17 @@ struct server_task_result_embd : server_task_result {
     int index = 0;
     std::vector<float> embedding;

+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }

     virtual json to_json() override {
         return json {
-            {"index",     index},
-            {"embedding", embedding},
+            {"index",            index},
+            {"embedding",        embedding},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
@@ -735,14 +738,17 @@ struct server_task_result_rerank : server_task_result {
     int index = 0;
     float score = -1e6;

+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }

     virtual json to_json() override {
         return json {
-            {"index", index},
-            {"score", score},
+            {"index",            index},
+            {"score",            score},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
@@ -1995,6 +2001,7 @@ struct server_context {
             auto res = std::make_unique<server_task_result_embd>();
             res->id = slot.id_task;
             res->index = slot.index;
+            res->n_tokens = slot.n_prompt_tokens;

             const int n_embd = llama_n_embd(model);

@@ -2030,6 +2037,7 @@ struct server_context {
             auto res = std::make_unique<server_task_result_rerank>();
             res->id = slot.id_task;
             res->index = slot.index;
+            res->n_tokens = slot.n_prompt_tokens;

             for (int i = 0; i < batch.n_tokens; ++i) {
                 if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
30 changes: 30 additions & 0 deletions examples/server/tests/unit/test_embedding.py
@@ -97,3 +97,33 @@ def test_same_prompt_give_same_result():
         vi = res.body['data'][i]['embedding']
         for x, y in zip(v0, vi):
             assert abs(x - y) < EPSILON
+
+
+@pytest.mark.parametrize(
+    "content,n_tokens",
+    [
+        ("I believe the meaning of life is", 7),
+        ("This is a test", 4),
+    ]
+)
+def test_embedding_usage_single(content, n_tokens):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"input": content})
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == n_tokens
+
+
+def test_embedding_usage_multiple():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+        ],
+    })
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == 2 * 7
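
The multi-input assertion above holds because the OpenAI-compatible formatter in utils.hpp (below) sums tokens_evaluated over all returned items. As a sketch, the response body for the two-input request would look roughly like this, with the embedding vectors elided, the model field omitted, and token counts assuming the test model's tokenizer:

body = {
    "object": "list",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.12, ...]},  # vector elided
        {"object": "embedding", "index": 1, "embedding": [0.12, ...]},  # vector elided
    ],
    "usage": {"prompt_tokens": 14, "total_tokens": 14},  # 7 tokens per input, summed
}
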
23 changes: 23 additions & 0 deletions examples/server/tests/unit/test_rerank.py
@@ -53,3 +53,26 @@ def test_invalid_rerank_req(documents):
     })
     assert res.status_code == 400
     assert "error" in res.body
+
+
+@pytest.mark.parametrize(
+    "query,doc1,doc2,n_tokens",
+    [
+        ("Machine learning is", "A machine", "Learning is", 19),
+        ("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
+    ]
+)
+def test_rerank_usage(query, doc1, doc2, n_tokens):
+    global server
+    server.start()
+
+    res = server.make_request("POST", "/rerank", data={
+        "query": query,
+        "documents": [
+            doc1,
+            doc2,
+        ]
+    })
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == n_tokens
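
As with embeddings, a minimal client-side sketch for the rerank endpoint; it assumes a llama-server instance running a reranking model locally (host and port are placeholders, and the token count matches the second parametrized case above):

import requests

# Score two candidate documents against a query.
res = requests.post("http://localhost:8080/rerank", json={
    "query": "Which city?",
    "documents": ["Machine learning is ", "Paris, capitale de la"],
})

# usage totals the prompt tokens evaluated across all query/document pairs.
print(res.json()["usage"])  # e.g. {'prompt_tokens': 26, 'total_tokens': 26}
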
18 changes: 12 additions & 6 deletions examples/server/utils.hpp
@@ -560,21 +560,24 @@ static json oaicompat_completion_params_parse(

 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & elem : embeddings) {
         data.push_back(json{
             {"embedding", json_value(elem, "embedding", json::array())},
             {"index",     i++},
             {"object",    "embedding"}
         });
+
+        n_tokens += json_value(elem, "tokens_evaluated", 0);
     }

     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json { // TODO: fill
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
         {"data", data}
     };
@@ -584,20 +587,23 @@ static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {

 static json format_response_rerank(const json & request, const json & ranks) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & rank : ranks) {
         data.push_back(json{
             {"index",           i++},
             {"relevance_score", json_value(rank, "score", 0.0)},
         });
+
+        n_tokens += json_value(rank, "tokens_evaluated", 0);
     }

     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json { // TODO: fill
-            {"prompt_tokens", 0},
-            {"total_tokens", 0}
+        {"usage", json {
+            {"prompt_tokens", n_tokens},
+            {"total_tokens", n_tokens}
         }},
         {"results", data}
     };
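
Restating the aggregation above in Python for illustration (the scores and per-result token counts here are hypothetical): each result's tokens_evaluated, filled in by server.cpp, is summed into a single usage object, and prompt_tokens equals total_tokens because these endpoints generate no completion tokens.

ranks = [{"score": 0.91, "tokens_evaluated": 12}, {"score": 0.07, "tokens_evaluated": 14}]
n_tokens = sum(r.get("tokens_evaluated", 0) for r in ranks)
usage = {"prompt_tokens": n_tokens, "total_tokens": n_tokens}
assert usage == {"prompt_tokens": 26, "total_tokens": 26}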
