diff --git a/libertai_agents/agents.py b/libertai_agents/agents.py
index d96919b..5fa4436 100644
--- a/libertai_agents/agents.py
+++ b/libertai_agents/agents.py
@@ -6,8 +6,9 @@
 from aiohttp import ClientSession
 from fastapi import APIRouter, FastAPI

-from libertai_agents.interfaces.common import Message, MessageRoleEnum, LlamaCppParams, MessageToolCall, \
-    ToolCallFunction, ToolCallMessage, CustomizableLlamaCppParams, ToolResponseMessage
+from libertai_agents.interfaces.llamacpp import CustomizableLlamaCppParams, LlamaCppParams
+from libertai_agents.interfaces.messages import Message, MessageRoleEnum, MessageToolCall, ToolCallFunction, \
+    ToolCallMessage, ToolResponseMessage
 from libertai_agents.interfaces.models import ModelInformation
 from libertai_agents.models import Model
 from libertai_agents.utils import find
@@ -18,10 +19,20 @@ class ChatAgent:
     system_prompt: str
     tools: list[Callable[..., Awaitable[Any]]]
     llamacpp_params: CustomizableLlamaCppParams
-    app: FastAPI
+    app: FastAPI | None

     def __init__(self, model: Model, system_prompt: str, tools: list[Callable[..., Awaitable[Any]]] | None = None,
-                 llamacpp_params: CustomizableLlamaCppParams = CustomizableLlamaCppParams()):
+                 llamacpp_params: CustomizableLlamaCppParams = CustomizableLlamaCppParams(),
+                 expose_api: bool = True):
+        """
+        Create a LibertAI chatbot agent that can answer messages from users
+
+        :param model: The LLM you want to use, selected from the available ones
+        :param system_prompt: Customize the behavior of the agent with your own prompt
+        :param tools: List of functions that the agent can call. Each function must be asynchronous, have a docstring and return a stringifiable response
+        :param llamacpp_params: Override the parameters given to llama.cpp when calling the model
+        :param expose_api: Set to False to avoid exposing an API (useful if you are using a custom trigger)
+        """
         if tools is None:
             tools = []
@@ -32,18 +43,28 @@ def __init__(self, model: Model, system_prompt: str, tools: list[Callable[..., A
         self.tools = tools
         self.llamacpp_params = llamacpp_params

-        # Define API routes
-        router = APIRouter()
-        router.add_api_route("/generate-answer", self.generate_answer, methods=["POST"])
-        router.add_api_route("/model", self.get_model_information, methods=["GET"])
+        if expose_api:
+            # Define API routes
+            router = APIRouter()
+            router.add_api_route("/generate-answer", self.generate_answer, methods=["POST"])
+            router.add_api_route("/model", self.get_model_information, methods=["GET"])

-        self.app = FastAPI(title="LibertAI ChatAgent")
-        self.app.include_router(router)
+            self.app = FastAPI(title="LibertAI ChatAgent")
+            self.app.include_router(router)

     def get_model_information(self) -> ModelInformation:
+        """
+        Get information about the model powering this agent
+        """
         return ModelInformation(id=self.model.model_id, context_length=self.model.context_length)

     async def generate_answer(self, messages: list[Message]) -> str:
+        """
+        Generate an answer based on a conversation
+
+        :param messages: List of messages previously sent in this conversation
+        :return: The string response of the agent
+        """
         if len(messages) == 0:
             raise ValueError("No previous message to respond to")
         if messages[-1].role not in [MessageRoleEnum.user, MessageRoleEnum.tool]:
diff --git a/libertai_agents/interfaces/llamacpp.py b/libertai_agents/interfaces/llamacpp.py
new file mode 100644
index 0000000..c7f257e
--- /dev/null
+++ b/libertai_agents/interfaces/llamacpp.py
@@ -0,0 +1,9 @@
+from pydantic import BaseModel
+
+
+class CustomizableLlamaCppParams(BaseModel):
+    stream: bool = False
+
+
+class LlamaCppParams(CustomizableLlamaCppParams):
+    prompt: str
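(Illustration, not part of the diff: a minimal sketch of overriding the relocated llama.cpp params when building an agent; the streaming value is an assumption for the example.)

    from libertai_agents.agents import ChatAgent
    from libertai_agents.interfaces.llamacpp import CustomizableLlamaCppParams
    from libertai_agents.models import get_model

    # Only the customizable subset (currently just `stream`) is user-facing;
    # LlamaCppParams adds the prompt internally when the model is called.
    agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
                      system_prompt="You are a helpful assistant",
                      llamacpp_params=CustomizableLlamaCppParams(stream=True))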
diff --git a/libertai_agents/interfaces/common.py b/libertai_agents/interfaces/messages.py
similarity index 82%
rename from libertai_agents/interfaces/common.py
rename to libertai_agents/interfaces/messages.py
index 470c14a..dfdddce 100644
--- a/libertai_agents/interfaces/common.py
+++ b/libertai_agents/interfaces/messages.py
@@ -34,11 +34,3 @@ class ToolCallMessage(Message):
 class ToolResponseMessage(Message):
     name: Optional[str] = None
     tool_call_id: Optional[str] = None
-
-
-class CustomizableLlamaCppParams(BaseModel):
-    stream: bool = False
-
-
-class LlamaCppParams(CustomizableLlamaCppParams):
-    prompt: str
diff --git a/libertai_agents/models/base.py b/libertai_agents/models/base.py
index 295eb9e..1a084cc 100644
--- a/libertai_agents/models/base.py
+++ b/libertai_agents/models/base.py
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod
 from typing import Literal

-from libertai_agents.interfaces.common import Message, ToolCallFunction, MessageRoleEnum
+from libertai_agents.interfaces.messages import Message, ToolCallFunction, MessageRoleEnum

 # Disables the error about models not available
 logging.getLogger("transformers").disabled = True
@@ -37,6 +37,14 @@ def __count_tokens(self, content: str) -> int:
         return len(tokens)

     def generate_prompt(self, messages: list[Message], system_prompt: str, tools: list) -> str:
+        """
+        Generate the whole chat prompt
+
+        :param messages: The conversation history of messages
+        :param system_prompt: Prompt to include at the beginning
+        :param tools: Available tools
+        :return: Prompt string
+        """
         system_message = Message(role=MessageRoleEnum.system, content=system_prompt)

         raw_messages = list(map(lambda x: x.model_dump(), messages))
@@ -52,6 +60,11 @@ def generate_prompt(self, messages: list[Message], system_prompt: str, tools: li
         raise ValueError(f"Can't fit messages into the available context length ({self.context_length} tokens)")

     def generate_tool_call_id(self) -> str | None:
+        """
+        Generate a random ID for a tool call
+
+        :return: A string, or None if this model doesn't require a tool call ID
+        """
         return None

     @staticmethod
diff --git a/libertai_agents/models/hermes.py b/libertai_agents/models/hermes.py
index 3ca37ee..8e85c6a 100644
--- a/libertai_agents/models/hermes.py
+++ b/libertai_agents/models/hermes.py
@@ -1,12 +1,12 @@
 import json
 import re

-from libertai_agents.interfaces.common import ToolCallFunction
-from libertai_agents.models.base import Model
+from libertai_agents.interfaces.messages import ToolCallFunction
+from libertai_agents.models.base import Model, ModelId


 class HermesModel(Model):
-    def __init__(self, model_id: str, vm_url: str, context_length: int):
+    def __init__(self, model_id: ModelId, vm_url: str, context_length: int):
         super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length)

     @staticmethod
diff --git a/libertai_agents/models/mistral.py b/libertai_agents/models/mistral.py
index 3bf424f..ed98841 100644
--- a/libertai_agents/models/mistral.py
+++ b/libertai_agents/models/mistral.py
@@ -2,12 +2,12 @@
 import random
 import string

-from libertai_agents.interfaces.common import ToolCallFunction
-from libertai_agents.models.base import Model
+from libertai_agents.interfaces.messages import ToolCallFunction
+from libertai_agents.models.base import Model, ModelId


 class MistralModel(Model):
-    def __init__(self, model_id: str, vm_url: str, context_length: int):
+    def __init__(self, model_id: ModelId, vm_url: str, context_length: int):
         super().__init__(model_id=model_id, vm_url=vm_url, context_length=context_length, system_message=False)

     @staticmethod
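(Illustration, not part of the diff: a sketch of the renamed messages module in use; the field values here are hypothetical.)

    from libertai_agents.interfaces.messages import MessageRoleEnum, ToolResponseMessage

    # Feed a tool result back into a conversation; name and tool_call_id are
    # optional fields, and the values below are made up for the example.
    tool_result = ToolResponseMessage(role=MessageRoleEnum.tool,
                                      content="22.0",
                                      name="get_current_temperature",
                                      tool_call_id="abc123")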
diff --git a/libertai_agents/models/models.py b/libertai_agents/models/models.py
index 67c4b87..b06a927 100644
--- a/libertai_agents/models/models.py
+++ b/libertai_agents/models/models.py
@@ -32,6 +32,13 @@ class ModelConfiguration(BaseModel):


 def get_model(model_id: ModelId, hf_token: str | None = None) -> Model:
+    """
+    Get one of the available models
+
+    :param model_id: HuggingFace ID of the model, must be one of the supported models
+    :param hf_token: Optional access token, required to use gated models
+    :return: An instance of the model
+    """
     model_configuration = MODELS_CONFIG.get(model_id)

     if model_configuration is None:
diff --git a/libertai_agents/tools.py b/libertai_agents/tools.py
deleted file mode 100644
index c859a3d..0000000
--- a/libertai_agents/tools.py
+++ /dev/null
@@ -1,11 +0,0 @@
-async def get_current_temperature(location: str, unit: str) -> float:
-    """
-    Get the current temperature at a location.
-
-    Args:
-        location: The location to get the temperature for, in the format "City, Country"
-        unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
-    Returns:
-        The current temperature at the specified location in the specified units, as a float.
-    """
-    return 22.  # A real function should probably actually get the temperature!
diff --git a/main.py b/main.py
index 50107f3..e451bb1 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,19 @@
 from libertai_agents.agents import ChatAgent
 from libertai_agents.models import get_model
-from libertai_agents.tools import get_current_temperature
+
+
+async def get_current_temperature(location: str, unit: str) -> float:
+    """
+    Get the current temperature at a location.
+
+    Args:
+        location: The location to get the temperature for, in the format "City, Country"
+        unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
+    Returns:
+        The current temperature at the specified location in the specified units, as a float.
+    """
+    return 22.  # A real function should probably actually get the temperature!
+

 agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
                   system_prompt="You are a helpful assistant",
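(Illustration, not part of the diff: a minimal sketch of the new expose_api flag; driving the coroutine with asyncio is an assumption about how a custom trigger might call the agent.)

    import asyncio

    from libertai_agents.agents import ChatAgent
    from libertai_agents.interfaces.messages import Message, MessageRoleEnum
    from libertai_agents.models import get_model

    # Skip the built-in FastAPI app entirely and invoke the agent directly.
    agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
                      system_prompt="You are a helpful assistant",
                      expose_api=False)

    # generate_answer is a coroutine, so drive it with an event loop.
    answer = asyncio.run(agent.generate_answer(
        [Message(role=MessageRoleEnum.user, content="Hello!")]))
    print(answer)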