diff --git a/samples/colab_llama2_enforcer.ipynb b/samples/colab_llama2_enforcer.ipynb
index 16171fd..0728dbe 100644
--- a/samples/colab_llama2_enforcer.ipynb
+++ b/samples/colab_llama2_enforcer.ipynb
@@ -32,7 +32,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install transformers torch lm-format-enforcer huggingface_hub accelerate bitsandbytes cpm_kernels\n",
+    "!pip install transformers torch lm-format-enforcer huggingface_hub accelerate bitsandbytes cpm_kernels langchain langchain-experimental\n",
     "!huggingface-cli login\n",
     "\n",
     "# When running from source / developing the library, use this instead\n",
@@ -41,7 +41,7 @@
     "# import sys\n",
     "# import os\n",
     "# sys.path.append(os.path.abspath('..'))\n",
-    "## os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
+    "# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
    ]
   },
   {
@@ -55,8 +55,7 @@
      "text": [
       "/home/noamgat/mambaforge/envs/commentranker/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
       "  from .autonotebook import tqdm as notebook_tqdm\n",
-      "Loading checkpoint shards: 100%|██████████| 2/2 [05:29<00:00, 164.58s/it]\n",
-      "Using pad_token, but it is not set yet.\n"
+      "Loading checkpoint shards: 100%|██████████| 2/2 [03:52<00:00, 116.06s/it]\n"
      ]
     }
    ],
@@ -80,9 +79,9 @@
     "else:\n",
     "    raise Exception('GPU not available')\n",
     "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
-    "if tokenizer.pad_token is None:\n",
+    "if tokenizer.pad_token_id is None:\n",
     "    # Required for batching example\n",
