Skip to content

Commit

Permalink
factor out + test
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Dec 6, 2024
1 parent 681fbcf commit d35c1fa
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 68 deletions.
113 changes: 45 additions & 68 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,40 +309,9 @@ def _run_one(
# invoke the tool with raw args to return raw value instead of a ToolMessage
response = tool.invoke(call["args"])
if isinstance(response, Command):
updated_command = deepcopy(response)
if isinstance(updated_command.update, dict):
if output_type != "dict":
raise ValueError(
f"When using dict with '{self.messages_key}' key as ToolNode input, tools must provide a dict in Command.update, got: {response.update} for tool '{call['name']}'"
)

state_update = updated_command.update or {}
messages_update = state_update.get(self.messages_key, [])
else:
if output_type != "list":
raise ValueError(
f"When using list of messages as ToolNode input, tools must provide `[('__root__', update)]` in Command.update, got: {response.update} for tool '{call['name']}'"
)

channels, messages_updates = zip(*updated_command.update)
if len(channels) != 1 or channels[0] != "__root__":
raise ValueError(
f"When using list of messages as ToolNode input, Command.update can only contain a single update in the following format: `[('__root__', update)]`, got: {updated_command.update} for tool '{call['name']}'"
)

messages_update = messages_updates[0]

if len(messages_update) != 1 or not isinstance(
messages_update[0], ToolMessage
):
raise ValueError(
f"Expected exactly one ToolMessage in Command.update for tool '{call['name']}', got: {messages_update}"
)

tool_message = messages_update[0]
tool_message.name = call["name"]
tool_message.tool_call_id = cast(str, call["id"])
return updated_command
return self._add_tool_call_name_and_id_to_command(
response, call, output_type
)
else:
return ToolMessage(
content=cast(Union[str, list], msg_content_output(response)),
Expand Down Expand Up @@ -402,40 +371,9 @@ async def _arun_one(
# invoke the tool with raw args to return raw value instead of a ToolMessage
response = await tool.ainvoke(call["args"])
if isinstance(response, Command):
updated_command = deepcopy(response)
if isinstance(updated_command.update, dict):
if output_type != "dict":
raise ValueError(
f"When using dict with '{self.messages_key}' key as ToolNode input, tools must provide a dict in Command.update, got: {response.update} for tool '{call['name']}'"
)

state_update = updated_command.update or {}
messages_update = state_update.get(self.messages_key, [])
else:
if output_type != "list":
raise ValueError(
f"When using list of messages as ToolNode input, tools must provide `[('__root__', update)]` in Command.update, got: {response.update} for tool '{call['name']}'"
)

channels, messages_updates = zip(*updated_command.update)
if len(channels) != 1 or channels[0] != "__root__":
raise ValueError(
f"When using list of messages as ToolNode input, Command.update can only contain a single update in the following format: `[('__root__', update)]`, got: {updated_command.update} for tool '{call['name']}'"
)

messages_update = messages_updates[0]

if len(messages_update) != 1 or not isinstance(
messages_update[0], ToolMessage
):
raise ValueError(
f"Expected exactly one ToolMessage in Command.update for tool '{call['name']}', got: {messages_update}"
)

tool_message = messages_update[0]
tool_message.name = call["name"]
tool_message.tool_call_id = cast(str, call["id"])
return updated_command
return self._add_tool_call_name_and_id_to_command(
response, call, output_type
)
else:
return ToolMessage(
content=cast(Union[str, list], msg_content_output(response)),
Expand Down Expand Up @@ -592,6 +530,45 @@ def _inject_tool_args(
tool_call_with_store = self._inject_store(tool_call_with_state, store)
return tool_call_with_store

def _add_tool_call_name_and_id_to_command(
self, command: Command, call: ToolCall, output_type: Literal["list", "dict"]
) -> Command:
if isinstance(command.update, dict):
if output_type != "dict":
raise ValueError(
f"When using dict with '{self.messages_key}' key as ToolNode input, tools must provide a dict in Command.update, got: {command.update} for tool '{call['name']}'"
)

updated_command = deepcopy(command)
state_update = cast(dict[str, Any], updated_command.update) or {}
messages_update = state_update.get(self.messages_key, [])
elif isinstance(command.update, list):
if output_type != "list":
raise ValueError(
f"When using list of messages as ToolNode input, tools must provide `[('__root__', update)]` in Command.update, got: {command.update} for tool '{call['name']}'"
)

updated_command = deepcopy(command)
channels, messages_updates = zip(*updated_command.update)
if len(channels) != 1 or channels[0] != "__root__":
raise ValueError(
f"When using list of messages as ToolNode input, Command.update can only contain a single update in the following format: `[('__root__', update)]`, got: {updated_command.update} for tool '{call['name']}'"
)

messages_update = messages_updates[0]
else:
return command

if len(messages_update) != 1 or not isinstance(messages_update[0], ToolMessage):
raise ValueError(
f"Expected exactly one ToolMessage in Command.update for tool '{call['name']}', got: {messages_update}"
)

tool_message: ToolMessage = messages_update[0]
tool_message.name = call["name"]
tool_message.tool_call_id = cast(str, call["id"])
return updated_command


def tools_condition(
state: Union[list[AnyMessage], dict[str, Any], BaseModel],
Expand Down
174 changes: 174 additions & 0 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,180 @@ def add(a: int, b: int) -> int:
]


async def test_tool_node_command_list_input():
command = Command(
update=[
("__root__", [ToolMessage(content="Transferred to Bob", tool_call_id="")])
],
goto="bob",
graph=Command.PARENT,
)

@dec_tool
def transfer_to_bob():
"""Transfer to Bob"""
return command

@dec_tool
async def async_transfer_to_bob():
"""Transfer to Bob"""
return command

class MyCustomTool(BaseTool):
def _run(*args: Any, **kwargs: Any):
return command

async def _arun(*args: Any, **kwargs: Any):
return command

custom_tool = MyCustomTool(
name="custom_transfer_to_bob", description="Transfer to bob"
)
async_custom_tool = MyCustomTool(
name="async_custom_transfer_to_bob", description="Transfer to bob"
)

# test mixing regular tools and tools returning commands
def add(a: int, b: int) -> int:
"""Add two numbers"""
return a + b

result = ToolNode([add, transfer_to_bob]).invoke(
[
AIMessage(
"",
tool_calls=[
{"args": {"a": 1, "b": 2}, "id": "1", "name": "add"},
{"args": {}, "id": "2", "name": "transfer_to_bob"},
],
)
]
)

assert result == [
[
ToolMessage(
content="3",
tool_call_id="1",
name="add",
)
],
Command(
update=[
(
"__root__",
[
ToolMessage(
content="Transferred to Bob",
tool_call_id="2",
name="transfer_to_bob",
)
],
)
],
goto="bob",
graph=Command.PARENT,
),
]

# test tools returning commands

# test sync tools
for tool in [transfer_to_bob, custom_tool]:
result = ToolNode([tool]).invoke(
[AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]
)
assert result == [
Command(
update=[
(
"__root__",
[
ToolMessage(
content="Transferred to Bob",
tool_call_id="1",
name=tool.name,
)
],
)
],
goto="bob",
graph=Command.PARENT,
)
]

# test async tools
for tool in [async_transfer_to_bob, async_custom_tool]:
result = await ToolNode([tool]).ainvoke(
[AIMessage("", tool_calls=[{"args": {}, "id": "1", "name": tool.name}])]
)
assert result == [
Command(
update=[
(
"__root__",
[
ToolMessage(
content="Transferred to Bob",
tool_call_id="1",
name=tool.name,
)
],
)
],
goto="bob",
graph=Command.PARENT,
)
]

# test multiple commands
result = ToolNode([transfer_to_bob, custom_tool]).invoke(
[
AIMessage(
"",
tool_calls=[
{"args": {}, "id": "1", "name": "transfer_to_bob"},
{"args": {}, "id": "2", "name": "custom_transfer_to_bob"},
],
)
]
)
assert result == [
Command(
update=[
(
"__root__",
[
ToolMessage(
content="Transferred to Bob",
tool_call_id="1",
name="transfer_to_bob",
)
],
)
],
goto="bob",
graph=Command.PARENT,
),
Command(
update=[
(
"__root__",
[
ToolMessage(
content="Transferred to Bob",
tool_call_id="2",
name="custom_transfer_to_bob",
)
],
)
],
goto="bob",
graph=Command.PARENT,
),
]


def test_react_agent_update_state():
class State(AgentState):
user_name: str
Expand Down

0 comments on commit d35c1fa

Please sign in to comment.