0.40.0 +smolvlm, +paligemma2, -deprec, updates
matatonic committed Dec 7, 2024
1 parent 6e7327d commit 3cc9051
Showing 11 changed files with 293 additions and 159 deletions.
7 changes: 2 additions & 5 deletions Dockerfile
@@ -6,20 +6,16 @@ RUN --mount=type=cache,target=/root/.cache/pip pip install --upgrade pip

WORKDIR /app
RUN git clone https://github.com/TIGER-AI-Lab/Mantis.git --single-branch /app/Mantis && \
git clone https://github.com/togethercomputer/Dragonfly --single-branch /app/Dragonfly && \
git clone https://github.com/baaivision/Emu3 --single-branch /app/Emu3

COPY requirements.txt .
ARG VERSION=latest
RUN if [ "$VERSION" = "alt" ]; then echo "transformers==4.41.2" >> requirements.txt; else echo "transformers>=4.45.2" >> requirements.txt ; fi
RUN if [ "$VERSION" = "alt" ]; then echo "transformers==4.41.2" >> requirements.txt; else echo "transformers>=4.47.0" >> requirements.txt ; fi
RUN --mount=type=cache,target=/root/.cache/pip pip install -U -r requirements.txt

WORKDIR /app/Mantis
RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps -e .

WORKDIR /app/Dragonfly
RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps -e .

WORKDIR /app

COPY *.py model_conf_tests.json README.md LICENSE /app/
@@ -31,6 +27,7 @@ ARG GROUP_ID=1000
ENV GROUP_ID=${GROUP_ID}
RUN groupadd -g ${GROUP_ID} openedai && \
useradd -r -u ${USER_ID} -g ${GROUP_ID} -M -d /app openedai
RUN chown openedai:openedai /app # for .triton, .config/matplotlib

USER openedai
ENV CLI_COMMAND="python vision.py"
42 changes: 32 additions & 10 deletions README.md
@@ -14,7 +14,9 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
<summary>Full list of supported models</summary>

- [X] [AIDC-AI](https://huggingface.co/AIDC-AI)
- - [X] [Ovis1.6-Llama3.2-3B](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B)
- - [X] [Ovis1.6-Gemma2-9B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B)
- - [X] [Ovis1.6-Gemma2-27B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-27B)
- - [X] [Ovis1.5-Gemma2-9B](https://huggingface.co/AIDC-AI/Ovis1.5-Gemma2-9B)
- - [X] [Ovis1.5-Llama3-8B](https://huggingface.co/AIDC-AI/Ovis1.5-Llama3-8B)
- [X] [Ai2](https://huggingface.co/allenai)
@@ -23,6 +25,7 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
- - [X] [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924)
- - [X] [MolmoE-1B-0924](https://huggingface.co/allenai/MolmoE-1B-0924)
- [X] [BAAI](https://huggingface.co/BAAI/)
- - [X] [BAAI/Aquila-VL-2B-llava-qwen](https://huggingface.co/BAAI/Aquila-VL-2B-llava-qwen)
- - [X] [BAAI/Bunny-v1_0-2B-zh](https://huggingface.co/BAAI/Bunny-v1_0-2B-zh)
- - [X] [BAAI/Bunny-v1_0-3B-zh](https://huggingface.co/BAAI/Bunny-v1_0-3B-zh)
- - [X] [BAAI/Bunny-v1_0-3B](https://huggingface.co/BAAI/Bunny-v1_0-3B)
@@ -45,19 +48,24 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
- - [X] [joy-caption-alpha-two](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-two) (with experimental multi-image support)
- - [X] [joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha) (caption only)
- [X] [fuyu-8b](https://huggingface.co/adept/fuyu-8b) [pretrain]
- [X] [HuggingFaceM4/idefics2](https://huggingface.co/HuggingFaceM4)
- [X] [Google](https://huggingface.co/google)
- - [X] [paligemma2-3b](https://huggingface.co/google/paligemma2-3b-ft-docci-448)
- - [X] [paligemma2-10b](https://huggingface.co/google/paligemma2-10b-ft-docci-448)
- [X] [HuggingFaceM4](https://huggingface.co/HuggingFaceM4)
- - [X] [idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b) (wont gpu split, alternate docker only)
- - [X] [idefics2-8b-AWQ](https://huggingface.co/HuggingFaceM4/idefics2-8b-AWQ) (wont gpu split, alternate docker only)
- - [X] [idefics2-8b-chatty](https://huggingface.co/HuggingFaceM4/idefics2-8b-chatty) (wont gpu split, alternate docker only)
- - [X] [idefics2-8b-chatty-AWQ](https://huggingface.co/HuggingFaceM4/idefics2-8b-chatty-AWQ) (wont gpu split, alternate docker only)
- [X] [HuggingFaceTB](https://huggingface.co/HuggingFaceTB)
- - [X] [SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct)
- [X] [InternLM](https://huggingface.co/internlm/)
- - [X] [XComposer2-2d5-7b](https://huggingface.co/internlm/internlm-xcomposer2d5-7b) (wont gpu split)
- - [X] [XComposer2-4KHD-7b](https://huggingface.co/internlm/internlm-xcomposer2-4khd-7b) (wont gpu split)
- - [X] [XComposer2-7b](https://huggingface.co/internlm/internlm-xcomposer2-7b) [finetune] (wont gpu split)
- - [X] [XComposer2-7b-4bit](https://huggingface.co/internlm/internlm-xcomposer2-7b-4bit) (not recommended)
- - [X] [XComposer2-VL](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) [pretrain] (wont gpu split)
- - [X] [XComposer2-VL-4bit](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b-4bit)
- - [X] [XComposer2-VL-1.8b](https://huggingface.co/internlm/internlm-xcomposer2-vl-1_8b)
- - [X] [XComposer2-7b](https://huggingface.co/internlm/internlm-xcomposer2-7b) [finetune] (wont gpu split) (0.39.2 only)
- - [X] [XComposer2-7b-4bit](https://huggingface.co/internlm/internlm-xcomposer2-7b-4bit) (not recommended) (0.39.2 only)
- - [X] [XComposer2-VL](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) [pretrain] (wont gpu split) (0.39.2 only)
- - [X] [XComposer2-VL-4bit](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b-4bit) (0.39.2 only)
- - [X] [XComposer2-VL-1.8b](https://huggingface.co/internlm/internlm-xcomposer2-vl-1_8b) (0.39.2 only)
- [X] [LMMs-Lab](https://huggingface.co/lmms-lab)
- - [X] [llava-onevision-qwen2-0.5b-ov](https://huggingface.co/lmms-lab/llava-onevision-qwen2-0.5b-ov)
- - [X] [llava-onevision-qwen2-7b-ov](https://huggingface.co/lmms-lab/llava-onevision-qwen2-7b-ov)
@@ -82,7 +90,7 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
- [X] [Mistral AI](https://huggingface.co/mistralai)
- - [X] [Pixtral-12B](https://huggingface.co/mistralai/Pixtral-12B-2409)
- [X] [mx262/MiniMonkey](https://huggingface.co/mx262/MiniMonkey)
- [X] [nvidia/NVLM-D-72B](https://huggingface.co/nvidia/NVLM-D-72B)
- [X] [nvidia/NVLM-D-72B](https://huggingface.co/nvidia/NVLM-D-72B) (0.39.2 only)
- [X] [omlab/omchat-v2.0-13B-single-beta_hf](https://huggingface.co/omlab/omchat-v2.0-13B-single-beta_hf) (alt docker)
- [X] [openbmb](https://huggingface.co/openbmb)
- - [X] [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) (video not supported yet)
@@ -121,8 +129,8 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
- - [X] [Mantis-8B-clip-llama3](https://huggingface.co/TIGER-Lab/Mantis-8B-clip-llama3) (wont gpu split, alt docker)
- - [X] [Mantis-8B-Fuyu](https://huggingface.co/TIGER-Lab/Mantis-8B-Fuyu) (wont gpu split)
- [X] [Together.ai](https://huggingface.co/togethercomputer)
- - [X] [Llama-3-8B-Dragonfly-v1](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-v1)
- - [X] [Llama-3-8B-Dragonfly-Med-v1](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-Med-v1)
- - [X] [Llama-3-8B-Dragonfly-v1](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-v1) (0.39.2 only)
- - [X] [Llama-3-8B-Dragonfly-Med-v1](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-Med-v1) (0.39.2 only)
- [X] [qihoo360](https://huggingface.co/qihoo360)
- - [X] [360VL-8B](https://huggingface.co/qihoo360/360VL-8B) (alt docker)
- - [X] [360VL-70B](https://huggingface.co/qihoo360/360VL-70B) (untested)
@@ -132,12 +140,16 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
- [X] [qresearch](https://huggingface.co/qresearch/)
- - [X] [llama-3-vision-alpha-hf](https://huggingface.co/qresearch/llama-3-vision-alpha-hf) (wont gpu split)
- [X] [Qwen](https://huggingface.co/Qwen/)
- - [X] [Qwen2-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct) (untested)
- - [X] [Qwen2-VL-72B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-AWQ)
- - [X] [Qwen2-VL-72B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4)
- - [X] [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)
- - [X] [Qwen2-VL-7B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-AWQ)
- - [X] [Qwen2-VL-7B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4)
- - [X] [Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)
- - [X] [Qwen2-VL-2B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-AWQ)
- - [X] [Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
- - [X] [Qwen2-VL-2B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4)
- [X] [Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
- [X] [stepfun-ai/GOT-OCR2_0](https://huggingface.co/stepfun-ai/GOT-OCR2_0) (ocr only model)
- [X] [vikhyatk](https://huggingface.co/vikhyatk)
- - [X] [moondream2](https://huggingface.co/vikhyatk/moondream2)
@@ -159,6 +171,16 @@ If you can't find your favorite model, you can [open a new issue](https://github

## Recent updates

Version 0.40.0

- new model support: AIDC-AI/Ovis1.6-Llama3.2-3B, AIDC-AI/Ovis1.6-Gemma2-27B
- new model support: BAAI/Aquila-VL-2B-llava-qwen
- new model support: HuggingFaceTB/SmolVLM-Instruct (see the example request after this list)
- new model support: google/paligemma2 family of models (very limited instruct/chat training so far)
- Qwen2-VL: unpin Qwen2-VL-7B & remove Qwen hacks, GPTQ-Int4/8 working again (still slow - why?)
- pin bitsandbytes==0.44.1
- ⚠️ DEPRECATED MODELS (use the `0.39.2` docker image for support of these models): internlm-xcomposer2-7b, internlm-xcomposer2-7b-4bit, internlm-xcomposer2-vl-1_8b, internlm-xcomposer2-vl-7b, internlm-xcomposer2-vl-7b-4bit, nvidia/NVLM-D-72B, Llama-3-8B-Dragonfly-Med-v1, Llama-3-8B-Dragonfly-v1
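
All of the backends above are served through the same OpenAI-compatible chat API, so the newly added models (SmolVLM, PaliGemma 2, Aquila-VL, Ovis1.6) are queried the same way as the existing ones. The sketch below is a minimal client example and is not taken from this commit: the base URL/port and the image URL are placeholders, and the `model` field is assumed to be informational (the server answers with whichever model it was started with).

```python
# Minimal client sketch (assumptions: local server, port 5006, placeholder image URL).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5006/v1", api_key="skip")

response = client.chat.completions.create(
    model="HuggingFaceTB/SmolVLM-Instruct",  # assumed informational; the loaded model answers
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/test.jpg"}},
        ],
    }],
    max_tokens=256,
)
print(response.choices[0].message.content)
```

Streaming responses should work the same way by passing `stream=True` to the request.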

Version 0.39.2

- performance: use float16 with Qwen2 AWQ, small performance improvement
2 changes: 2 additions & 0 deletions backend/llavanextgit.py
@@ -9,6 +9,8 @@
# lmms-lab/llava-onevision-qwen2-72b-ov
# lmms-lab/llava-onevision-qwen2-72b-si

# BAAI/Aquila-VL-2B-llava-qwen

import warnings
warnings.filterwarnings("ignore")

2 changes: 2 additions & 0 deletions backend/ovis16.py
@@ -2,7 +2,9 @@

from vision_qna import *

# AIDC-AI/Ovis1.6-Llama3.2-3B
# AIDC-AI/Ovis1.6-Gemma2-9B
# AIDC-AI/Ovis1.6-Gemma2-27B

IMAGE_TOKEN = "<image>"

61 changes: 61 additions & 0 deletions backend/paligemma.py
@@ -0,0 +1,61 @@
# "google/paligemma2-3b-ft-docci-448"
# "google/paligemma2-10b-ft-docci-448"
# "google/paligemma2-28b-pt-896" - pretrain

from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from vision_qna import *

class VisionQnA(VisionQnABase):
    model_name: str = "paligemma2"
    format: str = "gemma" # doesn't seem to actually be instruction trained
    visual_layers: List[str] = ["vision_tower", "multi_modal_projector"]

    def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
        super().__init__(model_id, device, device_map, extra_params, format)

        if not format:
            self.format = guess_model_format(model_id)

        for i in ['trust_remote_code']:
            del self.params[i]

        self.model = PaliGemmaForConditionalGeneration.from_pretrained(**self.params).eval()
        self.processor = PaliGemmaProcessor.from_pretrained(model_id)

        # bitsandbytes already moves the model to the device, so we don't need to do it again.
        if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
            self.model = self.model.to(self.device)

        self.loaded_banner()

    async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGenerator[str, None]:
        images, prompt = await prompt_from_messages(request.messages, self.format)

        if len(images) < 1:
            images = [ await url_to_image(black_pixel_url) ]
            prompt = "<image>\n" + prompt

        # Instruct the model to create a caption in English
        #prompt = "caption en"
        inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(dtype=self.dtype, device=self.device)

        default_params = {
            'do_sample': False,
            # 'eos_token_id': self.processor.tokenizer.eos_token_id,
            # 'pad_token_id': self.processor.tokenizer.eos_token_id,
        }

        params = self.get_generation_params(request, default_params=default_params)

        generation_kwargs = dict(
            **inputs,
            **params,
        )

        for new_text in threaded_streaming_generator(generate=self.model.generate, tokenizer=self.processor.tokenizer, generation_kwargs=generation_kwargs):
            end = new_text.find(self.processor.tokenizer.eos_token)
            if end == -1:
                yield new_text
            else:
                yield new_text[:end]
                break
18 changes: 8 additions & 10 deletions backend/qwen2-vl.py
@@ -9,10 +9,17 @@
# Qwen/Qwen2-VL-7B-Instruct-AWQ
# Qwen/Qwen2-VL-7B-Instruct
# Qwen/Qwen2-VL-72B-Instruct-AWQ
# Qwen/Qwen2-VL-72B-Instruct
# Not recommended:
# X Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
# X Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8
# X Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4
# X Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8
# X Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4
# X Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8

# https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4
# Performance: for A100 80GB Qwen claim 30-40 T/s, I can't reproduce with this setup, I see more like 5-10 T/s.

class VisionQnA(VisionQnABase):
    model_name: str = "qwen2-vl"
@@ -22,16 +29,13 @@ class VisionQnA(VisionQnABase):
    def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
        super().__init__(model_id, device, device_map, extra_params, format)

        if 'awq' in model_id.lower() and self.dtype == torch.bfloat16:
        if ('awq' in model_id.lower() or 'gptq' in model_id.lower()) and self.dtype == torch.bfloat16:
            self.dtype = self.params['torch_dtype'] = torch.float16 # recommended

        self.processor = AutoProcessor.from_pretrained(model_id)

        del self.params['trust_remote_code']

        if model_id == 'Qwen/Qwen2-VL-7B-Instruct-AWQ':
            self.params['revision'] = '9d72ae62396aaa1817b006e07ddbbd121024f50d' # re: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-AWQ/discussions/4

        self.model = Qwen2VLForConditionalGeneration.from_pretrained(**self.params).eval()

        self.loaded_banner()
@@ -46,12 +50,6 @@ async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGener
            msg = { 'role': m.role, 'content': [] }
            for c in m.content:
                if c.type == 'image_url':
                    # hack around https://github.com/QwenLM/Qwen2-VL/issues/202'
                    if c.image_url.url.startswith('data:image'):
                        parts = c.image_url.url.split(';')
                        if parts[1].startswith('charset='):
                            c.image_url.url = parts[0] + ';' + parts[2]

                    msg['content'].extend([{'type': c.type, 'image': c.image_url.url}])
                elif c.type == 'text':
                    msg['content'].extend([{'type': c.type, 'text': c.text}])
47 changes: 47 additions & 0 deletions backend/smolvlm.py
@@ -0,0 +1,47 @@
from transformers import AutoProcessor, AutoModelForVision2Seq

from vision_qna import *

# HuggingFaceTB/SmolVLM-Instruct

class VisionQnA(VisionQnABase):
    model_name: str = "generic"
    format: str = "internal"
    visual_layers: List[str] = ["vision_model"]

    def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
        super().__init__(model_id, device, device_map, extra_params, format)

        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForVision2Seq.from_pretrained(**self.params).eval()

        # bitsandbytes already moves the model to the device, so we don't need to do it again.
        if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
            self.model = self.model.to(self.device)

        self.loaded_banner()

    async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGenerator[str, None]:
        images, messages = await images_hfmessages_from_messages(request.messages)
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)

        if len(images) < 1:
            images = [ await url_to_image(black_pixel_url) ]
            prompt = "<image>\n" + prompt

        inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(self.device)

        params = self.get_generation_params(request)

        generation_kwargs = dict(
            **inputs,
            **params,
        )

        for new_text in threaded_streaming_generator(generate=self.model.generate, tokenizer=self.processor.tokenizer, generation_kwargs=generation_kwargs):
            end = new_text.find(self.processor.tokenizer.eos_token)
            if end == -1:
                yield new_text
            else:
                yield new_text[:end]
                break