# Reconstructed from a whitespace-collapsed patch touching two files:
#   * llama_cpp/llama_chat_format.py  -> format_mistral_instruct
#     (replaces the former ``format_mistral`` registered under the same key)
#   * tests/test_llama_chat_format.py -> test_mistral_instruct


# --- llama_cpp/llama_chat_format.py -----------------------------------------
@register_chat_format("mistral-instruct")
def format_mistral_instruct(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Mistral-Instruct prompt from a list of chat messages.

    Layout produced, in message order:

    * the prompt starts with the BOS token;
    * every user message contributes ``"[INST] " + content``;
    * every assistant message contributes ``" [/INST]" + content + EOS``;
    * a final ``" [/INST]"`` is always appended so that generation continues
      with the next assistant reply.

    Messages with any other role, or whose content is not a plain ``str``
    (``None``, missing, or a structured content list), are skipped.

    Args:
        messages: Conversation turns to render.
        **kwargs: Accepted for signature compatibility with other chat
            formatters; not used here.

    Returns:
        A ``ChatFormatterResponse`` whose ``prompt`` is the rendered text and
        whose ``stop`` is the EOS token.
    """
    # Sentence-boundary tokens of the Mistral tokenizer. They must be
    # non-empty: ``stop`` is handed back as the stop sequence, and an empty
    # string is not a usable one.
    bos = "<s>"
    eos = "</s>"

    pieces = [bos]
    for message in messages:
        # ``.get`` instead of indexing: some message kinds may omit the
        # "content" key, and those should be skipped rather than raise.
        content = message.get("content")
        if not isinstance(content, str):
            continue
        role = message["role"]
        if role == "user":
            pieces.append("[INST] " + content)
        elif role == "assistant":
            pieces.append(" [/INST]" + content + eos)

    # Close the last instruction so the model answers next.
    pieces.append(" [/INST]")
    return ChatFormatterResponse(prompt="".join(pieces), stop=eos)


# --- tests/test_llama_chat_format.py ----------------------------------------
def test_mistral_instruct():
    """Check the hand-written formatter against a Jinja chat template.

    The same three-turn conversation is rendered by
    ``format_mistral_instruct`` and by the Jinja template below, using the
    same BOS/EOS tokens; the two prompts must match exactly.
    """
    # Function-scope imports keep this reconstructed block self-contained;
    # in the test module these are ordinary top-of-file imports.
    import jinja2

    import llama_cpp.llama_types as llama_types
    import llama_cpp.llama_chat_format as llama_chat_format

    # Adjacent literals concatenate to the single-line template string.
    chat_template = (
        "{{ bos_token }}"
        "{% for message in messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if message['role'] == 'user' %}"
        "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ message['content'] + eos_token}}"
        "{% else %}"
        "{{ raise_exception('Only user and assistant roles are supported!') }}"
        "{% endif %}"
        "{% endfor %}"
    )
    chat_formatter = jinja2.Template(chat_template)
    messages = [
        llama_types.ChatCompletionRequestUserMessage(role="user", content="Instruction"),
        llama_types.ChatCompletionRequestAssistantMessage(role="assistant", content="Model answer"),
        llama_types.ChatCompletionRequestUserMessage(role="user", content="Follow-up instruction"),
    ]
    response = llama_chat_format.format_mistral_instruct(
        messages=messages,
    )
    # Must use the same tokens as the formatter, otherwise the comparison is
    # between two different prompts.
    reference = chat_formatter.render(
        messages=messages,
        bos_token="<s>",
        eos_token="</s>",
    )
    assert response.prompt == reference