From 5144b8f374dd18d7ebd8ab6b75b73685b8e1ea62 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Wed, 27 Nov 2024 12:54:59 -0500 Subject: [PATCH] langgraph: allow create_react_agent to take empty tools (#2553) --- .../langgraph/prebuilt/chat_agent_executor.py | 40 +++++++++++++------ libs/langgraph/tests/test_prebuilt.py | 3 ++ 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py index fc812ccbc..4c8699360 100644 --- a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py +++ b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py @@ -212,6 +212,7 @@ def create_react_agent( Args: model: The `LangChain` chat model that supports tool calling. tools: A list of tools, a ToolExecutor, or a ToolNode instance. + If an empty list is provided, the agent will consist of a single LLM node without tool calling. state_schema: An optional state schema that defines graph state. Must have `messages` and `is_last_step` keys. Defaults to `AgentState` that defines those two keys. @@ -540,19 +541,10 @@ class Agent,Tools otherClass # get the tool functions wrapped in a tool class from the ToolNode tool_classes = list(tool_node.tools_by_name.values()) - if _should_bind_tools(model, tool_classes): - model = cast(BaseChatModel, model).bind_tools(tool_classes) + tool_calling_enabled = len(tool_classes) > 0 - # Define the function that determines whether to continue or not - def should_continue(state: AgentState) -> Literal["tools", "__end__"]: - messages = state["messages"] - last_message = messages[-1] - # If there is no function call, then we finish - if not isinstance(last_message, AIMessage) or not last_message.tool_calls: - return "__end__" - # Otherwise if there is, we continue - else: - return "tools" + if _should_bind_tools(model, tool_classes) and tool_calling_enabled: + model = cast(BaseChatModel, model).bind_tools(tool_classes) # we're passing store here for validation preprocessor = _get_model_preprocessing_runnable( @@ -635,6 +627,30 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: # We return a list, because this will get added to the existing list return {"messages": [response]} + if not tool_calling_enabled: + # Define a new graph + workflow = StateGraph(state_schema or AgentState) + workflow.add_node("agent", RunnableCallable(call_model, acall_model)) + workflow.set_entry_point("agent") + return workflow.compile( + checkpointer=checkpointer, + store=store, + interrupt_before=interrupt_before, + interrupt_after=interrupt_after, + debug=debug, + ) + + # Define the function that determines whether to continue or not + def should_continue(state: AgentState) -> Literal["tools", "__end__"]: + messages = state["messages"] + last_message = messages[-1] + # If there is no function call, then we finish + if not isinstance(last_message, AIMessage) or not last_message.tool_calls: + return "__end__" + # Otherwise if there is, we continue + else: + return "tools" + # Define a new graph workflow = StateGraph(state_schema or AgentState) diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index a6655a451..0997668b2 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -102,6 +102,9 @@ def bind_tools( tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: + if len(tools) == 0: + raise ValueError("Must provide at least one tool") + tool_dicts = [] for tool in tools: if not isinstance(tool, BaseTool):