Skip to content

Commit

Permalink
update tests to match core changes
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Dec 10, 2024
1 parent 55a22f3 commit 442ad8c
Showing 1 changed file with 85 additions and 36 deletions.
121 changes: 85 additions & 36 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,8 +877,7 @@ def test_tool_node_individual_tool_error_handling():

tool_message: ToolMessage = result_individual_tool_error_handler["messages"][-1]
assert tool_message.type == "tool"
# TODO: figure out how to propagate this properly
# assert tool_message.status == "error"
assert tool_message.status == "error"
assert tool_message.content == "foo"
assert tool_message.tool_call_id == "some 0"

Expand Down Expand Up @@ -991,30 +990,58 @@ def handle(e: NodeInterrupt):


async def test_tool_node_command():
command = Command(
update={
"messages": [ToolMessage(content="Transferred to Bob", tool_call_id="")]
},
goto="bob",
graph=Command.PARENT,
)

@dec_tool
def transfer_to_bob():
def transfer_to_bob(tool_call_id: str):
"""Transfer to Bob"""
return command
return Command(
update={
"messages": [
ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)
]
},
goto="bob",
graph=Command.PARENT,
)

@dec_tool
async def async_transfer_to_bob():
async def async_transfer_to_bob(tool_call_id: str):
"""Transfer to Bob"""
return command
return Command(
update={
"messages": [
ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)
]
},
goto="bob",
graph=Command.PARENT,
)

class MyCustomTool(BaseTool):
def _run(*args: Any, **kwargs: Any):
return command
def _run(*args: Any, tool_call_id: str, **kwargs: Any):
return Command(
update={
"messages": [
ToolMessage(
content="Transferred to Bob", tool_call_id=tool_call_id
)
]
},
goto="bob",
graph=Command.PARENT,
)

async def _arun(*args: Any, **kwargs: Any):
return command
async def _arun(*args: Any, tool_call_id: str, **kwargs: Any):
return Command(
update={
"messages": [
ToolMessage(
content="Transferred to Bob", tool_call_id=tool_call_id
)
]
},
goto="bob",
graph=Command.PARENT,
)

custom_tool = MyCustomTool(
name="custom_transfer_to_bob", description="Transfer to bob"
Expand Down Expand Up @@ -1170,9 +1197,11 @@ def add(a: int, b: int) -> int:
with pytest.raises(InvalidToolCommandError):

@dec_tool
def list_update_tool():
def list_update_tool(tool_call_id: str):
"""My tool"""
return Command(update=[ToolMessage(content="foo", tool_call_id="")])
return Command(
update=[ToolMessage(content="foo", tool_call_id=tool_call_id)]
)

ToolNode([list_update_tool]).invoke(
{
Expand Down Expand Up @@ -1261,28 +1290,46 @@ def multiple_tool_messages_tool():


async def test_tool_node_command_list_input():
command = Command(
update=[ToolMessage(content="Transferred to Bob", tool_call_id="")],
goto="bob",
graph=Command.PARENT,
)

@dec_tool
def transfer_to_bob():
def transfer_to_bob(tool_call_id: str):
"""Transfer to Bob"""
return command
return Command(
update=[
ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)
],
goto="bob",
graph=Command.PARENT,
)

@dec_tool
async def async_transfer_to_bob():
async def async_transfer_to_bob(tool_call_id: str):
"""Transfer to Bob"""
return command
return Command(
update=[
ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)
],
goto="bob",
graph=Command.PARENT,
)

class MyCustomTool(BaseTool):
def _run(*args: Any, **kwargs: Any):
return command
def _run(*args: Any, tool_call_id: str, **kwargs: Any):
return Command(
update=[
ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)
],
goto="bob",
graph=Command.PARENT,
)

async def _arun(*args: Any, **kwargs: Any):
return command
async def _arun(*args: Any, tool_call_id: str, **kwargs: Any):
return Command(
update=[
ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)
],
goto="bob",
graph=Command.PARENT,
)

custom_tool = MyCustomTool(
name="custom_transfer_to_bob", description="Transfer to bob"
Expand Down Expand Up @@ -1410,10 +1457,12 @@ def add(a: int, b: int) -> int:
with pytest.raises(InvalidToolCommandError):

@dec_tool
def list_update_tool():
def list_update_tool(tool_call_id: str):
"""My tool"""
return Command(
update={"messages": [ToolMessage(content="foo", tool_call_id="")]}
update={
"messages": [ToolMessage(content="foo", tool_call_id=tool_call_id)]
}
)

ToolNode([list_update_tool]).invoke(
Expand Down

0 comments on commit 442ad8c

Please sign in to comment.