Skip to content

Commit

Permalink
Merge pull request #28 from Dino-Kupinic/16-add-image-endpoint
Browse files Browse the repository at this point in the history
16-add-image-endpoint
  • Loading branch information
Dino-Kupinic authored Aug 2, 2024
2 parents b1299f1 + 977fc19 commit fbcee45
Show file tree
Hide file tree
Showing 17 changed files with 426 additions and 32 deletions.
39 changes: 36 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@
> [!CAUTION]
> AI Backend is still in Development. You will find bugs and broken/unfinished features.
## ✨ Installation and Configuration
## 🌟 Overview

ai-backend is a backend for AI-powered applications. It leverages FastAPI and Ollama to provide a robust API for natural language processing tasks.

## 🚀 Installation and Configuration

### Prerequisites

- Python 3.12
- pip
- git

### Installation for Development

Expand Down Expand Up @@ -32,6 +42,8 @@ pip install -r requirements.txt
fastapi dev src/main.py
```

## 📖 Documentation

### OpenAPI Documentation

The OpenAPI documentation is available at `/docs`. It is automatically generated from the code.
Expand All @@ -52,6 +64,19 @@ curl -X POST "http://localhost:8000/message/" -H "Content-Type: application/json

// WIP

## 🧪 Testing

To run the test suite:

1. Ensure that both the AI Backend and Ollama services are running.
2. Execute the following command:

```bash
pytest
```

This will run all tests in the `tests/` directory.

## 📝 Contributing

// WIP
Expand All @@ -62,6 +87,14 @@ curl -X POST "http://localhost:8000/message/" -H "Content-Type: application/json
- [Ollama](https://ollama.com/)
- [Pydantic](https://pydantic-docs.helpmanual.io/)

## 😊 License
## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgements

- Special thanks to the FastAPI and Ollama communities for their excellent tools and documentation

---

[MIT](https://choosealicense.com/licenses/mit/)
For more information, please [open an issue](https://github.com/Dino-Kupinic/ai-backend/issues) or contact the maintainers.
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
pythonpath = .
17 changes: 17 additions & 0 deletions src/internal/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from enum import Enum


class TextModel(Enum):
"""
Enum for text models.
"""

LLAMA3 = "llama3"


class ImageModel(Enum):
"""
Enum for image models.
"""

LLAVA = "llava"
22 changes: 0 additions & 22 deletions src/internal/queries.py

This file was deleted.

106 changes: 106 additions & 0 deletions src/internal/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from abc import ABC, abstractmethod
from typing import overload, Any, Union

import ollama
from fastapi import HTTPException

from src.internal.models import TextModel, ImageModel
from src.types.alias import LLMResponse


class Query(ABC):
"""
Abstract class for querying a model
"""

@abstractmethod
def query(self, prompt: str, model: Any, **kwargs) -> LLMResponse:
"""
Query the model with the given prompt
:param prompt: The prompt to query the model with
:param model: The model to query
:return: A generator of the model's responses
"""
pass


class TextQuery(Query):
"""
Query a text model
"""

@overload
def query(self, prompt: str, model: TextModel, **kwargs) -> LLMResponse: ...

@overload
def query(self, prompt: str, model: str, **kwargs) -> LLMResponse: ...

def query(self, prompt: str, model: Union[TextModel, str], **kwargs) -> LLMResponse:
"""
Query the model with the given prompt
:param prompt: The prompt to query the model with
:param model: The model to query (TextModel instance or string)
:return: A generator of the model's responses
"""
try:
model_name = model.value if isinstance(model, TextModel) else model
messages = [{"role": "user", "content": prompt}]
yield from self._text_llm_call(model_name, messages)
except Exception as e:
print(f"Unexpected error: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Internal Server Error: {str(e)}"
)

@staticmethod
def _text_llm_call(model: str, messages: list) -> LLMResponse:
for chunk in ollama.chat(
model,
messages=messages,
stream=True,
):
yield chunk["message"]["content"]


class ImageQuery(Query):
"""
Query an image model
"""

@overload
def query(self, prompt: str, model: ImageModel, **kwargs) -> LLMResponse: ...

@overload
def query(self, prompt: str, model: str, **kwargs) -> LLMResponse: ...

def query(
self, prompt: str, model: Union[ImageModel, str], **kwargs
) -> LLMResponse:
"""
Query the model with the given prompt and images
:param prompt: The prompt to query the model with
:param model: The model to query (ImageModel instance)
:return: A generator of the model's responses
"""
try:
images = kwargs.get("images", [])
if not images:
raise ValueError("Images must be provided for an image query.")

model_name = model.value if isinstance(model, ImageModel) else model
messages = [{"role": "user", "content": prompt, "images": images}]
yield from self._image_llm_call(model_name, messages)
except Exception as e:
print(f"Unexpected error: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Internal Server Error: {str(e)}"
)

@staticmethod
def _image_llm_call(model: str, messages: list) -> LLMResponse:
for chunk in ollama.chat(
model,
messages=messages,
stream=True,
):
yield chunk["message"]["content"]
3 changes: 2 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from src.routers import message
from src.routers import message, image

app = FastAPI(
title="AI Backend",
Expand All @@ -20,3 +20,4 @@
)

app.include_router(message.router)
app.include_router(image.router)
30 changes: 30 additions & 0 deletions src/routers/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from src.internal.models import ImageModel
from src.internal.query import ImageQuery
from src.types.requests import Image

router = APIRouter(prefix="/image", tags=["image"])

image_query = ImageQuery()


@router.post("/")
async def send_message(image: Image):
try:
if isinstance(image.model, ImageModel):
model = image.model
elif isinstance(image.model, str):
model = image.model
else:
raise HTTPException(status_code=400, detail="Invalid model type")

return StreamingResponse(
image_query.query(prompt=image.prompt, model=model, images=image.images),
media_type="text/plain",
)
except HTTPException as http_exc:
raise http_exc
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
24 changes: 21 additions & 3 deletions src/routers/message.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from src.internal.queries import query
from src.internal.query import TextQuery
from src.types.requests import Message
from src.internal.models import TextModel

router = APIRouter(prefix="/message", tags=["message"])

text_query = TextQuery()


@router.post("/")
async def send_message(message: Message):
return StreamingResponse(query(message.text))
try:
if isinstance(message.model, TextModel):
model = message.model
elif isinstance(message.model, str):
model = message.model
else:
raise HTTPException(status_code=400, detail="Invalid model type")

return StreamingResponse(
text_query.query(prompt=message.prompt, model=model),
media_type="text/plain",
)
except HTTPException as http_exc:
raise http_exc
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
3 changes: 3 additions & 0 deletions src/types/alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import TypeAlias, Generator

LLMResponse: TypeAlias = Generator[str, None, None]
17 changes: 16 additions & 1 deletion src/types/requests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
from pydantic import BaseModel

from src.internal.models import TextModel, ImageModel


class Message(BaseModel):
text: str
"""
Request body for sending a message to a model
"""
prompt: str
model: TextModel


class Image(BaseModel):
"""
Request body for sending an image to a model
"""
prompt: str
model: ImageModel
images: list[str]
Binary file added tests/assets/images/bird.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/assets/images/cat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/assets/images/house.webp
Binary file not shown.
44 changes: 44 additions & 0 deletions tests/assets/test_images.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
[
{
"prompt": "What can you see in the image?",
"model": "llava",
"images": ["images/bird.jpg"],
"answers": [
"bird",
"animal",
"tree",
"yellow",
"grey",
"wings"
]
},
{
"prompt": "What can you see in the image?",
"model": "llava",
"images": ["images/cat.jpg"],
"answers": [
"cat",
"animal",
"black",
"white",
"furry",
"ears",
"green"
]
},
{
"prompt": "What can you see in the image?",
"model": "llava",
"images": ["images/house.webp"],
"answers": [
"house",
"building",
"blue",
"white",
"windows",
"door",
"roof",
"plants"
]
}
]
50 changes: 50 additions & 0 deletions tests/assets/test_images_base64.json

Large diffs are not rendered by default.

Loading

0 comments on commit fbcee45

Please sign in to comment.