16-add-image-endpoint #28

Merged
merged 4 commits on Aug 2, 2024
39 changes: 36 additions & 3 deletions README.md
@@ -3,7 +3,17 @@
> [!CAUTION]
> AI Backend is still in development. Expect bugs and broken or unfinished features.

## ✨ Installation and Configuration
## 🌟 Overview

ai-backend is a backend for AI-powered applications. It uses FastAPI and Ollama to provide a streaming API for natural language and image processing tasks.

## 🚀 Installation and Configuration

### Prerequisites

- Python 3.12
- pip
- git

### Installation for Development

@@ -32,6 +42,8 @@ pip install -r requirements.txt
fastapi dev src/main.py
```

## 📖 Documentation

### OpenAPI Documentation

The OpenAPI documentation is available at `/docs`. It is automatically generated from the code.
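
FastAPI also serves the raw schema as JSON at `/openapi.json` by default. Assuming the dev server address from the installation steps, it can be fetched with:

```bash
curl http://localhost:8000/openapi.json
```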
@@ -52,6 +64,19 @@ curl -X POST "http://localhost:8000/message/" -H "Content-Type: application/json"

// WIP

## 🧪 Testing

To run the test suite:

1. Ensure that both the AI Backend and Ollama services are running.
2. Execute the following command:

```bash
pytest
```

This will run all tests in the `tests/` directory.
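
pytest's standard selection flags also work here, for example:

```bash
# run tests verbosely
pytest -v

# run only tests whose names match a keyword
pytest -k "image"
```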

## 📝 Contributing

// WIP
@@ -62,6 +87,14 @@ curl -X POST "http://localhost:8000/message/" -H "Content-Type: application/json"
- [Ollama](https://ollama.com/)
- [Pydantic](https://pydantic-docs.helpmanual.io/)

## 😊 License
## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgements

- Special thanks to the FastAPI and Ollama communities for their excellent tools and documentation

---

[MIT](https://choosealicense.com/licenses/mit/)
For more information, please [open an issue](https://github.com/Dino-Kupinic/ai-backend/issues) or contact the maintainers.
2 changes: 2 additions & 0 deletions pytest.ini
@@ -0,0 +1,2 @@
[pytest]
pythonpath = .
17 changes: 17 additions & 0 deletions src/internal/models.py
@@ -0,0 +1,17 @@
from enum import Enum


class TextModel(Enum):
    """
    Enum for text models.
    """

    LLAMA3 = "llama3"


class ImageModel(Enum):
    """
    Enum for image models.
    """

    LLAVA = "llava"
22 changes: 0 additions & 22 deletions src/internal/queries.py

This file was deleted.

106 changes: 106 additions & 0 deletions src/internal/query.py
@@ -0,0 +1,106 @@
from abc import ABC, abstractmethod
from typing import overload, Any, Union

import ollama
from fastapi import HTTPException

from src.internal.models import TextModel, ImageModel
from src.types.alias import LLMResponse


class Query(ABC):
    """
    Abstract class for querying a model
    """

    @abstractmethod
    def query(self, prompt: str, model: Any, **kwargs) -> LLMResponse:
        """
        Query the model with the given prompt
        :param prompt: The prompt to query the model with
        :param model: The model to query
        :return: A generator of the model's responses
        """
        pass


class TextQuery(Query):
    """
    Query a text model
    """

    @overload
    def query(self, prompt: str, model: TextModel, **kwargs) -> LLMResponse: ...

    @overload
    def query(self, prompt: str, model: str, **kwargs) -> LLMResponse: ...

    def query(self, prompt: str, model: Union[TextModel, str], **kwargs) -> LLMResponse:
        """
        Query the model with the given prompt
        :param prompt: The prompt to query the model with
        :param model: The model to query (TextModel instance or string)
        :return: A generator of the model's responses
        """
        try:
            model_name = model.value if isinstance(model, TextModel) else model
            messages = [{"role": "user", "content": prompt}]
            yield from self._text_llm_call(model_name, messages)
        except Exception as e:
            print(f"Unexpected error: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Internal Server Error: {str(e)}"
            )

    @staticmethod
    def _text_llm_call(model: str, messages: list) -> LLMResponse:
        for chunk in ollama.chat(
            model,
            messages=messages,
            stream=True,
        ):
            yield chunk["message"]["content"]


class ImageQuery(Query):
    """
    Query an image model
    """

    @overload
    def query(self, prompt: str, model: ImageModel, **kwargs) -> LLMResponse: ...

    @overload
    def query(self, prompt: str, model: str, **kwargs) -> LLMResponse: ...

    def query(
        self, prompt: str, model: Union[ImageModel, str], **kwargs
    ) -> LLMResponse:
        """
        Query the model with the given prompt and images
        :param prompt: The prompt to query the model with
        :param model: The model to query (ImageModel instance or string)
        :return: A generator of the model's responses
        """
        try:
            images = kwargs.get("images", [])
            if not images:
                raise ValueError("Images must be provided for an image query.")

            model_name = model.value if isinstance(model, ImageModel) else model
            messages = [{"role": "user", "content": prompt, "images": images}]
            yield from self._image_llm_call(model_name, messages)
        except Exception as e:
            print(f"Unexpected error: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Internal Server Error: {str(e)}"
            )

    @staticmethod
    def _image_llm_call(model: str, messages: list) -> LLMResponse:
        for chunk in ollama.chat(
            model,
            messages=messages,
            stream=True,
        ):
            yield chunk["message"]["content"]
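
As a usage sketch (not part of this PR), both query classes return generators, so callers can consume chunks as Ollama streams them. Assuming Ollama is running locally with the `llama3` model pulled:

```python
from src.internal.models import TextModel
from src.internal.query import TextQuery

text_query = TextQuery()

# query() is a generator; chunks arrive as Ollama streams them
for chunk in text_query.query(prompt="Why is the sky blue?", model=TextModel.LLAMA3):
    print(chunk, end="", flush=True)
```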
3 changes: 2 additions & 1 deletion src/main.py
@@ -1,7 +1,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from src.routers import message
from src.routers import message, image

app = FastAPI(
title="AI Backend",
@@ -20,3 +20,4 @@
)

app.include_router(message.router)
app.include_router(image.router)
30 changes: 30 additions & 0 deletions src/routers/image.py
@@ -0,0 +1,30 @@
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from src.internal.models import ImageModel
from src.internal.query import ImageQuery
from src.types.requests import Image

router = APIRouter(prefix="/image", tags=["image"])

image_query = ImageQuery()


@router.post("/")
async def send_image(image: Image):
    try:
        if not isinstance(image.model, (ImageModel, str)):
            raise HTTPException(status_code=400, detail="Invalid model type")
        model = image.model

        return StreamingResponse(
            image_query.query(prompt=image.prompt, model=model, images=image.images),
            media_type="text/plain",
        )
    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
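
A hypothetical request against this endpoint, assuming the dev server on `localhost:8000` and a base64-encoded image (the format Ollama's vision models expect):

```bash
curl -X POST "http://localhost:8000/image/" \
  -H "Content-Type: application/json" \
  -d '{"prompt": "What can you see in the image?", "model": "llava", "images": ["<base64-encoded image>"]}'
```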
24 changes: 21 additions & 3 deletions src/routers/message.py
@@ -1,12 +1,30 @@
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from src.internal.queries import query
from src.internal.query import TextQuery
from src.types.requests import Message
from src.internal.models import TextModel

router = APIRouter(prefix="/message", tags=["message"])

text_query = TextQuery()


@router.post("/")
async def send_message(message: Message):
    return StreamingResponse(query(message.text))
    try:
        if not isinstance(message.model, (TextModel, str)):
            raise HTTPException(status_code=400, detail="Invalid model type")
        model = message.model

        return StreamingResponse(
            text_query.query(prompt=message.prompt, model=model),
            media_type="text/plain",
        )
    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
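
This change renames the request field from `text` to `prompt` and adds a `model` field, so existing clients need updating; a request against the new schema might look like (assuming the default dev server):

```bash
curl -X POST "http://localhost:8000/message/" \
  -H "Content-Type: application/json" \
  -d '{"prompt": "Why is the sky blue?", "model": "llama3"}'
```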
3 changes: 3 additions & 0 deletions src/types/alias.py
@@ -0,0 +1,3 @@
from typing import TypeAlias, Generator

LLMResponse: TypeAlias = Generator[str, None, None]
17 changes: 16 additions & 1 deletion src/types/requests.py
@@ -1,5 +1,20 @@
from pydantic import BaseModel

from src.internal.models import TextModel, ImageModel


class Message(BaseModel):
    text: str
    """
    Request body for sending a message to a model
    """
    prompt: str
    model: TextModel


class Image(BaseModel):
    """
    Request body for sending an image to a model
    """
    prompt: str
    model: ImageModel
    images: list[str]
Binary file added tests/assets/images/bird.jpg
Binary file not shown.
Binary file added tests/assets/images/cat.jpg
Binary file not shown.
Binary file added tests/assets/images/house.webp
Binary file not shown.
44 changes: 44 additions & 0 deletions tests/assets/test_images.json
@@ -0,0 +1,44 @@
[
  {
    "prompt": "What can you see in the image?",
    "model": "llava",
    "images": ["images/bird.jpg"],
    "answers": [
      "bird",
      "animal",
      "tree",
      "yellow",
      "grey",
      "wings"
    ]
  },
  {
    "prompt": "What can you see in the image?",
    "model": "llava",
    "images": ["images/cat.jpg"],
    "answers": [
      "cat",
      "animal",
      "black",
      "white",
      "furry",
      "ears",
      "green"
    ]
  },
  {
    "prompt": "What can you see in the image?",
    "model": "llava",
    "images": ["images/house.webp"],
    "answers": [
      "house",
      "building",
      "blue",
      "white",
      "windows",
      "door",
      "roof",
      "plants"
    ]
  }
]
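
The test code itself is not shown in this diff; a minimal sketch of how a test might consume these fixtures (the function name, use of `requests`, and the keyword-matching rule are all assumptions):

```python
import base64
import json
from pathlib import Path

import requests  # assumed test dependency, not shown in this diff

ASSETS = Path("tests/assets")


def test_image_descriptions():
    cases = json.loads((ASSETS / "test_images.json").read_text())
    for case in cases:
        # the fixture stores relative paths; the endpoint expects base64 strings
        images = [
            base64.b64encode((ASSETS / p).read_bytes()).decode()
            for p in case["images"]
        ]
        response = requests.post(
            "http://localhost:8000/image/",
            json={"prompt": case["prompt"], "model": case["model"], "images": images},
        )
        text = response.text.lower()
        # pass if any expected keyword appears in the streamed reply
        assert any(keyword in text for keyword in case["answers"])
```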
50 changes: 50 additions & 0 deletions tests/assets/test_images_base64.json

Large diffs are not rendered by default.
