From 7e017cfbc84ecde61bd9cfd18aadb41892be303e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 30 Aug 2024 18:02:28 +0200 Subject: [PATCH 1/3] server : add Hermes-3 tool call support --- common/common.cpp | 17 +++++ common/common.h | 6 ++ examples/server/server.cpp | 53 +++++++++++----- examples/server/tool-call.hpp | 114 ++++++++++++++++++++++++++++++++++ examples/server/utils.hpp | 75 +++++++++++++++++----- src/llama.cpp | 12 ++-- 6 files changed, 240 insertions(+), 37 deletions(-) create mode 100644 examples/server/tool-call.hpp diff --git a/common/common.cpp b/common/common.cpp index 9fa18472512ab..28ec4f5fc2a10 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2253,6 +2253,10 @@ bool string_parse_kv_override(const char * data, std::vector model_template(res, 0); + llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); + return std::string(model_template.data(), model_template.size()); + } +} + // // KV cache utils // diff --git a/common/common.h b/common/common.h index cb5e7f6df10c5..db0800432aef4 100644 --- a/common/common.h +++ b/common/common.h @@ -320,6 +320,8 @@ static std::vector string_split(const std::string & str, char delim) { bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); +bool string_contains(std::string haystack, std::string needle); + // // Filesystem utils // @@ -428,6 +430,10 @@ std::string llama_chat_format_single(const struct llama_model * model, std::string llama_chat_format_example(const struct llama_model * model, const std::string & tmpl); +// Returns the chat template stored inside the model +// (empty string if model does not have built-in chat template) +std::string llama_get_chat_template(const struct llama_model * model); + // // KV cache utils // diff --git a/examples/server/server.cpp b/examples/server/server.cpp index cc938e80d6a6d..73d0088396b29 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -4,6 +4,7 @@ #include "json-schema-to-grammar.h" #include "llama.h" #include "grammar-parser.h" +#include "tool-call.hpp" #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error @@ -157,6 +158,7 @@ struct server_slot { std::string generated_text; std::vector cache_tokens; std::vector generated_token_probs; + enum llama_response_state response_state = LLAMA_RESPONSE_STATE_UNKNOWN; bool infill = false; bool embedding = false; @@ -207,6 +209,7 @@ struct server_slot { infill = false; ga_i = 0; n_past_se = 0; + response_state = LLAMA_RESPONSE_STATE_UNKNOWN; generated_token_probs.clear(); } @@ -625,6 +628,7 @@ struct server_context { llama_model * model = nullptr; llama_context * ctx = nullptr; std::vector lora_adapters; + llama_tool_format tool_format = LLAMA_TOOL_FORMAT_NOT_SUPPORTED; gpt_params params; @@ -1217,7 +1221,13 @@ struct server_context { break; } - if (!incomplete) { + if (slot.response_state == LLAMA_RESPONSE_STATE_UNKNOWN) { + slot.response_state = check_response_state(tool_format, slot.generated_text); + } + + // if response is tool call, we cannot stream it + // instead, we wait for the full response, then extract JSON + if (!incomplete && slot.response_state == LLAMA_RESPONSE_STATE_TEXT) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); @@ -1247,9 +1257,7 @@ struct server_context { if (slot.params.stream) { send_partial_response(slot, result); } - } - - if (incomplete) { + } else { slot.has_next_token = true; } @@ -1396,6 +1404,10 @@ struct server_context { {"multimodal", false} }; + if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) { + res.data["tool_calls"] = parse_tool_response(tool_format, tkn.text_to_send); + } + if (slot.sparams.n_probs > 0) { const std::vector to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false); const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); @@ -1444,6 +1456,10 @@ struct server_context { {"timings", slot.get_formated_timings()} }; + if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) { + res.data["tool_calls"] = parse_tool_response(tool_format, slot.generated_text); + } + if (slot.sparams.n_probs > 0) { std::vector probs; if (!slot.params.stream && slot.stopped_word) { @@ -2937,19 +2953,14 @@ int main(int argc, char ** argv) { }; const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) { - std::string template_key = "tokenizer.chat_template", curr_tmpl; - int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); - if (tlen > 0) { - std::vector curr_tmpl_buf(tlen + 1, 0); - if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { - curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); - } - } + std::string chat_tmpl = ctx_server.params.chat_template.empty() + ? llama_get_chat_template(ctx_server.model) + : ctx_server.params.chat_template; json data = { { "system_prompt", ctx_server.system_prompt.c_str() }, { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params.n_parallel }, - { "chat_template", curr_tmpl.c_str() } + { "chat_template", chat_tmpl }, }; res.set_content(data.dump(), MIMETYPE_JSON); @@ -3056,7 +3067,13 @@ int main(int argc, char ** argv) { res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); + json body = json::parse(req.body); + + if (body.contains("tools") && ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) { + body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools")); + } + + json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3423,11 +3440,15 @@ int main(int argc, char ** argv) { } } + // decide if we can enable tool calls + ctx_server.tool_format = get_tool_format(ctx_server.ctx); + // print sample chat example to make it clear which template is used { LOG_INFO("chat template", { - {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, - {"built_in", params.chat_template.empty()}, + {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, + {"built_in", params.chat_template.empty()}, + {"tool_call_enabled", ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED }, }); } diff --git a/examples/server/tool-call.hpp b/examples/server/tool-call.hpp new file mode 100644 index 0000000000000..0326e8a93e8cd --- /dev/null +++ b/examples/server/tool-call.hpp @@ -0,0 +1,114 @@ +#pragma once + +#include "llama.h" +#include "common.h" +#include "utils.hpp" + +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "json.hpp" + +#include +#include +#include + +using json = nlohmann::ordered_json; + +enum llama_tool_format { + LLAMA_TOOL_FORMAT_NOT_SUPPORTED, + LLAMA_TOOL_FORMAT_HERMES_3, +}; + +enum llama_response_state { + LLAMA_RESPONSE_STATE_UNKNOWN, + LLAMA_RESPONSE_STATE_TEXT, + LLAMA_RESPONSE_STATE_TOOL_CALL, +}; + +// get the tool call format for the loaded model +// this function does linear search, so do not call it repeatedly +inline enum llama_tool_format get_tool_format(const struct llama_context * ctx) { + auto model = llama_get_model(ctx); + auto has_token = [&](std::string piece) { + for (int i = 0; i < llama_n_vocab(model); i++) { + const std::string token_str = llama_token_to_piece(ctx, i, true); + if (token_str == piece) { + return true; + } + } + return false; + }; + if (has_token("<|im_start|>") && has_token("")) { + return LLAMA_TOOL_FORMAT_HERMES_3; + } + return LLAMA_TOOL_FORMAT_NOT_SUPPORTED; +} + +inline std::string format_chat_with_tool(enum llama_tool_format format, const std::vector & messages, json tools) { + if (!tools.is_array()) { + throw std::runtime_error("tools must be an array"); + } + std::stringstream ss; + auto chat = parse_chat_messages(messages); + if (format == LLAMA_TOOL_FORMAT_HERMES_3) { + ss << "<|im_start|>system\n\n"; + ss << "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: \n\n"; + for (auto tool : tools) { + ss << tool.dump(1, '\t') << "\n\n"; + } + ss << " Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}, \"name\": {\"title\": \"Name\", \"type\": \"string\"}}, \"required\": [\"arguments\", \"name\"], \"title\": \"FunctionCall\", \"type\": \"object\"} For each function call return a json object with function name and arguments within XML tags as follows:\n"; + ss << "\n"; + ss << "{\"arguments\": , \"name\": }\n"; + ss << "<|im_end|>\n"; + for (auto & message : chat) { + std::string role(message.role); + if (role == "system") { + continue; // for optimal performance, we skip user-defined system message + } + ss << "<|im_start|>" << role << "\n\n"; + if (role == "tool") { + ss << "\n" << string_strip(message.content) << "\n\n"; + } else { + ss << string_strip(message.content) << "<|im_end|>\n"; + } + } + ss << "<|im_start|>assistant\n\n"; + } else { + throw std::runtime_error("tool_call is not supported by this model"); + } + LOG_VERBOSE("format_chat_with_tool", {{"text", ss.str()}}); + return ss.str(); +} + +// check if the response is text or tool_call +// if it is tool_call, we may have to disable streaming, because we must parse the whole JSON response +inline enum llama_response_state check_response_state(enum llama_tool_format format, const std::string & generated_text) { + if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) { + return LLAMA_RESPONSE_STATE_TEXT; + } else if (format == LLAMA_TOOL_FORMAT_HERMES_3 && generated_text.rfind("", 0) == 0) { + return LLAMA_RESPONSE_STATE_TOOL_CALL; + } + return LLAMA_RESPONSE_STATE_TEXT; +} + +// convert model's response to OAI format +inline json parse_tool_response(enum llama_tool_format format, const std::string & generated_text) { + if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) { + return json{}; + } else if (format == LLAMA_TOOL_FORMAT_HERMES_3) { + std::string tmp(generated_text); + string_replace_all(tmp, "", ""); + string_replace_all(tmp, "", ""); + json tool = json::parse(tmp); + std::vector tool_calls = {json{ + {"id", tool.at("name")}, + {"type", "function"}, + {"function", { + {"name", tool.at("name")}, + {"arguments", tool.at("arguments").dump()}, // OAI requires this to be JSON-stringified + }}, + }}; + return tool_calls; + } + return generated_text; +} diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index e6a1f069723ec..e46f7b0327c54 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -116,10 +116,9 @@ static inline void server_log(const char * level, const char * function, int lin // chat template utils // -// Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages) { +// convert input chat messages from JSON to llama_chat_msg +inline std::vector parse_chat_messages(const std::vector & messages) { std::vector chat; - for (size_t i = 0; i < messages.size(); ++i) { const auto & curr_msg = messages[i]; @@ -144,7 +143,12 @@ inline std::string format_chat(const struct llama_model * model, const std::stri chat.push_back({role, content}); } + return chat; +} +// Format given chat. If tmpl is empty, we take the template from model metadata +inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages) { + auto chat = parse_chat_messages(messages); auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true); LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); return formatted_chat; @@ -356,7 +360,9 @@ static json oaicompat_completion_params_parse( llama_params["__oaicompat"] = true; // Apply chat template to the list of messages - llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); + if (!body.contains("prompt")) { + llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); + } // Handle "stop" field if (body.contains("stop") && body.at("stop").is_string()) { @@ -391,7 +397,7 @@ static json oaicompat_completion_params_parse( } // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "tools", "tool_choice" }; + static const std::vector unsupported_params { "tool_choice" }; for (auto & param : unsupported_params) { if (body.contains(param)) { throw std::runtime_error("Unsupported param: " + param); @@ -417,20 +423,31 @@ static json format_final_response_oaicompat(const json & request, json result, c int num_tokens_predicted = json_value(result, "tokens_predicted", 0); int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); std::string content = json_value(result, "content", std::string("")); + bool has_tool_calls = result.contains("tool_calls"); std::string finish_reason = "length"; if (stopped_word || stopped_eos) { - finish_reason = "stop"; + finish_reason = has_tool_calls ? "tool_calls" : "stop"; } + json message = has_tool_calls + ? json{ + {"content", nullptr}, + {"role", "assistant"}, + {"tool_calls", result.at("tool_calls")}, + } + : json{ + {"content", content}, + {"role", "assistant"}, + }; + json choices = streaming ? json::array({json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}}) : json::array({json{{"finish_reason", finish_reason}, {"index", 0}, - {"message", json{{"content", content}, - {"role", "assistant"}}}}}); + {"message", message}}}); std::time_t t = std::time(0); @@ -472,10 +489,11 @@ static std::vector format_partial_response_oaicompat(json result, const st bool stopped_eos = json_value(result, "stopped_eos", false); bool stopped_limit = json_value(result, "stopped_limit", false); std::string content = json_value(result, "content", std::string("")); + bool has_tool_calls = result.contains("tool_calls"); std::string finish_reason; if (stopped_word || stopped_eos) { - finish_reason = "stop"; + finish_reason = has_tool_calls ? "tool_calls" : "stop"; } if (stopped_limit) { finish_reason = "length"; @@ -484,11 +502,41 @@ static std::vector format_partial_response_oaicompat(json result, const st std::time_t t = std::time(0); json choices; + json delta = has_tool_calls + ? json{ + {"content", nullptr}, + {"role", "assistant"}, + {"tool_calls", result.at("tool_calls")}, + } + : json{ + {"content", content}, + {"role", "assistant"}, + }; if (!finish_reason.empty()) { choices = json::array({json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}}}); + if (has_tool_calls) { + // tool call must be send as two updates + json initial_ret = json{{"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}}})}, + {"created", t}, + {"id", completion_id}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; + + json second_ret = json{ + {"choices", choices}, + {"created", t}, + {"id", completion_id}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; + + return std::vector({initial_ret, second_ret}); + } } else { if (first) { if (content.empty()) { @@ -511,9 +559,7 @@ static std::vector format_partial_response_oaicompat(json result, const st json second_ret = json{ {"choices", json::array({json{{"finish_reason", nullptr}, {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, + {"delta", delta}}})}, {"created", t}, {"id", completion_id}, {"model", modelname}, @@ -531,10 +577,7 @@ static std::vector format_partial_response_oaicompat(json result, const st choices = json::array({json{ {"finish_reason", nullptr}, {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, + {"delta", delta}, }}); } } diff --git a/src/llama.cpp b/src/llama.cpp index 2274296b45406..00dcc59f74834 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18058,7 +18058,7 @@ int32_t llama_model_meta_val_str(const struct llama_model * model, const char * } return -1; } - return snprintf(buf, buf_size, "%s", it->second.c_str()); + return buf != NULL ? snprintf(buf, buf_size, "%s", it->second.c_str()) : it->second.size(); } int32_t llama_model_meta_count(const struct llama_model * model) { @@ -19757,8 +19757,8 @@ static int32_t llama_chat_apply_template_internal( std::string & dest, bool add_ass) { // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 std::stringstream ss; - auto tmpl_contains = [&tmpl](std::string haystack) -> bool { - return tmpl.find(haystack) != std::string::npos; + auto tmpl_contains = [&tmpl](std::string part) -> bool { + return tmpl.find(part) != std::string::npos; }; if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) { // chatml template @@ -20026,13 +20026,15 @@ int32_t llama_chat_apply_template( if (tmpl == nullptr) { GGML_ASSERT(model != nullptr); // load template from model - std::vector model_template(2048, 0); // longest known template is about 1200 bytes std::string template_key = "tokenizer.chat_template"; - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); + // call with NULL buffer to get the total size of the string + int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0); if (res < 0) { // worst case: there is no information about template, we will use chatml by default curr_tmpl = "chatml"; // see llama_chat_apply_template_internal } else { + std::vector model_template(res, 0); + llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); curr_tmpl = std::string(model_template.data(), model_template.size()); } } From 5f06d37baf371626fd19bb79627c15b934ed0509 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 30 Aug 2024 21:40:49 +0200 Subject: [PATCH 2/3] add --tool-call argument --- common/common.cpp | 6 ++++++ common/common.h | 1 + examples/server/server.cpp | 17 +++++++++++++++-- examples/server/utils.hpp | 2 +- 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 28ec4f5fc2a10..478787edf43b0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -428,6 +428,7 @@ void gpt_params_parse_from_env(gpt_params & params) { get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching); get_env("LLAMA_ARG_HOST", params.hostname); get_env("LLAMA_ARG_PORT", params.port); + get_env("LLAMA_ARG_TOOL_CALLS", params.enable_tool_calls); } bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { @@ -1046,6 +1047,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.lora_init_without_apply = true; return true; } + if (arg == "--tool-call" || arg == "--tool-calls") { + params.enable_tool_calls = true; + return true; + } if (arg == "--control-vector") { CHECK_ARG params.control_vectors.push_back({ 1.0f, argv[i], }); @@ -2036,6 +2041,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY", "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity }); options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"}); + options.push_back({ "server", " --tool-call(s)", "enable OAI tool calls for chat completion endpoint (default: %s)", params.enable_tool_calls ? "enabled" : "disabled"}); #ifndef LOG_DISABLE_LOGS options.push_back({ "logging" }); diff --git a/common/common.h b/common/common.h index db0800432aef4..fe5f485e587e7 100644 --- a/common/common.h +++ b/common/common.h @@ -221,6 +221,7 @@ struct gpt_params { std::string chat_template = ""; std::string system_prompt = ""; bool enable_chat_template = true; + bool enable_tool_calls = false; std::vector api_keys; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 73d0088396b29..ccc45fb6d2ce0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3071,6 +3071,7 @@ int main(int argc, char ** argv) { if (body.contains("tools") && ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) { body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools")); + body.erase(body.find("tools")); } json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template); @@ -3441,14 +3442,26 @@ int main(int argc, char ** argv) { } // decide if we can enable tool calls - ctx_server.tool_format = get_tool_format(ctx_server.ctx); + bool tool_call_support = false; + if (ctx_server.params.enable_tool_calls) { + ctx_server.tool_format = get_tool_format(ctx_server.ctx); + tool_call_support = ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED; + if (tool_call_support) { + LOG_WARNING("Tool call is EXPERIMENTAL and maybe unstable. Use with your own risk", {}); + } else { + LOG_ERROR("Tool call is not supported for this model. Please remove --tool-call or use with a supported model", {}); + clean_up(); + t.join(); + return 1; + } + } // print sample chat example to make it clear which template is used { LOG_INFO("chat template", { {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, {"built_in", params.chat_template.empty()}, - {"tool_call_enabled", ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED }, + {"tool_call_support", tool_call_support}, }); } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index e46f7b0327c54..253f8d42f082c 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -397,7 +397,7 @@ static json oaicompat_completion_params_parse( } // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "tool_choice" }; + static const std::vector unsupported_params { "tools", "tool_choice" }; for (auto & param : unsupported_params) { if (body.contains(param)) { throw std::runtime_error("Unsupported param: " + param); From d25cd7f9e4e56d10c56b55439b8a84a5bd847aab Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 30 Aug 2024 21:51:12 +0200 Subject: [PATCH 3/3] refactor --- examples/server/server.cpp | 11 ++++-- examples/server/utils.hpp | 78 +++++++++++++++++--------------------- 2 files changed, 42 insertions(+), 47 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ccc45fb6d2ce0..ca15fe7244055 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3069,9 +3069,14 @@ int main(int argc, char ** argv) { } json body = json::parse(req.body); - if (body.contains("tools") && ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) { - body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools")); - body.erase(body.find("tools")); + if (body.contains("tools")) { + if (ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) { + body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools")); + body.erase(body.find("tools")); + } else { + res_error(res, format_error_response("This server does not support tool calls. Start it with `--tool-calls`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } } json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 253f8d42f082c..fe6828c74793e 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -499,7 +499,15 @@ static std::vector format_partial_response_oaicompat(json result, const st finish_reason = "length"; } - std::time_t t = std::time(0); + auto wrap_choices = [&completion_id, &modelname](json choices) -> json { + return json{ + {"choices", choices}, + {"created", std::time(0)}, + {"id", completion_id}, + {"model", modelname}, + {"object", "chat.completion.chunk"} + }; + }; json choices; json delta = has_tool_calls @@ -519,22 +527,14 @@ static std::vector format_partial_response_oaicompat(json result, const st {"delta", json::object()}}}); if (has_tool_calls) { // tool call must be send as two updates - json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", delta}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = json{ - {"choices", choices}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - + json initial_ret = wrap_choices(json::array({ + json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}, + } + })); + json second_ret = wrap_choices(choices); return std::vector({initial_ret, second_ret}); } } else { @@ -545,26 +545,22 @@ static std::vector format_partial_response_oaicompat(json result, const st {"delta", json{{"role", "assistant"}}}}}); } else { // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", delta}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - + json initial_ret = wrap_choices(json::array({ + json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"}, + }}, + } + })); + json second_ret = wrap_choices(json::array({ + json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", delta}, + } + })); return std::vector({initial_ret, second_ret}); } } else { @@ -582,13 +578,7 @@ static std::vector format_partial_response_oaicompat(json result, const st } } - json ret = json { - {"choices", choices}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} - }; + json ret = wrap_choices(choices); if (!finish_reason.empty()) { int num_tokens_predicted = json_value(result, "tokens_predicted", 0); int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);