diff --git a/Dockerfile b/Dockerfile
index bfa36a3..4088e7d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,9 +1,11 @@
-FROM python:3-slim
+FROM python:3.11-slim
 
 RUN mkdir -p /app
 WORKDIR /app
 COPY requirements.txt .
+RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu122torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
 RUN pip install -r requirements.txt
+
 COPY *.py .
 COPY backend /app/backend
 CMD python vision.py
diff --git a/README.md b/README.md
index f975113..d79649d 100644
--- a/README.md
+++ b/README.md
@@ -9,10 +9,11 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview`
 
 Backend Model support:
 - [X] Moondream2 [vikhyatk/moondream2](https://huggingface.co/vikhyatk/moondream2) *(only a single image and single question currently supported)
+- [X] Llava [llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) *(mistral only for now, single image/question)
 - [ ] Deepseek-VL - (in progress) [deepseek-ai/deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat)
 - [ ] ...
 
-Version: 0.1.0
+Version: 0.2.0
 
 API Documentation
 -----
@@ -34,16 +35,19 @@ Usage
 -----
 
 ```
-usage: vision.py [-h] [-m MODEL] [-b BACKEND] [-d DEVICE] [-P PORT] [-H HOST] [--preload]
+usage: vision.py [-h] [-m MODEL] [-b BACKEND] [--load-in-4bit] [--load-in-8bit] [--use-flash-attn] [-d DEVICE] [-P PORT] [-H HOST] [--preload]
 
 OpenedAI Vision API Server
 
 options:
   -h, --help            show this help message and exit
   -m MODEL, --model MODEL
-                        The model to use, Ex. deepseek-ai/deepseek-vl-7b-chat (default: vikhyatk/moondream2)
+                        The model to use, Ex. llava-hf/llava-v1.6-mistral-7b-hf (default: vikhyatk/moondream2)
   -b BACKEND, --backend BACKEND
-                        The backend to use (moondream, deepseek) (default: moondream)
+                        The backend to use (moondream, llava) (default: moondream)
+  --load-in-4bit        load in 4bit (default: False)
+  --load-in-8bit        load in 8bit (default: False)
+  --use-flash-attn      Use Flash Attention 2 (default: False)
   -d DEVICE, --device DEVICE
                         Set the torch device for the model. Ex. cuda:1 (default: auto)
   -P PORT, --port PORT  Server tcp port (default: 5006)
diff --git a/backend/llava.py b/backend/llava.py
new file mode 100644
index 0000000..a8bcffe
--- /dev/null
+++ b/backend/llava.py
@@ -0,0 +1,58 @@
+from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
+import torch
+
+# Assumes mistral prompt format!!
+# model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
+
+from vision_qna import VisionQnABase
+
+class VisionQnA(VisionQnABase):
+    model_name: str = "llava"
+
+    def __init__(self, model_id: str, device: str, extra_params = {}):
+        self.device = self.select_device() if device == 'auto' else device
+
+        params = {
+            'pretrained_model_name_or_path': model_id,
+            'torch_dtype': torch.float32 if device == 'cpu' else torch.float16,
+            'low_cpu_mem_usage': True,
+        }
+        if extra_params.get('load_in_4bit', False):
+            load_in_4bit_params = {
+                'bnb_4bit_compute_dtype': torch.float32 if device == 'cpu' else torch.float16,
+                'load_in_4bit': True,
+            }
+            params.update(load_in_4bit_params)
+
+        if extra_params.get('load_in_8bit', False):
+            load_in_8bit_params = {
+                'load_in_8bit': True,
+            }
+            params.update(load_in_8bit_params)
+
+#        'use_flash_attention_2': True,
+        if extra_params.get('use_flash_attn', False):
+            flash_attn_params = {
+                "attn_implementation": "flash_attention_2",
+            }
+            params.update(flash_attn_params)
+
+        self.processor = LlavaNextProcessor.from_pretrained(model_id)
+        self.model = LlavaNextForConditionalGeneration.from_pretrained(**params)
+        if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
+            self.model.to(self.device)
+
+        print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")
+
+    async def single_question(self, image_url: str, prompt: str) -> str:
+        image = await self.url_to_image(image_url)
+
+        # prepare image and text prompt, using the appropriate prompt template
+        prompt = f"[INST] <image>\n{prompt} [/INST]"
+        inputs = self.processor(prompt, image, return_tensors="pt").to(self.device)
+
+        # autoregressively complete prompt
+        output = self.model.generate(**inputs, max_new_tokens=300)
+        answer = self.processor.decode(output[0], skip_special_tokens=True)
+        id = answer.rfind('[/INST]')
+        return answer[id + 8:]
diff --git a/backend/moondream.py b/backend/moondream.py
index 4e5fa36..4007fd1 100644
--- a/backend/moondream.py
+++ b/backend/moondream.py
@@ -8,7 +8,7 @@ class VisionQnA(VisionQnABase):
     model_name: str = "moondream2"
     revision: str = '2024-03-13'
 
-    def __init__(self, model_id: str, device: str):
+    def __init__(self, model_id: str, device: str, extra_params = {}):
         if device == 'auto':
             device = self.select_device()
 
@@ -18,12 +18,10 @@ def __init__(self, model_id: str, device: str):
             'revision': self.revision,
             'torch_dtype': torch.float32 if device == 'cpu' else torch.float16,
         }
-
+        params.update(extra_params)
+
         self.model = AutoModelForCausalLM.from_pretrained(**params).to(device)
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-    def select_device(self):
-        return 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
 
     async def single_question(self, image_url: str, prompt: str) -> str:
         image = await self.url_to_image(image_url)
diff --git a/docker-compose.yml b/docker-compose.yml
index db376d1..fee4174 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -11,7 +11,7 @@ services:
       - ./hf_home:/app/hf_home
     ports:
       - 5006:5006
-    command: ["python", "vision.py", "--host", "0.0.0.0", "--port", "5006"]
+    command: ["python", "vision.py", "--host", "0.0.0.0", "--port", "5006", "--backend", "llava", "--model", "llava-hf/llava-v1.6-mistral-7b-hf"]
     runtime: nvidia
     deploy:
       resources:
diff --git a/requirements.txt b/requirements.txt
index b73db69..9102d6f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,9 @@ fastapi
 # moondream
 timm
 einops
-transformers>=4.39.*
\ No newline at end of file
+
+transformers>=4.39.*
+torch==2.2.*
+accelerate
+bitsandbytes
+flash_attn
\ No newline at end of file
diff --git a/vision.py b/vision.py
index 3955c5f..fcd1119 100644
--- a/vision.py
+++ b/vision.py
@@ -74,8 +74,12 @@ def parse_args(argv=None):
         description='OpenedAI Vision API Server',
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
-    parser.add_argument('-m', '--model', action='store', default="vikhyatk/moondream2", help="The model to use, Ex. deepseek-ai/deepseek-vl-7b-chat")
-    parser.add_argument('-b', '--backend', action='store', default="moondream", help="The backend to use (moondream, deepseek)")
+    parser.add_argument('-m', '--model', action='store', default="vikhyatk/moondream2", help="The model to use, Ex. llava-hf/llava-v1.6-mistral-7b-hf")
+    parser.add_argument('-b', '--backend', action='store', default="moondream", help="The backend to use (moondream, llava)")
+    #'load_in_4bit', 'load_in_8bit', 'use_flash_attn'
+    parser.add_argument('--load-in-4bit', action='store_true', help="load in 4bit")
+    parser.add_argument('--load-in-8bit', action='store_true', help="load in 8bit")
+    parser.add_argument('--use-flash-attn', action='store_true', help="Use Flash Attention 2")
     parser.add_argument('-d', '--device', action='store', default="auto", help="Set the torch device for the model. Ex. cuda:1")
     parser.add_argument('-P', '--port', action='store', default=5006, type=int, help="Server tcp port")
     parser.add_argument('-H', '--host', action='store', default='localhost', help="Host to listen on, Ex. 0.0.0.0")
@@ -87,7 +91,16 @@ def parse_args(argv=None):
 
     print(f"Loading VisionQnA[{args.backend}] with {args.model}")
     backend = importlib.import_module(f'backend.{args.backend}')
-    vision_qna = backend.VisionQnA(args.model, args.device)
+
+    extra_params = {}
+    if args.load_in_4bit:
+        extra_params['load_in_4bit'] = True
+    if args.load_in_8bit:
+        extra_params['load_in_8bit'] = True
+    if args.use_flash_attn:
+        extra_params['use_flash_attn'] = True
+
+    vision_qna = backend.VisionQnA(args.model, args.device, extra_params)
 
     if args.preload:
         sys.exit(0)
diff --git a/vision_qna.py b/vision_qna.py
index ceeb1a9..02af4f1 100644
--- a/vision_qna.py
+++ b/vision_qna.py
@@ -3,11 +3,12 @@
 import requests
 from datauri import DataURI
 from PIL import Image
+import torch
 
 class VisionQnABase:
     model_name: str = None
 
-    def __init__(self, model_id: str, device: str):
+    def __init__(self, model_id: str, device: str, extra_params = {}):
         pass
 
     async def url_to_image(self, img_url: str) -> Image.Image:
@@ -19,6 +20,9 @@ async def url_to_image(self, img_url: str) -> Image.Image:
             img_data = DataURI(img_url).data
 
         return Image.open(io.BytesIO(img_data)).convert("RGB")
+
+    def select_device(self):
+        return 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
 
     async def single_question(self, image_url: str, prompt: str) -> str:
         pass
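
As a rough illustration of what this change enables (not taken from the diff itself): once the server is started with the llava backend, a client can query it through the OpenAI-compatible chat completions API, the same way `gpt-4-vision-preview` is used. The `base_url`, `api_key`, `model` value, and image URL below are placeholder assumptions, not values defined by this repository.

```python
# Illustrative client-side sketch, assuming the server is listening on localhost:5006
# and exposes the OpenAI-compatible /v1/chat/completions route.
# base_url, api_key, model name, and image URL are placeholder assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5006/v1", api_key="sk-unused")

response = client.chat.completions.create(
    model="gpt-4-vision-preview",  # placeholder; the actual backend/model is chosen at server startup
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},
        ],
    }],
    max_tokens=300,
)
print(response.choices[0].message.content)
```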