Draft: Refactor Chat Formatter for Enhanced Flexibility and Extensibility #809
Conversation
- Introduce `BASE_TEMPLATE` for common chat formatting structure.
- Implement a protocol-based `ChatFormatterTemplate` for custom formatters.
- Add `Llama2Formatter` to handle specific Llama-2 formatting.
- Create `ChatFormatter` class for registering and retrieving formatters.
- Remove redundant functions like `_format_llama2`.

Refactored the chat message formatting to use a more structured and extensible approach. Now supports multiple templates and ensures a cleaner codebase.

- Introduce `test_llama_chat_formatters.py` for testing chat formatters.
- Implement `test_llama2_formatter` to validate Llama2 message formatting.

Added unit tests to ensure the correctness of the newly refactored Llama2Formatter. This ensures that message formatting adheres to the expected template.

- Introduce pytest fixture `sequence_of_messages` in `test_llama_chat_formatters.py`.
- Refactor `test_llama2_formatter` to use the new fixture.

Utilizing pytest fixtures enhances the modularity of our test suite, allowing for cleaner test cases and potential reusability across multiple tests.
- Introduced `TokenizerCache` to efficiently reuse tokenizers.
- Merged specific formatter classes into a generic `ChatFormatterTemplate` leveraging HuggingFace's `AutoTokenizer` and Jinja2 template capabilities.
- Simplified the `ChatFormatter` class to manage chat format registrations and perform formatting and parsing operations.
- Reduced overall source lines of code while enhancing code clarity and maintainability.

Note: This refactor aims to provide a more flexible and extensible approach to chat formatting, making it easier to add and manage different model templates in the future.
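The commits above describe a fixture-based unit test for the Llama2 formatter. As a rough illustration only, here is a minimal pytest sketch of that shape, assuming the `ChatFormatter` registry proposed below ends up in `llama_cpp.llama_chat_format`; the fixture contents and assertion are hypothetical, and the default template pulls a gated HuggingFace tokenizer, so a test like this needs network access and credentials:

```python
# Illustrative sketch only; the actual test_llama_chat_formatters.py may differ.
import pytest

from llama_cpp.llama_chat_format import ChatFormatter  # assumed location


@pytest.fixture
def sequence_of_messages():
    # Hypothetical conversation shared across tests.
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]


def test_llama2_formatter(sequence_of_messages):
    formatter = ChatFormatter()
    formatter.register_chat_format("llama-2")
    prompt = formatter.format("llama-2", sequence_of_messages)
    # The exact expected string depends on the tokenizer's chat template.
    assert "[INST]" in prompt
```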
For clarity and simplicity, here are the proposed changes:

```python
import dataclasses
from typing import Dict, List, Optional

from transformers import AutoTokenizer

from . import llama_types

# NOTE: Custom Templates use Jinja2.
# If no template is given, then it should default to hf's tokenizer template.
# We can define the model and template on a model-to-model basis,
# however, this should be allowed to be overridden for flexibility and extensibility.
# We only need 2 keys: the model name and the jinja2 template.
#
#   template = {"model": "meta-llama/Llama-2-7b-chat-hf", "template": None}
#
# or
#
#   chat_template = {
#       "model": "meta-llama/Llama-2-7b-chat-hf",
#       "jinja": "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + message['content'].strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '[ASST] ' + message['content'] + ' [/ASST]' + eos_token }}{% endif %}{% endfor %}",
#   }
#
# We can probably employ some kind of method for reading a template in from a file if necessary.
#
# We leave the template empty here because HuggingFace defined it already.
#
# Source: https://huggingface.co/docs/transformers/main/chat_templating
#
# Special thanks and credit goes to bioshazard for the idea and preliminary implementation.
# Source: https://github.com/abetlen/llama-cpp-python/pull/790

# NOTE: We can still use this for backward compatibility with the currently employed API.
# This can be modified, if needed, in the future.
@dataclasses.dataclass
class ChatFormatterResponse:
    prompt: str
    stop: Optional[List[str]] = None


class TokenizerCache:
    _cache: Dict[str, AutoTokenizer] = {}

    @classmethod
    def get_tokenizer(cls, model_name: str) -> AutoTokenizer:
        if model_name not in cls._cache:
            cls._cache[model_name] = AutoTokenizer.from_pretrained(model_name)
        return cls._cache[model_name]


class ChatFormatterTemplate:
    def __init__(self, template: Optional[Dict[str, str]] = None):
        if template:
            self.template = template
        else:
            self.template = {
                "model": "meta-llama/Llama-2-7b-chat-hf",
                "jinja": None,
                "tokenize": False,
            }
        self.tokenizer = TokenizerCache.get_tokenizer(self.template["model"])

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        # If a custom template is provided, override the tokenizer's default template
        if self.template.get("jinja"):
            self.tokenizer.chat_template = self.template["jinja"]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=self.template["tokenize"]
        )

    def parse_response(self, messages: List[Dict[str, str]]) -> ChatFormatterResponse:
        formatted_content = self._format_messages(messages)
        return ChatFormatterResponse(
            prompt=formatted_content, stop=[self.tokenizer.eos_token]
        )


class ChatFormatter:
    _chat_formatters: Dict[str, ChatFormatterTemplate] = {}

    def register_chat_format(
        self, model_name: str, template: Optional[Dict[str, str]] = None
    ):
        self._chat_formatters[model_name] = ChatFormatterTemplate(template)

    def get_chat_format(self, model_name: str) -> ChatFormatterTemplate:
        if model_name not in self._chat_formatters:
            raise ValueError(f"Model {model_name} is not registered.")
        return self._chat_formatters[model_name]

    def format(self, model_name: str, messages: List[Dict[str, str]]) -> str:
        formatter = self.get_chat_format(model_name)
        return formatter._format_messages(messages)

    def parse(
        self, model_name: str, messages: List[Dict[str, str]]
    ) -> ChatFormatterResponse:
        formatter = self.get_chat_format(model_name)
        return formatter.parse_response(messages)


# NOTE: Template registration is currently a WIP (work in progress).
```

With this refactoring, you can easily extend the functionality in the future if needed. However, as of now, this design offers a good balance of simplicity, flexibility, and efficiency.

@bioshazard @abetlen @delock @earonesty Let me know what you guys think.
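For illustration, a hypothetical usage sketch of the registry-style `ChatFormatter` above (the messages are made up, and downloading the default Llama-2 tokenizer requires HuggingFace access):

```python
# Hypothetical usage of the ChatFormatter sketched above.
chat_formatter = ChatFormatter()

# Register "llama-2" with the default template (meta-llama/Llama-2-7b-chat-hf).
chat_formatter.register_chat_format("llama-2")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

response = chat_formatter.parse("llama-2", messages)
print(response.prompt)  # the rendered prompt string
print(response.stop)    # [tokenizer.eos_token]
```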
I appreciate the thanks and credit! Looking forward to AutoTokenizer support either way.
It should be noted that this seriously has its own set of drawbacks. The previous implementation is way more flexible. I'm thinking we should have a middle ground for several reasons.

I feel these points should be considered with a serious level of thought, given the gravity of the implications of such changes. The middle ground I'm thinking of might be to scrap the current design, including the ones I've come up with, and go for a Jinja2-like template system, something like what HuggingFace has with their Chat Template system. It'll be easier to test, implement, and streamline as a result. It'll also be interoperable with the HuggingFace API by simply hot-swapping templates if a user wants to. We should also make HuggingFace an optional dependency due to the previously mentioned concerns I have with this approach.
Enabling the ability to create custom datasets and fine-tuned models is a valuable aspect for me. Flexibility to deviate from existing templates should be a core principle in the design we settle on. Ultimately, the direction we take largely depends on @abetlen since it's his repository. I believe it's important for us to collaborate and reach a consensus while striving for maximum flexibility with minimal lines of code (SLoC) and abstraction. Excessive SLoC and abstraction can be a warning sign that we've done something erroneous in our design.

Regarding the use of Jinja2, I have reservations mainly because I've been using it for a long time, and its original design wasn't intended for this specific use case. While it might work as a clever workaround, it also raises concerns about our ability to troubleshoot any potential issues that might arise. To ensure the integrity of the expected output, I strongly recommend implementing thorough testing for this approach and treating it as experimental, given its unconventional usage.

My motivation for advocating certain tools and approaches is rooted in a desire to maintain a proactive, inclusive, and open approach to accommodating other developers' preferences. While I do lean towards pragmatism and minimalism, I'm open to considering alternative approaches because my primary goal is to enhance this library and its utility. I appreciate your support for Jinja2, even though I express caution in its use. Please let me know if you have any other thoughts or considerations regarding any of the points I've made. I'm open to feedback and healthy criticism.
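To make the recommended testing concrete, here is a minimal, hypothetical pytest sketch that renders a simplified Llama-2-style Jinja2 template and pins the expected prompt string (the template and expected output are illustrative, not taken from the repository):

```python
# Illustrative sketch: pin the output of a Jinja2 chat template in a test.
from typing import Dict, List

import jinja2

# A simplified, hypothetical Llama-2-style template (not the real one).
LLAMA2_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]"
    "{% elif message['role'] == 'assistant' %} {{ message['content'] }}"
    "{% endif %}"
    "{% endfor %}"
)


def render(template: str, messages: List[Dict[str, str]]) -> str:
    env = jinja2.Environment(loader=jinja2.BaseLoader())
    return env.from_string(template).render(messages=messages)


def test_llama2_template_renders_expected_prompt():
    messages = [{"role": "user", "content": "Hello"}]
    assert render(LLAMA2_TEMPLATE, messages) == "[INST] Hello [/INST]"
```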
Another possible approach would be the following:

```python
# NOTE: In addition to the earlier imports, this assumes
# `from typing import Any, Dict, List, Optional, Protocol, Tuple, Type`.
class ChatFormatterInterface(Protocol):
    def __init__(self, template: Optional[Dict[str, Any]] = None):
        raise NotImplementedError

    def format_messages(
        self, messages: List[llama_types.ChatCompletionRequestMessage]
    ) -> str:
        raise NotImplementedError

    def parse_response(
        self,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs,
    ) -> ChatFormatterResponse:
        raise NotImplementedError


class AutoTokenizerFormatter(ChatFormatterInterface):
    def __init__(self, template: Optional[Dict[str, str]] = None):
        if template:
            self.template = template
        else:
            self.template = {
                "model": "meta-llama/Llama-2-7b-chat-hf",
                "jinja": None,
                "tokenize": False,
            }
        self.tokenizer = TokenizerCache.get_tokenizer(self.template["model"])

    def format_messages(
        self, messages: List[llama_types.ChatCompletionRequestMessage]
    ) -> str:
        # If a custom template is provided, override the tokenizer's default template
        if self.template.get("jinja"):
            self.tokenizer.chat_template = self.template["jinja"]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=self.template["tokenize"]
        )

    def parse_response(
        self,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs,
    ) -> ChatFormatterResponse:
        formatted_content = self.format_messages(messages)
        return ChatFormatterResponse(
            prompt=formatted_content, stop=[self.tokenizer.eos_token]
        )


class ChatFormatter:
    _chat_formatters: Dict[str, ChatFormatterInterface] = {}

    @staticmethod
    def register_chat_format(name: str):
        def decorator(cls: Type[ChatFormatterInterface]):
            ChatFormatter._chat_formatters[name] = cls()
            return cls

        return decorator

    def get_chat_format(self, name: str) -> ChatFormatterInterface:
        try:
            return self._chat_formatters[name]
        except KeyError:
            raise ValueError(
                f"Invalid chat format: {name}. Valid formats: {list(self._chat_formatters.keys())}"
            )

    def format(self, name: str, messages: List[Dict[str, str]]) -> str:
        formatter = self.get_chat_format(name)
        return formatter.format_messages(messages)

    def parse(
        self,
        name: str,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs,
    ) -> ChatFormatterResponse:
        formatter = self.get_chat_format(name)
        return formatter.parse_response(messages)
```

This would create a form of backwards compatibility:
```python
# NOTE: Template registration is currently a WIP (work in progress).
# External developers can now use the `@ChatFormatter.register_chat_format`
# decorator to register their own custom formatters.
@ChatFormatter.register_chat_format("llama-2")
class Llama2Formatter(AutoTokenizerFormatter):
    def __init__(self, template: Optional[Dict[str, Any]] = None):
        super().__init__(template)
```

It also allows us to override the implementation if necessary:
```python
# NOTE: Template registration is currently a WIP (work in progress).
# External developers can now use the `@ChatFormatter.register_chat_format`
# decorator to register their own custom formatters.
@ChatFormatter.register_chat_format("llama-2")
class Llama2Formatter(ChatFormatterInterface):
    _system_template = "<<SYS>>{system_message}<</SYS>>\n"
    _roles = dict(user="[INST]", assistant="[/INST]")
    _sep = "\n"

    def _get_system_message(self, messages: List[llama_types.ChatCompletionRequestMessage]) -> str:
        try:
            if messages[0]["role"] == "system":
                return self._system_template.format(system_message=messages[0]["content"])
            return ""
        except (IndexError, KeyError):
            return ""

    def _map_roles(self, messages: List[llama_types.ChatCompletionRequestMessage]) -> List[Tuple[str, str]]:
        mapped_messages = []
        for message in messages:
            if message["role"] in self._roles:
                mapped_messages.append((self._roles[message["role"]], message["content"]))
        mapped_messages.append((self._roles["assistant"], None))
        return mapped_messages

    def format_messages(self, messages: List[llama_types.ChatCompletionRequestMessage]) -> str:
        system_message = self._get_system_message(messages)
        mapped_messages = self._map_roles(messages)
        return system_message + self._sep.join([msg for role, msg in mapped_messages if msg])

    def parse_response(
        self,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs,
    ) -> ChatFormatterResponse:
        formatted_content = self.format_messages(messages)
        return ChatFormatterResponse(prompt=formatted_content)
```

This still requires a more well-thought-out design, though. Personally, I prefer the previous approach I took in an earlier commit (361e25).
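For illustration, a hypothetical usage of the decorator-based registry above, assuming the pure-Python `Llama2Formatter` override is the formatter registered under "llama-2" (the messages are made up):

```python
# Hypothetical usage of the decorator-based registry sketched above.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

chat_formatter = ChatFormatter()

# Registration happened at class-definition time via the decorator,
# so the formatter can simply be looked up by name.
prompt = chat_formatter.format("llama-2", messages)
print(prompt)
```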
- Added introductory comment explaining module purpose
- Defined default templates for HuggingFace and common roles
- Introduced Llama2Formatter and AlpacaFormatter classes
- Registered predefined chat format models
- Implemented ChatFormatterFactory for managing formatters

These changes enhance the flexibility and customization of chat formatting, allowing for the registration of custom formatters and providing default templates for different chat models.

- Consolidated and isolated the HuggingFace login process for improved security.
- Used the module name `huggingface_hub` instead of directly importing the `login` function for clarity.
- Corrected the formatter name to "llama-2" for consistency.

These changes enhance security by isolating the login process and improve code clarity by using the module name for HuggingFace operations.

…hat_format.py
- Updated the important notes section with clear and concise information about special tokens and Python version compatibility.
- Anonymized the example templates by replacing names with "Llama" and "User" for clarity.
- Made formatting changes to improve code readability and organization in llama_chat_format.py.
- Added the Vicuna model template.

This commit enhances the clarity of important notes, anonymizes example templates, and improves code formatting in llama_chat_format.py.

- Added basic type definitions for better code clarity.
- Removed repetitive comments in the code.
- Added a note about the Vicuna template being version 1.5 and differing from v0.
- Applied new type definitions to chat templates.
- Introduced a new VicunaFormatter class to replace the older one, improving code readability.

This commit enhances code clarity, maintains consistency, and improves the structure of the codebase.

…template in llama_chat_format.py
- Replaced the Open Assistant Hybrid chat template with the classical template.
- Added the original Open Assistant chat template for non-hybrid models.

This commit streamlines the chat templates by using the classical template for Open Assistant and adds the original template for non-hybrid models, reducing duplication.
I notice that in HF, I can use
It would be wrapped. We would ideally just pass the template in during instantiation. I'm still working out the details in the current draft, but I'm considering replacing the current templates with Jinja2 templates at some point in the future. The caveat between the two is that I think it would be preferable to manage a single approach while accommodating each interface appropriately. You can see it here: https://github.com/teleprint-me/llama-cpp-python/blob/refactored.templates/llama_cpp/llama_chat_format.py#L280
That was the motivation behind my PR (#790), though this PR here is a more comprehensive approach.
Can we just stick with Jinja? Or at least that should be an option. If somebody specifies a Jinja template as a chat format, we could just use it. Right now our formats are very flexible.
Yeah, I included it in the original draft. I'm still waiting on feedback. Also, I have a lot going on and my mental bandwidth is low at the moment. Jinja2 would already be included as a dependency with transformers, so I don't really see that as an issue.
I was thinking of a lot less structure: literally just copying the Jinja templates from HuggingFace and executing them in the right context. For example, this is the Mistral template:
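As a sketch of what "executing them in the right context" could look like, here is a hypothetical rendering of the Mistral-style template quoted in the reply below, using jinja2 directly; note that such templates call `raise_exception`, so a helper has to be supplied to the environment (all details here are illustrative):

```python
# Illustrative: render a HuggingFace-style chat template directly with jinja2.
import jinja2

mistral_template = (
    "{{ bos_token }}{% for message in messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
    "{% endif %}"
    "{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}"
    "{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}"
    "{% endif %}{% endfor %}"
)


def raise_exception(message: str) -> None:
    # Chat templates call `raise_exception`; jinja2 does not provide it by default.
    raise ValueError(message)


env = jinja2.Environment(loader=jinja2.BaseLoader())
env.globals["raise_exception"] = raise_exception

prompt = env.from_string(mistral_template).render(
    messages=[{"role": "user", "content": "Hello!"}],
    bos_token="<s>",
    eos_token="</s>",
)
print(prompt)  # <s>[INST] Hello! [/INST]
```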
I think it's crucial to consider what's happening here and what the goal of this draft is. The ChatML structure is passed in as input: a sequence of dictionaries representing roles and their associated content. That sequence is then reformatted into a structured string that llama.cpp expects as input. Appropriately formatting that input is the goal here. The problem is the variety and variability of templates that require management. This is why @abetlen used the following structure,

```python
@dataclasses.dataclass
class ChatFormatterResponse:
    prompt: str  # This is the output to llama.cpp
    stop: Optional[Union[str, List[str]]] = None
```

and returned it as a result:

```python
@register_chat_format("llama-2")
def format_llama2(
    # `messages` is the input we need to transform and format.
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    # ...
    # return the transformed and formatted string for model input.
    return ChatFormatterResponse(prompt=_prompt)
```

So, if we're to consider your proposal, that would mean having a template string to be parsed for each model baked in, which is admittedly a bit of an improvement, but fails to solve the overarching problem at hand. Consider your example:

```python
mistral = {
    "bos_token": "<s>",
    "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
    "eos_token": "</s>",
}
```

We technically don't need the special tokens at all, because everything necessary to format the string should be included already:

```python
mistral_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
```

This is all we would technically need. llama.cpp doesn't expect much, and there are experimental mechanisms for defining the prefix and postfix, but none of these solutions really solve the root problem: the inherent variability involved in a chat template. For example, we would need to rely on the user being able to define a template for FIM (Fill-in-the-Middle) models. This is why I considered @bioshazard's original proposal: if we really want to reduce the amount of SLoC, we could consider scrapping the joint interfaces and simply go for integrating HuggingFace.

```python
class AutoTokenizerFormatter(ChatFormatterInterface):
    def __init__(self, template: Optional[Dict[str, str]] = None):
        self.template = template or huggingface_template
        self.huggingface_login()
        self.tokenizer = TokenizerCache.get_tokenizer(self.template["model"])

    def __call__(
        self,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs,
    ) -> ChatFormatterResponse:
        formatted_content = self.format_messages(messages)
        return ChatFormatterResponse(
            prompt=formatted_content, stop=[self.tokenizer.eos_token]
        )

    def huggingface_login(self) -> None:
        # NOTE: Keep in mind there are other ways to do this and this is only a draft.
        token = os.getenv("HF_TOKEN")
        if token is None:
            raise AttributeError(
                "Failed to login to huggingface. "
                "Did you forget to set the `HF_TOKEN` environment variable with your huggingface token?"
            )
        huggingface_hub.login(token)

    def format_messages(
        self, messages: List[llama_types.ChatCompletionRequestMessage]
    ) -> str:
        # If a custom template is provided, override the tokenizer's default template
        if self.template.get("jinja"):
            self.tokenizer.chat_template = self.template["jinja"]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=self.template.get("tokenize", False)
        )
```

But as I previously stated, this seriously has its own set of drawbacks. My rationale behind all of this is that we would need to be prepared to parse and build the template, then format it to a string, and then pass it back, even if we chose to use only jinja2. The upside to this approach is that it would apparently make it compatible with HuggingFace. This, unfortunately, is not really a simple problem. It's one of those deceptively simple problems, right up until you start trying to solve it.
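For context, a hypothetical usage sketch of the `AutoTokenizerFormatter` above; it assumes `HF_TOKEN` is set, a `huggingface_template` dict like the one discussed earlier (model name plus optional Jinja override), and network access to download the gated tokenizer:

```python
import os

# Hypothetical usage of AutoTokenizerFormatter; all values are placeholders.
os.environ.setdefault("HF_TOKEN", "hf_...")  # use a real token in practice

huggingface_template = {
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "jinja": None,      # fall back to the tokenizer's built-in chat template
    "tokenize": False,  # return a string rather than token ids
}

formatter = AutoTokenizerFormatter(huggingface_template)
response = formatter(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
)
print(response.prompt)
print(response.stop)  # [tokenizer.eos_token]
```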
Seems like most of the drawbacks were the inappropriate, excessive dependencies. If AutoTokenizer's chat templating function could be broken into its own library with minimal, non-conflicting dependencies, would you support its candidacy for universal chat template transparency? Especially if you could refer to a local file too, rather than a relative model path.
The problem with HuggingFace is introducing unnecessary dependencies that will literally only be used to gain access to the templates. It's hard to justify this argument in my own mind.

I deleted my last comment after rereading yours. Valid points about the dependencies. Curious about your thoughts on my replacement reply.

I don't know. I can't really find any valid arguments against @earonesty's proposal. I'm only attempting to elucidate the complexity involved in his proposal, not shut it down. I want him to genuinely think about it. I find this to be a hard problem. I'm hoping we can come up with an elegant solution together.

Not suggesting his proposal is inappropriate. Maybe I'm just thinking out loud about attempting to write a minimal library to break out the model template functionality, as I offered.
I feel like the base idea of just supporting raw Jinja, and having it be compatible with the templates that people typically use in HF, is fairly simple to implement. I'm still okay with leaving the existing Python function-based templates as an option; those are more flexible, and maybe there are some things you can do with them that you can't do with Jinja.
I hear you, and I'm inclined to feel the same. There are pros and cons to every approach here. It's important to realize that we'll always be making a compromise. I'm hoping to go for one of those solutions that's only obvious in hindsight, where we're all happy with it and it's easy to read, understand, and implement. If you have something more concrete, rather than an abstract notion, I'm all ears. A concrete snippet exemplifying what you're envisioning would be helpful. Sharing is caring 😉
I made a new draft in a separate branch, as this branch now conflicts with recent changes: teleprint-me:llama-cpp-python:jinja2-templates. It's the same general idea: it reduces complexity and SLoC, and uses the idea of jinja2 templating.

```python
class AutoChatFormatter(ChatFormatterInterface):
    def __init__(
        self,
        template: Optional[str] = None,
        template_class: Optional[Template] = None,
    ):
        if template is not None:
            self._template = template
        else:
            self._template = llama2_template  # default template

        self._renderer = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(
            self._template,
            template_class=template_class,
        )

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        formatted_sequence = self._renderer.render(messages=messages, **kwargs)
        return ChatFormatterResponse(prompt=formatted_sequence)

    @property
    def template(self) -> str:
        return self._template
```

See PR #875 for more information.
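As a usage sketch, instantiating the formatter above with a custom Jinja2 template might look like this (the template string is illustrative, not the branch's `llama2_template`):

```python
# Hypothetical usage of AutoChatFormatter with a custom Jinja2 template.
custom_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]"
    "{% elif message['role'] == 'assistant' %} {{ message['content'] }}"
    "{% endif %}"
    "{% endfor %}"
)

formatter = AutoChatFormatter(template=custom_template)
response = formatter([{"role": "user", "content": "Hello!"}])
print(response.prompt)  # [INST] Hello! [/INST]
```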
Draft: Refactor Chat Formatter for Enhanced Flexibility and Extensibility

Description:

Summary:
This PR aims to refactor the chat formatter to provide a more flexible and extensible framework. It introduces a foundational template and protocol that allows for easy addition and management of various chat formatting styles.

Details:

Refactored Chat Formatter:
- Introduced a base template (`BASE_TEMPLATE`) to serve as the foundation for various chat formatting styles.
- Implemented the `ChatFormatterTemplate` protocol, which provides a standardized interface for all formatter classes.
- Common methods such as `_get_system_message` and `_map_roles` live on this protocol, promoting code reuse and reducing redundancy.

Llama2 Formatter:
- Derived a dedicated formatter, `Llama2Formatter`, from the `ChatFormatterTemplate`.

Testing and Validation:

Commit Strategy:

Future Goals:

Feedback Requested:
I'm looking for early feedback on the current approach, especially concerning the overall structure and the introduction of the new `ChatFormatterTemplate` protocol. Any suggestions or improvements are welcome!