Skip to content

Commit

Permalink
image support
Browse files Browse the repository at this point in the history
  • Loading branch information
genekogan committed Mar 12, 2024
1 parent 5db5927 commit 8aa4515
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 10 deletions.
19 changes: 18 additions & 1 deletion app/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def get_messages(
def __call__(
self,
prompt: Union[str, Any],
image: Optional[str] = None,
id: Union[str, UUID] = None,
system: str = None,
save_messages: bool = None,
Expand All @@ -162,6 +163,7 @@ def __call__(
return sess.gen_with_tools(
model,
prompt,
image,
tools,
client=self.client,
system=system,
Expand All @@ -172,6 +174,7 @@ def __call__(
return sess.gen(
model,
prompt,
image,
client=self.client,
system=system,
save_messages=save_messages,
Expand All @@ -182,7 +185,9 @@ def __call__(

def stream(
self,
model: str,
prompt: str,
image: Optional[str] = None,
id: Union[str, UUID] = None,
system: str = None,
save_messages: bool = None,
Expand All @@ -191,7 +196,9 @@ def stream(
) -> str:
sess = self.get_session(id)
return sess.stream(
model,
prompt,
image,
client=self.client,
system=system,
save_messages=save_messages,
Expand Down Expand Up @@ -235,7 +242,10 @@ def print_messages(self, id: Union[str, UUID] = None) -> None:
session = self.get_session(id) if id else self.default_session
if session:
for msg in session.messages:
print(f"{msg.role} : {msg.content}")
message_str = f"{msg.role} : {msg.content}"
if msg.image:
message_str += f" : ((image))"
print(message_str)

def __repr__(self) -> str:
    """Return an empty string, suppressing the default object repr."""
    # NOTE(review): deliberately blank — presumably to avoid echoing
    # session/message state in REPLs and logs; confirm before changing.
    return ""
Expand Down Expand Up @@ -343,6 +353,7 @@ async def __call__(
self,
model: str,
prompt: str,
image: Optional[str] = None,
id: Union[str, UUID] = None,
system: str = None,
save_messages: bool = None,
Expand All @@ -362,6 +373,7 @@ async def __call__(
return await sess.gen_with_tools_async(
model,
prompt,
image,
tools,
client=self.client,
system=system,
Expand All @@ -372,6 +384,7 @@ async def __call__(
return await sess.gen_async(
model,
prompt,
image,
client=self.client,
system=system,
save_messages=save_messages,
Expand All @@ -382,7 +395,9 @@ async def __call__(

async def stream(
self,
model: str,
prompt: str,
image: Optional[str] = None,
id: Union[str, UUID] = None,
system: str = None,
save_messages: bool = None,
Expand All @@ -394,7 +409,9 @@ async def stream(
self.client = AsyncClient(proxies=os.getenv("https_proxy"))
sess = self.get_session(id)
return sess.stream_async(
model,
prompt,
image,
client=self.client,
system=system,
save_messages=save_messages,
Expand Down
51 changes: 42 additions & 9 deletions app/llm/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import time
import os
from io import BytesIO
import base64
from pydantic import BaseModel, SecretStr, HttpUrl, Field
from uuid import uuid4, UUID
from httpx import Client, AsyncClient
Expand All @@ -8,12 +10,13 @@
import datetime

from ..models import ChatMessage
from ..utils import remove_a_key, now_tz
from ..utils import remove_a_key, now_tz, url_to_image_data


ALLOWED_MODELS = [
"gpt-3.5-turbo",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gryphe/mythomax-l2-13b-8k",
"mistralai/mistral-medium",
"mistralai/mixtral-8x7b-instruct",
Expand Down Expand Up @@ -80,6 +83,7 @@ def format_input_messages(
if self.recent_messages
else self.messages
)
# Todo: include images in previous messages
messages = (
[system_message.model_dump(include=self.input_fields, exclude_none=True)]
+ [
Expand All @@ -88,7 +92,20 @@ def format_input_messages(
]
)
if user_message:
messages += [user_message.model_dump(include=self.input_fields, exclude_none=True)]
new_message = user_message.model_dump(include=self.input_fields, exclude_none=True)
if user_message.image:
img_data_url = url_to_image_data(user_message.image)
new_message["content"] = [
{
"type": "text",
"text": user_message.content
},
{
"type": "image_url",
"image_url": img_data_url
}
]
messages += [new_message]
return messages

def add_messages(
Expand All @@ -114,6 +131,7 @@ def prepare_request(
self,
model: str = "gpt-3.5-turbo",
prompt: str = None,
image: Optional[str] = None,
system: str = None,
params: Dict[str, Any] = None,
stream: bool = False,
Expand All @@ -127,6 +145,9 @@ def prepare_request(

if model not in ALLOWED_MODELS:
raise ValueError(f"Invalid model: {model}. Available models: {ALLOWED_MODELS}")

if image:
model = "gpt-4-vision-preview"

provider = "openai" if "gpt-" in model else "openrouter"

Expand All @@ -146,14 +167,15 @@ def prepare_request(

if prompt:
if not input_schema:
user_message = ChatMessage(role="user", content=prompt)
user_message = ChatMessage(role="user", content=prompt, image=image)
else:
assert isinstance(
prompt, input_schema
), f"prompt must be an instance of {input_schema.__name__}"
user_message = ChatMessage(
role="function",
content=prompt.model_dump_json(),
image=image,
name=input_schema.__name__,
)

Expand Down Expand Up @@ -203,6 +225,7 @@ def gen(
self,
model: str,
prompt: str,
image: Optional[str],
client: Union[Client, AsyncClient],
system: str = None,
save_messages: bool = None,
Expand All @@ -218,7 +241,7 @@ def gen(

while not finished:
api_url, headers, data, user_message = self.prepare_request(
model, prompt, system, params, False, input_schema, output_schema
model, prompt, image, system, params, False, input_schema, output_schema
)

resp = client.post(
Expand Down Expand Up @@ -269,15 +292,17 @@ def gen(

def stream(
self,
model: str,
prompt: str,
image: Optional[str],
client: Union[Client, AsyncClient],
system: str = None,
save_messages: bool = None,
params: Dict[str, Any] = None,
input_schema: Any = None,
):
api_url, headers, data, user_message = self.prepare_request(
prompt, system, params, True, input_schema
model, prompt, image, system, params, True, input_schema
)

with client.stream(
Expand Down Expand Up @@ -311,6 +336,7 @@ def stream(
def gen_with_tools(
self,
prompt: str,
image: Optional[str],
tools: List[Any],
client: Union[Client, AsyncClient],
system: str = None,
Expand All @@ -328,6 +354,7 @@ def gen_with_tools(
tool_idx = int(
self.gen(
prompt,
image,
client=client,
system=tool_prompt_format,
save_messages=False,
Expand All @@ -344,6 +371,7 @@ def gen_with_tools(
return {
"response": self.gen(
prompt,
image,
client=client,
system=system,
save_messages=save_messages,
Expand Down Expand Up @@ -371,7 +399,7 @@ def gen_with_tools(
)

# manually append the nonmodified user message + normal AI response
user_message = ChatMessage(role="user", content=prompt)
user_message = ChatMessage(role="user", content=prompt, image=image)
assistant_message = ChatMessage(
role="assistant", content=context_dict["response"]
)
Expand All @@ -383,6 +411,7 @@ async def gen_async(
self,
model: str,
prompt: str,
image: Optional[str],
client: Union[Client, AsyncClient],
system: str = None,
save_messages: bool = None,
Expand All @@ -391,7 +420,7 @@ async def gen_async(
output_schema: Any = None,
):
api_url, headers, data, user_message = self.prepare_request(
model, prompt, system, params, False, input_schema, output_schema
model, prompt, image, system, params, False, input_schema, output_schema
)

r = await client.post(
Expand Down Expand Up @@ -430,14 +459,15 @@ async def stream_async(
self,
model: str,
prompt: str,
image: Optional[str],
client: Union[Client, AsyncClient],
system: str = None,
save_messages: bool = None,
params: Dict[str, Any] = None,
input_schema: Any = None,
):
api_url, headers, data, user_message = self.prepare_request(
model, prompt, system, params, True, input_schema
model, prompt, image, system, params, True, input_schema
)

async with client.stream(
Expand Down Expand Up @@ -469,6 +499,7 @@ async def stream_async(
async def gen_with_tools_async(
self,
prompt: str,
image: Optional[str],
tools: List[Any],
client: Union[Client, AsyncClient],
system: str = None,
Expand All @@ -486,6 +517,7 @@ async def gen_with_tools_async(
tool_idx = int(
await self.gen_async(
prompt,
image,
client=client,
system=tool_prompt_format,
save_messages=False,
Expand Down Expand Up @@ -522,14 +554,15 @@ async def gen_with_tools_async(

context_dict["response"] = await self.gen_async(
new_prompt,
image,
client=client,
system=new_system,
save_messages=False,
params=params,
)

# manually append the nonmodified user message + normal AI response
user_message = ChatMessage(role="user", content=prompt)
user_message = ChatMessage(role="user", content=prompt, image=image)
assistant_message = ChatMessage(
role="assistant", content=context_dict["response"]
)
Expand Down
1 change: 1 addition & 0 deletions app/models/characters.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ChatMessage(BaseModel):

role: str
content: str
image: Optional[str] = Field(None)
name: Optional[str] = None
function_call: Optional[str] = None
received_at: datetime.datetime = Field(default_factory=now_tz)
Expand Down
9 changes: 9 additions & 0 deletions app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import os
import re
import base64
import traceback
import requests
import math
Expand Down Expand Up @@ -80,6 +81,14 @@ def PIL_to_bytes(image, ext="JPEG", quality=95):
return img_byte_arr.getvalue()


def url_to_image_data(url, ext="JPEG", quality=95):
    """Fetch the image at *url* and return it as a base64 ``data:`` URL.

    Args:
        url: Source image URL; fetched via ``download_image``.
        ext: Pillow format name used when re-encoding (default ``"JPEG"``,
            matching ``PIL_to_bytes``'s default).
        quality: Encoder quality forwarded to ``PIL_to_bytes``.

    Returns:
        A ``data:image/<subtype>;base64,<payload>`` string suitable for
        inline embedding (e.g. a vision-model ``image_url`` field).
    """
    img = download_image(url)
    img_bytes = PIL_to_bytes(img, ext=ext, quality=quality)
    payload = base64.b64encode(img_bytes).decode("utf-8")
    # Derive the MIME subtype from the actual encoding format instead of
    # hard-coding "jpeg"; for the default ext this yields the same URL
    # ("data:image/jpeg;base64,...") as before.
    return f"data:image/{ext.lower()};base64,{payload}"


def calculate_target_dimensions(images, max_pixels):
min_w = float('inf')
min_h = float('inf')
Expand Down

0 comments on commit 8aa4515

Please sign in to comment.