Skip to content

Commit

Permalink
remove special exception
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Dec 10, 2024
1 parent 91313f3 commit d0c6690
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 46 deletions.
75 changes: 36 additions & 39 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."


class InvalidToolCommandError(Exception):
"""Raised when the Command returned by a tool is invalid."""


def msg_content_output(output: Any) -> Union[str, list[dict]]:
recognized_content_block_types = ("image", "image_url", "text", "json")
if isinstance(output, str):
Expand Down Expand Up @@ -302,25 +298,14 @@ def _run_one(
try:
input = {**call, **{"type": "tool_call"}}
response = self.tools_by_name[call["name"]].invoke(input)
if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
response.content = cast(
Union[str, list], msg_content_output(response.content)
)
return response
else:
raise TypeError(
f"Tool {call['name']} returned unexpected type: {type(response)}"
)

# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
# (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
except (GraphBubbleUp, InvalidToolCommandError) as e:
except GraphBubbleUp as e:
raise e
except Exception as e:
if isinstance(self.handle_tool_errors, tuple):
Expand All @@ -337,10 +322,21 @@ def _run_one(
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)
return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)

return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)
if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
response.content = cast(
Union[str, list], msg_content_output(response.content)
)
return response
else:
raise TypeError(
f"Tool {call['name']} returned unexpected type: {type(response)}"
)

async def _arun_one(
self,
Expand All @@ -354,25 +350,14 @@ async def _arun_one(
try:
input = {**call, **{"type": "tool_call"}}
response = await self.tools_by_name[call["name"]].ainvoke(input)
if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
response.content = cast(
Union[str, list], msg_content_output(response.content)
)
return response
else:
raise TypeError(
f"Tool {call['name']} returned unexpected type: {type(response)}"
)

# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
# (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
except (GraphBubbleUp, InvalidToolCommandError) as e:
except GraphBubbleUp as e:
raise e
except Exception as e:
if isinstance(self.handle_tool_errors, tuple):
Expand All @@ -390,9 +375,21 @@ async def _arun_one(
else:
content = _handle_tool_error(e, flag=self.handle_tool_errors)

return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)
return ToolMessage(
content=content, name=call["name"], tool_call_id=call["id"], status="error"
)

if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
response.content = cast(
Union[str, list], msg_content_output(response.content)
)
return response
else:
raise TypeError(
f"Tool {call['name']} returned unexpected type: {type(response)}"
)

def _parse_input(
self,
Expand Down Expand Up @@ -522,7 +519,7 @@ def _validate_tool_command(
if isinstance(command.update, dict):
# input type is dict when ToolNode is invoked with a dict input (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
if input_type != "dict":
raise InvalidToolCommandError(
raise ValueError(
f"Tools can provide a dict in Command.update only when using dict with '{self.messages_key}' key as ToolNode input, "
f"got: {command.update} for tool '{call['name']}'"
)
Expand All @@ -533,7 +530,7 @@ def _validate_tool_command(
elif isinstance(command.update, list):
# input type is list when ToolNode is invoked with a list input (e.g. [AIMessage(..., tool_calls=[...])])
if input_type != "list":
raise InvalidToolCommandError(
raise ValueError(
f"Tools can provide a list of messages in Command.update only when using list of messages as ToolNode input, "
f"got: {command.update} for tool '{call['name']}'"
)
Expand All @@ -551,12 +548,12 @@ def _validate_tool_command(
continue

if have_seen_tool_messages:
raise InvalidToolCommandError(
raise ValueError(
f"Expected at most one ToolMessage in Command.update for tool '{call['name']}', got multiple: {messages_update}."
)

if message.tool_call_id != call["id"]:
raise InvalidToolCommandError(
raise ValueError(
f"ToolMessage.tool_call_id must match the tool call id. Expected: {call['id']}, got: {message.tool_call_id} for tool '{call['name']}'."
)

Expand All @@ -570,7 +567,7 @@ def _validate_tool_command(
if input_type == "dict"
else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
)
raise InvalidToolCommandError(
raise ValueError(
f"Expected exactly one message (ToolMessage) in Command.update for tool '{call['name']}', got: {messages_update}. "
"Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. "
f"You can fix it by modifying the tool to return {example_update}."
Expand Down
13 changes: 6 additions & 7 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
TOOL_CALL_ERROR_TEMPLATE,
InjectedState,
InjectedStore,
InvalidToolCommandError,
_get_state_args,
_infer_handled_types,
)
Expand Down Expand Up @@ -1208,7 +1207,7 @@ def add(a: int, b: int) -> int:
]

# test validation (mismatch between input type and command.update type)
with pytest.raises(InvalidToolCommandError):
with pytest.raises(ValueError):

@dec_tool
def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
Expand All @@ -1231,7 +1230,7 @@ def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
)

# test validation (missing tool message in the update for current graph)
with pytest.raises(InvalidToolCommandError):
with pytest.raises(ValueError):

@dec_tool
def no_update_tool():
Expand Down Expand Up @@ -1269,7 +1268,7 @@ def node_update_parent_tool():
) == [Command(update={"messages": []}, graph=Command.PARENT)]

# test validation (multiple tool messages)
with pytest.raises(InvalidToolCommandError):
with pytest.raises(ValueError):
for graph in (None, Command.PARENT):

@dec_tool
Expand Down Expand Up @@ -1485,7 +1484,7 @@ def add(a: int, b: int) -> int:
]

# test validation (mismatch between input type and command.update type)
with pytest.raises(InvalidToolCommandError):
with pytest.raises(ValueError):

@dec_tool
def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
Expand All @@ -1506,7 +1505,7 @@ def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
)

# test validation (missing tool message in the update for current graph)
with pytest.raises(InvalidToolCommandError):
with pytest.raises(ValueError):

@dec_tool
def no_update_tool():
Expand Down Expand Up @@ -1538,7 +1537,7 @@ def node_update_parent_tool():
) == [Command(update=[], graph=Command.PARENT)]

# test validation (multiple tool messages)
with pytest.raises(InvalidToolCommandError):
with pytest.raises(ValueError):
for graph in (None, Command.PARENT):

@dec_tool
Expand Down

0 comments on commit d0c6690

Please sign in to comment.