feat: Add add_generation_prompt option for Jinja2ChatFormatter.
abetlen committed Jan 21, 2024
1 parent ac2e96d commit 7f3209b
Showing 1 changed file with 9 additions and 6 deletions.
llama_cpp/llama_chat_format.py
@@ -152,11 +152,13 @@ def __init__(
         template: str,
         eos_token: str,
         bos_token: str,
+        add_generation_prompt: bool = True,
     ):
         """A chat formatter that uses jinja2 templates to format the prompt."""
         self.template = template
         self.eos_token = eos_token
         self.bos_token = bos_token
+        self.add_generation_prompt = add_generation_prompt
 
         self._environment = jinja2.Environment(
             loader=jinja2.BaseLoader(),
@@ -170,12 +172,13 @@ def __call__(
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
-        messages = [
-            *messages,
-            llama_types.ChatCompletionRequestAssistantMessage(
-                role="assistant", content=""
-            ),
-        ]
+        if self.add_generation_prompt:
+            messages = [
+                *messages,
+                llama_types.ChatCompletionRequestAssistantMessage(
+                    role="assistant", content=""
+                ),
+            ]
         prompt = self._environment.render(
             messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
         )
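For reference, a minimal sketch of how the new option changes the rendered prompt. The ChatML-style template, message content, and variable names below are illustrative assumptions, not part of the commit; only Jinja2ChatFormatter, its constructor arguments, and the empty-assistant-message behavior come from the diff above.

# Illustrative usage of the add_generation_prompt option (sketch).
from llama_cpp.llama_chat_format import Jinja2ChatFormatter

# An assumed ChatML-style template that renders each message as one turn.
chatml_template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message.role }}\n"
    "{{ message.content }}<|im_end|>\n"
    "{% endfor %}"
)

formatter = Jinja2ChatFormatter(
    template=chatml_template,
    eos_token="<|im_end|>",
    bos_token="<|im_start|>",
    add_generation_prompt=True,  # the new option; defaults to True
)

# With add_generation_prompt=True the formatter appends an empty
# assistant message before rendering, so the prompt ends with an
# (empty) assistant turn that cues the model to respond; with False
# the prompt stops after the last user message.
response = formatter(messages=[{"role": "user", "content": "Hello!"}])
print(response.prompt)

Note that this implementation injects an empty assistant message into the message list rather than exposing a jinja2 add_generation_prompt variable to the template, so a template like the one above will also emit the closing <|im_end|> for that empty final turn.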
