-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #28 from Dino-Kupinic/16-add-image-endpoint
16-add-image-endpoint
- Loading branch information
Showing
17 changed files
with
426 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[pytest] | ||
pythonpath = . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from enum import Enum | ||
|
||
|
||
class TextModel(Enum): | ||
""" | ||
Enum for text models. | ||
""" | ||
|
||
LLAMA3 = "llama3" | ||
|
||
|
||
class ImageModel(Enum): | ||
""" | ||
Enum for image models. | ||
""" | ||
|
||
LLAVA = "llava" |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import overload, Any, Union | ||
|
||
import ollama | ||
from fastapi import HTTPException | ||
|
||
from src.internal.models import TextModel, ImageModel | ||
from src.types.alias import LLMResponse | ||
|
||
|
||
class Query(ABC): | ||
""" | ||
Abstract class for querying a model | ||
""" | ||
|
||
@abstractmethod | ||
def query(self, prompt: str, model: Any, **kwargs) -> LLMResponse: | ||
""" | ||
Query the model with the given prompt | ||
:param prompt: The prompt to query the model with | ||
:param model: The model to query | ||
:return: A generator of the model's responses | ||
""" | ||
pass | ||
|
||
|
||
class TextQuery(Query): | ||
""" | ||
Query a text model | ||
""" | ||
|
||
@overload | ||
def query(self, prompt: str, model: TextModel, **kwargs) -> LLMResponse: ... | ||
|
||
@overload | ||
def query(self, prompt: str, model: str, **kwargs) -> LLMResponse: ... | ||
|
||
def query(self, prompt: str, model: Union[TextModel, str], **kwargs) -> LLMResponse: | ||
""" | ||
Query the model with the given prompt | ||
:param prompt: The prompt to query the model with | ||
:param model: The model to query (TextModel instance or string) | ||
:return: A generator of the model's responses | ||
""" | ||
try: | ||
model_name = model.value if isinstance(model, TextModel) else model | ||
messages = [{"role": "user", "content": prompt}] | ||
yield from self._text_llm_call(model_name, messages) | ||
except Exception as e: | ||
print(f"Unexpected error: {str(e)}") | ||
raise HTTPException( | ||
status_code=500, detail=f"Internal Server Error: {str(e)}" | ||
) | ||
|
||
@staticmethod | ||
def _text_llm_call(model: str, messages: list) -> LLMResponse: | ||
for chunk in ollama.chat( | ||
model, | ||
messages=messages, | ||
stream=True, | ||
): | ||
yield chunk["message"]["content"] | ||
|
||
|
||
class ImageQuery(Query): | ||
""" | ||
Query an image model | ||
""" | ||
|
||
@overload | ||
def query(self, prompt: str, model: ImageModel, **kwargs) -> LLMResponse: ... | ||
|
||
@overload | ||
def query(self, prompt: str, model: str, **kwargs) -> LLMResponse: ... | ||
|
||
def query( | ||
self, prompt: str, model: Union[ImageModel, str], **kwargs | ||
) -> LLMResponse: | ||
""" | ||
Query the model with the given prompt and images | ||
:param prompt: The prompt to query the model with | ||
:param model: The model to query (ImageModel instance) | ||
:return: A generator of the model's responses | ||
""" | ||
try: | ||
images = kwargs.get("images", []) | ||
if not images: | ||
raise ValueError("Images must be provided for an image query.") | ||
|
||
model_name = model.value if isinstance(model, ImageModel) else model | ||
messages = [{"role": "user", "content": prompt, "images": images}] | ||
yield from self._image_llm_call(model_name, messages) | ||
except Exception as e: | ||
print(f"Unexpected error: {str(e)}") | ||
raise HTTPException( | ||
status_code=500, detail=f"Internal Server Error: {str(e)}" | ||
) | ||
|
||
@staticmethod | ||
def _image_llm_call(model: str, messages: list) -> LLMResponse: | ||
for chunk in ollama.chat( | ||
model, | ||
messages=messages, | ||
stream=True, | ||
): | ||
yield chunk["message"]["content"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from fastapi import APIRouter, HTTPException | ||
from fastapi.responses import StreamingResponse | ||
|
||
from src.internal.models import ImageModel | ||
from src.internal.query import ImageQuery | ||
from src.types.requests import Image | ||
|
||
router = APIRouter(prefix="/image", tags=["image"]) | ||
|
||
image_query = ImageQuery() | ||
|
||
|
||
@router.post("/") | ||
async def send_message(image: Image): | ||
try: | ||
if isinstance(image.model, ImageModel): | ||
model = image.model | ||
elif isinstance(image.model, str): | ||
model = image.model | ||
else: | ||
raise HTTPException(status_code=400, detail="Invalid model type") | ||
|
||
return StreamingResponse( | ||
image_query.query(prompt=image.prompt, model=model, images=image.images), | ||
media_type="text/plain", | ||
) | ||
except HTTPException as http_exc: | ||
raise http_exc | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,30 @@ | ||
from fastapi import APIRouter | ||
from fastapi import APIRouter, HTTPException | ||
from fastapi.responses import StreamingResponse | ||
|
||
from src.internal.queries import query | ||
from src.internal.query import TextQuery | ||
from src.types.requests import Message | ||
from src.internal.models import TextModel | ||
|
||
router = APIRouter(prefix="/message", tags=["message"]) | ||
|
||
text_query = TextQuery() | ||
|
||
|
||
@router.post("/") | ||
async def send_message(message: Message): | ||
return StreamingResponse(query(message.text)) | ||
try: | ||
if isinstance(message.model, TextModel): | ||
model = message.model | ||
elif isinstance(message.model, str): | ||
model = message.model | ||
else: | ||
raise HTTPException(status_code=400, detail="Invalid model type") | ||
|
||
return StreamingResponse( | ||
text_query.query(prompt=message.prompt, model=model), | ||
media_type="text/plain", | ||
) | ||
except HTTPException as http_exc: | ||
raise http_exc | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from typing import TypeAlias, Generator | ||
|
||
LLMResponse: TypeAlias = Generator[str, None, None] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,20 @@ | ||
from pydantic import BaseModel | ||
|
||
from src.internal.models import TextModel, ImageModel | ||
|
||
|
||
class Message(BaseModel): | ||
text: str | ||
""" | ||
Request body for sending a message to a model | ||
""" | ||
prompt: str | ||
model: TextModel | ||
|
||
|
||
class Image(BaseModel): | ||
""" | ||
Request body for sending an image to a model | ||
""" | ||
prompt: str | ||
model: ImageModel | ||
images: list[str] |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
[ | ||
{ | ||
"prompt": "What can you see in the image?", | ||
"model": "llava", | ||
"images": ["images/bird.jpg"], | ||
"answers": [ | ||
"bird", | ||
"animal", | ||
"tree", | ||
"yellow", | ||
"grey", | ||
"wings" | ||
] | ||
}, | ||
{ | ||
"prompt": "What can you see in the image?", | ||
"model": "llava", | ||
"images": ["images/cat.jpg"], | ||
"answers": [ | ||
"cat", | ||
"animal", | ||
"black", | ||
"white", | ||
"furry", | ||
"ears", | ||
"green" | ||
] | ||
}, | ||
{ | ||
"prompt": "What can you see in the image?", | ||
"model": "llava", | ||
"images": ["images/house.webp"], | ||
"answers": [ | ||
"house", | ||
"building", | ||
"blue", | ||
"white", | ||
"windows", | ||
"door", | ||
"roof", | ||
"plants" | ||
] | ||
} | ||
] |
Large diffs are not rendered by default.
Oops, something went wrong.
Oops, something went wrong.