Skip to content

Commit

Permalink
feat: Basic FastAPI setup
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaRahemtola committed Aug 26, 2024
1 parent 192ad1c commit 0cd51d0
Show file tree
Hide file tree
Showing 11 changed files with 882 additions and 212 deletions.
18 changes: 16 additions & 2 deletions libertai_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import aiohttp
from aiohttp import ClientSession
from fastapi import APIRouter, FastAPI

from libertai_agents.interfaces import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, ToolCallFunction, \
ToolCallMessage, CustomizableLlamaCppParams, ToolResponseMessage
from libertai_agents.interfaces.common import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, \
ToolCallFunction, ToolCallMessage, CustomizableLlamaCppParams, ToolResponseMessage
from libertai_agents.interfaces.models import ModelInformation
from libertai_agents.models import Model
from libertai_agents.utils import find

Expand All @@ -16,6 +18,7 @@ class ChatAgent:
system_prompt: str
tools: list[Callable[..., Awaitable[Any]]]
llamacpp_params: CustomizableLlamaCppParams
app: FastAPI

def __init__(self, model: Model, system_prompt: str, tools: list[Callable[..., Awaitable[Any]]] | None = None,
llamacpp_params: CustomizableLlamaCppParams = CustomizableLlamaCppParams()):
Expand All @@ -29,6 +32,17 @@ def __init__(self, model: Model, system_prompt: str, tools: list[Callable[..., A
self.tools = tools
self.llamacpp_params = llamacpp_params

# Define API routes
router = APIRouter()
router.add_api_route("/generate-answer", self.generate_answer, methods=["POST"])
router.add_api_route("/model", self.get_model_information, methods=["GET"])

self.app = FastAPI(title="LibertAI ChatAgent")
self.app.include_router(router)

def get_model_information(self) -> ModelInformation:
return ModelInformation(id=self.model.model_id, context_length=self.model.context_length)

async def generate_answer(self, messages: list[Message]) -> str:
if len(messages) == 0:
raise ValueError("No previous message to respond to")
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@


class MessageRoleEnum(str, Enum):
system = 'system'
user = 'user'
assistant = 'assistant'
system = 'system'
tool = 'tool'


Expand Down
8 changes: 8 additions & 0 deletions libertai_agents/interfaces/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pydantic import BaseModel

from libertai_agents.models.base import ModelId


class ModelInformation(BaseModel):
id: ModelId
context_length: int
13 changes: 11 additions & 2 deletions libertai_agents/models/base.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
import logging
from abc import ABC, abstractmethod
from typing import Literal

from libertai_agents.interfaces import Message, ToolCallFunction, MessageRoleEnum
from libertai_agents.interfaces.common import Message, ToolCallFunction, MessageRoleEnum

# Disables the error about models not available
logging.getLogger("transformers").disabled = True

ModelId = Literal[
"NousResearch/Hermes-2-Pro-Llama-3-8B",
"NousResearch/Hermes-3-Llama-3.1-8B",
"mistralai/Mistral-Nemo-Instruct-2407"
]


class Model(ABC):
from transformers import PreTrainedTokenizerFast

tokenizer: PreTrainedTokenizerFast
model_id: ModelId
vm_url: str
context_length: int
system_message: bool

def __init__(self, model_id: str, vm_url: str, context_length: int, system_message: bool = True):
def __init__(self, model_id: ModelId, vm_url: str, context_length: int, system_message: bool = True):
from transformers import AutoTokenizer

self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model_id = model_id
self.vm_url = vm_url
self.context_length = context_length
self.system_message = system_message
Expand Down
2 changes: 1 addition & 1 deletion libertai_agents/models/hermes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import re

from libertai_agents.interfaces import ToolCallFunction
from libertai_agents.interfaces.common import ToolCallFunction
from libertai_agents.models.base import Model


Expand Down
2 changes: 1 addition & 1 deletion libertai_agents/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import random
import string

from libertai_agents.interfaces import ToolCallFunction
from libertai_agents.interfaces.common import ToolCallFunction
from libertai_agents.models.base import Model


Expand Down
7 changes: 1 addition & 6 deletions libertai_agents/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from huggingface_hub import login
from pydantic import BaseModel

from libertai_agents.models.base import Model
from libertai_agents.models.base import Model, ModelId
from libertai_agents.models.hermes import HermesModel
from libertai_agents.models.mistral import MistralModel

Expand All @@ -14,11 +14,6 @@ class ModelConfiguration(BaseModel):
constructor: typing.Type[Model]


ModelId = typing.Literal[
"NousResearch/Hermes-2-Pro-Llama-3-8B",
"NousResearch/Hermes-3-Llama-3.1-8B",
"mistralai/Mistral-Nemo-Instruct-2407"
]
MODEL_IDS: list[ModelId] = list(typing.get_args(ModelId))

# TODO: update URLs with prod, and check context size (if we deploy it with a lower one)
Expand Down
17 changes: 4 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
import asyncio

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces import Message, MessageRoleEnum
from libertai_agents.models import get_model
from libertai_agents.tools import get_current_temperature

agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
system_prompt="You are a helpful assistant",
tools=[get_current_temperature])

async def start():
agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
system_prompt="You are a helpful assistant",
tools=[get_current_temperature])
response = await agent.generate_answer(
[Message(role=MessageRoleEnum.user, content="What's the temperature in Paris and in Lyon in Celsius ?")])
print(response)


asyncio.run(start())
app = agent.app
Loading

0 comments on commit 0cd51d0

Please sign in to comment.