Commit bc23c72
matatonic committed Apr 2, 2024
1 parent: 37db248
Showing 11 changed files with 456 additions and 180 deletions.
Changed file — the llava backend:
@@ -1,58 +1,48 @@

Removed (previous implementation, LlavaNext-based with a hard-coded Mistral prompt):

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch

# Assumes mistral prompt format!!
# model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

from vision_qna import VisionQnABase

class VisionQnA(VisionQnABase):
    model_name: str = "llava"

    def __init__(self, model_id: str, device: str, extra_params = {}):
        self.device = self.select_device() if device == 'auto' else device

        params = {
            'pretrained_model_name_or_path': model_id,
            'torch_dtype': torch.float32 if device == 'cpu' else torch.float16,
            'low_cpu_mem_usage': True,
        }
        if extra_params.get('load_in_4bit', False):
            load_in_4bit_params = {
                'bnb_4bit_compute_dtype': torch.float32 if device == 'cpu' else torch.float16,
                'load_in_4bit': True,
            }
            params.update(load_in_4bit_params)

        if extra_params.get('load_in_8bit', False):
            load_in_8bit_params = {
                'load_in_8bit': True,
            }
            params.update(load_in_8bit_params)

        # 'use_flash_attention_2': True,
        if extra_params.get('use_flash_attn', False):
            flash_attn_params = {
                "attn_implementation": "flash_attention_2",
            }
            params.update(flash_attn_params)

        self.processor = LlavaNextProcessor.from_pretrained(model_id)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(**params)
        if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
            self.model.to(self.device)

    async def single_question(self, image_url: str, prompt: str) -> str:
        image = await self.url_to_image(image_url)

        # prepare image and text prompt, using the appropriate prompt template
        prompt = f"[INST] <image>\n{prompt} [/INST]"
        inputs = self.processor(prompt, image, return_tensors="pt").to(self.device)

        # autoregressively complete prompt
        output = self.model.generate(**inputs, max_new_tokens=300)
        answer = self.processor.decode(output[0], skip_special_tokens=True)
        id = answer.rfind('[/INST]')
        return answer[id + 8:]

Added (new implementation, Llava-based with per-model prompt formats):

from transformers import LlavaProcessor, LlavaForConditionalGeneration
from vision_qna import *

# llava-hf/bakLlava-v1-hf   # llama2
# llava-hf/llava-1.5-7b-hf  # vicuna
# llava-hf/llava-1.5-13b-hf # vicuna

class VisionQnA(VisionQnABase):
    model_name: str = "llava"
    format: str = 'vicuna'

    def __init__(self, model_id: str, device: str, extra_params = {}, format = None):
        super().__init__(model_id, device, extra_params, format)

        if not format:
            # guess the format based on model id
            if 'mistral' in model_id.lower():
                self.format = 'llama2'
            elif 'bakllava' in model_id.lower():
                self.format = 'llama2'
            elif 'vicuna' in model_id.lower():
                self.format = 'vicuna'

        self.processor = LlavaProcessor.from_pretrained(model_id)
        self.model = LlavaForConditionalGeneration.from_pretrained(**self.params)

        print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

    async def chat_with_images(self, messages: list[Message], max_tokens: int) -> str:
        images, prompt = await prompt_from_messages(messages, self.format)
        inputs = self.processor(prompt, images, return_tensors="pt").to(self.device)

        output = self.model.generate(**inputs, max_new_tokens=max_tokens)
        answer = self.processor.decode(output[0], skip_special_tokens=True)

        if self.format in ['llama2', 'mistral']:
            idx = answer.rfind('[/INST]') + len('[/INST]') + 1 #+ len(images)
            return answer[idx:]
        elif self.format == 'vicuna':
            idx = answer.rfind('ASSISTANT:') + len('ASSISTANT:') + 1 #+ len(images)
            return answer[idx:]
        elif self.format == 'chatml':
            idx = answer.rfind('<|im_start|>assistant\n') + len('<|im_start|>assistant\n') + 1 #+ len(images)
            end_idx = answer.rfind('<|im_end|>')
            return answer[idx:end_idx]

        return answer
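For readability, the per-format answer extraction above can be viewed as one helper. The sketch below is illustrative only and not part of the commit; extract_answer is a hypothetical name, and it assumes the decoded output echoes the prompt template followed by the model's reply, exactly as the branches above do:

def extract_answer(decoded: str, fmt: str) -> str:
    # marker that ends the prompt for each supported chat format
    markers = {
        'llama2': '[/INST]',
        'mistral': '[/INST]',
        'vicuna': 'ASSISTANT:',
        'chatml': '<|im_start|>assistant\n',
    }
    marker = markers.get(fmt)
    if marker is None:
        return decoded
    # skip past the marker plus one following character, as the backends above do
    idx = decoded.rfind(marker) + len(marker) + 1
    if fmt == 'chatml':
        return decoded[idx:decoded.rfind('<|im_end|>')]
    return decoded[idx:]

# e.g. extract_answer("USER: <image>\nWhat is this? ASSISTANT: A cat.", 'vicuna') -> "A cat."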
New file — the llavanext backend:
@@ -0,0 +1,50 @@

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from vision_qna import *

# model_id = "llava-hf/llava-v1.6-mistral-7b-hf"   # llama2
# model_id = "llava-hf/llava-v1.6-34b-hf"          # chatml
# model_id = "llava-hf/llava-v1.6-vicuna-13b-hf"   # vicuna
# model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"    # vicuna

class VisionQnA(VisionQnABase):
    model_name: str = "llavanext"
    format: str = 'llama2'

    def __init__(self, model_id: str, device: str, extra_params = {}, format = None):
        super().__init__(model_id, device, extra_params, format)

        if not format:
            if 'mistral' in model_id:
                self.format = 'llama2'
            elif 'vicuna' in model_id:
                self.format = 'vicuna'
            elif 'v1.6-34b' in model_id:
                self.format = 'chatml'

        self.processor = LlavaNextProcessor.from_pretrained(model_id)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(**self.params)

        print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

    async def chat_with_images(self, messages: list[Message], max_tokens: int) -> str:
        images, prompt = await prompt_from_messages(messages, self.format)
        inputs = self.processor(prompt, images, return_tensors="pt").to(self.model.device)

        output = self.model.generate(**inputs, max_new_tokens=max_tokens)
        answer = self.processor.decode(output[0], skip_special_tokens=True)

        if self.format in ['llama2', 'mistral']:
            idx = answer.rfind('[/INST]') + len('[/INST]') + 1 #+ len(images)
            return answer[idx:]
        elif self.format == 'vicuna':
            idx = answer.rfind('ASSISTANT:') + len('ASSISTANT:') + 1 #+ len(images)
            return answer[idx:]
        elif self.format == 'chatml':
            # XXX This is broken with the 34b, extra spaces in the tokenizer
            # XXX You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
            idx = answer.rfind('<|im_start|>assistant\n') + len('<|im_start|>assistant\n') + 1 #+ len(images)
            end_idx = answer.rfind('<|im_end|>')
            return answer[idx:end_idx]

        return answer
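As a usage sketch (assumed, not part of this commit): the prompt format is normally guessed from the model id, but the constructor also accepts an explicit format to override that guess, along the lines of:

backend = VisionQnA(
    "llava-hf/llava-v1.6-34b-hf",           # one of the model ids listed above
    device="auto",
    extra_params={"load_in_4bit": True},
    format="chatml",                        # explicit override; otherwise guessed from the id
)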
One of the changed files was deleted; its contents are not shown.
New file — the moondream1 backend:
@@ -0,0 +1,44 @@

import re
from transformers import CodeGenTokenizerFast, AutoModelForCausalLM

from vision_qna import *

class VisionQnA(VisionQnABase):
    model_name: str = "moondream1"
    format: str = 'phi15'

    def __init__(self, model_id: str, device: str, extra_params = {}, format = None):
        super().__init__(model_id, device, extra_params, format)

        # not supported yet
        del self.params['device_map']

        self.tokenizer = CodeGenTokenizerFast.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(**self.params, trust_remote_code=True)

        # bitsandbytes already moves the model to the device, so we don't need to do it again.
        if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
            self.model.to(self.device)

        print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

    async def chat_with_images(self, messages: list[Message], max_tokens: int) -> str:
        images, prompt = await prompt_from_messages(messages, self.format)
        encoded_images = self.model.encode_image(images[0]).to(self.device)

        # XXX currently broken here...
        """
        File "hf_home/modules/transformers_modules/vikhyatk/moondream1/f6e9da68e8f1b78b8f3ee10905d56826db7a5802/modeling_phi.py", line 318, in forward
            padding_mask.masked_fill_(key_padding_mask, 0.0)
        RuntimeError: The expanded size of the tensor (747) must match the existing size (748) at non-singleton dimension 1. Target sizes: [1, 747]. Tensor sizes: [1, 748]
        """
        answer = self.model.generate(
            encoded_images,
            prompt,
            eos_text="<END>",
            tokenizer=self.tokenizer,
            max_new_tokens=max_tokens,
        )[0]
        answer = re.sub("<$|<END$", "", answer).strip()
        return answer
New file — the moondream2 backend:
@@ -0,0 +1,40 @@

import re
from transformers import AutoTokenizer, AutoModelForCausalLM

from vision_qna import *

class VisionQnA(VisionQnABase):
    model_name: str = "moondream2"
    revision: str = '2024-03-13' # 'main'
    format: str = 'phi15'

    def __init__(self, model_id: str, device: str, extra_params = {}, format = None):
        super().__init__(model_id, device, extra_params, format)

        # not supported yet
        del self.params['device_map']

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(**self.params, trust_remote_code=True)

        # bitsandbytes already moves the model to the device, so we don't need to do it again.
        if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
            self.model.to(self.device)

        print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

    async def chat_with_images(self, messages: list[Message], max_tokens: int) -> str:
        images, prompt = await prompt_from_messages(messages, self.format)

        encoded_images = self.model.encode_image(images).to(self.device)

        answer = self.model.generate(
            encoded_images,
            prompt,
            eos_text="<END>",
            tokenizer=self.tokenizer,
            max_new_tokens=max_tokens,
            #**kwargs,
        )[0]
        answer = re.sub("<$|<END$", "", answer).strip()
        return answer
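The revision attribute presumably pins moondream2 to a specific snapshot rather than 'main'. As a rough sketch (assumed, not shown in this commit), loading that snapshot directly with transformers looks something like:

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "vikhyatk/moondream2"  # assumed model id for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id, revision="2024-03-13")
model = AutoModelForCausalLM.from_pretrained(model_id, revision="2024-03-13", trust_remote_code=True)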
New file — an interactive chat test script using the OpenAI client:
@@ -0,0 +1,47 @@

#!/usr/bin/env python
import argparse
from datauri import DataURI
from openai import OpenAI

# Initialize argparse
parser = argparse.ArgumentParser(description='Test vision using OpenAI')
parser.add_argument('image_url', type=str, help='URL or image file to be tested')
parser.add_argument('questions', type=str, nargs='*', help='The question to ask the image')
args = parser.parse_args()

client = OpenAI(base_url='http://localhost:5006/v1', api_key='skip')

image_url = args.image_url

if not image_url.startswith('http'):
    image_url = str(DataURI.from_file(image_url))

messages = [ { "role": "user", "content": [
    { "type": "text", "text": ' '.join(args.questions) },
    { "type": "image_url", "image_url": { "url": image_url } }
]}]

while True:
    response = client.chat.completions.create(model="gpt-4-vision-preview", messages=messages, max_tokens=512,)
    print(f"Answer: {response.choices[0].message.content}\n")

    image_url = None
    try:
        q = input("Question: ")
        # if q.startswith('http'):
        #     image_url = q
        #     q = input("Question: ")
    except EOFError as e:
        break

    messages.extend([
        { "role": "assistant", "content": [ { 'type': 'text', 'text': response.choices[0].message.content } ] },
        { "role": "user", "content": [ { 'type': 'text', 'text': q } ] }
    ])

    # if image_url:
    #     messages[-1]['content'].extend([
    #         {"type": "image_url", "image_url": { "url": image_url } }
    #     ])
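Running the script might look like this (the file name and the sample output are assumed, not part of the commit):

python chat.py photo.jpg "What do you see in this image?"
Answer: ...
Question: what color is it?

Each loop iteration appends the assistant's reply and the next question to messages, so follow-up questions are answered with the full conversation sent back to the server at http://localhost:5006/v1. The commented-out blocks sketch attaching a new image mid-conversation, but they are disabled in this commit.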