+llava-mistral models
matatonic committed Apr 1, 2024
1 parent 93bc2bc commit 37db248
Showing 8 changed files with 100 additions and 16 deletions.
4 changes: 3 additions & 1 deletion Dockerfile
@@ -1,9 +1,11 @@
FROM python:3-slim
FROM python:3.11-slim

RUN mkdir -p /app
WORKDIR /app
COPY requirements.txt .
RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
RUN pip install -r requirements.txt

COPY *.py .
COPY backend /app/backend
CMD python vision.py
12 changes: 8 additions & 4 deletions README.md
@@ -9,10 +9,11 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview`

Backend Model support:
- [X] Moondream2 [vikhyatk/moondream2](https://huggingface.co/vikhyatk/moondream2) *(only a single image and single question currently supported)
- [X] Llava [llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) *(mistral only for now, single image/question)
- [ ] Deepseek-VL - (in progress) [deepseek-ai/deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat)
- [ ] ...
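
Since the server is OpenAI API compatible and behaves like `gpt-4-vision-preview`, a client call could look like the sketch below; the `/v1` base path, placeholder API key, model name, and image URL are assumptions, not taken from this commit.

```python
# Hedged sketch of a client request against the running server.
# The "/v1" base path, api_key value, model name, and image URL are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5006/v1", api_key="sk-unused")

response = client.chat.completions.create(
    model="gpt-4-vision-preview",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
        ],
    }],
    max_tokens=300,
)
print(response.choices[0].message.content)
```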

Version: 0.1.0
Version: 0.2.0


API Documentation
@@ -34,16 +35,19 @@ Usage
-----

```
usage: vision.py [-h] [-m MODEL] [-b BACKEND] [-d DEVICE] [-P PORT] [-H HOST] [--preload]
usage: vision.py [-h] [-m MODEL] [-b BACKEND] [--load-in-4bit] [--load-in-8bit] [--use-flash-attn] [-d DEVICE] [-P PORT] [-H HOST] [--preload]
OpenedAI Vision API Server
options:
-h, --help show this help message and exit
-m MODEL, --model MODEL
The model to use, Ex. deepseek-ai/deepseek-vl-7b-chat (default: vikhyatk/moondream2)
The model to use, Ex. llava-hf/llava-v1.6-mistral-7b-hf (default: vikhyatk/moondream2)
-b BACKEND, --backend BACKEND
The backend to use (moondream, deepseek) (default: moondream)
The backend to use (moondream, llava) (default: moondream)
--load-in-4bit load in 4bit (default: False)
--load-in-8bit load in 8bit (default: False)
--use-flash-attn Use Flash Attention 2 (default: False)
-d DEVICE, --device DEVICE
Set the torch device for the model. Ex. cuda:1 (default: auto)
-P PORT, --port PORT Server tcp port (default: 5006)
58 changes: 58 additions & 0 deletions backend/llava.py
@@ -0,0 +1,58 @@
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch

# Assumes mistral prompt format!!
# model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

from vision_qna import VisionQnABase

class VisionQnA(VisionQnABase):
model_name: str = "llava"

def __init__(self, model_id: str, device: str, extra_params = {}):
self.device = self.select_device() if device == 'auto' else device

params = {
'pretrained_model_name_or_path': model_id,
'torch_dtype': torch.float32 if device == 'cpu' else torch.float16,
'low_cpu_mem_usage': True,
}
if extra_params.get('load_in_4bit', False):
load_in_4bit_params = {
'bnb_4bit_compute_dtype': torch.float32 if device == 'cpu' else torch.float16,
'load_in_4bit': True,
}
params.update(load_in_4bit_params)

if extra_params.get('load_in_8bit', False):
load_in_8bit_params = {
'load_in_8bit': True,
}
params.update(load_in_8bit_params)

# 'use_flash_attention_2': True,
if extra_params.get('use_flash_attn', False):
flash_attn_params = {
"attn_implementation": "flash_attention_2",
}
params.update(flash_attn_params)

self.processor = LlavaNextProcessor.from_pretrained(model_id)
self.model = LlavaNextForConditionalGeneration.from_pretrained(**params)
if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
self.model.to(self.device)

print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

async def single_question(self, image_url: str, prompt: str) -> str:
image = await self.url_to_image(image_url)

# prepare image and text prompt, using the appropriate prompt template
prompt = f"[INST] <image>\n{prompt} [/INST]"
inputs = self.processor(prompt, image, return_tensors="pt").to(self.device)

# autoregressively complete prompt
output = self.model.generate(**inputs, max_new_tokens=300)
answer = self.processor.decode(output[0], skip_special_tokens=True)
id = answer.rfind('[/INST]')
return answer[id + 8:]
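
For context, a minimal sketch of driving this new backend directly, mirroring how `vision.py` (further down in this commit) constructs it; the image URL and prompt are placeholders.

```python
# Minimal sketch: load the llava backend the same way vision.py does and
# ask a single question. The image URL and prompt are placeholders.
import asyncio
from backend.llava import VisionQnA

qna = VisionQnA(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    device="auto",
    extra_params={"load_in_4bit": True, "use_flash_attn": True},
)

answer = asyncio.run(
    qna.single_question("https://example.com/photo.jpg", "What is in this image?")
)
print(answer)
```
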
8 changes: 3 additions & 5 deletions backend/moondream.py
@@ -8,7 +8,7 @@ class VisionQnA(VisionQnABase):
model_name: str = "moondream2"
revision: str = '2024-03-13'

def __init__(self, model_id: str, device: str):
def __init__(self, model_id: str, device: str, extra_params = {}):
if device == 'auto':
device = self.select_device()

@@ -18,12 +18,10 @@ def __init__(self, model_id: str, device: str):
'revision': self.revision,
'torch_dtype': torch.float32 if device == 'cpu' else torch.float16,
}

params.update(extra_params)

self.model = AutoModelForCausalLM.from_pretrained(**params).to(device)
self.tokenizer = AutoTokenizer.from_pretrained(model_id)

def select_device(self):
return 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

async def single_question(self, image_url: str, prompt: str) -> str:
image = await self.url_to_image(image_url)
2 changes: 1 addition & 1 deletion docker-compose.yml
@@ -11,7 +11,7 @@ services:
- ./hf_home:/app/hf_home
ports:
- 5006:5006
command: ["python", "vision.py", "--host", "0.0.0.0", "--port", "5006"]
command: ["python", "vision.py", "--host", "0.0.0.0", "--port", "5006", "--backend", "llava", "--model", "llava-hf/llava-v1.6-mistral-7b-hf"]
runtime: nvidia
deploy:
resources:
7 changes: 6 additions & 1 deletion requirements.txt
@@ -6,4 +6,9 @@ fastapi
# moondream
timm
einops
transformers>=4.39.*

transformers>=4.39.*
torch==2.2.*
accelerate
bitsandbytes
flash_attn
19 changes: 16 additions & 3 deletions vision.py
@@ -74,8 +74,12 @@ def parse_args(argv=None):
description='OpenedAI Vision API Server',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('-m', '--model', action='store', default="vikhyatk/moondream2", help="The model to use, Ex. deepseek-ai/deepseek-vl-7b-chat")
parser.add_argument('-b', '--backend', action='store', default="moondream", help="The backend to use (moondream, deepseek)")
parser.add_argument('-m', '--model', action='store', default="vikhyatk/moondream2", help="The model to use, Ex. llava-hf/llava-v1.6-mistral-7b-hf")
parser.add_argument('-b', '--backend', action='store', default="moondream", help="The backend to use (moondream, llava)")
#'load_in_4bit', 'load_in_8bit', 'use_flash_attn'
parser.add_argument('--load-in-4bit', action='store_true', help="load in 4bit")
parser.add_argument('--load-in-8bit', action='store_true', help="load in 8bit")
parser.add_argument('--use-flash-attn', action='store_true', help="Use Flash Attention 2")
parser.add_argument('-d', '--device', action='store', default="auto", help="Set the torch device for the model. Ex. cuda:1")
parser.add_argument('-P', '--port', action='store', default=5006, type=int, help="Server tcp port")
parser.add_argument('-H', '--host', action='store', default='localhost', help="Host to listen on, Ex. 0.0.0.0")
@@ -87,7 +91,16 @@ def parse_args(argv=None):

print(f"Loading VisionQnA[{args.backend}] with {args.model}")
backend = importlib.import_module(f'backend.{args.backend}')
vision_qna = backend.VisionQnA(args.model, args.device)

extra_params = {}
if args.load_in_4bit:
extra_params['load_in_4bit'] = True
if args.load_in_8bit:
extra_params['load_in_8bit'] = True
if args.use_flash_attn:
extra_params['use_flash_attn'] = True

vision_qna = backend.VisionQnA(args.model, args.device, extra_params)

if args.preload:
sys.exit(0)
6 changes: 5 additions & 1 deletion vision_qna.py
@@ -3,11 +3,12 @@
import requests
from datauri import DataURI
from PIL import Image
import torch

class VisionQnABase:
model_name: str = None

def __init__(self, model_id: str, device: str):
def __init__(self, model_id: str, device: str, extra_params = {}):
pass

async def url_to_image(self, img_url: str) -> Image.Image:
@@ -19,6 +20,9 @@ async def url_to_image(self, img_url: str) -> Image.Image:
img_data = DataURI(img_url).data

return Image.open(io.BytesIO(img_data)).convert("RGB")

def select_device(self):
return 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

async def single_question(self, image_url: str, prompt: str) -> str:
pass
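
The base class above defines the contract every backend implements. A hypothetical additional backend would follow the same shape as the moondream and llava modules; in the sketch below the `Auto*` classes, prompt handling, and generation settings are illustrative placeholders, not part of this commit.

```python
# Hypothetical backend sketch following the VisionQnABase contract; the
# Auto* classes and generation settings are placeholders, not from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

from vision_qna import VisionQnABase

class VisionQnA(VisionQnABase):
    model_name: str = "example"

    def __init__(self, model_id: str, device: str, extra_params = {}):
        # resolve 'auto' the same way the shipped backends do
        self.device = self.select_device() if device == 'auto' else device

        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32 if self.device == 'cpu' else torch.float16,
            low_cpu_mem_usage=True,
        ).to(self.device)

    async def single_question(self, image_url: str, prompt: str) -> str:
        # url_to_image is inherited from VisionQnABase
        image = await self.url_to_image(image_url)
        inputs = self.processor(prompt, image, return_tensors="pt").to(self.device)
        output = self.model.generate(**inputs, max_new_tokens=300)
        return self.processor.decode(output[0], skip_special_tokens=True)
```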
