feat: Documentation on public functions and methods

RezaRahemtola committed Aug 26, 2024
1 parent 0cd51d0 commit 5f9d2c8

Showing 9 changed files with 81 additions and 37 deletions.
41 changes: 31 additions & 10 deletions libertai_agents/agents.py
@@ -6,8 +6,9 @@
 from aiohttp import ClientSession
 from fastapi import APIRouter, FastAPI
 
-from libertai_agents.interfaces.common import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, \
-    ToolCallFunction, ToolCallMessage, CustomizableLlamaCppParams, ToolResponseMessage
+from libertai_agents.interfaces.llamacpp import CustomizableLlamaCppParams, LlamaCppParams
+from libertai_agents.interfaces.messages import Message, MessageRoleEnum, MessageToolCall, ToolCallFunction, \
+    ToolCallMessage, ToolResponseMessage
 from libertai_agents.interfaces.models import ModelInformation
 from libertai_agents.models import Model
 from libertai_agents.utils import find
@@ -18,10 +19,20 @@ class ChatAgent:
     system_prompt: str
     tools: list[Callable[..., Awaitable[Any]]]
     llamacpp_params: CustomizableLlamaCppParams
-    app: FastAPI
+    app: FastAPI | None
 
     def __init__(self, model: Model, system_prompt: str, tools: list[Callable[..., Awaitable[Any]]] | None = None,
-                 llamacpp_params: CustomizableLlamaCppParams = CustomizableLlamaCppParams()):
+                 llamacpp_params: CustomizableLlamaCppParams = CustomizableLlamaCppParams(),
+                 expose_api: bool = True):
+        """
+        Create a LibertAI chatbot agent that can answer messages from users
+        :param model: The LLM you want to use, selected from the available ones
+        :param system_prompt: Customize the behavior of the agent with your own prompt
+        :param tools: List of functions that the agent can call. Each function must be asynchronous, have a docstring and return a stringifiable response
+        :param llamacpp_params: Override params given to llamacpp when calling the model
+        :param expose_api: Set to False to avoid exposing an API (useful if you are using a custom trigger)
+        """
         if tools is None:
             tools = []

@@ -32,18 +43,28 @@ def __init__(self, model: Model, system_prompt: str, tools: list[Callable[..., Awaitable[Any]]] | None = None,
         self.tools = tools
         self.llamacpp_params = llamacpp_params
 
-        # Define API routes
-        router = APIRouter()
-        router.add_api_route("/generate-answer", self.generate_answer, methods=["POST"])
-        router.add_api_route("/model", self.get_model_information, methods=["GET"])
+        if expose_api:
+            # Define API routes
+            router = APIRouter()
+            router.add_api_route("/generate-answer", self.generate_answer, methods=["POST"])
+            router.add_api_route("/model", self.get_model_information, methods=["GET"])
 
-        self.app = FastAPI(title="LibertAI ChatAgent")
-        self.app.include_router(router)
+            self.app = FastAPI(title="LibertAI ChatAgent")
+            self.app.include_router(router)
 
     def get_model_information(self) -> ModelInformation:
+        """
+        Get information about the model powering this agent
+        """
         return ModelInformation(id=self.model.model_id, context_length=self.model.context_length)
 
     async def generate_answer(self, messages: list[Message]) -> str:
+        """
+        Generate an answer based on a conversation
+        :param messages: List of messages previously sent in this conversation
+        :return: The string response of the agent
+        """
         if len(messages) == 0:
             raise ValueError("No previous message to respond to")
         if messages[-1].role not in [MessageRoleEnum.user, MessageRoleEnum.tool]:
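A minimal usage sketch of the new `expose_api` flag, assuming the package layout shown in this diff (the model ID is the one used in `main.py` below; the custom-trigger logic is illustrative, and the `generate_answer` call needs network access to the model endpoint at runtime):

```python
import asyncio

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces.messages import Message, MessageRoleEnum
from libertai_agents.models import get_model

# With expose_api=False no FastAPI app is created, so the agent can be
# driven by a custom trigger (queue consumer, cron job, ...) instead.
agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
                  system_prompt="You are a helpful assistant",
                  expose_api=False)

# Calling the agent directly instead of going through the HTTP route
answer = asyncio.run(agent.generate_answer(
    [Message(role=MessageRoleEnum.user, content="Hello!")]))
print(answer)
```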
9 changes: 9 additions & 0 deletions libertai_agents/interfaces/llamacpp.py
@@ -0,0 +1,9 @@
+from pydantic import BaseModel
+
+
+class CustomizableLlamaCppParams(BaseModel):
+    stream: bool = False
+
+
+class LlamaCppParams(CustomizableLlamaCppParams):
+    prompt: str
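Since `LlamaCppParams` extends `CustomizableLlamaCppParams`, the user-tunable options and the internal prompt compose into a single payload. A small illustration (how the library merges them internally is an assumption; field values are arbitrary):

```python
from libertai_agents.interfaces.llamacpp import CustomizableLlamaCppParams, LlamaCppParams

custom = CustomizableLlamaCppParams(stream=True)  # the part users may override
params = LlamaCppParams(prompt="<full chat prompt>", **custom.model_dump())
print(params.model_dump())  # {'stream': True, 'prompt': '<full chat prompt>'}
```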
8 changes: 0 additions & 8 deletions libertai_agents/interfaces/messages.py
@@ -34,11 +34,3 @@ class ToolCallMessage(Message):
 class ToolResponseMessage(Message):
     name: Optional[str] = None
     tool_call_id: Optional[str] = None
-
-
-class CustomizableLlamaCppParams(BaseModel):
-    stream: bool = False
-
-
-class LlamaCppParams(CustomizableLlamaCppParams):
-    prompt: str
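The two classes removed here were moved, not dropped: they now live in `libertai_agents/interfaces/llamacpp.py` (added earlier in this diff), so message types and llama.cpp parameters sit in separate modules. A hypothetical caller now imports them as:

```python
from libertai_agents.interfaces.llamacpp import CustomizableLlamaCppParams, LlamaCppParams
from libertai_agents.interfaces.messages import Message, MessageRoleEnum, ToolResponseMessage
```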
15 changes: 14 additions & 1 deletion libertai_agents/models/base.py
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod
 from typing import Literal
 
-from libertai_agents.interfaces.common import Message, ToolCallFunction, MessageRoleEnum
+from libertai_agents.interfaces.messages import Message, ToolCallFunction, MessageRoleEnum
 
 # Disables the error about models not available
 logging.getLogger("transformers").disabled = True
@@ -37,6 +37,14 @@ def __count_tokens(self, content: str) -> int:
         return len(tokens)
 
     def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
+        """
+        Generate the whole chat prompt
+        :param messages: Conversation history of messages
+        :param system_prompt: Prompt to include at the beginning
+        :param tools: Available tools
+        :return: Prompt string
+        """
         system_message = Message(role=MessageRoleEnum.system, content=system_prompt)
         raw_messages = list(map(lambda x: x.model_dump(), messages))
 
@@ -52,6 +60,11 @@ def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
         raise ValueError(f"Can't fit messages into the available context length ({self.context_length} tokens)")
 
     def generate_tool_call_id(self) -> str | None:
+        """
+        Generate a random ID for a tool call
+        :return: A string, or None if this model doesn't require a tool call ID
+        """
         return None
 
     @staticmethod
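The base `generate_tool_call_id` returns None for models that don't need tool call IDs; subclasses that do (Mistral, per the diff below) override it. An illustrative override, not the library's actual code, assuming a short random alphanumeric ID is wanted:

```python
import random
import string

from libertai_agents.models.base import Model


class RandomToolCallIdModel(Model):  # hypothetical subclass, for illustration only
    def generate_tool_call_id(self) -> str | None:
        # A random 9-character alphanumeric ID, e.g. 'a3F9kQ2zX'
        return "".join(random.choices(string.ascii_letters + string.digits, k=9))
```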
6 changes: 3 additions & 3 deletions libertai_agents/models/hermes.py
@@ -1,12 +1,12 @@
 import json
 import re
 
-from libertai_agents.interfaces.common import ToolCallFunction
-from libertai_agents.models.base import Model
+from libertai_agents.interfaces.messages import ToolCallFunction
+from libertai_agents.models.base import Model, ModelId
 
 
 class HermesModel(Model):
-    def __init__(self, model_id: str, vm_url: str, context_length: int):
+    def __init__(self, model_id: ModelId, vm_url: str, context_length: int):
         super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length)
 
     @staticmethod
6 changes: 3 additions & 3 deletions libertai_agents/models/mistral.py
@@ -2,12 +2,12 @@
 import random
 import string
 
-from libertai_agents.interfaces.common import ToolCallFunction
-from libertai_agents.models.base import Model
+from libertai_agents.interfaces.messages import ToolCallFunction
+from libertai_agents.models.base import Model, ModelId
 
 
 class MistralModel(Model):
-    def __init__(self, model_id: str, vm_url: str, context_length: int):
+    def __init__(self, model_id: ModelId, vm_url: str, context_length: int):
         super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length, system_message=False)
 
     @staticmethod
7 changes: 7 additions & 0 deletions libertai_agents/models/models.py
@@ -32,6 +32,13 @@ class ModelConfiguration(BaseModel):
 
 
 def get_model(model_id: ModelId, hf_token: str | None = None) -> Model:
+    """
+    Get one of the available models
+    :param model_id: HuggingFace ID of the model, must be one of the supported models
+    :param hf_token: Optional access token, required to use gated models
+    :return: An instance of the model
+    """
     model_configuration = MODELS_CONFIG.get(model_id)
 
     if model_configuration is None:
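A hedged usage sketch for `get_model`; reading the token from an `HF_TOKEN` environment variable is an assumption, not something this repository prescribes:

```python
import os

from libertai_agents.models import get_model

# Open model: no token required
model = get_model("NousResearch/Hermes-2-Pro-Llama-3-8B")

# For a gated model, pass a HuggingFace access token
# (the environment variable name is illustrative)
model = get_model("NousResearch/Hermes-2-Pro-Llama-3-8B",
                  hf_token=os.environ.get("HF_TOKEN"))
```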
11 changes: 0 additions & 11 deletions libertai_agents/tools.py

This file was deleted.

15 changes: 14 additions & 1 deletion main.py
@@ -1,6 +1,19 @@
 from libertai_agents.agents import ChatAgent
 from libertai_agents.models import get_model
-from libertai_agents.tools import get_current_temperature
+
+
+async def get_current_temperature(location: str, unit: str) -> float:
+    """
+    Get the current temperature at a location.
+
+    Args:
+        location: The location to get the temperature for, in the format "City, Country"
+        unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
+
+    Returns:
+        The current temperature at the specified location in the specified units, as a float.
+    """
+    return 22.  # A real function should probably actually get the temperature!
 
 
 agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
                   system_prompt="You are a helpful assistant",
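With `expose_api` left at its default of True, the agent's FastAPI app can be served by any ASGI server. A minimal sketch assuming `uvicorn` is installed and `main.py` defines `agent` as above:

```python
# run.py -- hypothetical entry point, not part of this commit
import uvicorn

from main import agent

if __name__ == "__main__":
    # agent.app exposes POST /generate-answer and GET /model (see agents.py above)
    uvicorn.run(agent.app, host="0.0.0.0", port=8000)
```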
