Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example script for rendering jinja2 templates #7246

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
eac2e83
gguf: Add example script for extracting chat template
teleprint-me May 12, 2024
f572213
Merge branch 'ggerganov:master' into gguf-model-template
teleprint-me May 12, 2024
bf5154f
docs: Fix filename in docstring and remove return type from main
teleprint-me May 13, 2024
4a018e7
feat: Add assistant turn
teleprint-me May 13, 2024
8b9ed88
patch: Handle how templates are rendered if no system prompt is allowed
teleprint-me May 13, 2024
668c7ee
refactor: Use render template instead of format
teleprint-me May 13, 2024
fa0b0b1
feat: Allow toggling verbosity
teleprint-me May 13, 2024
6be3576
feat: Add sane defaults and options for setting special tokens
teleprint-me May 13, 2024
214e9e6
refactor: Add logging debug and clean up logger implementation
teleprint-me May 13, 2024
f8bb223
refactor: Remove rename from display to render and return result inst…
teleprint-me May 13, 2024
da96fdd
patch: Apply patch for advisories/GHSA-56xg-wfcc-g829
teleprint-me May 13, 2024
cfe659d
feat: Add option for adding generation prompt
teleprint-me May 13, 2024
b4b6f1f
fix: End messages with a user role due to jinja2 conditional checks
teleprint-me May 15, 2024
2185e5c
docs: Update and fix CLI help descriptions
teleprint-me May 15, 2024
8b67acc
Merge branch 'ggerganov:master' into gguf-model-template
teleprint-me May 20, 2024
3c23d9f
Merge branch 'master' into gguf-model-template
teleprint-me Jun 3, 2024
174bb3b
Merge branch 'gguf-model-template' of github.com:teleprint-me/llama.c…
teleprint-me Jun 3, 2024
1b18688
Merge branch 'ggerganov:master' into gguf-model-template
teleprint-me Jul 14, 2024
4204cab
chore : Apply snake case as described in #8305
teleprint-me Jul 14, 2024
b7528fd
chore : Add jinja2 as dev dependency in pyproject.toml and explicit d…
teleprint-me Jul 15, 2024
0cb404c
feat : Add shebang and executable bit to enable script execution
teleprint-me Jul 15, 2024
f455e82
Merge branch 'ggerganov:master' into gguf-model-template
teleprint-me Jul 15, 2024
27070de
Merge branch 'master' of github.com:teleprint-me/llama.cpp into gguf-…
teleprint-me Jul 19, 2024
5481cec
Merge branch 'ggerganov:master' into gguf-model-template
teleprint-me Jul 19, 2024
6875ace
Merge branch 'gguf-model-template' of github.com:teleprint-me/llama.c…
teleprint-me Jul 19, 2024
0de43fc
chore : Set sentencepiece to 0.2.0 to match requirements.txt
teleprint-me Jul 19, 2024
fe88330
chore : Add gguf_template to poetry scripts
teleprint-me Jul 19, 2024
a083c6c
chore : Add gguf template entrypoint to scripts sub package and fix t…
teleprint-me Jul 19, 2024
43eef8d
Merge branch 'ggerganov:master' into gguf-model-template
teleprint-me Jul 26, 2024
964ee4b
Merge branch 'ggerganov:master' into gguf-model-template
teleprint-me Jul 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gguf-py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pyyaml = ">=5.1"

[tool.poetry.dev-dependencies]
pytest = "^5.2"
jinja2 = ">=3.1.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand All @@ -36,3 +37,4 @@ gguf-convert-endian = "scripts:gguf_convert_endian_entrypoint"
gguf-dump = "scripts:gguf_dump_entrypoint"
gguf-set-metadata = "scripts:gguf_set_metadata_entrypoint"
gguf-new-metadata = "scripts:gguf_new_metadata_entrypoint"
gguf-template = "scripts:gguf_template_entrypoint"
1 change: 1 addition & 0 deletions gguf-py/scripts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .gguf_dump import main as gguf_dump_entrypoint
from .gguf_set_metadata import main as gguf_set_metadata_entrypoint
from .gguf_new_metadata import main as gguf_new_metadata_entrypoint
from .gguf_template import main as gguf_template_entrypoint
160 changes: 160 additions & 0 deletions gguf-py/scripts/gguf_template.py
teleprint-me marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#!/usr/bin/env python3
"""
teleprint-me marked this conversation as resolved.
Show resolved Hide resolved
gguf_template.py - example file to extract the chat template from the model's metadata
"""

from __future__ import annotations

import argparse
import logging
import os
import sys
from pathlib import Path

import jinja2
teleprint-me marked this conversation as resolved.
Show resolved Hide resolved
import jinja2.sandbox

