Automatically set chat format from gguf (abetlen#1110)
* Use jinja formatter to load chat format from gguf

* Fix off-by-one error in metadata loader

* Implement chat format auto-detection
abetlen authored Jan 29, 2024
1 parent 059f6b3 commit da003d8
Showing 4 changed files with 68 additions and 7 deletions.
4 changes: 2 additions & 2 deletions llama_cpp/_internals.py
@@ -216,13 +216,13 @@ def metadata(self) -> Dict[str, str]:
         for i in range(llama_cpp.llama_model_meta_count(self.model)):
             nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             key = buffer.value.decode("utf-8")
             nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             value = buffer.value.decode("utf-8")
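The off-by-one fix above matters because `ctypes.create_string_buffer` must leave room for the trailing NUL that C strings carry. Below is a minimal, self-contained sketch of the retry pattern, assuming the llama.cpp metadata getters have snprintf-like semantics (write at most `buffer_size - 1` bytes plus a NUL, return the full string length); `fake_meta_val` is a hypothetical stand-in for `llama_model_meta_val_str_by_index`.

```python
import ctypes

# Hypothetical stand-in for llama_model_meta_val_str_by_index, assuming
# snprintf-like semantics: write at most buffer_size - 1 bytes plus a
# trailing NUL, and return the full length of the value (NUL excluded).
def fake_meta_val(buffer: ctypes.Array, buffer_size: int) -> int:
    value = b"general.name"                  # 12 bytes
    n = min(len(value), buffer_size - 1)
    buffer[:n] = value[:n]
    return len(value)

buffer_size = 8
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = fake_meta_val(buffer, buffer_size)
if nbytes > buffer_size:
    # The old code set buffer_size = nbytes (12): the retry buffer could
    # then hold only 11 characters plus the NUL, silently truncating the
    # value to "general.nam".
    buffer_size = nbytes + 1                 # room for the trailing NUL
    buffer = ctypes.create_string_buffer(buffer_size)
    nbytes = fake_meta_val(buffer, buffer_size)
print(buffer.value.decode("utf-8"))          # "general.name", not truncated
```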
37 changes: 36 additions & 1 deletion llama_cpp/llama.py
@@ -87,7 +87,7 @@ def __init__(
         # Backend Params
         numa: bool = False,
         # Chat Format Params
-        chat_format: str = "llama-2",
+        chat_format: Optional[str] = None,
         chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
         # Misc
         verbose: bool = True,
@@ -343,6 +343,41 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
+        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
+            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+
+            if chat_format is not None:
+                self.chat_format = chat_format
+                if self.verbose:
+                    print(f"Guessed chat format: {chat_format}", file=sys.stderr)
+            else:
+                template = self.metadata["tokenizer.chat_template"]
+                try:
+                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
+                except:
+                    eos_token_id = self.token_eos()
+                try:
+                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
+                except:
+                    bos_token_id = self.token_bos()
+
+                eos_token = self.detokenize([eos_token_id]).decode("utf-8")
+                bos_token = self.detokenize([bos_token_id]).decode("utf-8")
+
+                if self.verbose:
+                    print(f"Using chat template: {template}", file=sys.stderr)
+                    print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
+                    print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
+
+                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
+                    template=template,
+                    eos_token=eos_token,
+                    bos_token=bos_token
+                ).to_chat_handler()
+
+        if self.chat_format is None and self.chat_handler is None:
+            self.chat_format = "llama-2"
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         assert self._ctx.ctx is not None
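With this change the constructor can pick the chat format on its own. A minimal usage sketch of what the new default enables (the model path is hypothetical):

```python
from llama_cpp import Llama

# chat_format now defaults to None. On load, the constructor inspects the
# GGUF metadata: a recognized tokenizer.chat_template maps to a built-in
# format ("chatml" or "mistral-instruct"); any other template is wrapped
# in a Jinja2ChatFormatter; with no template at all it falls back to
# "llama-2".
llm = Llama(model_path="./models/openhermes-2.5-mistral-7b.Q4_K_M.gguf")

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response["choices"][0]["message"]["content"])
```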
30 changes: 28 additions & 2 deletions llama_cpp/llama_chat_format.py
@@ -14,6 +14,20 @@
 
 from ._utils import suppress_stdout_stderr, Singleton
 
+### Common Chat Templates and Special Tokens ###
+
+# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
+CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+CHATML_BOS_TOKEN = "<s>"
+CHATML_EOS_TOKEN = "<|im_end|>"
+
+# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
+MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
+
+
 ### Chat Completion Handler ###
 
 class LlamaChatCompletionHandler(Protocol):
     """Base Protocol for a llama chat completion handler.
@@ -118,7 +132,6 @@ def decorator(f: LlamaChatCompletionHandler):
 
 ### Chat Formatter ###
 
-
 @dataclasses.dataclass
 class ChatFormatterResponse:
     """Dataclass that stores completion parameters for a given chat format and
@@ -440,7 +453,20 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
+def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
+    if "tokenizer.chat_template" not in metadata:
+        return None
+
+    if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
+        return "chatml"
+
+    if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
+        return "mistral-instruct"
+
+    return None
+
+
 ### Utility functions for formatting chat prompts ###
 # TODO: Replace these with jinja2 templates
 
 
 def _get_system_message(
@@ -929,7 +955,6 @@ def format_openchat(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)
 
-
 # Chat format for Saiga models, see more details and available models:
 # https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
 @register_chat_format("saiga")
@@ -951,6 +976,7 @@ def format_saiga(
     _prompt += "<s>bot"
     return ChatFormatterResponse(prompt=_prompt.strip())
 
+# Tricky chat formats that require custom chat handlers
 
 @register_chat_completion_handler("functionary")
 def functionary_chat_handler(
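The new helper only returns a format name for templates it matches byte-for-byte against the constants defined above; everything else is left to the Jinja2 path. A quick sketch of that contract:

```python
from llama_cpp import llama_chat_format

# A template that matches one of the known constants maps to a format name.
metadata = {"tokenizer.chat_template": llama_chat_format.CHATML_CHAT_TEMPLATE}
assert llama_chat_format.guess_chat_format_from_gguf_metadata(metadata) == "chatml"

# Anything unrecognized (or missing) yields None, and the caller in
# llama.py falls back to building a Jinja2ChatFormatter instead.
assert llama_chat_format.guess_chat_format_from_gguf_metadata(
    {"tokenizer.chat_template": "{{ messages }}"}
) is None
assert llama_chat_format.guess_chat_format_from_gguf_metadata({}) is None
```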
4 changes: 2 additions & 2 deletions llama_cpp/server/settings.py
@@ -113,8 +113,8 @@ class ModelSettings(BaseSettings):
         description="Enable NUMA support.",
     )
     # Chat Format Params
-    chat_format: str = Field(
-        default="llama-2",
+    chat_format: Optional[str] = Field(
+        default=None,
         description="Chat format to use.",
     )
     clip_model_path: Optional[str] = Field(
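On the server side the field now mirrors the constructor: leaving it unset defers the decision to `Llama.__init__`. A minimal sketch (the model path is hypothetical):

```python
from llama_cpp.server.settings import ModelSettings

# chat_format is Optional now; omitting it (or the --chat_format CLI
# flag) leaves the value as None, so the Llama constructor auto-detects
# the format from the GGUF metadata instead of assuming "llama-2".
settings = ModelSettings(model="./models/example.gguf")
print(settings.chat_format)  # None -> auto-detect
```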
