Skip to content

Commit

Permalink
Added LangChain integration example to sample notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
noamgat committed Oct 11, 2023
1 parent e1fa09d commit 3b886d1
Showing 1 changed file with 239 additions and 23 deletions.
262 changes: 239 additions & 23 deletions samples/colab_llama2_enforcer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install transformers torch lm-format-enforcer huggingface_hub accelerate bitsandbytes cpm_kernels\n",
"!pip install transformers torch lm-format-enforcer huggingface_hub accelerate bitsandbytes cpm_kernels langchain langchain-experimental \n",
"!huggingface-cli login\n",
"\n",
"# When running from source / developing the library, use this instead\n",
Expand All @@ -41,7 +41,7 @@
"# import sys\n",
"# import os\n",
"# sys.path.append(os.path.abspath('..'))\n",
"## os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
"# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
]
},
{
Expand All @@ -55,8 +55,7 @@
"text": [
"/home/noamgat/mambaforge/envs/commentranker/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [05:29<00:00, 164.58s/it]\n",
"Using pad_token, but it is not set yet.\n"
"Loading checkpoint shards: 100%|██████████| 2/2 [03:52<00:00, 116.06s/it]\n"
]
}
],
Expand All @@ -80,9 +79,9 @@
"else:\n",
" raise Exception('GPU not available')\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"if tokenizer.pad_token is None:\n",
"if tokenizer.pad_token_id is None:\n",
" # Required for batching example\n",
" tokenizer.pad_token = tokenizer.eos_token \n"
" tokenizer.pad_token_id = tokenizer.eos_token_id \n"
]
},
{
Expand Down Expand Up @@ -129,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -194,7 +193,7 @@
" sequences = output['sequences']\n",
" # skip_prompt=True doesn't work consistenly, so we hack around it.\n",
" string_outputs = [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in sequences]\n",
" string_outputs = [string_output.replace(prompt[3:], ' ') for string_output, prompt in zip(string_outputs, prompts)]\n",
" string_outputs = [string_output.replace(prompt[3:], '') for string_output, prompt in zip(string_outputs, prompts)]\n",
" if parser and not is_multi_message:\n",
" enforced_scores_dict = output.enforced_scores\n",
" enforced_scores = pd.DataFrame(enforced_scores_dict)\n",
Expand Down Expand Up @@ -264,7 +263,7 @@
"data": {
"text/markdown": [
"```\n",
" { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
" { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
"```"
],
"text/plain": [
Expand All @@ -290,14 +289,18 @@
"data": {
"text/markdown": [
"```\n",
" Of course! I'd be happy to provide information about Michael Jordan using the provided JSON schema.\n",
" Of course! I'd be happy to help you with information about Michael Jordan. Here is the information you requested, formatted according to the JSON schema you provided:\n",
"{\n",
"\"title\": \"AnswerFormat\",\n",
"\"type\": \"object\",\n",
"\"properties\": {\n",
"\"first_name\": {\"title\": \"First Name\", \"type\": \"string\"},\n",
"\"last_name\": {\"title\": \"Last Name\", \"type\": \"string\"},\n",
"\"year_of_birth\": {\"title\": \"Year Of Birth\", \"\n",
"\"first_name\": {\n",
"\"title\": \"First Name\",\n",
"\"type\": \"string\",\n",
"\"description\": \"Michael Jordan's first name\"\n",
"},\n",
"\"last_name\": {\n",
"\n",
"```"
],
"text/plain": [
Expand Down Expand Up @@ -357,7 +360,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -946,7 +949,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -991,7 +994,7 @@
"data": {
"text/markdown": [
"```\n",
" Thank you for asking! Michael Jordan was born in the year 1963.\n",
" Michael Jordan was born in the year 1963.\n",
"```"
],
"text/plain": [
Expand Down Expand Up @@ -1038,7 +1041,7 @@
"data": {
"text/markdown": [
"```\n",
" Michael Jordan was born in 1963.\n",
" Michael Jordan was born in 1963.\n",
"```"
],
"text/plain": [
Expand Down Expand Up @@ -1268,7 +1271,7 @@
"data": {
"text/markdown": [
"```\n",
" The answer is 1963\n",
"The answer is 1963\n",
"```"
],
"text/plain": [
Expand Down Expand Up @@ -1469,7 +1472,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -1488,7 +1491,7 @@
"data": {
"text/markdown": [
"```\n",
" { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
" { \"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15 }\n",
"```"
],
"text/plain": [
Expand All @@ -1502,7 +1505,7 @@
"data": {
"text/markdown": [
"```\n",
" { \"first_name\": \"Timothy\", \"last_name\": \"Duncan\", \"year_of_birth\": 1976, \"num_seasons_in_nba\": 19 }\n",
" { \"first_name\": \"Timothy\", \"last_name\": \"Duncan\", \"year_of_birth\": 1976, \"num_seasons_in_nba\": 19 }\n",
"```"
],
"text/plain": [
Expand All @@ -1516,7 +1519,7 @@
"data": {
"text/markdown": [
"```\n",
" { \"first_name\": \"Kobe\", \"last_name\": \"Bryant\", \"year_of_birth\": 1978, \"num_seasons_in_nba\": 20 }\n",
" { \"first_name\": \"Kobe\", \"last_name\": \"Bryant\", \"year_of_birth\": 1978, \"num_seasons_in_nba\": 20 }\n",
"```"
],
"text/plain": [
Expand All @@ -1530,7 +1533,7 @@
"data": {
"text/markdown": [
"```\n",
" { \"first_name\": \"Kareem\", \"last_name\": \"Abdul-Jabbar\", \"year_of_birth\": 1947, \"num_seasons_in_nba\": 20 }\n",
" { \"first_name\": \"Kareem Abdul-Jabbar\", \"last_name\": \"Abdul-Jabbar\", \"year_of_birth\": 1947, \"num_seasons_in_nba\": 20 }\n",
"```"
],
"text/plain": [
Expand All @@ -1551,6 +1554,219 @@
"for result in results:\n",
" display_content(result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LangChain Integration\n",
"\n",
"The next cell contains an implementation of a LangChain LLM that combines huggingface transformers with the LM Format enforcer. It was inspired by the [JsonFormer Decoder](https://github.com/langchain-ai/langchain/blob/ce0019b646e4a70ac01d8daa37575bfd364cb3a6/libs/experimental/langchain_experimental/llms/jsonformer_decoder.py) implementation, but also supports batch generation.\n",
"\n",
"It works using the JSON use case (not the Regular Expression use case)\n",
"\n",
"Hopefully it will eventually be part of the LangChain project."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Experimental implementation of lm-format-enforcer wrapped LLM.\"\"\"\n",
"from __future__ import annotations\n",
"from email.policy import default\n",
"\n",
"import json\n",
"from typing import TYPE_CHECKING, Any, List, Optional, cast\n",
"\n",
"from langchain.callbacks.manager import CallbackManagerForLLMRun\n",
"from langchain.llms.huggingface_pipeline import HuggingFacePipeline\n",
"\n",
"from langchain_experimental.pydantic_v1 import Field, root_validator\n",
"from transformers.pipelines import Text2TextGenerationPipeline\n",
"\n",
"if TYPE_CHECKING:\n",
" import lmformatenforcer\n",
"\n",
"\n",
"def import_lmformatenforcer() -> lmformatenforcer:\n",
" \"\"\"Lazily import lmformatenforcer.\"\"\"\n",
" try:\n",
" import lmformatenforcer\n",
" except ImportError:\n",
" raise ImportError(\n",
" \"Could not import lmformatenforcer python package. \"\n",
" \"Please install it with `pip install lm-format-enforcer`.\"\n",
" )\n",
" return lmformatenforcer\n",
"\n",
"\n",
"class LMFormatEnforcer(HuggingFacePipeline):\n",
" \"\"\"LMFormatEnforcer wrapped LLM using HuggingFace Pipeline API.\n",
"\n",
" This pipeline is experimental and not yet stable.\n",
" \"\"\"\n",
"\n",
" json_schema: dict = Field(..., description=\"The JSON Schema to complete.\")\n",
" prompt_suffix_format: str = Field(description=\"The format string to append to the prompt. It will receive the JSON Schema as a parameter.\", \n",
" default=\"You MUST answer using the following json schema: {}\")\n",
" @root_validator\n",
" def check_lmformatenforcer_installation(cls, values: dict) -> dict:\n",
" import_lmformatenforcer()\n",
" return values\n",
"\n",
" def _generate(\n",
" self,\n",
" prompts: List[str],\n",
" stop: Optional[List[str]] = None,\n",
" run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
" **kwargs: Any,\n",
" ):\n",
" # We integrate lmformatenforcer by adding a prefix_allowed_tokens_fn.\n",
" # It has to be done on each call, because the prefix function is stateful, so it needs to be reinitialized.\n",
" if 'prefix_allowed_tokens_fn' in self.pipeline._forward_params:\n",
" raise RuntimeError(\"prefix_allowed_tokens_fn is already set, unsupported by LMFormatEnforcer.\")\n",
" lmformatenforcer = import_lmformatenforcer()\n",
" prompts = [prompt + \"\\n\" + self.prompt_suffix_format.format(json.dumps(self.json_schema)) \n",
" for prompt in prompts]\n",
" parser = lmformatenforcer.JsonSchemaParser(self.json_schema)\n",
" pipeline = cast(Text2TextGenerationPipeline, self.pipeline)\n",
" prefix_function = lmformatenforcer.build_transformers_prefix_allowed_tokens_fn(pipeline.tokenizer, parser)\n",
" self.pipeline._forward_params['prefix_allowed_tokens_fn'] = prefix_function\n",
"\n",
" result = super()._generate(\n",
" prompts,\n",
" stop=stop,\n",
" run_manager=run_manager,\n",
" **kwargs,\n",
" )\n",
" \n",
" del self.pipeline._forward_params['prefix_allowed_tokens_fn']\n",
" return result\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the LMFormatEnforcer\n",
"\n",
"Now, we can set up a pipeline, create the enforcer from it, and use it to generate text. Both direct call and batch modes are supported."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"**Call mode**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"```\n",
" {\"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15}\n",
"```"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"**Batched mode**"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"```\n",
" {\"first_name\": \"Michael\", \"last_name\": \"Jordan\", \"year_of_birth\": 1963, \"num_seasons_in_nba\": 15}\n",
"```"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"```\n",
" {\"first_name\": \"Larry\", \"last_name\": \"Bird\", \"year_of_birth\": 1956, \"num_seasons_in_nba\": 13}\n",
"```"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"```\n",
" {\"first_name\": \"Tim\", \"last_name\": \"Duncan\", \"year_of_birth\": 1976, \"num_seasons_in_nba\": 19}\n",
"```"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import pipeline\n",
"\n",
"# We create a transformers pipeline to avoid loading the model twice, but we could also use LMFormatEnforcer.from_model_id()\n",
"pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer, max_new_tokens=100)\n",
"langchain_pipeline = LMFormatEnforcer(pipeline=pipe, json_schema=AnswerFormat.schema())\n",
"\n",
"DEFAULT_SYSTEM_PROMPT = \"\"\"\\\n",
"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\\\n",
"\"\"\"\n",
"\n",
"players = ['Michael Jordan', 'Larry Bird', 'Tim Duncan']\n",
"question = 'Please give me information about {}.'\n",
"prompts = [get_prompt(question.format(player), DEFAULT_SYSTEM_PROMPT) for player in players]\n",
"\n",
"display_header('Call mode')\n",
"result = langchain_pipeline(prompts[0])\n",
"display_content(result)\n",
"\n",
"display_header('Batched mode')\n",
"results = langchain_pipeline.generate(prompts)\n",
"for generation in results.generations:\n",
" display_content(generation[0].text)"
]
}
],
"metadata": {
Expand Down

0 comments on commit 3b886d1

Please sign in to comment.