diff --git a/examples/server/README.md b/examples/server/README.md
index 39bbab7b0ee5f..45ffb547fcbcc 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -416,7 +416,7 @@ node index.js
 
 `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.
 
-`timing_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
+`timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
 
 **Response format**
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 839b88ca92ee2..72b04c5c37f8e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -177,7 +177,7 @@ struct server_slot {
     bool stopped_word   = false;
     bool stopped_limit  = false;
 
-    bool timing_per_token = false;
+    bool timings_per_token = false;
 
     bool oaicompat = false;
 
@@ -884,7 +884,7 @@ struct server_context {
             slot.oaicompat_model = "";
         }
 
-        slot.timing_per_token = json_value(data, "timing_per_token", false);
+        slot.timings_per_token = json_value(data, "timings_per_token", false);
 
         slot.params.stream       = json_value(data, "stream",       false);
         slot.params.cache_prompt = json_value(data, "cache_prompt", true);
@@ -1283,7 +1283,7 @@ struct server_context {
             {"speculative.n_max",   slot.params.speculative.n_max},
             {"speculative.n_min",   slot.params.speculative.n_min},
             {"speculative.p_min",   slot.params.speculative.p_min},
-            {"timing_per_token",    slot.timing_per_token},
+            {"timings_per_token",   slot.timings_per_token},
         };
     }
 
@@ -1341,7 +1341,7 @@ struct server_context {
             res.data["model"] = slot.oaicompat_model;
         }
 
-        if (slot.timing_per_token) {
+        if (slot.timings_per_token) {
             res.data["timings"] = slot.get_formated_timings();
         }
 
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 1048d6fcaf500..8a439f9ef0f29 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -146,3 +146,20 @@ def test_invalid_chat_completion_req(messages):
     })
     assert res.status_code == 400 or res.status_code == 500
     assert "error" in res.body
+
+
+def test_chat_completion_with_timings_per_token():
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/chat/completions", data={
+        "max_tokens": 10,
+        "messages": [{"role": "user", "content": "test"}],
+        "stream": True,
+        "timings_per_token": True,
+    })
+    for data in res:
+        assert "timings" in data
+        assert "prompt_per_second" in data["timings"]
+        assert "predicted_per_second" in data["timings"]
+        assert "predicted_n" in data["timings"]
+        assert data["timings"]["predicted_n"] <= 10
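
For context, a minimal client-side sketch of the behavior the new test exercises, assuming a server listening on `http://localhost:8080` and the usual OpenAI-compatible SSE framing (`data: `-prefixed JSON chunks, terminated by `data: [DONE]`); the `requests`-based client below is illustrative only and is not the `make_stream_request` helper from the test suite:

```python
# Sketch: stream a chat completion with timings_per_token enabled and
# print the per-chunk speed fields asserted in the test above.
# The URL is an assumption; adjust for your server.
import json
import requests

res = requests.post(
    "http://localhost:8080/chat/completions",  # assumed local server address
    json={
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "timings_per_token": True,
    },
    stream=True,
)

for line in res.iter_lines():
    # Each SSE data line carries one JSON chunk; skip the [DONE] sentinel.
    if line.startswith(b"data: ") and line != b"data: [DONE]":
        chunk = json.loads(line[len(b"data: "):])
        timings = chunk.get("timings", {})
        print(
            timings.get("prompt_per_second"),
            timings.get("predicted_per_second"),
            timings.get("predicted_n"),
        )
```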