Adding visual LLM data extraction example
noamgat committed Oct 16, 2024
1 parent ad0175d commit 343b790
Showing 2 changed files with 259 additions and 0 deletions.
259 changes: 259 additions & 0 deletions samples/colab_llama32_vision_enforcer.ipynb
@@ -0,0 +1,259 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visual Data Extraction using LM Format Enforcer\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_llama32_vision_enforcer.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>\n",
"\n",
"This notebook shows how you can integrate [LM Format Enforcer](https://github.com/noamgat/lm-format-enforcer) with the [Llama 3.2 Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision) model. It shows how you can use LMFE to extract structured output by querying images.\n",
"\n",
"\n",
"## Setting up the COLAB runtime (user action required)\n",
"\n",
"This colab-friendly notebook is targeted at demoing the enforcer on visual LLMs. We will be using LLAMA3.2. It can run on a free GPU on Google Colab with fp4 quantization.\n",
"Make sure that your runtime is set to GPU:\n",
"\n",
"Menu Bar -> Runtime -> Change runtime type -> T4 GPU (at the time of writing this notebook). [Guide here](https://www.codesansar.com/deep-learning/using-free-gpu-tpu-google-colab.htm)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"!pip install transformers torch lm-format-enforcer huggingface_hub accelerate bitsandbytes cpm_kernels pillow\n",
"\n",
"# When running from source / developing the library, use this instead\n",
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"# import sys\n",
"# import os\n",
"# sys.path.append(os.path.abspath('..'))\n",
"# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gathering huggingface credentials (user action required)\n",
"\n",
"We begin by installing the dependencies. This demo uses llama3.2, so you will have to create a free huggingface account, request access to the llama2 model, create an access token, and insert it when executing the next cell will request it.\n",
"\n",
"Links:\n",
"\n",
"- [Request access to llama model](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision). See the \"Access Llama 3.2 on Hugging Face\" section.\n",
"- [Create huggingface access token](https://huggingface.co/settings/tokens)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create the model. This may take a few minutes. In this demo we assume you are running on free Colab, so we quaniize to 4 bit. If you are running with 24GB VRAM or more, you can set `run_in_4bit = False` to get better results."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/noamgat/mambaforge/envs/commentranker/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading checkpoint shards: 100%|██████████| 5/5 [00:04<00:00, 1.18it/s]\n"
]
}
],
"source": [
"import torch\n",
"from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig\n",
"from transformers import MllamaForConditionalGeneration, AutoProcessor\n",
"\n",
"model_id = \"meta-llama/Llama-3.2-11B-Vision-Instruct\"\n",
"\n",
"run_in_4bit = True # Can set this to false if running with 24GB VRAM or more\n",
"bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16) if run_in_4bit else None\n",
"\n",
"device = 'cuda'\n",
"\n",
"if torch.cuda.is_available():\n",
" config = AutoConfig.from_pretrained(model_id)\n",
" config.pretraining_tp = 1\n",
" model = MllamaForConditionalGeneration.from_pretrained(\n",
" model_id,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"cuda:0\",\n",
" quantization_config=bnb_config,\n",
" )\n",
" processor = AutoProcessor.from_pretrained(model_id)\n",
"else:\n",
" raise Exception('GPU not available')\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"if tokenizer.pad_token_id is None:\n",
" # Required for batching example\n",
" tokenizer.pad_token_id = tokenizer.eos_token_id \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the previous cell executed successfully, you have propertly set up your Colab runtime and huggingface account!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Definining the desired structured output\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel\n",
"from lmformatenforcer import JsonSchemaParser\n",
"from typing import List\n",
"from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn, build_token_enforcer_tokenizer_data\n",
"\n",
"class Brand(BaseModel):\n",
" brands: List[str]\n",
"\n",
"schema = Brand.model_json_schema()\n",
"parser = JsonSchemaParser(schema)\n",
"\n",
"tokenizer_data = build_token_enforcer_tokenizer_data(processor.tokenizer, model.vocab_size)\n",
"prefix_func = build_transformers_prefix_allowed_tokens_fn(tokenizer_data, parser)\n"
]
},
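{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before querying the model, we can print the generated schema to see exactly what structure the enforcer will constrain generation to. This is an optional sanity check, using only names defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"# Optional: inspect the JSON schema derived from the Brand Pydantic model\n",
"print(json.dumps(schema, indent=2))"
]
},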
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading the image and generating structured output"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation took 2.1 seconds\n"
]
}
],
"source": [
"import json\n",
"from PIL import Image\n",
"from time import time\n",
"\n",
"user = '''Tell me what brands you can see on the provided screenshot, format it in json with the following format: '''\n",
"image_path = 'colab_llama32_vision_input.png'\n",
"image = Image.open(image_path)\n",
"\n",
"messages = [\n",
" {\"role\": \"user\", \"content\": [\n",
" {\"type\": \"image\"},\n",
" {\"type\": \"text\", \"text\": user+json.dumps(schema)}\n",
" ]}\n",
"]\n",
"input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
"\n",
"inputs = processor(image, input_text, return_tensors=\"pt\").to('cuda:0')\n",
"start_generation = inputs['input_ids'].shape[1]\n",
"\n",
"start_time = time()\n",
"output = model.generate(**inputs, max_new_tokens=512, prefix_allowed_tokens_fn=prefix_func)\n",
"result = processor.batch_decode(output[:, start_generation:], skip_special_tokens=True)[0]\n",
"duration = time() - start_time\n",
"print(f\"Generation took {duration:.1f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Viewing the result"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"brands\": [\n",
" \"Apple\",\n",
" \"Google\"\n",
" ]\n",
"}\n"
]
}
],
"source": [
"print(result)"
]
},
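{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check, we can round-trip the generated text through the `Brand` Pydantic model. `model_validate_json` is the standard Pydantic v2 API and raises a `ValidationError` if the output does not match the schema:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: parse the generated JSON back into the Pydantic model.\n",
"# This raises pydantic.ValidationError if the output violates the schema.\n",
"parsed = Brand.model_validate_json(result)\n",
"print(parsed.brands)"
]
},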
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, the result adheres to the JSON Schema, and gives us the desired information from the image!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "commentranker",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file added samples/colab_llama32_vision_input.png