Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Dec 10, 2024
1 parent 24515bf commit 6a79de0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
8 changes: 4 additions & 4 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ def _run_one(
return invalid_tool_message

try:
tool = self.tools_by_name[call["name"]]
response = tool.invoke({**call, **{"type": "tool_call"}})
input = {**call, **{"type": "tool_call"}}
response = self.tools_by_name[call["name"]].invoke(input)
if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
Expand Down Expand Up @@ -352,8 +352,8 @@ async def _arun_one(
return invalid_tool_message

try:
tool = self.tools_by_name[call["name"]]
response = await tool.ainvoke({**call, **{"type": "tool_call"}})
input = {**call, **{"type": "tool_call"}}
response = await self.tools_by_name[call["name"]].ainvoke(input)
if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
elif isinstance(response, ToolMessage):
Expand Down
12 changes: 12 additions & 0 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,10 @@ def handle(e: NodeInterrupt):
assert task.interrupts == (Interrupt(value="foo", when="during"),)


@pytest.mark.skipif(
not IS_LANGCHAIN_CORE_030_OR_GREATER,
reason="Langchain core 0.3.0 or greater is required",
)
async def test_tool_node_command():
@dec_tool
def transfer_to_bob(tool_call_id: str):
Expand Down Expand Up @@ -1289,6 +1293,10 @@ def multiple_tool_messages_tool():
)


@pytest.mark.skipif(
not IS_LANGCHAIN_CORE_030_OR_GREATER,
reason="Langchain core 0.3.0 or greater is required",
)
async def test_tool_node_command_list_input():
@dec_tool
def transfer_to_bob(tool_call_id: str):
Expand Down Expand Up @@ -1537,6 +1545,10 @@ def multiple_tool_messages_tool():
)


@pytest.mark.skipif(
not IS_LANGCHAIN_CORE_030_OR_GREATER,
reason="Langchain core 0.3.0 or greater is required",
)
def test_react_agent_update_state():
class State(AgentState):
user_name: str
Expand Down

0 comments on commit 6a79de0

Please sign in to comment.