Handle multiple providers/llms #45

Open · wants to merge 5 commits into base: main
Changes from all commits
54 changes: 14 additions & 40 deletions ai_chatbots/chatbots.py
@@ -11,21 +11,20 @@
import posthog
from django.conf import settings
from django.utils.module_loading import import_string
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.tools.base import BaseTool
from langgraph.constants import END
from langgraph.graph import MessagesState, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import ToolNode, create_react_agent
from langgraph.prebuilt import ToolNode, create_react_agent, tools_condition
from langgraph.prebuilt.chat_agent_executor import AgentState
from openai import BadRequestError
from typing_extensions import TypedDict

from ai_chatbots import tools
from ai_chatbots.api import ChatMemory, get_search_tool_metadata
from ai_chatbots.constants import LLMClassEnum
from ai_chatbots.tools import search_content_files

log = logging.getLogger(__name__)
@@ -58,7 +57,7 @@ def __init__( # noqa: PLR0913
):
"""Initialize the AI chat agent service"""
self.bot_name = name
self.model = model or settings.AI_MODEL
self.model = model or settings.AI_DEFAULT_MODEL
self.temperature = temperature or DEFAULT_TEMPERATURE
self.instructions = instructions or self.INSTRUCTIONS
self.user_id = user_id
@@ -69,8 +68,10 @@
f"ai_chatbots.proxies.{settings.AI_PROXY_CLASS}"
)()
self.proxy.create_proxy_user(self.user_id)
self.proxy_prefix = self.proxy.PROXY_MODEL_PREFIX
else:
self.proxy = None
self.proxy_prefix = ""
self.tools = self.create_tools()
self.llm = self.get_llm()
self.agent = None
@@ -82,21 +83,17 @@ def create_tools(self):
def get_llm(self, **kwargs) -> BaseChatModel:
"""
Return the LLM instance for the chatbot.
Determine the LLM class to use based on the AI_PROVIDER setting.
Set it up to use a proxy, with required proxy kwargs, if applicable.
Bind the LLM to any tools if they are present.
"""
try:
llm_class = LLMClassEnum[settings.AI_PROVIDER].value
except KeyError:
raise NotImplementedError from KeyError
llm = llm_class(
model=self.model,
llm = ChatLiteLLM(
model=f"{self.proxy_prefix}{self.model}",
**(self.proxy.get_api_kwargs() if self.proxy else {}),
**(self.proxy.get_additional_kwargs(self) if self.proxy else {}),
**kwargs,
)
if self.temperature:
# Set the temperature if it's supported by the model
if self.temperature and self.model not in settings.AI_UNSUPPORTED_TEMP_MODELS:
Member Author: The new o3-mini model does not support the temperature parameter and will raise an exception if it is passed along. (A minimal sketch of this guard follows the hunk.)

llm.temperature = self.temperature
# Bind tools to the LLM if any
if self.tools:
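
A minimal sketch (not part of the diff) of how the temperature guard and the single ChatLiteLLM entry point fit together, assuming `AI_UNSUPPORTED_TEMP_MODELS` is a plain list of model-name strings such as `["o3-mini"]`:

```python
# Sketch only: illustrates the temperature guard and the proxy model prefix.
# AI_UNSUPPORTED_TEMP_MODELS and the "litellm_proxy/" prefix are assumptions
# based on this PR, not a definitive implementation.
from langchain_community.chat_models import ChatLiteLLM

AI_UNSUPPORTED_TEMP_MODELS = ["o3-mini"]


def build_llm(model: str, temperature: float | None, proxy_prefix: str = "") -> ChatLiteLLM:
    """Create a ChatLiteLLM, skipping temperature for models that reject it."""
    llm = ChatLiteLLM(model=f"{proxy_prefix}{model}")
    if temperature and model not in AI_UNSUPPORTED_TEMP_MODELS:
        llm.temperature = temperature
    return llm


# o3-mini keeps its default temperature; gpt-4o gets 0.3 and is routed
# through the LiteLLM proxy via the "litellm_proxy/" prefix.
build_llm("o3-mini", 0.3)
build_llm("gpt-4o", 0.3, proxy_prefix="litellm_proxy/")
```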
@@ -127,46 +124,23 @@ def create_agent_graph(self) -> CompiledGraph:
agent_node = "agent"
tools_node = "tools"

def start_agent(state: MessagesState) -> MessagesState:
def tool_calling_llm(state: MessagesState) -> MessagesState:
"""Call the LLM, injecting system prompt"""
if len(state["messages"]) == 1:
# New chat, so inject the system prompt
state["messages"].insert(0, SystemMessage(self.instructions))
return MessagesState(messages=[self.llm.invoke(state["messages"])])

def continue_on_tool_call(state: MessagesState) -> str:
"""
Define the conditional edge that determines whether
to continue or not
"""
messages = state["messages"]
last_message = messages[-1]
# Finish if no tool call is requested
if not last_message.tool_calls:
return END
# If there is, run the tool
else:
return CONTINUE

agent_graph = StateGraph(MessagesState)
# Add the agent node that first calls the LLM
agent_graph.add_node(agent_node, start_agent)
agent_graph.add_node(agent_node, tool_calling_llm)
if self.tools:
# Add the tools node
agent_graph.add_node(tools_node, ToolNode(tools=self.tools))
# Add a conditional edge that determines when to run the tools.
# If no tool call is requested, the edge is not taken and the
# agent node will end its response.
agent_graph.add_conditional_edges(
agent_node,
continue_on_tool_call,
{
# If tool requested then we call the tool node.
CONTINUE: tools_node,
# Otherwise finish.
END: END,
},
)
agent_graph.add_conditional_edges(agent_node, tools_condition)
Member Author, on lines -160 to +143: This example was probably more verbose than it needed to be; the built-in tools_condition function should suffice for most use cases. (A sketch using tools_condition follows this hunk.)

# Send the tool node output back to the agent node
agent_graph.add_edge(tools_node, agent_node)
# Set the entry point to the agent node
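
For reference, a self-contained sketch of the simplified wiring under this change. The tool and model names here are stand-ins for illustration, not the PR's actual implementations, and `bind_tools` is assumed to be supported by the pinned ChatLiteLLM version as the PR relies on:

```python
# Sketch of the simplified graph: tools_condition routes to the "tools" node
# when the last AI message requests a tool call, otherwise to END, replacing
# the hand-written continue_on_tool_call edge.
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import SystemMessage
from langchain_core.tools import tool
from langgraph.graph import MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition


@tool
def search_content_files(q: str) -> str:
    """Stand-in search tool for illustration."""
    return f"results for {q}"


tools = [search_content_files]
llm_with_tools = ChatLiteLLM(model="gpt-4o").bind_tools(tools)


def tool_calling_llm(state: MessagesState) -> MessagesState:
    """Call the LLM, injecting the system prompt on a new chat."""
    if len(state["messages"]) == 1:
        state["messages"].insert(0, SystemMessage("You are a helpful assistant."))
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


graph = StateGraph(MessagesState)
graph.add_node("agent", tool_calling_llm)
graph.add_node("tools", ToolNode(tools=tools))
graph.add_conditional_edges("agent", tools_condition)
graph.add_edge("tools", "agent")
graph.set_entry_point("agent")
compiled_graph = graph.compile()
```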
@@ -385,7 +359,7 @@ def __init__( # noqa: PLR0913
super().__init__(
user_id,
name=name,
model=model,
model=model or settings.AI_DEFAULT_RECOMMENDATION_MODEL,
temperature=temperature,
instructions=instructions,
thread_id=thread_id,
@@ -447,7 +421,7 @@ def __init__( # noqa: PLR0913
super().__init__(
user_id,
name=name,
model=model or settings.AI_MODEL,
model=model or settings.AI_DEFAULT_SYLLABUS_MODEL,
Member Author: After AWS Bedrock access is set up, the default syllabus model should be changed to bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0. (A sketch of that settings change follows this hunk.)

temperature=temperature,
instructions=instructions,
thread_id=thread_id,
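
If Bedrock access does land, the switch would presumably be just a settings change; a hypothetical example, given that LiteLLM routes `bedrock/`-prefixed model names to AWS Bedrock:

```python
# Hypothetical future default: the setting name comes from this PR, the value
# from the author's comment above.
AI_DEFAULT_SYLLABUS_MODEL = "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0"
```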
53 changes: 45 additions & 8 deletions ai_chatbots/chatbots_test.py
@@ -5,6 +5,7 @@

import pytest
from django.conf import settings
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableBinding

@@ -15,13 +16,13 @@
SyllabusBot,
)
from ai_chatbots.conftest import MockAsyncIterator
from ai_chatbots.constants import LLMClassEnum
from ai_chatbots.factories import (
AIMessageChunkFactory,
HumanMessageFactory,
SystemMessageFactory,
ToolMessageFactory,
)
from ai_chatbots.proxies import LiteLLMProxy
from ai_chatbots.tools import SearchToolSchema
from main.test_utils import assert_json_equal

@@ -85,18 +86,18 @@ def test_recommendation_bot_initialization_defaults(
temperature=temperature,
instructions=instructions,
)
assert chatbot.model == (model if model else settings.AI_MODEL)
assert chatbot.model == (
model if model else settings.AI_DEFAULT_RECOMMENDATION_MODEL
)
assert chatbot.temperature == (temperature if temperature else DEFAULT_TEMPERATURE)
assert chatbot.instructions == (
instructions if instructions else chatbot.instructions
)
worker_llm = chatbot.llm
assert (
worker_llm.__class__ == RunnableBinding
if has_tools
else LLMClassEnum.openai.value
assert worker_llm.__class__ == RunnableBinding if has_tools else ChatLiteLLM
assert worker_llm.model == (
model if model else settings.AI_DEFAULT_RECOMMENDATION_MODEL
)
assert worker_llm.model_name == (model if model else settings.AI_MODEL)


@pytest.mark.django_db
Expand Down Expand Up @@ -248,8 +249,10 @@ async def test_syllabus_bot_create_agent_graph_(mocker):
)


async def test_syllabus_bot_get_completion_state(mocker, mock_openai_astream):
@pytest.mark.parametrize("default_model", ["gpt-3.5-turbo", "gpt-4", "gpt-4o"])
async def test_syllabus_bot_get_completion_state(mock_openai_astream, default_model):
"""Proper state should get passed along by get_completion"""
settings.AI_DEFAULT_SYLLABUS_MODEL = default_model
chatbot = SyllabusBot("anonymous", name="test agent", thread_id="foo")
extra_state = {
"course_id": ["mitx1.23"],
@@ -262,6 +265,7 @@
chatbot.config,
stream_mode="messages",
)
assert chatbot.llm.model == default_model


@pytest.mark.django_db
@@ -408,3 +412,36 @@ async def test_get_tool_metadata_error(mocker):
assert metadata == json.dumps(
{"error": "Error parsing tool metadata", "content": "Could not connect to api"}
)


@pytest.mark.parametrize("use_proxy", [True, False])
def test_proxy_settings(settings, mocker, use_proxy):
"""Test that the proxy settings are set correctly"""
mock_create_proxy_user = mocker.patch(
"ai_chatbots.proxies.LiteLLMProxy.create_proxy_user"
)
mock_llm = mocker.patch("ai_chatbots.chatbots.ChatLiteLLM")
settings.AI_PROXY_CLASS = "LiteLLMProxy" if use_proxy else None
settings.AI_PROXY_URL = "http://proxy.url"
settings.AI_PROXY_AUTH_TOKEN = "test" # noqa: S105
model_name = "openai/o9-turbo"
settings.AI_DEFAULT_RECOMMENDATION_MODEL = model_name
chatbot = ResourceRecommendationBot("user1")
if use_proxy:
mock_create_proxy_user.assert_called_once_with("user1")
assert chatbot.proxy_prefix == LiteLLMProxy.PROXY_MODEL_PREFIX
assert isinstance(chatbot.proxy, LiteLLMProxy)
mock_llm.assert_called_once_with(
model=f"{LiteLLMProxy.PROXY_MODEL_PREFIX}{model_name}",
**chatbot.proxy.get_api_kwargs(),
**chatbot.proxy.get_additional_kwargs(chatbot),
)
else:
mock_create_proxy_user.assert_not_called()
assert chatbot.proxy_prefix == ""
assert chatbot.proxy is None
mock_llm.assert_called_once_with(
model=model_name,
**{},
**{},
)
12 changes: 0 additions & 12 deletions ai_chatbots/constants.py
@@ -1,24 +1,12 @@
"""Constants for the AI Chat application."""

from langchain_openai import ChatOpenAI
from named_enum import ExtendedEnum

GROUP_STAFF_AI_SYTEM_PROMPT_EDITORS = "ai_system_prompt_editors"
AI_ANONYMOUS_USER = "anonymous"
AI_THREAD_COOKIE_KEY = "ai_thread_id"


class LLMClassEnum(ExtendedEnum):
"""
Enum for determining which LLM class to
use based on settings.AI_PROVIDER. For example,
if AI_PROVIDER == "openai", the OpenAI LLM class
should be used.
"""

openai = ChatOpenAI


class LearningResourceType(ExtendedEnum):
"""Enum for LearningResource resource_type values"""

4 changes: 2 additions & 2 deletions ai_chatbots/consumers_test.py
@@ -88,8 +88,8 @@ async def test_recommend_agent_handle( # noqa: PLR0913
assert recommendation_consumer.bot.llm.temperature == (
temperature if temperature else settings.AI_DEFAULT_TEMPERATURE
)
assert recommendation_consumer.bot.llm.model_name == (
model if model else settings.AI_MODEL
assert recommendation_consumer.bot.llm.model == (
model if model else settings.AI_DEFAULT_RECOMMENDATION_MODEL
)
assert recommendation_consumer.bot.instructions == (
instructions if instructions else ResourceRecommendationBot.INSTRUCTIONS
23 changes: 22 additions & 1 deletion ai_chatbots/proxies.py
@@ -18,6 +18,9 @@ class AIProxy(ABC):

REQUIRED_SETTINGS = []

# if the proxy model needs to be prefixed, set this here
PROXY_MODEL_PREFIX = ""

def __init__(self):
"""Raise an error if required settings are missing."""
missing_settings = [
@@ -46,16 +49,34 @@ class LiteLLMProxy(AIProxy):
"""Helper class for the Lite LLM proxy."""

REQUIRED_SETTINGS = ("AI_PROXY_URL", "AI_PROXY_AUTH_TOKEN")
PROXY_MODEL_PREFIX = "litellm_proxy/"

def get_api_kwargs(
self, base_url_key: str = "base_url", api_key_key: str = "openai_api_key"
self, base_url_key: str = "api_base", api_key_key: str = "openai_api_key"
) -> dict:
"""
Get the required API kwargs to connect to the Lite LLM proxy.
When using the ChatLiteLLM class, these kwargs should be
"api_base" and "openai_api_key".

Args:
base_url_key (str): The key to pass in the proxy API URL.
api_key_key (str): The key to pass in the proxy authentication token.

Returns:
dict: The proxy API kwargs.
"""

return {
f"{base_url_key}": settings.AI_PROXY_URL,
f"{api_key_key}": settings.AI_PROXY_AUTH_TOKEN,
}

def get_additional_kwargs(self, service: BaseChatbot) -> dict:
"""
Get additional kwargs to send to the Lite LLM proxy, such
as user_id and job/task identification.
"""
return {
"user": service.user_id,
"store": True,
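
A rough usage sketch of how these two helpers feed ChatLiteLLM, mirroring `BaseChatbot.get_llm` above. Here `chatbot` stands in for an instantiated `BaseChatbot` subclass and is an assumption, not code from the PR:

```python
# Sketch only: assumes AI_PROXY_URL and AI_PROXY_AUTH_TOKEN are configured and
# that `chatbot` is an instantiated BaseChatbot subclass.
from langchain_community.chat_models import ChatLiteLLM

proxy = LiteLLMProxy()
llm = ChatLiteLLM(
    model=f"{proxy.PROXY_MODEL_PREFIX}gpt-4o",  # -> "litellm_proxy/gpt-4o"
    **proxy.get_api_kwargs(),                   # api_base, openai_api_key
    **proxy.get_additional_kwargs(chatbot),     # user, store, ...
)
```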
5 changes: 1 addition & 4 deletions ai_chatbots/serializers.py
@@ -1,16 +1,13 @@
"""Serializers for the ai_chatbots app"""

from django.conf import settings
from rest_framework import serializers


class ChatRequestSerializer(serializers.Serializer):
"""Serializer for chatbot requests"""

message = serializers.CharField(required=True, allow_blank=False)
model = serializers.CharField(
default=settings.AI_MODEL, required=False, allow_blank=True
)
model = serializers.CharField(required=False, allow_blank=True)
temperature = serializers.FloatField(
min_value=0.0,
max_value=1.0,