# Prefer the in-tree gguf package over an installed copy unless NO_LOCAL_GGUF is set.
# This has to run before the gguf imports below so that they see the updated sys.path.
if "NO_LOCAL_GGUF" not in os.environ:
    if (Path(__file__).parent.parent.parent / "gguf-py").exists():
        sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf.constants import Keys  # noqa: E402
from gguf.gguf_reader import GGUFReader  # noqa: E402

# Module-level logger; the log level is set via logging.basicConfig() in main()
logger = logging.getLogger("gguf-chat-template")


def get_chat_template(model_file: str) -> str:
    """
    Read the chat template stored in a GGUF model file's metadata.

    Args:
        model_file (str): Path to the GGUF model file.

    Returns:
        str: The chat template decoded as UTF-8, or an empty string if the
            metadata has no chat template field (an error is logged in that case).
    """
    gguf_reader = GGUFReader(model_file)

    # List every metadata key; only shown when debug logging is enabled
    logger.debug("Detected model metadata!")
    logger.debug("Outputting available model fields:")
    for field_name in gguf_reader.fields:
        logger.debug(field_name)

    # Look the template up by its metadata key
    template_field = gguf_reader.fields.get(Keys.Tokenizer.CHAT_TEMPLATE)

    if not template_field:
        logger.error("Chat template field not found in model metadata.")
        return ""

    # The field's last part is read as the raw bytes of the template string
    return template_field.parts[-1].tobytes().decode("utf-8")


def render_chat_template(
    chat_template: str,
    bos_token: str,
    eos_token: str,
    add_generation_prompt: bool = False,
    render_template: bool = False,
) -> str:
    """
    Return the chat template, either as-is or rendered with Jinja2 against a sample conversation.

    Args:
        chat_template (str): The extracted chat template.
        bos_token (str): Value passed to the template as 'bos_token'.
        eos_token (str): Value passed to the template as 'eos_token'.
        add_generation_prompt (bool, optional): Value passed to the template as
            'add_generation_prompt'. Defaults to False.
        render_template (bool, optional): Whether to render the template using Jinja2. Defaults to False.

    Returns:
        str: The rendered sample conversation if render_template is True,
            otherwise the chat template unchanged.
    """
    logger.debug(f"Render Template: {render_template}")
    logger.debug(f"Add Generation Prompt: {add_generation_prompt}")

    if not render_template:
        # Hand back the raw template
        return chat_template

    # Compile the template inside an immutable sandboxed Jinja2 environment
    sandbox = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
    )
    compiled = sandbox.from_string(chat_template)

    # Sample conversation; it ends on a user turn
    conversation = [
        {"role": "system", "content": "I am a helpful assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hello! How may I assist you today?"},
        {"role": "user", "content": "Can you tell me what pickled mayonnaise is?"},
        {"role": "assistant", "content": "Certainly! What would you like to know about it?"},
        {"role": "user", "content": "Is it just regular mayonnaise with vinegar or something else?"},
    ]

    # Variables shared by both render attempts
    context = {
        "bos_token": bos_token,
        "eos_token": eos_token,
        "add_generation_prompt": add_generation_prompt,
    }

    try:
        return compiled.render(messages=conversation, **context)
    except jinja2.exceptions.UndefinedError:
        # system message is incompatible with set format; retry without it
        return compiled.render(messages=conversation[1:], **context)


# Example usage:
def main():
    """Command-line entry point: parse arguments, read the chat template, and print it (raw or rendered)."""
    cli = argparse.ArgumentParser(description="Extract chat template from a GGUF model file")
    cli.add_argument("model_file", type=str, help="Path to the GGUF model file")
    cli.add_argument(
        "-r", "--render-template", action="store_true",
        help="Render the chat template using Jinja2. Default is False.",
    )
    cli.add_argument("-b", "--bos", default="<s>", help="Set a bos special token. Default is '<s>'.")
    cli.add_argument("-e", "--eos", default="</s>", help="Set a eos special token. Default is '</s>'.")
    cli.add_argument("-g", "--agp", action="store_true", help="Add generation prompt. Default is False.")
    cli.add_argument("-v", "--verbose", action="store_true", help="Output model keys. Default is False.")
    options = cli.parse_args()

    # --verbose enables the debug messages (including the metadata key listing)
    log_level = logging.DEBUG if options.verbose else logging.INFO
    logging.basicConfig(level=log_level)

    template_text = get_chat_template(options.model_file)
    output = render_chat_template(
        template_text,
        options.bos,
        options.eos,
        add_generation_prompt=options.agp,
        render_template=options.render_template,
    )
    print(output)  # noqa: NP100


if __name__ == "__main__":
main()
Loading
Loading