Skip to content

Commit

Permalink
langgraph: allow create_react_agent to take empty tools (#2553)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Nov 27, 2024
1 parent 4b1b3ce commit 5144b8f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
40 changes: 28 additions & 12 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def create_react_agent(
Args:
model: The `LangChain` chat model that supports tool calling.
tools: A list of tools, a ToolExecutor, or a ToolNode instance.
If an empty list is provided, the agent will consist of a single LLM node without tool calling.
state_schema: An optional state schema that defines graph state.
Must have `messages` and `is_last_step` keys.
Defaults to `AgentState` that defines those two keys.
Expand Down Expand Up @@ -540,19 +541,10 @@ class Agent,Tools otherClass
# get the tool functions wrapped in a tool class from the ToolNode
tool_classes = list(tool_node.tools_by_name.values())

if _should_bind_tools(model, tool_classes):
model = cast(BaseChatModel, model).bind_tools(tool_classes)
tool_calling_enabled = len(tool_classes) > 0

# Define the function that determines whether to continue or not
def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
messages = state["messages"]
last_message = messages[-1]
# If there is no function call, then we finish
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
return "__end__"
# Otherwise if there is, we continue
else:
return "tools"
if _should_bind_tools(model, tool_classes) and tool_calling_enabled:
model = cast(BaseChatModel, model).bind_tools(tool_classes)

# we're passing store here for validation
preprocessor = _get_model_preprocessing_runnable(
Expand Down Expand Up @@ -635,6 +627,30 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
# We return a list, because this will get added to the existing list
return {"messages": [response]}

if not tool_calling_enabled:
# Define a new graph
workflow = StateGraph(state_schema or AgentState)
workflow.add_node("agent", RunnableCallable(call_model, acall_model))
workflow.set_entry_point("agent")
return workflow.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
)

# Define the function that determines whether to continue or not
def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
messages = state["messages"]
last_message = messages[-1]
# If there is no function call, then we finish
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
return "__end__"
# Otherwise if there is, we continue
else:
return "tools"

# Define a new graph
workflow = StateGraph(state_schema or AgentState)

Expand Down
3 changes: 3 additions & 0 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def bind_tools(
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
if len(tools) == 0:
raise ValueError("Must provide at least one tool")

tool_dicts = []
for tool in tools:
if not isinstance(tool, BaseTool):
Expand Down

0 comments on commit 5144b8f

Please sign in to comment.