diff --git a/app/llm/llm.py b/app/llm/llm.py index d57e2de..1252d5e 100644 --- a/app/llm/llm.py +++ b/app/llm/llm.py @@ -145,6 +145,7 @@ def get_messages( def __call__( self, prompt: Union[str, Any], + image: Optional[str] = None, id: Union[str, UUID] = None, system: str = None, save_messages: bool = None, @@ -162,6 +163,7 @@ def __call__( return sess.gen_with_tools( model, prompt, + image, tools, client=self.client, system=system, @@ -172,6 +174,7 @@ def __call__( return sess.gen( model, prompt, + image, client=self.client, system=system, save_messages=save_messages, @@ -182,7 +185,9 @@ def __call__( def stream( self, + model: str, prompt: str, + image: Optional[str] = None, id: Union[str, UUID] = None, system: str = None, save_messages: bool = None, @@ -191,7 +196,9 @@ def stream( ) -> str: sess = self.get_session(id) return sess.stream( + model, prompt, + image, client=self.client, system=system, save_messages=save_messages, @@ -235,7 +242,10 @@ def print_messages(self, id: Union[str, UUID] = None) -> None: session = self.get_session(id) if id else self.default_session if session: for msg in session.messages: - print(f"{msg.role} : {msg.content}") + message_str = f"{msg.role} : {msg.content}" + if msg.image: + message_str += f" : ((image))" + print(message_str) def __repr__(self) -> str: return "" @@ -343,6 +353,7 @@ async def __call__( self, model: str, prompt: str, + image: Optional[str] = None, id: Union[str, UUID] = None, system: str = None, save_messages: bool = None, @@ -362,6 +373,7 @@ async def __call__( return await sess.gen_with_tools_async( model, prompt, + image, tools, client=self.client, system=system, @@ -372,6 +384,7 @@ async def __call__( return await sess.gen_async( model, prompt, + image, client=self.client, system=system, save_messages=save_messages, @@ -382,7 +395,9 @@ async def __call__( async def stream( self, + model: str, prompt: str, + image: Optional[str] = None, id: Union[str, UUID] = None, system: str = None, save_messages: bool = None, @@ -394,7 +409,9 @@ async def stream( self.client = AsyncClient(proxies=os.getenv("https_proxy")) sess = self.get_session(id) return sess.stream_async( + model, prompt, + image, client=self.client, system=system, save_messages=save_messages, diff --git a/app/llm/session.py b/app/llm/session.py index 3a3498f..8473ac3 100644 --- a/app/llm/session.py +++ b/app/llm/session.py @@ -1,5 +1,7 @@ import time import os +from io import BytesIO +import base64 from pydantic import BaseModel, SecretStr, HttpUrl, Field from uuid import uuid4, UUID from httpx import Client, AsyncClient @@ -8,12 +10,13 @@ import datetime from ..models import ChatMessage -from ..utils import remove_a_key, now_tz +from ..utils import remove_a_key, now_tz, url_to_image_data ALLOWED_MODELS = [ "gpt-3.5-turbo", "gpt-4-1106-preview", + "gpt-4-vision-preview", "gryphe/mythomax-l2-13b-8k", "mistralai/mistral-medium", "mistralai/mixtral-8x7b-instruct", @@ -80,6 +83,7 @@ def format_input_messages( if self.recent_messages else self.messages ) + # Todo: include images in previous messages messages = ( [system_message.model_dump(include=self.input_fields, exclude_none=True)] + [ @@ -88,7 +92,20 @@ def format_input_messages( ] ) if user_message: - messages += [user_message.model_dump(include=self.input_fields, exclude_none=True)] + new_message = user_message.model_dump(include=self.input_fields, exclude_none=True) + if user_message.image: + img_data_url = url_to_image_data(user_message.image) + new_message["content"] = [ + { + "type": "text", + "text": user_message.content + }, + { + "type": "image_url", + "image_url": img_data_url + } + ] + messages += [new_message] return messages def add_messages( @@ -114,6 +131,7 @@ def prepare_request( self, model: str = "gpt-3.5-turbo", prompt: str = None, + image: Optional[str] = None, system: str = None, params: Dict[str, Any] = None, stream: bool = False, @@ -127,6 +145,9 @@ def prepare_request( if model not in ALLOWED_MODELS: raise ValueError(f"Invalid model: {model}. Available models: {ALLOWED_MODELS}") + + if image: + model = "gpt-4-vision-preview" provider = "openai" if "gpt-" in model else "openrouter" @@ -146,7 +167,7 @@ def prepare_request( if prompt: if not input_schema: - user_message = ChatMessage(role="user", content=prompt) + user_message = ChatMessage(role="user", content=prompt, image=image) else: assert isinstance( prompt, input_schema @@ -154,6 +175,7 @@ def prepare_request( user_message = ChatMessage( role="function", content=prompt.model_dump_json(), + image=image, name=input_schema.__name__, ) @@ -203,6 +225,7 @@ def gen( self, model: str, prompt: str, + image: Optional[str], client: Union[Client, AsyncClient], system: str = None, save_messages: bool = None, @@ -218,7 +241,7 @@ def gen( while not finished: api_url, headers, data, user_message = self.prepare_request( - model, prompt, system, params, False, input_schema, output_schema + model, prompt, image, system, params, False, input_schema, output_schema ) resp = client.post( @@ -269,7 +292,9 @@ def gen( def stream( self, + model: str, prompt: str, + image: Optional[str], client: Union[Client, AsyncClient], system: str = None, save_messages: bool = None, @@ -277,7 +302,7 @@ def stream( input_schema: Any = None, ): api_url, headers, data, user_message = self.prepare_request( - prompt, system, params, True, input_schema + model, prompt, image, system, params, True, input_schema ) with client.stream( @@ -311,6 +336,7 @@ def stream( def gen_with_tools( self, prompt: str, + image: Optional[str], tools: List[Any], client: Union[Client, AsyncClient], system: str = None, @@ -328,6 +354,7 @@ def gen_with_tools( tool_idx = int( self.gen( prompt, + image, client=client, system=tool_prompt_format, save_messages=False, @@ -344,6 +371,7 @@ def gen_with_tools( return { "response": self.gen( prompt, + image, client=client, system=system, save_messages=save_messages, @@ -371,7 +399,7 @@ def gen_with_tools( ) # manually append the nonmodified user message + normal AI response - user_message = ChatMessage(role="user", content=prompt) + user_message = ChatMessage(role="user", content=prompt, image=image) assistant_message = ChatMessage( role="assistant", content=context_dict["response"] ) @@ -383,6 +411,7 @@ async def gen_async( self, model: str, prompt: str, + image: Optional[str], client: Union[Client, AsyncClient], system: str = None, save_messages: bool = None, @@ -391,7 +420,7 @@ async def gen_async( output_schema: Any = None, ): api_url, headers, data, user_message = self.prepare_request( - model, prompt, system, params, False, input_schema, output_schema + model, prompt, image, system, params, False, input_schema, output_schema ) r = await client.post( @@ -430,6 +459,7 @@ async def stream_async( self, model: str, prompt: str, + image: Optional[str], client: Union[Client, AsyncClient], system: str = None, save_messages: bool = None, @@ -437,7 +467,7 @@ async def stream_async( input_schema: Any = None, ): api_url, headers, data, user_message = self.prepare_request( - model, prompt, system, params, True, input_schema + model, prompt, image, system, params, True, input_schema ) async with client.stream( @@ -469,6 +499,7 @@ async def stream_async( async def gen_with_tools_async( self, prompt: str, + image: Optional[str], tools: List[Any], client: Union[Client, AsyncClient], system: str = None, @@ -486,6 +517,7 @@ async def gen_with_tools_async( tool_idx = int( await self.gen_async( prompt, + image, client=client, system=tool_prompt_format, save_messages=False, @@ -522,6 +554,7 @@ async def gen_with_tools_async( context_dict["response"] = await self.gen_async( new_prompt, + image, client=client, system=new_system, save_messages=False, @@ -529,7 +562,7 @@ async def gen_with_tools_async( ) # manually append the nonmodified user message + normal AI response - user_message = ChatMessage(role="user", content=prompt) + user_message = ChatMessage(role="user", content=prompt, image=image) assistant_message = ChatMessage( role="assistant", content=context_dict["response"] ) diff --git a/app/models/characters.py b/app/models/characters.py index a71a147..c2eb044 100644 --- a/app/models/characters.py +++ b/app/models/characters.py @@ -46,6 +46,7 @@ class ChatMessage(BaseModel): role: str content: str + image: Optional[str] = Field(None) name: Optional[str] = None function_call: Optional[str] = None received_at: datetime.datetime = Field(default_factory=now_tz) diff --git a/app/utils.py b/app/utils.py index dc0204c..ae37509 100644 --- a/app/utils.py +++ b/app/utils.py @@ -2,6 +2,7 @@ import time import os import re +import base64 import traceback import requests import math @@ -80,6 +81,14 @@ def PIL_to_bytes(image, ext="JPEG", quality=95): return img_byte_arr.getvalue() +def url_to_image_data(url): + img = download_image(url) + img_bytes = PIL_to_bytes(img) + data = base64.b64encode(img_bytes).decode("utf-8") + data = "data:image/jpeg;base64," + data + return data + + def calculate_target_dimensions(images, max_pixels): min_w = float('inf') min_h = float('inf')