Skip to content

Commit

Permalink
langgraph: relax constraints in ToolNode Command validation (#2778)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Dec 16, 2024
1 parent d2794ed commit d9b7aaa
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 80 deletions.
24 changes: 8 additions & 16 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,33 +548,25 @@ def _validate_tool_command(

# convert to message objects if updates are in a dict format
messages_update = convert_to_messages(messages_update)
have_seen_tool_messages = False
has_matching_tool_message = False
for message in messages_update:
if not isinstance(message, ToolMessage):
continue

if have_seen_tool_messages:
raise ValueError(
f"Expected at most one ToolMessage in Command.update for tool '{call['name']}', got multiple: {messages_update}."
)

if message.tool_call_id != call["id"]:
raise ValueError(
f"ToolMessage.tool_call_id must match the tool call id. Expected: {call['id']}, got: {message.tool_call_id} for tool '{call['name']}'."
)

message.name = call["name"]
have_seen_tool_messages = True
if message.tool_call_id == call["id"]:
message.name = call["name"]
has_matching_tool_message = True

# validate that we always have exactly one ToolMessage in Command.update if command is sent to the CURRENT graph
if updated_command.graph is None and not have_seen_tool_messages:
# validate that we always have a ToolMessage matching the tool call in
# Command.update if command is sent to the CURRENT graph
if updated_command.graph is None and not has_matching_tool_message:
example_update = (
'`Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
if input_type == "dict"
else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
)
raise ValueError(
f"Expected exactly one message (ToolMessage) in Command.update for tool '{call['name']}', got: {messages_update}. "
f"Expected to have a matching ToolMessage in Command.update for tool '{call['name']}', got: {messages_update}. "
"Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. "
f"You can fix it by modifying the tool to return {example_update}."
)
Expand Down
110 changes: 46 additions & 64 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,33 @@ def no_update_tool():
}
)

# test validation (tool message with a wrong tool call ID)
with pytest.raises(ValueError):

@dec_tool
def mismatching_tool_call_id_tool():
"""My tool"""
return Command(
update={"messages": [ToolMessage(content="foo", tool_call_id="2")]}
)

ToolNode([mismatching_tool_call_id_tool]).invoke(
{
"messages": [
AIMessage(
"",
tool_calls=[
{
"args": {},
"id": "1",
"name": "mismatching_tool_call_id_tool",
}
],
)
]
}
)

# test validation (missing tool message in the update for parent graph is OK)
@dec_tool
def node_update_parent_tool():
Expand All @@ -1268,40 +1295,6 @@ def node_update_parent_tool():
}
) == [Command(update={"messages": []}, graph=Command.PARENT)]

# test validation (multiple tool messages)
with pytest.raises(ValueError):
for graph in (None, Command.PARENT):

@dec_tool
def multiple_tool_messages_tool():
"""My tool"""
return Command(
update={
"messages": [
ToolMessage(content="foo", tool_call_id=""),
ToolMessage(content="bar", tool_call_id=""),
]
},
graph=graph,
)

ToolNode([multiple_tool_messages_tool]).invoke(
{
"messages": [
AIMessage(
"",
tool_calls=[
{
"args": {},
"id": "1",
"name": "multiple_tool_messages_tool",
}
],
)
]
}
)


@pytest.mark.skipif(
not IS_LANGCHAIN_CORE_030_OR_GREATER,
Expand Down Expand Up @@ -1524,6 +1517,25 @@ def no_update_tool():
]
)

# test validation (tool message with a wrong tool call ID)
with pytest.raises(ValueError):

@dec_tool
def mismatching_tool_call_id_tool():
"""My tool"""
return Command(update=[ToolMessage(content="foo", tool_call_id="2")])

ToolNode([mismatching_tool_call_id_tool]).invoke(
[
AIMessage(
"",
tool_calls=[
{"args": {}, "id": "1", "name": "mismatching_tool_call_id_tool"}
],
)
]
)

# test validation (missing tool message in the update for parent graph is OK)
@dec_tool
def node_update_parent_tool():
Expand All @@ -1539,36 +1551,6 @@ def node_update_parent_tool():
]
) == [Command(update=[], graph=Command.PARENT)]

# test validation (multiple tool messages)
with pytest.raises(ValueError):
for graph in (None, Command.PARENT):

@dec_tool
def multiple_tool_messages_tool():
"""My tool"""
return Command(
update=[
ToolMessage(content="foo", tool_call_id=""),
ToolMessage(content="bar", tool_call_id=""),
],
graph=graph,
)

ToolNode([multiple_tool_messages_tool]).invoke(
[
AIMessage(
"",
tool_calls=[
{
"args": {},
"id": "1",
"name": "multiple_tool_messages_tool",
}
],
)
]
)


@pytest.mark.skipif(
not IS_LANGCHAIN_CORE_030_OR_GREATER,
Expand Down

0 comments on commit d9b7aaa

Please sign in to comment.