From 642330ac7cea11dc1f9d3df2b8f3dbd10b5e3f3e Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Mon, 2 Dec 2024 22:10:19 +0100
Subject: [PATCH] llama : add enum for built-in chat templates (#10623)

* llama : add enum for supported chat templates

* use "built-in" instead of "supported"

* arg: print list of built-in templates

* fix test

* update server README
---
 common/arg.cpp               |  20 +-
 examples/server/README.md    |  11 +-
 include/llama.h              |   3 +
 src/llama.cpp                | 357 +++++++++++++++++++++++++----
 tests/test-chat-template.cpp |  20 +-
 5 files changed, 307 insertions(+), 104 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 32d9a964c1716..078c7538490c4 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -348,6 +348,18 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
     return true;
 }
 
+static std::string list_builtin_chat_templates() {
+    std::vector<const char *> supported_tmpl;
+    int32_t res = llama_chat_builtin_templates(nullptr, 0);
+    supported_tmpl.resize(res);
+    res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
+    std::ostringstream msg;
+    for (auto & tmpl : supported_tmpl) {
+        msg << tmpl << (&tmpl == &supported_tmpl.back() ? "" : ", ");
+    }
+    return msg.str();
+}
+
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
     // load dynamic backends
     ggml_backend_load_all();
@@ -1814,9 +1826,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
-        "set custom jinja chat template (default: template taken from model's metadata)\n"
-        "if suffix/prefix are specified, template will be disabled\n"
-        "only commonly used templates are accepted:\nhttps://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
+        string_format(
+            "set custom jinja chat template (default: template taken from model's metadata)\n"
+            "if suffix/prefix are specified, template will be disabled\n"
+            "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
+        ),
         [](common_params & params, const std::string & value) {
             if (!common_chat_verify_template(value)) {
                 throw std::runtime_error(string_format(
diff --git a/examples/server/README.md b/examples/server/README.md
index aa99d06f95731..3f0d45e5bed1b 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -69,6 +69,8 @@ The project is under active development, and we are [looking for feedback and co
 | `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
 | `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock)<br/>(env: LLAMA_ARG_NO_MMAP) |
 | `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggerganov/llama.cpp/issues/1437<br/>(env: LLAMA_ARG_NUMA) |
+| `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>use --list-devices to see a list of available devices<br/>(env: LLAMA_ARG_DEVICE) |
+| `--list-devices` | print list of available devices and exit |
 | `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
 | `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
 | `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
@@ -158,9 +160,16 @@ The project is under active development, and we are [looking for feedback and co
 | `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
 | `--no-slots` | disables slots monitoring endpoint<br/>(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) |
 | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) |
-| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted:<br/>https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
+| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>list of built-in templates:<br/>chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, exaone3, gemma, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, monarch, openchat, orion, phi3, rwkv-world, vicuna, vicuna-orca, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
 | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
 | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
+| `--draft-max, --draft, --draft-n N` | number of tokens to draft for speculative decoding (default: 16) |
+| `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 5) |
+| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.9) |
+| `-cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model) |
+| `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | +| `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model | +| `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused) | Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. diff --git a/include/llama.h b/include/llama.h index ab5e376e6c7f2..439e0ff0c7e0c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -990,6 +990,9 @@ extern "C" { char * buf, int32_t length); + // Get list of built-in chat templates + int32_t llama_chat_builtin_templates(const char ** output, size_t len); + // // Sampling API // diff --git a/src/llama.cpp b/src/llama.cpp index 6e9ba97272287..6a6f4c2a5eb7e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1549,6 +1549,67 @@ static const std::map> LLM_TENSOR_N }, }; +enum llm_chat_template { + LLM_CHAT_TEMPLATE_CHATML, + LLM_CHAT_TEMPLATE_LLAMA_2, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP, + LLM_CHAT_TEMPLATE_MISTRAL_V1, + LLM_CHAT_TEMPLATE_MISTRAL_V3, + LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, + LLM_CHAT_TEMPLATE_MISTRAL_V7, + LLM_CHAT_TEMPLATE_PHI_3, + LLM_CHAT_TEMPLATE_ZEPHYR, + LLM_CHAT_TEMPLATE_MONARCH, + LLM_CHAT_TEMPLATE_GEMMA, + LLM_CHAT_TEMPLATE_ORION, + LLM_CHAT_TEMPLATE_OPENCHAT, + LLM_CHAT_TEMPLATE_VICUNA, + LLM_CHAT_TEMPLATE_VICUNA_ORCA, + LLM_CHAT_TEMPLATE_DEEPSEEK, + LLM_CHAT_TEMPLATE_DEEPSEEK_2, + LLM_CHAT_TEMPLATE_COMMAND_R, + LLM_CHAT_TEMPLATE_LLAMA_3, + LLM_CHAT_TEMPLATE_CHATGML_3, + LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_MINICPM, + LLM_CHAT_TEMPLATE_EXAONE_3, + LLM_CHAT_TEMPLATE_RWKV_WORLD, + LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_UNKNOWN, +}; + +static const std::map LLM_CHAT_TEMPLATES = { + { "chatml", LLM_CHAT_TEMPLATE_CHATML }, + { "llama2", LLM_CHAT_TEMPLATE_LLAMA_2 }, + { "llama2-sys", LLM_CHAT_TEMPLATE_LLAMA_2_SYS }, + { "llama2-sys-bos", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS }, + { "llama2-sys-strip", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP }, + { "mistral-v1", LLM_CHAT_TEMPLATE_MISTRAL_V1 }, + { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, + { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, + { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, + { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, + { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, + { "monarch", LLM_CHAT_TEMPLATE_MONARCH }, + { "gemma", LLM_CHAT_TEMPLATE_GEMMA }, + { "orion", LLM_CHAT_TEMPLATE_ORION }, + { "openchat", LLM_CHAT_TEMPLATE_OPENCHAT }, + { "vicuna", LLM_CHAT_TEMPLATE_VICUNA }, + { "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA }, + { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, + { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, + { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, + { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, + { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, + { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, + { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, + { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE }, +}; + static llm_arch llm_arch_from_string(const std::string & name) { for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT if (kv.second == name) { @@ -21843,18 +21904,109 @@ int32_t llama_detokenize( // chat templates // +static llm_chat_template llama_chat_detect_template(const std::string & tmpl) { + if (LLM_CHAT_TEMPLATES.find(tmpl) != LLM_CHAT_TEMPLATES.end()) { + return LLM_CHAT_TEMPLATES.at(tmpl); + 
} + auto tmpl_contains = [&tmpl](const char * haystack) -> bool { + return tmpl.find(haystack) != std::string::npos; + }; + if (tmpl_contains("<|im_start|>")) { + return LLM_CHAT_TEMPLATE_CHATML; + } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { + if (tmpl_contains("[SYSTEM_PROMPT]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V7; + } else if ( + // catches official 'v1' template + tmpl_contains("' [INST] ' + system_message") + // catches official 'v3' and 'v3-tekken' templates + || tmpl_contains("[AVAILABLE_TOOLS]") + ) { + // Official mistral 'v1', 'v3' and 'v3-tekken' templates + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + if (tmpl_contains(" [INST]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V1; + } else if (tmpl_contains("\"[INST]\"")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN; + } + return LLM_CHAT_TEMPLATE_MISTRAL_V3; + } else { + // llama2 template and its variants + // [variant] support system message + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl_contains("<>"); + bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); + bool strip_message = tmpl_contains("content.strip()"); + if (strip_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; + } else if (add_bos_inside_history) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; + } else if (support_system_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS; + } else { + return LLM_CHAT_TEMPLATE_LLAMA_2; + } + } + } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { + return LLM_CHAT_TEMPLATE_PHI_3; + } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { + return LLM_CHAT_TEMPLATE_ZEPHYR; + } else if (tmpl_contains("bos_token + message['role']")) { + return LLM_CHAT_TEMPLATE_MONARCH; + } else if (tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_GEMMA; + } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + // OrionStarAI/Orion-14B-Chat + return LLM_CHAT_TEMPLATE_ORION; + } else if (tmpl_contains("GPT4 Correct ")) { + // openchat/openchat-3.5-0106 + return LLM_CHAT_TEMPLATE_OPENCHAT; + } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + if (tmpl_contains("SYSTEM: ")) { + return LLM_CHAT_TEMPLATE_VICUNA_ORCA; + } + return LLM_CHAT_TEMPLATE_VICUNA; + } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) { + // deepseek-ai/deepseek-coder-33b-instruct + return LLM_CHAT_TEMPLATE_DEEPSEEK; + } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) { + // CohereForAI/c4ai-command-r-plus + return LLM_CHAT_TEMPLATE_COMMAND_R; + } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) { + return LLM_CHAT_TEMPLATE_LLAMA_3; + } else if (tmpl_contains("[gMASK]sop")) { + // chatglm3-6b + return LLM_CHAT_TEMPLATE_CHATGML_3; + } else if (tmpl_contains("[gMASK]")) { + return LLM_CHAT_TEMPLATE_CHATGML_4; + } else if (tmpl_contains(LU8("<用户>"))) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + return LLM_CHAT_TEMPLATE_MINICPM; + } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_2; + } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { + // ref: 
https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + return LLM_CHAT_TEMPLATE_EXAONE_3; + } else if (tmpl_contains("rwkv-world")) { + return LLM_CHAT_TEMPLATE_RWKV_WORLD; + } else if (tmpl_contains("<|start_of_role|>")) { + return LLM_CHAT_TEMPLATE_GRANITE; + } + return LLM_CHAT_TEMPLATE_UNKNOWN; +} + // Simple version of "llama_apply_chat_template" that only works with strings // This function uses heuristic checks to determine commonly used template. It is not a jinja parser. static int32_t llama_chat_apply_template_internal( - const std::string & tmpl, + const llm_chat_template tmpl, const std::vector & chat, std::string & dest, bool add_ass) { // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 std::stringstream ss; - auto tmpl_contains = [&tmpl](std::string haystack) -> bool { - return tmpl.find(haystack) != std::string::npos; - }; - if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) { + if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { // chatml template for (auto message : chat) { ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; @@ -21862,86 +22014,84 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|im_start|>assistant\n"; } - } else if (tmpl == "llama2" || tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { - if (tmpl == "mistral-v7" || tmpl_contains("[SYSTEM_PROMPT]")) { - // Official mistral 'v7' template - // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 - for (auto message : chat) { - std::string role(message->role); - std::string content(message->content); - if (role == "system") { - ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]"; - } else if (role == "user") { - ss << "[INST] " << content << "[/INST]"; - } - else { - ss << " " << content << ""; - } + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) { + // Official mistral 'v7' template + // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 + for (auto message : chat) { + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]"; + } else if (role == "user") { + ss << "[INST] " << content << "[/INST]"; } - } else if (tmpl == "mistral-v1" || tmpl == "mistral-v3" || tmpl == "mistral-v3-tekken" - || tmpl_contains("' [INST] ' + system_message") // catches official 'v1' template - || tmpl_contains("[AVAILABLE_TOOLS]")) { // catches official 'v3' and 'v3-tekken' templates - // Official mistral 'v1', 'v3' and 'v3-tekken' templates - // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md - // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md - std::string leading_space = (tmpl == "mistral-v1" || tmpl_contains(" [INST]") ? " " : ""); - std::string trailing_space = (tmpl == "mistral-v3-tekken" || tmpl_contains("\"[INST]\"") ? 
"" : " "); - bool trim_assistant_message = tmpl_contains("|trim + eos_token"); - bool is_inside_turn = false; - for (auto message : chat) { - if (!is_inside_turn) { - ss << leading_space << "[INST]" << trailing_space; - is_inside_turn = true; - } - std::string role(message->role); - std::string content(message->content); - if (role == "system") { - ss << content << "\n\n"; - } else if (role == "user") { - ss << content << leading_space << "[/INST]"; - } else { - ss << trailing_space << (trim_assistant_message ? trim(content) : content) << ""; - is_inside_turn = false; - } + else { + ss << " " << content << ""; } - } else { - // llama2 template and its variants - // [variant] support system message - // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 - bool support_system_message = tmpl_contains("<>") || tmpl == "llama2"; - // [variant] space before + after response - bool space_around_response = tmpl_contains("' ' + eos_token"); - // [variant] add BOS inside history - bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); - // [variant] trim spaces from the input message - bool strip_message = tmpl_contains("content.strip()"); - // construct the prompt - bool is_inside_turn = true; // skip BOS at the beginning - ss << "[INST] "; - for (auto message : chat) { - std::string content = strip_message ? trim(message->content) : message->content; - std::string role(message->role); - if (!is_inside_turn) { - is_inside_turn = true; - ss << (add_bos_inside_history ? "[INST] " : "[INST] "); - } - if (role == "system") { - if (support_system_message) { - ss << "<>\n" << content << "\n<>\n\n"; - } else { - // if the model does not support system message, we still include it in the first message, but without <> - ss << content << "\n"; - } - } else if (role == "user") { - ss << content << " [/INST]"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) { + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : ""; + std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " "; + bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3; + bool is_inside_turn = false; + for (auto message : chat) { + if (!is_inside_turn) { + ss << leading_space << "[INST]" << trailing_space; + is_inside_turn = true; + } + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << content << "\n\n"; + } else if (role == "user") { + ss << content << leading_space << "[/INST]"; + } else { + ss << trailing_space << (trim_assistant_message ? 
trim(content) : content) << ""; + is_inside_turn = false; + } + } + } else if ( + tmpl == LLM_CHAT_TEMPLATE_LLAMA_2 + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) { + // llama2 template and its variants + // [variant] support system message + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2; + // [variant] add BOS inside history + bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; + // [variant] trim spaces from the input message + bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; + // construct the prompt + bool is_inside_turn = true; // skip BOS at the beginning + ss << "[INST] "; + for (auto message : chat) { + std::string content = strip_message ? trim(message->content) : message->content; + std::string role(message->role); + if (!is_inside_turn) { + is_inside_turn = true; + ss << (add_bos_inside_history ? "[INST] " : "[INST] "); + } + if (role == "system") { + if (support_system_message) { + ss << "<>\n" << content << "\n<>\n\n"; } else { - ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << ""; - is_inside_turn = false; + // if the model does not support system message, we still include it in the first message, but without <> + ss << content << "\n"; } + } else if (role == "user") { + ss << content << " [/INST]"; + } else { + ss << content << ""; + is_inside_turn = false; } - // llama2 templates seem to not care about "add_generation_prompt } - } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) { // Phi 3 for (auto message : chat) { std::string role(message->role); @@ -21950,7 +22100,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>\n"; } - } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) { // zephyr template for (auto message : chat) { ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; @@ -21958,7 +22108,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>\n"; } - } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) { // mlabonne/AlphaMonarch-7B template (the is included inside history) for (auto message : chat) { std::string bos = (message == chat.front()) ? 
"" : ""; // skip BOS for first message @@ -21967,7 +22117,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "assistant\n"; } - } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) { // google/gemma-7b-it std::string system_prompt = ""; for (auto message : chat) { @@ -21989,7 +22139,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "model\n"; } - } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) { // OrionStarAI/Orion-14B-Chat std::string system_prompt = ""; for (auto message : chat) { @@ -22009,7 +22159,7 @@ static int32_t llama_chat_apply_template_internal( ss << message->content << ""; } } - } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) { // openchat/openchat-3.5-0106, for (auto message : chat) { std::string role(message->role); @@ -22023,13 +22173,13 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "GPT4 Correct Assistant:"; } - } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { // eachadea/vicuna-13b-1.1 (and Orca variant) for (auto message : chat) { std::string role(message->role); if (role == "system") { // Orca-Vicuna variant uses a system prefix - if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) { + if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { ss << "SYSTEM: " << message->content << "\n"; } else { ss << message->content << "\n\n"; @@ -22043,7 +22193,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "ASSISTANT:"; } - } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) { // deepseek-ai/deepseek-coder-33b-instruct for (auto message : chat) { std::string role(message->role); @@ -22058,7 +22208,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "### Response:\n"; } - } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) { // CohereForAI/c4ai-command-r-plus for (auto message : chat) { std::string role(message->role); @@ -22073,7 +22223,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; } - } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) { // Llama 3 for (auto message : chat) { std::string role(message->role); @@ -22082,7 +22232,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -22092,7 +22242,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); @@ -22101,7 +22251,7 @@ static int32_t 
llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { // MiniCPM-3B-OpenHermes-2.5-v2-GGUF for (auto message : chat) { std::string role(message->role); @@ -22113,7 +22263,7 @@ static int32_t llama_chat_apply_template_internal( ss << trim(message->content); } } - } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) { // DeepSeek-V2 for (auto message : chat) { std::string role(message->role); @@ -22128,7 +22278,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "Assistant:"; } - } else if (tmpl == "exaone3" || (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct for (auto message : chat) { @@ -22144,7 +22294,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "[|assistant|]"; } - } else if (tmpl == "rwkv-world" || tmpl_contains("rwkv-world")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token for (auto message : chat) { std::string role(message->role); @@ -22154,7 +22304,7 @@ static int32_t llama_chat_apply_template_internal( ss << message->content << "\n\n"; } } - } else if (tmpl == "granite" || tmpl_contains("<|start_of_role|>")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { // IBM Granite template for (const auto & message : chat) { std::string role(message->role); @@ -22206,7 +22356,11 @@ int32_t llama_chat_apply_template( } std::string formatted_chat; - int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); + llm_chat_template detected_tmpl = llama_chat_detect_template(curr_tmpl); + if (detected_tmpl == LLM_CHAT_TEMPLATE_UNKNOWN) { + return -1; + } + int32_t res = llama_chat_apply_template_internal(detected_tmpl, chat_vec, formatted_chat, add_ass); if (res < 0) { return res; } @@ -22216,6 +22370,15 @@ int32_t llama_chat_apply_template( return res; } +int32_t llama_chat_builtin_templates(const char ** output, size_t len) { + auto it = LLM_CHAT_TEMPLATES.begin(); + for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) { + output[i] = it->first.c_str(); + std::advance(it, 1); + } + return (int32_t) LLM_CHAT_TEMPLATES.size(); +} + // // sampling // diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index dd8f7d5f096dc..aa140b5696f74 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -82,9 +82,9 @@ int main(void) { // mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt) "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", // TheBloke/FusionNet_34Bx2_MoE-AWQ - "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", // bofenghuang/vigogne-2-70b-chat - "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an 
assistant [INST] Another question [/INST]", + "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", // mlabonne/AlphaMonarch-7B "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", // google/gemma-7b-it @@ -133,6 +133,17 @@ int main(void) { std::vector formatted_chat(1024); int32_t res; + // list all supported templates + std::vector supported_tmpl; + res = llama_chat_builtin_templates(nullptr, 0); + assert(res > 0); + supported_tmpl.resize(res); + res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size()); + printf("Built-in chat templates:\n"); + for (auto tmpl : supported_tmpl) { + printf(" %s\n", tmpl); + } + // test invalid chat template res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size()); assert(res < 0); @@ -174,7 +185,8 @@ int main(void) { assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n"); assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n"); assert(fmt_sys("mistral-v7") == "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT]"); - assert(fmt_sys("llama2") == "[INST] <>\nYou are a helpful assistant\n<>\n\n"); + assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n"); + assert(fmt_sys("llama2-sys") == "[INST] <>\nYou are a helpful assistant\n<>\n\n"); assert(fmt_sys("mistral") == "[INST] You are a helpful assistant\n"); // for old pre-v1 templates assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>"); @@ -203,5 +215,7 @@ int main(void) { assert(fmt_single("gemma") == "\nuser\nHow are you\nmodel\n"); assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); + printf("Test chat templates: OK\n"); + return 0; }
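
For reference, a minimal usage sketch (not part of the patch) of the new llama_chat_builtin_templates() API declared in include/llama.h above. It follows the same two-call pattern used by common/arg.cpp and tests/test-chat-template.cpp in this patch: a first call with (nullptr, 0) returns the number of built-in templates, and a second call fills a caller-provided array with their names. The file name list_templates.cpp is hypothetical.

// list_templates.cpp (hypothetical example, not part of this patch)
#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    // first call: query how many built-in templates exist
    const int32_t count = llama_chat_builtin_templates(nullptr, 0);

    // second call: fill a caller-provided array with the template names
    std::vector<const char *> names(count);
    llama_chat_builtin_templates(names.data(), names.size());

    for (const char * name : names) {
        printf("%s\n", name);
    }
    return 0;
}

As implemented in src/llama.cpp above, the returned pointers are the keys of the static LLM_CHAT_TEMPLATES map, so they remain valid for the lifetime of the process and need no cleanup.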