Add Jinja template support #11016
Conversation
Feel free to also add the option to llama-run for basic testing, @ochafik
Impressive work, thanks! Let's wait for @ggerganov to do another pass, then I think it's good to go!
```diff
@@ -4,22 +4,26 @@
 server = ServerPreset.tinyllama2()

-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
```
I'm not exceptionally good at pytest, so maybe I'm missing something. Could you explain why `scope="module"` is removed here?
`scope="module"` was making the ServerProcess server instance shared between all the tests in the module (file). Even though it's stopped in stop_server_after_each_test, it carried previous settings over to subsequent tests, spilling server flags across them (this became more important with the jinja & chat_template params).
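For illustration, a minimal self-contained sketch (using a stand-in FakeServer class rather than the actual ServerProcess test utilities) of why a module-scoped autouse fixture leaks state across tests while the default function scope does not:

```python
import pytest

class FakeServer:
    """Stand-in for ServerProcess: flags set by one test would otherwise persist."""
    def __init__(self):
        self.flags = {}

# With scope="module", this object would be created once per file and reused,
# so flags set in one test leak into the next. The default function scope
# rebuilds it for every test.
@pytest.fixture(autouse=True)
def server():
    srv = FakeServer()
    yield srv
    # teardown: stopping a shared server would not reset attributes on the instance

def test_sets_flag(server):
    server.flags["jinja"] = True
    assert server.flags == {"jinja": True}

def test_starts_clean(server):
    # passes with function scope; would fail if the fixture were module-scoped
    assert server.flags == {}
```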
Ok, thanks for the explanation. Seems like `scope="module"` is not what I wanted. I want the fixture to only affect a single file, since the idea is that one test unit uses one model.
```diff
 ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
```
TODO: we can also add a "slow" test that tests tool calls with a big model like Hermes or Qwen (see an example in test_infill.py). I'll have a look in the next few days.
hehe absolutely, this is coming in #9639 or another subpart of it (tool call parsing + conditional grammars)
btw, I wonder if there's any reason to override LLAMA_CACHE to a tmp directory in the server tests? I've been struggling with disk space on my MBP 😅
It's mostly to provide a way to isolate tests if the user has multiple clones of the llama.cpp source code on the machine. Maybe you can symlink that tmp directory to external storage?
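If it helps, a small sketch of that workaround; both paths below are assumptions for illustration and should be adjusted to wherever your checkout keeps its test cache:

```python
# Replace the per-checkout tmp cache directory with a symlink to external storage,
# keeping tests isolated per clone while the downloaded models live on a bigger disk.
import shutil
from pathlib import Path

cache_dir = Path("examples/server/tests/tmp")       # hypothetical per-checkout cache
external = Path("/Volumes/External/llama-cache")    # hypothetical external drive

external.mkdir(parents=True, exist_ok=True)
if cache_dir.exists() and not cache_dir.is_symlink():
    shutil.rmtree(cache_dir)                        # drop the old on-disk cache first
cache_dir.symlink_to(external, target_is_directory=True)
```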
Guard against missing eos/bos tokens (null token otherwise throws in llama_vocab::impl::token_get_attr)
common/minja.hpp
Outdated
```cpp
// On Windows, collapse CRLF line endings to plain LF; a no-op elsewhere.
static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32
    static const std::regex nl_regex("\r\n");
    return std::regex_replace(s, nl_regex, "\n");
#else
    return s;
#endif
}
```
Not sure what the original purpose of this was, but I think it can be removed, as well as the definition of `ENDL` as `\r\n` on win32. It shouldn't make a difference with `stringstream`.
Dropped `ENDL` + 1 usage of this function (at the end of rendering; one is still needed to shield the parser from CRLFs), thanks!
Small thing to note: some jinja templates are not "linear", meaning each conversation turn is not self-contained but can modify the content before it. For example, the new deepseek-r1 distilled template does this. The consequence is that it will break conversation-level prompt caching. A solution is to also track the cached tokens at the token level (not the conversation level), which I introduced here: #11203. @ericcurtin feel free to port this to llama-run.
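To illustrate the non-linearity point with a toy example (this is not the actual deepseek-r1 template, and it uses the jinja2 package as a stand-in for minja): when earlier turns are rewritten, the rendering of turn N+1 no longer extends the rendering of turn N, so a conversation-level cache cannot be reused.

```python
from jinja2 import Template

# Toy non-linear template: every assistant turn except the last one is replaced
# by a placeholder, so adding a new message rewrites the already-rendered history.
tmpl = Template(
    "{% for m in messages %}"
    "{% if m.role == 'assistant' and not loop.last %}<assistant: [omitted]>"
    "{% else %}<{{ m.role }}: {{ m.content }}>"
    "{% endif %}"
    "{% endfor %}"
)

turn_1 = [{"role": "user", "content": "hi"},
          {"role": "assistant", "content": "hello!"}]
turn_2 = turn_1 + [{"role": "user", "content": "how are you?"}]

a = tmpl.render(messages=turn_1)  # <user: hi><assistant: hello!>
b = tmpl.render(messages=turn_2)  # <user: hi><assistant: [omitted]><user: how are you?>
assert not b.startswith(a)        # the new prompt does not extend the previous one
```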
Thanks everyone for the insightful reviews! More from #9639 to come soon :-)
Not sure if this is a special case or the template is broken, but when I load minimax-text-01 (my work-in-progress) with its chat template, llama.cpp with this PR crashes during model loading.
Hey @fairydreaming, thanks for testing & reporting! Your template contains an exotic keyword that minja doesn't handle. I could certainly make the error more informative though, feel free to file something on https://github.com/google/minja to that end (and/or any feature request). Looking forward to testing your model, good luck with it!
@ochafik I did some research and it seems to be a custom keyword introduced in HF transformers: huggingface/transformers#30650. Fortunately, among all the models I currently have on disk, only MiniMax-Text-01 uses this.
@fairydreaming thanks for researching that, will track support in google/minja#28
* Copy minja from google/minja@58f0ca6
* Add --jinja and --chat-template-file flags
* Add missing <optional> include
* Avoid print in get_hf_chat_template.py
* No designated initializers yet
* Try and work around msvc++ non-macro max resolution quirk
* Update test_chat_completion.py
* Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template
* Refactor test-chat-template
* Test templates w/ minja
* Fix deprecation
* Add --jinja to llama-run
* Update common_chat_format_example to use minja template wrapper
* Test chat_template in e2e test
* Update utils.py
* Update test_chat_completion.py
* Update run.cpp
* Update arg.cpp
* Refactor common_chat_* functions to accept minja template + use_jinja option
* Attempt to fix linkage of LLAMA_CHATML_TEMPLATE
* Revert LLAMA_CHATML_TEMPLATE refactor
* Normalize newlines in test-chat-templates for windows tests
* Forward decl minja::chat_template to avoid eager json dep
* Flush stdout in chat template before potential crash
* Fix copy elision warning
* Rm unused optional include
* Add missing optional include to server.cpp
* Disable jinja test that has a cryptic windows failure
* minja: fix vigogne (google/minja#22)
* Apply suggestions from code review

  Co-authored-by: Xuan Son Nguyen <[email protected]>
  Co-authored-by: Georgi Gerganov <[email protected]>
* Finish suggested renamings
* Move chat_templates inside server_context + remove mutex
* Update --chat-template-file w/ recent change to --chat-template
* Refactor chat template validation
* Guard against missing eos/bos tokens (null token otherwise throws in llama_vocab::impl::token_get_attr)
* Warn against missing eos / bos tokens when jinja template references them
* rename: common_chat_template[s]
* reinstate assert on chat_templates.template_default
* Update minja to google/minja@b8437df
* Update minja to google/minja#25
* Update minja from google/minja#27
* rm unused optional header

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Subset of #9639 with just the Jinja templating support.
Proper tool support (grammar constraints, lazy grammar triggering, tool call parsing & stop reason) will come in a follow-up PR.
- Adds a `--jinja` flag to llama-server, llama-cli, llama-run
- Adds a `--chat-template-file` flag to llama-server, llama-cli (related: Added chat template support to llama-run #11215)
- Reads the chat template from the model's `tokenizer.chat_template` metadata (or `tokenizer.chat_template.tool_use` if defined, only when the request has tools)
- Templates are rendered with `trim_blocks = true, lstrip_blocks = true`

Example usage:
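As a rough, hedged stand-in for the example usage (assuming a server started with `llama-server --jinja -m <model>.gguf` and listening on the default `localhost:8080`), a minimal Python sketch exercising the OpenAI-compatible endpoint:

```python
# Query a llama-server instance running with --jinja, so the request is rendered
# through the model's own Jinja chat template.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ],
        "max_tokens": 64,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```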
TODO: