Skip to content

Commit

Permalink
feat: Integrate functionary v1.4 and v2 models + add custom tokenizer…
Browse files Browse the repository at this point in the history
… support to Llama class (abetlen#1078)

* convert functionary-v1 chat handler to use hf autotokenizer

* add hf_tokenizer + inteegrate functionary-v1.4 prompt template

* integrate functionary v2 prompt template

* update readme

* set up parallel function calling wip

* set up parallel function calling

* Update README.md

* Update README.md

* refactor tokenizers

* include old functionary handler for backward compatibility

* add hf_tokenizer_path in server ModelSettings

* convert functionary-v1 chat handler to use hf autotokenizer

* add hf_tokenizer + inteegrate functionary-v1.4 prompt template

* integrate functionary v2 prompt template

* update readme

* set up parallel function calling wip

* resolve merge conflict

* Update README.md

* Update README.md

* refactor tokenizers

* include old functionary handler for backward compatibility

* add hf_tokenizer_path in server ModelSettings

* Cleanup PR, fix breaking changes

* Use hf_pretrained_model_name_or_path for tokenizer

* fix hf tokenizer in streaming

* update README

* refactor offset mapping

---------

Co-authored-by: Andrei <[email protected]>
  • Loading branch information
jeffrey-fong and abetlen authored Feb 8, 2024
1 parent 34f3104 commit 9018270
Show file tree
Hide file tree
Showing 4 changed files with 525 additions and 34 deletions.
19 changes: 8 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,19 +293,16 @@ To constrain the response to a specific JSON Schema, you can use the `schema` pr

The high-level API also provides a simple interface for function calling.

Note that the only model that supports full function calling at this time is "functionary".
The gguf-converted files for this model can be found here: [functionary-7b-v1](https://huggingface.co/abetlen/functionary-7b-v1-GGUF)
The only set of models that supports full function calling at this time is [functionary](https://github.com/MeetKai/functionary). The various gguf-converted files for this set of models can be found [here](https://huggingface.co/meetkai). Functionary is able to intelligently call functions and also analyze any provided function outputs to generate coherent responses. All v2 models of functionary supports **parallel function calling**. You can provide either `functionary-v1` or `functionary-v2` for the `chat_format` when initializing the Llama class.

Note that due to discrepancies between llama.cpp and HuggingFace's tokenizers, it is required to provide HF Tokenizer for functionary. The `LlamaHFTokenizer` class can be initialized and passed into the Llama class. This will override the default llama.cpp tokenizer used in Llama class. The tokenizer files are already included in the respective HF repositories hosting the gguf files.

```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", chat_format="functionary")
>>> from llama_cpp import Llama, LlamaHFTokenizer
>>> tokenizer = LlamaHFTokenizer.from_pretrained("path/to/functionary/")
>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", tokenizer=tokenizer, chat_format="functionary-v2")
>>> llm.create_chat_completion(
messages = [
{
"role": "system",
"content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"

},
{
"role": "user",
"content": "Extract Jason is 25 years old"
Expand All @@ -332,12 +329,12 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h
}
}
}],
tool_choice=[{
tool_choice={
"type": "function",
"function": {
"name": "UserDetail"
}
}]
},
)
```

Expand Down
101 changes: 79 additions & 22 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import sys
import abc
import uuid
import time
import multiprocessing
Expand All @@ -14,11 +15,14 @@
Iterator,
Deque,
Callable,
Any,
)
from collections import deque

import ctypes

from llama_cpp.llama_types import List

from .llama_types import *
from .llama_grammar import LlamaGrammar
from .llama_cache import (
Expand Down Expand Up @@ -95,6 +99,8 @@ def __init__(
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
# Speculative Decoding
draft_model: Optional[LlamaDraftModel] = None,
# Tokenizer Override
tokenizer: Optional[BaseLlamaTokenizer] = None,
# Misc
verbose: bool = True,
# Extra Params
Expand Down Expand Up @@ -159,6 +165,7 @@ def __init__(
chat_format: String specifying the chat format to use when calling create_chat_completion.
chat_handler: Optional chat handler to use when calling create_chat_completion.
draft_model: Optional draft model to use for speculative decoding.
tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
verbose: Print verbose output to stderr.
Raises:
Expand Down Expand Up @@ -235,6 +242,7 @@ def __init__(
self.n_threads_batch = n_threads_batch or max(
multiprocessing.cpu_count() // 2, 1
)

# Context Params
self.context_params = llama_cpp.llama_context_default_params()
self.context_params.seed = seed
Expand Down Expand Up @@ -286,6 +294,10 @@ def __init__(
self._model = _LlamaModel(
path_model=self.model_path, params=self.model_params, verbose=self.verbose
)

# Override tokenizer
self.tokenizer_ = tokenizer or LlamaTokenizer(self)

# Set the default value for the context and correct the batch
if n_ctx == 0:
n_ctx = self._model.n_ctx_train()
Expand Down Expand Up @@ -431,18 +443,19 @@ def tokenize(
Returns:
A list of tokens.
"""
return self._model.tokenize(text, add_bos, special)
return self.tokenizer_.tokenize(text, add_bos, special)

def detokenize(self, tokens: List[int]) -> bytes:
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
"""Detokenize a list of tokens.
Args:
tokens: The list of tokens to detokenize.
prev_tokens: The list of previous tokens. Offset mapping will be performed if provided
Returns:
The detokenized string.
"""
return self._model.detokenize(tokens)
return self.tokenizer_.detokenize(tokens, prev_tokens)

def set_cache(self, cache: Optional[BaseLlamaCache]):
"""Set the cache.
Expand Down Expand Up @@ -935,7 +948,8 @@ def logit_bias_processor(

if stream:
remaining_tokens = completion_tokens[returned_tokens:]
remaining_text = self.detokenize(remaining_tokens)
prev_tokens = completion_tokens[:returned_tokens]
remaining_text = self.detokenize(completion_tokens, prev_tokens)
remaining_length = len(remaining_text)

# We want to avoid yielding any characters from
Expand All @@ -957,13 +971,13 @@ def logit_bias_processor(
for token in remaining_tokens:
if token == self.token_bos():
continue
token_end_position += len(self.detokenize([token]))
token_end_position += len(remaining_text)
# Check if stop sequence is in the token
if token_end_position > (
remaining_length - first_stop_position
):
break
token_str = self.detokenize([token]).decode(
token_str = remaining_text.decode(
"utf-8", errors="ignore"
)
text_offset = len(prompt) + len(
Expand All @@ -988,11 +1002,7 @@ def logit_bias_processor(
}
top_logprob.update({token_str: current_logprobs[int(token)]})
logprobs_or_none = {
"tokens": [
self.detokenize([token]).decode(
"utf-8", errors="ignore"
)
],
"tokens": [token_str],
"text_offset": [text_offset],
"token_logprobs": [current_logprobs[int(token)]],
"top_logprobs": [top_logprob],
Expand All @@ -1005,9 +1015,7 @@ def logit_bias_processor(
"model": model_name,
"choices": [
{
"text": self.detokenize([token]).decode(
"utf-8", errors="ignore"
),
"text": token_str,
"index": 0,
"logprobs": logprobs_or_none,
"finish_reason": None,
Expand All @@ -1019,7 +1027,7 @@ def logit_bias_processor(
decode_success = False
for i in range(1, len(remaining_tokens) + 1):
try:
bs = self.detokenize(remaining_tokens[:i])
bs = remaining_text
ts = bs.decode("utf-8")
decode_success = True
break
Expand Down Expand Up @@ -1055,6 +1063,7 @@ def logit_bias_processor(

if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens)

finish_reason = "length"
break

Expand Down Expand Up @@ -1693,8 +1702,8 @@ def n_vocab(self) -> int:
"""Return the vocabulary size."""
return self._model.n_vocab()

def tokenizer(self) -> "LlamaTokenizer":
"""Return the tokenizer for this model."""
def tokenizer(self) -> LlamaTokenizer:
"""Return the llama tokenizer for this model."""
return LlamaTokenizer(self)

def token_eos(self) -> int:
Expand Down Expand Up @@ -1738,23 +1747,71 @@ def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
return longest_prefix


class LlamaTokenizer:
class BaseLlamaTokenizer(abc.ABC):
@abc.abstractmethod
def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
raise NotImplementedError

@abc.abstractmethod
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
raise NotImplementedError


class LlamaTokenizer(BaseLlamaTokenizer):
def __init__(self, llama: Llama):
self.llama = llama
self._model = llama._model # type: ignore

def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
return self._model.tokenize(text, add_bos=add_bos, special=special)

def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
if prev_tokens is not None:
return self._model.detokenize(tokens[len(prev_tokens):])
else:
return self._model.detokenize(tokens)

def encode(self, text: str, add_bos: bool = True) -> List[int]:
return self.llama.tokenize(
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=True
def encode(self, text: str, add_bos: bool = True, special: bool = True) -> List[int]:
return self.tokenize(
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
)

def decode(self, tokens: List[int]) -> str:
return self.llama.detokenize(tokens).decode("utf-8", errors="ignore")
return self.detokenize(tokens).decode("utf-8", errors="ignore")

@classmethod
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
return cls(Llama(model_path=path, vocab_only=True))


class LlamaHFTokenizer(BaseLlamaTokenizer):
def __init__(self, hf_tokenizer: Any):
self.hf_tokenizer = hf_tokenizer

def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
return self.hf_tokenizer.encode(text.decode("utf-8", errors="ignore"), add_special_tokens=special)

def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
if prev_tokens is not None:
text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
prev_text = self.hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
return text[len(prev_text):]
else:
return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
try:
from transformers import AutoTokenizer
except ImportError:
raise ImportError(
"The `transformers` library is required to use the `HFTokenizer`."
"You can install it with `pip install transformers`."
)
hf_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
return cls(hf_tokenizer)


class LlamaState:
def __init__(
self,
Expand Down
Loading

0 comments on commit 9018270

Please sign in to comment.