Commit b9a14d9

v0.3.6 - Llama.cpp integration

noamgat committed Oct 19, 2023
1 parent e2ea9cf commit b9a14d9

Showing 4 changed files with 57 additions and 50 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -50,7 +50,7 @@ print(result)

## Capabilities / Advantages

-- Works with any Python language model and tokenizer. Already supports [transformers](https://github.com/huggingface/transformers), [LangChain](https://docs.langchain.com/docs/) and [vLLM](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_vllm_integration.ipynb). Can be adapted to others.
+- Works with any Python language model and tokenizer. Already supports [transformers](https://github.com/huggingface/transformers), [LangChain](https://docs.langchain.com/docs/), [llama.cpp](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_llamacpppython_integration.ipynb) and [vLLM](https://github.com/noamgat/lm-format-enforcer/blob/main/samples/colab_vllm_integration.ipynb). Can be adapted to others.
- Supports batched generation and beam searches - each input / beam can have different tokens filtered at every timestep
- Supports both JSON Schema and Regular Expression formats
- Supports both required and optional fields in JSON schemas
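The README line above summarizes the pattern the notebooks in this commit demonstrate. As a minimal sketch of the transformers + regex path it refers to (the model choice and prompt are illustrative; `build_transformers_prefix_allowed_tokens_fn` is the integration helper as published in the project README):

```python
from transformers import pipeline
from lm_format_enforcer import RegexParser
from lm_format_enforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn

# Illustrative model; any causal LM pipeline works the same way
hf_pipeline = pipeline('text-generation', model='gpt2')

# Constrain the completion to look like a semantic version, e.g. "0.3.6"
parser = RegexParser(r'\d+\.\d+\.\d+')
prefix_fn = build_transformers_prefix_allowed_tokens_fn(hf_pipeline.tokenizer, parser)

result = hf_pipeline('The new release version is ', prefix_allowed_tokens_fn=prefix_fn)
print(result[0]['generated_text'])
```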
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "lm-format-enforcer"
-version = "0.3.5"
+version = "0.3.6"
description = "Enforce the output format (JSON Schema, Regex etc) of a language model"
authors = ["Noam Gat <[email protected]>"]
license = "MIT"
94 changes: 53 additions & 41 deletions samples/colab_llamacpppython_integration.ipynb
@@ -27,9 +27,18 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"!pip install llama-cpp-python lm-format-enforcer huggingface-hub\n",
"\n",
@@ -53,14 +62,17 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"llama_model_loader: loaded meta data with 19 key-value pairs and 291 tensors from /mnt/e/manual/llama-2-7b-chat.Q5_K_M.gguf (version GGUF V2 (latest))\n",
"/home/noamgat/mambaforge/envs/llamacpppy/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"llama-2-7b-chat.Q5_K_M.gguf: 100%|██████████| 4.78G/4.78G [02:19<00:00, 34.3MB/s]\n",
"llama_model_loader: loaded meta data with 19 key-value pairs and 291 tensors from /home/noamgat/huggingface/hub/models--TheBloke--Llama-2-7b-Chat-GGUF/snapshots/191239b3e26b2882fb562ffccdd1cf0f65402adb/llama-2-7b-chat.Q5_K_M.gguf (version GGUF V2 (latest))\n",
"llama_model_loader: - tensor 0: token_embd.weight q5_K [ 4096, 32000, 1, 1 ]\n",
"llama_model_loader: - tensor 1: blk.0.attn_norm.weight f32 [ 4096, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 2: blk.0.ffn_down.weight q6_K [ 11008, 4096, 1, 1 ]\n",
@@ -439,7 +451,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -502,7 +514,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -526,7 +538,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -548,7 +560,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -576,14 +588,14 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_401088/4169945469.py:13: PydanticDeprecatedSince20: The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.4/migration/\n",
"/tmp/ipykernel_409047/2888867492.py:13: PydanticDeprecatedSince20: The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.4/migration/\n",
" question_with_schema = f'{question}{AnswerFormat.schema_json()}'\n"
]
},
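The `PydanticDeprecatedSince20` warning in the new output is harmless but easy to silence: Pydantic v2 renamed `schema_json()` to `model_json_schema()` plus an explicit `json.dumps`. A sketch of the v2-native spelling, with the `AnswerFormat` fields inferred from the model answers shown further down:

```python
import json
from pydantic import BaseModel

class AnswerFormat(BaseModel):
    # Field names inferred from the answers in this notebook
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int

# Deprecated v1 spelling: AnswerFormat.schema_json()
# v2 replacement suggested by the warning:
schema_as_json = json.dumps(AnswerFormat.model_json_schema())
```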
@@ -622,7 +634,7 @@
{
"data": {
"text/markdown": [
"**Answer, With json schema enforcing:**"
"**Answer, Without json schema enforcing:**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
@@ -635,24 +647,27 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_401088/4169945469.py:21: PydanticDeprecatedSince20: The `schema` method is deprecated; use `model_json_schema` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.4/migration/\n",
" result = llamacpp_with_character_level_parser(llm, prompt, JsonSchemaParser(AnswerFormat.schema()))\n",
"\n",
"llama_print_timings: load time = 16791.06 ms\n",
"llama_print_timings: sample time = 17.44 ms / 53 runs ( 0.33 ms per token, 3038.64 tokens per second)\n",
"llama_print_timings: prompt eval time = 16791.00 ms / 294 tokens ( 57.11 ms per token, 17.51 tokens per second)\n",
"llama_print_timings: eval time = 4649.13 ms / 52 runs ( 89.41 ms per token, 11.18 tokens per second)\n",
"llama_print_timings: total time = 21662.33 ms\n"
"llama_print_timings: load time = 16716.39 ms\n",
"llama_print_timings: sample time = 33.72 ms / 93 runs ( 0.36 ms per token, 2757.76 tokens per second)\n",
"llama_print_timings: prompt eval time = 16716.30 ms / 294 tokens ( 56.86 ms per token, 17.59 tokens per second)\n",
"llama_print_timings: eval time = 10525.06 ms / 92 runs ( 114.40 ms per token, 8.74 tokens per second)\n",
"llama_print_timings: total time = 27395.89 ms\n"
]
},
{
"data": {
"text/markdown": [
"```\n",
" { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
"\n",
"\n",
" Of course! I'd be happy to provide information about Michael Jordan using the provided JSON schema.\n",
"{\n",
"\"first_name\": \"Michael\",\n",
"\"last_name\": \"Jordan\",\n",
"\"year_of_birth\": 1963,\n",
"\"num_seasons_in_nba\": 15\n",
"}\n",
"\n",
"I hope this helps! Let me know if you have any other questions.\n",
"```"
],
"text/plain": [
@@ -665,7 +680,7 @@
{
"data": {
"text/markdown": [
"**Answer, Without json schema enforcing:**"
"**Answer, With json schema enforcing:**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
@@ -678,27 +693,24 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_409047/2888867492.py:24: PydanticDeprecatedSince20: The `schema` method is deprecated; use `model_json_schema` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.4/migration/\n",
" result = llamacpp_with_character_level_parser(llm, prompt, JsonSchemaParser(AnswerFormat.schema()))\n",
"Llama.generate: prefix-match hit\n",
"\n",
"llama_print_timings: load time = 16791.06 ms\n",
"llama_print_timings: sample time = 34.38 ms / 99 runs ( 0.35 ms per token, 2879.58 tokens per second)\n",
"llama_print_timings: load time = 16716.39 ms\n",
"llama_print_timings: sample time = 17.67 ms / 52 runs ( 0.34 ms per token, 2943.01 tokens per second)\n",
"llama_print_timings: prompt eval time = 0.00 ms / 1 tokens ( 0.00 ms per token, inf tokens per second)\n",
"llama_print_timings: eval time = 8800.92 ms / 99 runs ( 88.90 ms per token, 11.25 tokens per second)\n",
"llama_print_timings: total time = 8956.60 ms\n"
"llama_print_timings: eval time = 5051.36 ms / 52 runs ( 97.14 ms per token, 10.29 tokens per second)\n",
"llama_print_timings: total time = 5253.00 ms\n"
]
},
{
"data": {
"text/markdown": [
"```\n",
-" Of course! I'd be happy to provide information about Michael Jordan using the provided JSON schema. Here is the information for you:\n",
-"{\n",
-"\"first_name\": \"Michael\",\n",
-"\"last_name\": \"Jordan\",\n",
-"\"year_of_birth\": 1963,\n",
-"\"num_seasons_in_nba\": 15\n",
-"}\n",
-"I hope this helps! Let me know if you have any other questions.\n",
+" { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
+"\n",
+"\n",
"```"
],
"text/plain": [
@@ -728,22 +740,22 @@
"display_header(\"Prompt:\")\n",
"display_content(prompt)\n",
"\n",
"display_header(\"Answer, With json schema enforcing:\")\n",
"\n",
"result = llamacpp_with_character_level_parser(llm, prompt, JsonSchemaParser(AnswerFormat.schema()))\n",
"display_content(result)\n",
"\n",
"display_header(\"Answer, Without json schema enforcing:\")\n",
"result = llamacpp_with_character_level_parser(llm, prompt, None)\n",
"display_content(result)\n",
"\n"
"\n",
"display_header(\"Answer, With json schema enforcing:\")\n",
"result = llamacpp_with_character_level_parser(llm, prompt, JsonSchemaParser(AnswerFormat.schema()))\n",
"display_content(result)\n"
]
},
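The reordered cell above calls `llamacpp_with_character_level_parser`, which is defined in a cell the diff keeps collapsed. A plausible sketch of such a helper, assuming the integration module exposes `build_llamacpp_logits_processor` (the exact 0.3.6 signature may differ):

```python
from typing import Optional

from llama_cpp import Llama, LogitsProcessorList
from lm_format_enforcer import CharacterLevelParser
from lm_format_enforcer.integrations.llamacpp import build_llamacpp_logits_processor


def llamacpp_with_character_level_parser(
    llm: Llama, prompt: str, parser: Optional[CharacterLevelParser]
) -> str:
    # With no parser, generation is completely unconstrained
    logits_processors: Optional[LogitsProcessorList] = None
    if parser:
        # The processor masks out any token that would take the output
        # outside the language accepted by the parser (e.g. a JSON schema)
        logits_processors = LogitsProcessorList(
            [build_llamacpp_logits_processor(llm, parser)]
        )
    output = llm(prompt, logits_processor=logits_processors, max_tokens=100)
    return output['choices'][0]['text']
```

Passing `None` runs the same generation path with no mask at all, which is what makes the with/without comparison in this cell meaningful.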
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, the enforced output matches the required schema, while the unenforced does not. We have successfully integrated with llama.cpp!"
"As you can see, the enforced output matches the required schema, while the unenforced does not. We have successfully integrated with llama.cpp!\n",
"\n",
"Ending note - the last cell probably took quite a long time to run. This is due to this notebook using CPU inference. LM Format Enforcer's runtime footprint is negligible compared to the model's runtime."
]
}
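The new ending note is worth heeding: the timings above come from CPU inference. If llama-cpp-python was built with GPU support, layer offload is the usual remedy; `n_gpu_layers` is a standard `Llama` constructor argument, though treat the `-1` = all-layers shorthand as version-dependent:

```python
from llama_cpp import Llama

# Hypothetical tweak to the load cell: offload every layer to the GPU
# (-1 means "all layers" in recent llama-cpp-python builds)
llm = Llama(model_path='llama-2-7b-chat.Q5_K_M.gguf', n_gpu_layers=-1)
```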
],
9 changes: 2 additions & 7 deletions samples/colab_vllm_integration.ipynb
@@ -265,7 +265,7 @@
"source": [
"## Setting up the prompt for the specific language model\n",
"\n",
"We set up the prompting style according to the demo at https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/app.py . We simplify the implementation a bit as we don't need chat history for this demo."
"We set up the prompting style according to the [Llama2 demo](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/app.py). We simplify the implementation a bit as we don't need chat history for this demo."
]
},
{
@@ -279,12 +279,7 @@
"\"\"\"\n",
"\n",
"def get_prompt(message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:\n",
" texts = [f'<s>[INST] <<SYS>>\\n{system_prompt}\\n<</SYS>>\\n\\n']\n",
" # The first user input is _not_ stripped\n",
" do_strip = False\n",
" message = message.strip() if do_strip else message\n",
" texts.append(f'{message} [/INST]')\n",
" return ''.join(texts)"
" return f'<s>[INST] <<SYS>>\\n{system_prompt}\\n<</SYS>>\\n\\n{message} [/INST]'"
]
},
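The simplified `get_prompt` is behavior-preserving: the deleted version never stripped the message (`do_strip` was hard-coded to `False`), so the single f-string yields identical output. A quick check with a hypothetical message and placeholder system prompt:

```python
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."  # placeholder; the notebook defines its own

def get_prompt(message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]'

print(get_prompt("Who wrote Hamlet?"))
# <s>[INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Who wrote Hamlet? [/INST]
```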
{
