server : add Hermes-3 tool call support (WIP) #9254

Draft · wants to merge 3 commits into base: master
23 changes: 23 additions & 0 deletions common/common.cpp
@@ -428,6 +428,7 @@ void gpt_params_parse_from_env(gpt_params & params) {
get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
get_env("LLAMA_ARG_HOST", params.hostname);
get_env("LLAMA_ARG_PORT", params.port);
get_env("LLAMA_ARG_TOOL_CALLS", params.enable_tool_calls);
}

bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
@@ -1046,6 +1047,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.lora_init_without_apply = true;
return true;
}
if (arg == "--tool-call" || arg == "--tool-calls") {
params.enable_tool_calls = true;
return true;
}
if (arg == "--control-vector") {
CHECK_ARG
params.control_vectors.push_back({ 1.0f, argv[i], });
@@ -2036,6 +2041,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
options.push_back({ "server", " --tool-call(s)", "enable OAI tool calls for chat completion endpoint (default: %s)", params.enable_tool_calls ? "enabled" : "disabled"});

#ifndef LOG_DISABLE_LOGS
options.push_back({ "logging" });
@@ -2253,6 +2259,10 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
return true;
}

bool string_contains(std::string haystack, std::string needle) {
return haystack.find(needle) != std::string::npos;
}

//
// Filesystem utils
//
@@ -3186,6 +3196,19 @@ std::string llama_chat_format_example(const struct llama_model * model,
return llama_chat_apply_template(model, tmpl, msgs, true);
}

std::string llama_get_chat_template(const struct llama_model * model) {
std::string template_key = "tokenizer.chat_template";
// call with NULL buffer to get the total size of the string
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
if (res < 0) {
return "";
} else {
std::vector<char> model_template(res, 0);
llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
return std::string(model_template.data(), model_template.size());
}
}

//
// KV cache utils
//
7 changes: 7 additions & 0 deletions common/common.h
@@ -221,6 +221,7 @@ struct gpt_params {
std::string chat_template = "";
std::string system_prompt = "";
bool enable_chat_template = true;
bool enable_tool_calls = false;

std::vector<std::string> api_keys;

@@ -320,6 +321,8 @@ static std::vector<T> string_split(const std::string & str, char delim) {
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);

bool string_contains(std::string haystack, std::string needle);

//
// Filesystem utils
//
@@ -428,6 +431,10 @@ std::string llama_chat_format_single(const struct llama_model * model,
std::string llama_chat_format_example(const struct llama_model * model,
const std::string & tmpl);

// Returns the chat template stored inside the model
// (empty string if model does not have built-in chat template)
std::string llama_get_chat_template(const struct llama_model * model);

//
// KV cache utils
//
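A minimal usage sketch (not part of the PR) of the two helpers declared above; the GGUF path and the ChatML check below are illustrative assumptions, not code from this change.

```cpp
// Sketch only: inspect a model's built-in chat template with the new helpers.
#include <cstdio>

#include "common.h"
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path
    if (model == NULL) {
        return 1;
    }

    const std::string tmpl = llama_get_chat_template(model);
    if (tmpl.empty()) {
        printf("model has no built-in chat template\n");
    } else if (string_contains(tmpl, "<|im_start|>")) {
        printf("built-in template looks ChatML-style\n"); // heuristic, for illustration only
    }

    llama_free_model(model);
    return 0;
}
```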
71 changes: 55 additions & 16 deletions examples/server/server.cpp
@@ -4,6 +4,7 @@
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "grammar-parser.h"
#include "tool-call.hpp"

#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
@@ -157,6 +158,7 @@ struct server_slot {
std::string generated_text;
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
enum llama_response_state response_state = LLAMA_RESPONSE_STATE_UNKNOWN;

bool infill = false;
bool embedding = false;
@@ -207,6 +209,7 @@ struct server_slot {
infill = false;
ga_i = 0;
n_past_se = 0;
response_state = LLAMA_RESPONSE_STATE_UNKNOWN;

generated_token_probs.clear();
}
@@ -625,6 +628,7 @@ struct server_context {
llama_model * model = nullptr;
llama_context * ctx = nullptr;
std::vector<llama_lora_adapter_container> lora_adapters;
llama_tool_format tool_format = LLAMA_TOOL_FORMAT_NOT_SUPPORTED;

gpt_params params;

@@ -1217,7 +1221,13 @@ struct server_context {
break;
}

if (!incomplete) {
if (slot.response_state == LLAMA_RESPONSE_STATE_UNKNOWN) {
slot.response_state = check_response_state(tool_format, slot.generated_text);
}

// if response is tool call, we cannot stream it
// instead, we wait for the full response, then extract JSON
if (!incomplete && slot.response_state == LLAMA_RESPONSE_STATE_TEXT) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

const std::string str_test = slot.generated_text.substr(pos);
@@ -1247,9 +1257,7 @@
if (slot.params.stream) {
send_partial_response(slot, result);
}
}

if (incomplete) {
} else {
slot.has_next_token = true;
}

@@ -1396,6 +1404,10 @@
{"multimodal", false}
};

if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
res.data["tool_calls"] = parse_tool_response(tool_format, tkn.text_to_send);
}

if (slot.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
@@ -1444,6 +1456,10 @@
{"timings", slot.get_formated_timings()}
};

if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
res.data["tool_calls"] = parse_tool_response(tool_format, slot.generated_text);
}

if (slot.sparams.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
@@ -2937,19 +2953,14 @@ int main(int argc, char ** argv) {
};

const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
std::string template_key = "tokenizer.chat_template", curr_tmpl;
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
}
}
std::string chat_tmpl = ctx_server.params.chat_template.empty()
? llama_get_chat_template(ctx_server.model)
: ctx_server.params.chat_template;
json data = {
{ "system_prompt", ctx_server.system_prompt.c_str() },
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params.n_parallel },
{ "chat_template", curr_tmpl.c_str() }
{ "chat_template", chat_tmpl },
};

res.set_content(data.dump(), MIMETYPE_JSON);
@@ -3056,7 +3067,19 @@ int main(int argc, char ** argv) {
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
json body = json::parse(req.body);

if (body.contains("tools")) {
if (ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools"));
body.erase(body.find("tools"));
} else {
res_error(res, format_error_response("This server does not support tool calls. Start it with `--tool-calls`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
}

json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template);

const int id_task = ctx_server.queue_tasks.get_new_id();

@@ -3423,11 +3446,27 @@ int main(int argc, char ** argv) {
}
}

// decide if we can enable tool calls
bool tool_call_support = false;
if (ctx_server.params.enable_tool_calls) {
ctx_server.tool_format = get_tool_format(ctx_server.ctx);
tool_call_support = ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
if (tool_call_support) {
LOG_WARNING("Tool call is EXPERIMENTAL and maybe unstable. Use with your own risk", {});
} else {
LOG_ERROR("Tool call is not supported for this model. Please remove --tool-call or use with a supported model", {});
clean_up();
t.join();
return 1;
}
}

// print sample chat example to make it clear which template is used
{
LOG_INFO("chat template", {
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
{"built_in", params.chat_template.empty()},
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
{"built_in", params.chat_template.empty()},
{"tool_call_support", tool_call_support},
});
}

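For reference, a sketch (not part of the PR) of the request body shape consumed by the new `tools` branch in the chat completions handler above: the `tools` array follows the OpenAI function-calling schema, and each entry is dumped as-is into the Hermes-3 system prompt by `format_chat_with_tool()`. The weather function below is a made-up example.

```cpp
// Sketch only: an OpenAI-style chat completion request with "tools", as read by
// the handler above via body.at("messages") and body.at("tools").
#include <iostream>
#include "json.hpp"

using json = nlohmann::ordered_json;

int main() {
    json body = {
        {"messages", json::array({
            {{"role", "user"}, {"content", "What is the weather in Paris?"}},
        })},
        {"tools", json::array({
            {
                {"type", "function"},
                {"function", {
                    {"name", "get_weather"},
                    {"description", "Get the current weather for a city"},
                    {"parameters", {
                        {"type", "object"},
                        {"properties", {{"location", {{"type", "string"}}}}},
                        {"required", json::array({"location"})},
                    }},
                }},
            },
        })},
    };

    // with --tool-calls and a supported model, the server rewrites this into
    // body["prompt"] via format_chat_with_tool() and erases the "tools" field
    std::cout << body.dump(2) << std::endl;
}
```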
114 changes: 114 additions & 0 deletions examples/server/tool-call.hpp
@@ -0,0 +1,114 @@
#pragma once

#include "llama.h"
#include "common.h"
#include "utils.hpp"

// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"

#include <string>
#include <vector>
#include <sstream>

using json = nlohmann::ordered_json;

enum llama_tool_format {
LLAMA_TOOL_FORMAT_NOT_SUPPORTED,
LLAMA_TOOL_FORMAT_HERMES_3,
};

enum llama_response_state {
LLAMA_RESPONSE_STATE_UNKNOWN,
LLAMA_RESPONSE_STATE_TEXT,
LLAMA_RESPONSE_STATE_TOOL_CALL,
};

// get the tool call format for the loaded model
// this function does linear search, so do not call it repeatedly
inline enum llama_tool_format get_tool_format(const struct llama_context * ctx) {
auto model = llama_get_model(ctx);
auto has_token = [&](std::string piece) {
for (int i = 0; i < llama_n_vocab(model); i++) {
const std::string token_str = llama_token_to_piece(ctx, i, true);
if (token_str == piece) {
return true;
}
}
return false;
};
if (has_token("<|im_start|>") && has_token("<tool_call>")) {
return LLAMA_TOOL_FORMAT_HERMES_3;
}
return LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
}

inline std::string format_chat_with_tool(enum llama_tool_format format, const std::vector<json> & messages, json tools) {
if (!tools.is_array()) {
throw std::runtime_error("tools must be an array");
}
std::stringstream ss;
auto chat = parse_chat_messages(messages);
if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
ss << "<|im_start|>system\n\n";
ss << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools>\n\n";
for (auto tool : tools) {
ss << tool.dump(1, '\t') << "\n\n";

Why the tab indentation? It increases the number of tokens, but I don't think it adds useful information.

Suggested change
- ss << tool.dump(1, '\t') << "\n\n";
+ ss << tool.dump() << "\n\n";
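To make the size difference concrete, a minimal sketch (not part of the PR, assuming nlohmann/json and a made-up tool definition) comparing the two dump calls:

```cpp
// Illustration of the review suggestion above; the tool definition is made up.
#include <iostream>
#include "json.hpp"

using json = nlohmann::ordered_json;

int main() {
    json tool = {
        {"type", "function"},
        {"function", {
            {"name", "get_weather"},
            {"parameters", {{"type", "object"}}},
        }},
    };

    // pretty-printed with one tab per level: the extra newlines/tabs become prompt tokens
    std::cout << tool.dump(1, '\t') << "\n\n";

    // compact form: same information, fewer tokens
    std::cout << tool.dump() << std::endl;
}
```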

}
ss << "</tools> Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}, \"name\": {\"title\": \"Name\", \"type\": \"string\"}}, \"required\": [\"arguments\", \"name\"], \"title\": \"FunctionCall\", \"type\": \"object\"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n";
ss << "<tool_call>\n";
ss << "{\"arguments\": <args-dict>, \"name\": <function-name>}\n";
ss << "</tool_call><|im_end|>\n";
for (auto & message : chat) {
std::string role(message.role);
if (role == "system") {
continue; // for optimal performance, we skip user-defined system message
}
ss << "<|im_start|>" << role << "\n\n";
if (role == "tool") {
ss << "<tool_response>\n" << string_strip(message.content) << "\n</tool_response>\n";
} else {
ss << string_strip(message.content) << "<|im_end|>\n";
}
}
ss << "<|im_start|>assistant\n\n";
} else {
throw std::runtime_error("tool_call is not supported by this model");
}
LOG_VERBOSE("format_chat_with_tool", {{"text", ss.str()}});
return ss.str();
}

// check if the response is text or tool_call
// if it is tool_call, we may have to disable streaming, because we must parse the whole JSON response
inline enum llama_response_state check_response_state(enum llama_tool_format format, const std::string & generated_text) {
if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
return LLAMA_RESPONSE_STATE_TEXT;
} else if (format == LLAMA_TOOL_FORMAT_HERMES_3 && generated_text.rfind("<tool_call>", 0) == 0) {
return LLAMA_RESPONSE_STATE_TOOL_CALL;
}
return LLAMA_RESPONSE_STATE_TEXT;
}

// convert model's response to OAI format
inline json parse_tool_response(enum llama_tool_format format, const std::string & generated_text) {
if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
return json{};
} else if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
std::string tmp(generated_text);
string_replace_all(tmp, "<tool_call>", "");
string_replace_all(tmp, "</tool_call>", "");
json tool = json::parse(tmp);
std::vector<json> tool_calls = {json{
{"id", tool.at("name")},
{"type", "function"},
{"function", {
{"name", tool.at("name")},
{"arguments", tool.at("arguments").dump()}, // OAI requires this to be JSON-stringified
}},
}};
return tool_calls;
}
return generated_text;
}
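Finally, a small sketch (not part of the PR, assuming the server example's build context for tool-call.hpp) of the input/output mapping `parse_tool_response()` performs for a single well-formed Hermes-3 tool call; the function name and arguments are made up.

```cpp
// Expected input/output of parse_tool_response() for one Hermes-3 tool call.
#include <iostream>
#include <string>

#include "tool-call.hpp"

int main() {
    const std::string generated_text =
        "<tool_call>\n"
        "{\"arguments\": {\"location\": \"Paris\"}, \"name\": \"get_weather\"}\n"
        "</tool_call>";

    json tool_calls = parse_tool_response(LLAMA_TOOL_FORMAT_HERMES_3, generated_text);
    std::cout << tool_calls.dump(2) << std::endl;
    // expected OAI-style result:
    // [{ "id": "get_weather", "type": "function",
    //    "function": { "name": "get_weather", "arguments": "{\"location\":\"Paris\"}" } }]
}
```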