Handle multiple providers/llms #45
@@ -11,21 +11,20 @@
 import posthog
 from django.conf import settings
 from django.utils.module_loading import import_string
+from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.tools.base import BaseTool
 from langgraph.constants import END
 from langgraph.graph import MessagesState, StateGraph
 from langgraph.graph.graph import CompiledGraph
-from langgraph.prebuilt import ToolNode, create_react_agent
+from langgraph.prebuilt import ToolNode, create_react_agent, tools_condition
 from langgraph.prebuilt.chat_agent_executor import AgentState
 from openai import BadRequestError
 from typing_extensions import TypedDict

 from ai_chatbots import tools
 from ai_chatbots.api import ChatMemory, get_search_tool_metadata
-from ai_chatbots.constants import LLMClassEnum
 from ai_chatbots.tools import search_content_files

 log = logging.getLogger(__name__)
@@ -58,7 +57,7 @@ def __init__( # noqa: PLR0913
     ):
         """Initialize the AI chat agent service"""
         self.bot_name = name
-        self.model = model or settings.AI_MODEL
+        self.model = model or settings.AI_DEFAULT_MODEL
         self.temperature = temperature or DEFAULT_TEMPERATURE
         self.instructions = instructions or self.INSTRUCTIONS
         self.user_id = user_id
@@ -69,8 +68,10 @@ def __init__( # noqa: PLR0913
                 f"ai_chatbots.proxies.{settings.AI_PROXY_CLASS}"
             )()
             self.proxy.create_proxy_user(self.user_id)
+            self.proxy_prefix = self.proxy.PROXY_MODEL_PREFIX
         else:
             self.proxy = None
+            self.proxy_prefix = ""
         self.tools = self.create_tools()
         self.llm = self.get_llm()
         self.agent = None
@@ -82,21 +83,17 @@ def create_tools(self):
     def get_llm(self, **kwargs) -> BaseChatModel:
         """
         Return the LLM instance for the chatbot.
-        Determine the LLM class to use based on the AI_PROVIDER setting.
         Set it up to use a proxy, with required proxy kwargs, if applicable.
         Bind the LLM to any tools if they are present.
         """
-        try:
-            llm_class = LLMClassEnum[settings.AI_PROVIDER].value
-        except KeyError:
-            raise NotImplementedError from KeyError
-        llm = llm_class(
-            model=self.model,
+        llm = ChatLiteLLM(
+            model=f"{self.proxy_prefix}{self.model}",
             **(self.proxy.get_api_kwargs() if self.proxy else {}),
             **(self.proxy.get_additional_kwargs(self) if self.proxy else {}),
             **kwargs,
         )
-        if self.temperature:
+        # Set the temperature if it's supported by the model
+        if self.temperature and self.model not in settings.AI_UNSUPPORTED_TEMP_MODELS:
             llm.temperature = self.temperature
         # Bind tools to the LLM if any
         if self.tools:
@@ -127,46 +124,23 @@ def create_agent_graph(self) -> CompiledGraph:
         agent_node = "agent"
         tools_node = "tools"

-        def start_agent(state: MessagesState) -> MessagesState:
+        def tool_calling_llm(state: MessagesState) -> MessagesState:
             """Call the LLM, injecting system prompt"""
             if len(state["messages"]) == 1:
                 # New chat, so inject the system prompt
                 state["messages"].insert(0, SystemMessage(self.instructions))
             return MessagesState(messages=[self.llm.invoke(state["messages"])])

-        def continue_on_tool_call(state: MessagesState) -> str:
-            """
-            Define the conditional edge that determines whether
-            to continue or not
-            """
-            messages = state["messages"]
-            last_message = messages[-1]
-            # Finish if no tool call is requested
-            if not last_message.tool_calls:
-                return END
-            # If there is, run the tool
-            else:
-                return CONTINUE
-
         agent_graph = StateGraph(MessagesState)
         # Add the agent node that first calls the LLM
-        agent_graph.add_node(agent_node, start_agent)
+        agent_graph.add_node(agent_node, tool_calling_llm)
         if self.tools:
             # Add the tools node
             agent_graph.add_node(tools_node, ToolNode(tools=self.tools))
-            # Add a conditional edge that determines when to run the tools.
-            # If no tool call is requested, the edge is not taken and the
-            # agent node will end its response.
-            agent_graph.add_conditional_edges(
-                agent_node,
-                continue_on_tool_call,
-                {
-                    # If tool requested then we call the tool node.
-                    CONTINUE: tools_node,
-                    # Otherwise finish.
-                    END: END,
-                },
-            )
+            agent_graph.add_conditional_edges(agent_node, tools_condition)
Review comment on lines -160 to +143: This example was probably more verbose than it needed to be; the built-in tools_condition does the same thing.

             # Send the tool node output back to the agent node
             agent_graph.add_edge(tools_node, agent_node)
         # Set the entry point to the agent node
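For reference, langgraph's prebuilt `tools_condition` performs the same branching the removed `continue_on_tool_call` function implemented: it routes to the node named `"tools"` when the last message contains tool calls and to `END` otherwise, which is why the existing `tools_node = "tools"` name still lines up. A sketch of the resulting wiring, where `llm_with_tools` and `tool_list` are placeholders standing in for `self.llm` and `self.tools`:

```python
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

# Placeholders for self.llm (already bound to tools) and self.tools
def call_model(state: MessagesState) -> MessagesState:
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

graph = StateGraph(MessagesState)
graph.add_node("agent", call_model)
graph.add_node("tools", ToolNode(tools=tool_list))

# tools_condition inspects the last message: tool calls -> "tools", else END,
# replacing the hand-rolled continue_on_tool_call conditional edge.
graph.add_conditional_edges("agent", tools_condition)
graph.add_edge("tools", "agent")
graph.set_entry_point("agent")
compiled = graph.compile()
```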
@@ -385,7 +359,7 @@ def __init__( # noqa: PLR0913
         super().__init__(
             user_id,
             name=name,
-            model=model,
+            model=model or settings.AI_DEFAULT_RECOMMENDATION_MODEL,
             temperature=temperature,
             instructions=instructions,
             thread_id=thread_id,
@@ -447,7 +421,7 @@ def __init__( # noqa: PLR0913
         super().__init__(
             user_id,
             name=name,
-            model=model or settings.AI_MODEL,
+            model=model or settings.AI_DEFAULT_SYLLABUS_MODEL,
Review comment (on the AI_DEFAULT_SYLLABUS_MODEL line): After AWS Bedrock access is set up, the default syllabus model should be changed to …
             temperature=temperature,
             instructions=instructions,
             thread_id=thread_id,
Review comment: The new o3-mini model does not support the temperature parameter and will raise an exception if it is passed along.
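This is what the new `AI_UNSUPPORTED_TEMP_MODELS` guard in `get_llm()` accounts for: the temperature is only applied when the chosen model accepts it. A minimal sketch of that guard; listing `o3-mini` here is an assumption drawn from this review comment, not from the repo's actual settings, and `apply_temperature` is a hypothetical helper used only for illustration:

```python
# Assumed contents of the new setting; o3-mini is listed based on the
# review comment above, not on the repo's actual settings module.
AI_UNSUPPORTED_TEMP_MODELS = ["o3-mini"]


def apply_temperature(llm, model: str, temperature: float | None) -> None:
    """Hypothetical helper mirroring the guard get_llm() now uses:
    only set the temperature when the model supports it."""
    if temperature and model not in AI_UNSUPPORTED_TEMP_MODELS:
        llm.temperature = temperature
```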