-    "    tokenizer.pad_token = tokenizer.eos_token \n"
+    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n"
    ]
   },
   {
@@ -129,7 +128,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
@@ -194,7 +193,7 @@
     "    sequences = output['sequences']\n",
     "    # skip_prompt=True doesn't work consistently, so we hack around it.\n",
     "    string_outputs = [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in sequences]\n",
-    "    string_outputs = [string_output.replace(prompt[3:], ' ') for string_output, prompt in zip(string_outputs, prompts)]\n",
+    "    string_outputs = [string_output.replace(prompt[3:], '') for string_output, prompt in zip(string_outputs, prompts)]\n",
     "    if parser and not is_multi_message:\n",
     "        enforced_scores_dict = output.enforced_scores\n",
     "        enforced_scores = pd.DataFrame(enforced_scores_dict)\n",
@@ -264,7 +263,7 @@
      "data": {
       "text/markdown": [
        "```\n",
-       "  { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
+       " { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
        "```"
       ],
       "text/plain": [
@@ -290,14 +289,18 @@
      "data": {
       "text/markdown": [
        "```\n",
-       "  Of course! I'd be happy to provide information about Michael Jordan using the provided JSON schema.\n",
+       " Of course! I'd be happy to help you with information about Michael Jordan. Here is the information you requested, formatted according to the JSON schema you provided:\n",
        "{\n",
        "\"title\": \"AnswerFormat\",\n",
        "\"type\": \"object\",\n",
        "\"properties\": {\n",
-       "\"first_name\": {\"title\": \"First Name\", \"type\": \"string\"},\n",
-       "\"last_name\": {\"title\": \"Last Name\", \"type\": \"string\"},\n",
-       "\"year_of_birth\": {\"title\": \"Year Of Birth\", \"\n",
+       "\"first_name\": {\n",
+       "\"title\": \"First Name\",\n",
+       "\"type\": \"string\",\n",
+       "\"description\": \"Michael Jordan's first name\"\n",
+       "},\n",
+       "\"last_name\": {\n",
+       "\n",
        "```"
       ],
       "text/plain": [
@@ -357,7 +360,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
@@ -946,7 +949,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
@@ -991,7 +994,7 @@
      "data": {
       "text/markdown": [
        "```\n",
-       "  Thank you for asking! Michael Jordan was born in the year 1963.\n",
+       " Michael Jordan was born in the year 1963.\n",
        "```"
       ],
       "text/plain": [
@@ -1038,7 +1041,7 @@
      "data": {
       "text/markdown": [
        "```\n",
-       "  Michael Jordan was born in 1963.\n",
+       " Michael Jordan was born in 1963.\n",
        "```"
       ],
       "text/plain": [
@@ -1268,7 +1271,7 @@
      "data": {
       "text/markdown": [
        "```\n",
-       " The answer is 1963\n",
+       "The answer is 1963\n",
        "```"
       ],
       "text/plain": [
@@ -1469,7 +1472,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
@@ -1488,7 +1491,7 @@
      "data": {
       "text/markdown": [
        "```\n",
-       "  { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
+       " { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
        "```"
       ],
       "text/plain": [
@@ -1502,7 +1505,7 @@
      "data": {
       "text/markdown": [
        "```\n",
-       "  { \"first_name\": \"Timothy\", \"last_name\": \"Duncan\", \"year_of_birth\": 1976, \"num_seasons_in_nba\": 19 }\n",
+       " { \"first_name\": \"Timothy\", \"last_name\": \"Duncan\", \"year_of_birth\": 1976, \"num_seasons_in_nba\": 19 }\n",
        "```"
       ],
       "text/plain": [
@@ -1516,7 +1519,7 @@
      "data": {
       "text/markdown": [
        "```\n",
-       "  { \"first_name\": \"Kobe\", \"last_name\": \"Bryant\", \"year_of_birth\": 1978, \"num_seasons_in_nba\": 20 }\n",
+       " { \"first_name\": \"Kobe\", \"last_name\": \"Bryant\", \"year_of_birth\": 1978, \"num_seasons_in_nba\": 20 }\n",
        "```"
       ],
       "text/plain": [
@@ -1530,7 +1533,7 @@
      "data": {
       "text/markdown": [
        "```\n",
-       "  { \"first_name\": \"Kareem\", \"last_name\": \"Abdul-Jabbar\", \"year_of_birth\": 1947, \"num_seasons_in_nba\": 20 }\n",
+       " { \"first_name\": \"Kareem Abdul-Jabbar\", \"last_name\": \"Abdul-Jabbar\", \"year_of_birth\": 1947, \"num_seasons_in_nba\": 20 }\n",
        "```"
       ],
       "text/plain": [
@@ -1551,6 +1554,219 @@
     "for result in results:\n",
     "    display_content(result)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# LangChain Integration\n",
+    "\n",
+    "The next cell contains an implementation of a LangChain LLM that combines Hugging Face transformers with the LM Format Enforcer. It was inspired by the [JsonFormer Decoder](https://github.com/langchain-ai/langchain/blob/ce0019b646e4a70ac01d8daa37575bfd364cb3a6/libs/experimental/langchain_experimental/llms/jsonformer_decoder.py) implementation, but also supports batch generation.\n",
+    "\n",
+    "It works with the JSON Schema use case (not the regular expression use case).\n",
+    "\n",
+    "Hopefully it will eventually become part of the LangChain project."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\"\"\"Experimental implementation of an lm-format-enforcer wrapped LLM.\"\"\"\n",
+    "from __future__ import annotations\n",
+    "\n",
+    "import json\n",
+    "from typing import TYPE_CHECKING, Any, List, Optional, cast\n",
+    "\n",
+    "from langchain.callbacks.manager import CallbackManagerForLLMRun\n",
+    "from langchain.llms.huggingface_pipeline import HuggingFacePipeline\n",
+    "\n",
+    "from langchain_experimental.pydantic_v1 import Field, root_validator\n",
+    "from transformers.pipelines import Text2TextGenerationPipeline\n",
+    "\n",
+    "if TYPE_CHECKING:\n",
+    "    import lmformatenforcer\n",
+    "\n",
+    "\n",
+    "def import_lmformatenforcer() -> lmformatenforcer:\n",
+    "    \"\"\"Lazily import lmformatenforcer.\"\"\"\n",
+    "    try:\n",
+    "        import lmformatenforcer\n",
+    "    except ImportError:\n",
+    "        raise ImportError(\n",
+    "            \"Could not import lmformatenforcer python package. \"\n",
+    "            \"Please install it with `pip install lm-format-enforcer`.\"\n",
+    "        )\n",
+    "    return lmformatenforcer\n",
+    "\n",
+    "\n",
+    "class LMFormatEnforcer(HuggingFacePipeline):\n",
+    "    \"\"\"LMFormatEnforcer wrapped LLM using the HuggingFace Pipeline API.\n",
+    "\n",
+    "    This pipeline is experimental and not yet stable.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    json_schema: dict = Field(..., description=\"The JSON Schema to complete.\")\n",
+    "    prompt_suffix_format: str = Field(\n",
+    "        description=\"The format string to append to the prompt. It will receive the JSON Schema as a parameter.\",\n",
+    "        default=\"You MUST answer using the following json schema: {}\")\n",
+    "\n",
+    "    @root_validator\n",
+    "    def check_lmformatenforcer_installation(cls, values: dict) -> dict:\n",
+    "        import_lmformatenforcer()\n",
+    "        return values\n",
+    "\n",
+    "    def _generate(\n",
+    "        self,\n",
+    "        prompts: List[str],\n",
+    "        stop: Optional[List[str]] = None,\n",
+    "        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
+    "        **kwargs: Any,\n",
+    "    ):\n",
+    "        # We integrate lmformatenforcer by adding a prefix_allowed_tokens_fn.\n",
+    "        # It has to be done on each call, because the prefix function is stateful and needs to be reinitialized.\n",
+    "        if 'prefix_allowed_tokens_fn' in self.pipeline._forward_params:\n",
+    "            raise RuntimeError(\"prefix_allowed_tokens_fn is already set, unsupported by LMFormatEnforcer.\")\n",
+    "        lmformatenforcer = import_lmformatenforcer()\n",
+    "        prompts = [prompt + \"\\n\" + self.prompt_suffix_format.format(json.dumps(self.json_schema))\n",
+    "                   for prompt in prompts]\n",
+    "        parser = lmformatenforcer.JsonSchemaParser(self.json_schema)\n",
+    "        pipeline = cast(Text2TextGenerationPipeline, self.pipeline)\n",
+    "        prefix_function = lmformatenforcer.build_transformers_prefix_allowed_tokens_fn(pipeline.tokenizer, parser)\n",
+    "        self.pipeline._forward_params['prefix_allowed_tokens_fn'] = prefix_function\n",
+    "\n",
+    "        result = super()._generate(\n",
+    "            prompts,\n",
+    "            stop=stop,\n",
+    "            run_manager=run_manager,\n",
+    "            **kwargs,\n",
+    "        )\n",
+    "\n",
+    "        del self.pipeline._forward_params['prefix_allowed_tokens_fn']\n",
+    "        return result"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Using the LMFormatEnforcer\n",
+    "\n",
+    "Now, we can set up a pipeline, create the enforcer from it, and use it to generate text. Both direct call and batch modes are supported."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/markdown": [
+       "**Call mode**"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Markdown object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/markdown": [
+       "```\n",
+       " {\"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15}\n",
+       "```"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Markdown object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/markdown": [
+       "**Batched mode**"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Markdown object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/markdown": [
+       "```\n",
+       " {\"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15}\n",
+       "```"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Markdown object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/markdown": [
+       "```\n",
+       " {\"first_name\": \"Larry\", \"last_name\": \"Bird\", \"year_of_birth\": 1956, \"num_seasons_in_nba\": 13}\n",
+       "```"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Markdown object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/markdown": [
+       "```\n",
+       " {\"first_name\": \"Tim\", \"last_name\": \"Duncan\", \"year_of_birth\": 1976, \"num_seasons_in_nba\": 19}\n",
+       "```"
+      ],
+      "text/plain": [
+       "<IPython.core.display.Markdown object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from transformers import pipeline\n",
+    "\n",
+    "# We create a transformers pipeline to avoid loading the model twice, but we could also use LMFormatEnforcer.from_model_id()\n",
+    "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer, max_new_tokens=100)\n",
+    "langchain_pipeline = LMFormatEnforcer(pipeline=pipe, json_schema=AnswerFormat.schema())\n",
+    "\n",
+    "DEFAULT_SYSTEM_PROMPT = \"\"\"\\\n",
+    "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\\\n",
+    "\"\"\"\n",
+    "\n",
+    "players = ['Michael Jordan', 'Larry Bird', 'Tim Duncan']\n",
+    "question = 'Please give me information about {}.'\n",
+    "prompts = [get_prompt(question.format(player), DEFAULT_SYSTEM_PROMPT) for player in players]\n",
+    "\n",
+    "display_header('Call mode')\n",
+    "result = langchain_pipeline(prompts[0])\n",
+    "display_content(result)\n",
+    "\n",
+    "display_header('Batched mode')\n",
+    "results = langchain_pipeline.generate(prompts)\n",
+    "for generation in results.generations:\n",
+    "    display_content(generation[0].text)"
+   ]
   }
  ],
  "metadata": {