Skip to content

Commit

Permalink
Merge pull request #234 from MeetKai/fix_get_template_from_tokenizer
Browse files Browse the repository at this point in the history
fix function get_prompt_template from tokenizer
  • Loading branch information
musab-mk authored Aug 7, 2024
2 parents 3b5e585 + 0d6725e commit a530177
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
13 changes: 13 additions & 0 deletions functionary/prompt_template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functionary.prompt_template.llava_prompt_template import LlavaLlama
from functionary.prompt_template.prompt_template_v1 import PromptTemplateV1
from functionary.prompt_template.prompt_template_v2 import PromptTemplateV2
import re


def get_available_prompt_template_versions() -> List[PromptTemplate]:
Expand Down Expand Up @@ -61,6 +62,18 @@ def get_prompt_template_from_tokenizer(tokenizer: Any) -> PromptTemplate:
Returns:
_type_: _description_
"""
# find prompt template using jinja chat template first
for version in _TEMPLATE_DIC:
if _TEMPLATE_DIC[version].get_chat_template_jinja() == tokenizer.chat_template:
return _TEMPLATE_DIC[version]

# find prompt template by searching for version information in jinja tempalte comment, e.g: {# version=abc #}
chat_template = tokenizer.chat_template
match = re.search("\{\# version=(?P<version_name>.+) \#\}", chat_template)
if match:
version_name = match.group("version_name").strip()
return _TEMPLATE_DIC[version_name]

p1 = PromptTemplateV1.get_prompt_template()
p2 = _TEMPLATE_DIC[PromptTemplateV2.version]
p3 = _TEMPLATE_DIC[Llama3Template.version]
Expand Down
2 changes: 1 addition & 1 deletion functionary/prompt_template/base_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_raw_response_from_assistant_message(

def get_chat_template_jinja(self):
"""Return chat_template in jinja format"""
return ""
return "{# " + f"version={self.version}" + " #}"

def get_generation_prefix_for_tool_choice(self, tool_choice: Any):
if tool_choice == "auto" or tool_choice is None:
Expand Down

0 comments on commit a530177

Please sign in to comment.