diff --git a/libs/langgraph/langgraph/types.py b/libs/langgraph/langgraph/types.py index e8dad7ba7..2086552f9 100644 --- a/libs/langgraph/langgraph/types.py +++ b/libs/langgraph/langgraph/types.py @@ -35,7 +35,9 @@ try: from langchain_core.messages.tool import ToolOutputMixin except ImportError: - ToolOutputMixin = object + + class ToolOutputMixin: # type: ignore[no-redef] + pass All = Literal["*"] diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 109ebd2bf..b80552698 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -31,6 +31,7 @@ from langchain_core.runnables import Runnable, RunnableLambda from langchain_core.tools import BaseTool, ToolException from langchain_core.tools import tool as dec_tool +from langchain_core.tools.base import InjectedToolCallId from pydantic import BaseModel, ValidationError from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 @@ -995,7 +996,7 @@ def handle(e: NodeInterrupt): ) async def test_tool_node_command(): @dec_tool - def transfer_to_bob(tool_call_id: str): + def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): """Transfer to Bob""" return Command( update={ @@ -1008,7 +1009,7 @@ def transfer_to_bob(tool_call_id: str): ) @dec_tool - async def async_transfer_to_bob(tool_call_id: str): + async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): """Transfer to Bob""" return Command( update={ @@ -1020,13 +1021,17 @@ async def async_transfer_to_bob(tool_call_id: str): graph=Command.PARENT, ) + class CustomToolSchema(BaseModel): + tool_call_id: Annotated[str, InjectedToolCallId] + class MyCustomTool(BaseTool): - def _run(*args: Any, tool_call_id: str, **kwargs: Any): + def _run(*args: Any, **kwargs: Any): return Command( update={ "messages": [ ToolMessage( - content="Transferred to Bob", tool_call_id=tool_call_id + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], ) ] }, @@ -1034,12 +1039,13 @@ def _run(*args: Any, tool_call_id: str, **kwargs: Any): graph=Command.PARENT, ) - async def _arun(*args: Any, tool_call_id: str, **kwargs: Any): + async def _arun(*args: Any, **kwargs: Any): return Command( update={ "messages": [ ToolMessage( - content="Transferred to Bob", tool_call_id=tool_call_id + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], ) ] }, @@ -1048,10 +1054,14 @@ async def _arun(*args: Any, tool_call_id: str, **kwargs: Any): ) custom_tool = MyCustomTool( - name="custom_transfer_to_bob", description="Transfer to bob" + name="custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, ) async_custom_tool = MyCustomTool( - name="async_custom_transfer_to_bob", description="Transfer to bob" + name="async_custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, ) # test mixing regular tools and tools returning commands @@ -1201,7 +1211,7 @@ def add(a: int, b: int) -> int: with pytest.raises(InvalidToolCommandError): @dec_tool - def list_update_tool(tool_call_id: str): + def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]): """My tool""" return Command( update=[ToolMessage(content="foo", tool_call_id=tool_call_id)] @@ -1299,7 +1309,7 @@ def multiple_tool_messages_tool(): ) async def test_tool_node_command_list_input(): @dec_tool - def transfer_to_bob(tool_call_id: str): + def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): """Transfer to Bob""" return Command( update=[ @@ -1310,7 +1320,7 @@ def transfer_to_bob(tool_call_id: str): ) @dec_tool - async def async_transfer_to_bob(tool_call_id: str): + async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]): """Transfer to Bob""" return Command( update=[ @@ -1320,30 +1330,43 @@ async def async_transfer_to_bob(tool_call_id: str): graph=Command.PARENT, ) + class CustomToolSchema(BaseModel): + tool_call_id: Annotated[str, InjectedToolCallId] + class MyCustomTool(BaseTool): - def _run(*args: Any, tool_call_id: str, **kwargs: Any): + def _run(*args: Any, **kwargs: Any): return Command( update=[ - ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id) + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) ], goto="bob", graph=Command.PARENT, ) - async def _arun(*args: Any, tool_call_id: str, **kwargs: Any): + async def _arun(*args: Any, **kwargs: Any): return Command( update=[ - ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id) + ToolMessage( + content="Transferred to Bob", + tool_call_id=kwargs["tool_call_id"], + ) ], goto="bob", graph=Command.PARENT, ) custom_tool = MyCustomTool( - name="custom_transfer_to_bob", description="Transfer to bob" + name="custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, ) async_custom_tool = MyCustomTool( - name="async_custom_transfer_to_bob", description="Transfer to bob" + name="async_custom_transfer_to_bob", + description="Transfer to bob", + args_schema=CustomToolSchema, ) # test mixing regular tools and tools returning commands @@ -1465,7 +1488,7 @@ def add(a: int, b: int) -> int: with pytest.raises(InvalidToolCommandError): @dec_tool - def list_update_tool(tool_call_id: str): + def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]): """My tool""" return Command( update={ @@ -1554,14 +1577,16 @@ class State(AgentState): user_name: str @dec_tool - def get_user_name(): + def get_user_name(tool_call_id: Annotated[str, InjectedToolCallId]): """Retrieve user name""" user_name = interrupt("Please provider user name:") return Command( update={ "user_name": user_name, "messages": [ - ToolMessage("Successfully retrieved user name", tool_call_id="") + ToolMessage( + "Successfully retrieved user name", tool_call_id=tool_call_id + ) ], } )