Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Dec 10, 2024
1 parent 6a79de0 commit 91313f3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 21 deletions.
4 changes: 3 additions & 1 deletion libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
try:
from langchain_core.messages.tool import ToolOutputMixin
except ImportError:
ToolOutputMixin = object

class ToolOutputMixin: # type: ignore[no-redef]
pass


All = Literal["*"]
Expand Down
65 changes: 45 additions & 20 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.tools import BaseTool, ToolException
from langchain_core.tools import tool as dec_tool
from langchain_core.tools.base import InjectedToolCallId
from pydantic import BaseModel, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
Expand Down Expand Up @@ -995,7 +996,7 @@ def handle(e: NodeInterrupt):
)
async def test_tool_node_command():
@dec_tool
def transfer_to_bob(tool_call_id: str):
def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]):
"""Transfer to Bob"""
return Command(
update={
Expand All @@ -1008,7 +1009,7 @@ def transfer_to_bob(tool_call_id: str):
)

@dec_tool
async def async_transfer_to_bob(tool_call_id: str):
async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]):
"""Transfer to Bob"""
return Command(
update={
Expand All @@ -1020,26 +1021,31 @@ async def async_transfer_to_bob(tool_call_id: str):
graph=Command.PARENT,
)

class CustomToolSchema(BaseModel):
tool_call_id: Annotated[str, InjectedToolCallId]

class MyCustomTool(BaseTool):
def _run(*args: Any, tool_call_id: str, **kwargs: Any):
def _run(*args: Any, **kwargs: Any):
return Command(
update={
"messages": [
ToolMessage(
content="Transferred to Bob", tool_call_id=tool_call_id
content="Transferred to Bob",
tool_call_id=kwargs["tool_call_id"],
)
]
},
goto="bob",
graph=Command.PARENT,
)

async def _arun(*args: Any, tool_call_id: str, **kwargs: Any):
async def _arun(*args: Any, **kwargs: Any):
return Command(
update={
"messages": [
ToolMessage(
content="Transferred to Bob", tool_call_id=tool_call_id
content="Transferred to Bob",
tool_call_id=kwargs["tool_call_id"],
)
]
},
Expand All @@ -1048,10 +1054,14 @@ async def _arun(*args: Any, tool_call_id: str, **kwargs: Any):
)

custom_tool = MyCustomTool(
name="custom_transfer_to_bob", description="Transfer to bob"
name="custom_transfer_to_bob",
description="Transfer to bob",
args_schema=CustomToolSchema,
)
async_custom_tool = MyCustomTool(
name="async_custom_transfer_to_bob", description="Transfer to bob"
name="async_custom_transfer_to_bob",
description="Transfer to bob",
args_schema=CustomToolSchema,
)

# test mixing regular tools and tools returning commands
Expand Down Expand Up @@ -1201,7 +1211,7 @@ def add(a: int, b: int) -> int:
with pytest.raises(InvalidToolCommandError):

@dec_tool
def list_update_tool(tool_call_id: str):
def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
"""My tool"""
return Command(
update=[ToolMessage(content="foo", tool_call_id=tool_call_id)]
Expand Down Expand Up @@ -1299,7 +1309,7 @@ def multiple_tool_messages_tool():
)
async def test_tool_node_command_list_input():
@dec_tool
def transfer_to_bob(tool_call_id: str):
def transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]):
"""Transfer to Bob"""
return Command(
update=[
Expand All @@ -1310,7 +1320,7 @@ def transfer_to_bob(tool_call_id: str):
)

@dec_tool
async def async_transfer_to_bob(tool_call_id: str):
async def async_transfer_to_bob(tool_call_id: Annotated[str, InjectedToolCallId]):
"""Transfer to Bob"""
return Command(
update=[
Expand All @@ -1320,30 +1330,43 @@ async def async_transfer_to_bob(tool_call_id: str):
graph=Command.PARENT,
)

class CustomToolSchema(BaseModel):
tool_call_id: Annotated[str, InjectedToolCallId]

class MyCustomTool(BaseTool):
def _run(*args: Any, tool_call_id: str, **kwargs: Any):
def _run(*args: Any, **kwargs: Any):
return Command(
update=[
ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)
ToolMessage(
content="Transferred to Bob",
tool_call_id=kwargs["tool_call_id"],
)
],
goto="bob",
graph=Command.PARENT,
)

async def _arun(*args: Any, tool_call_id: str, **kwargs: Any):
async def _arun(*args: Any, **kwargs: Any):
return Command(
update=[
ToolMessage(content="Transferred to Bob", tool_call_id=tool_call_id)
ToolMessage(
content="Transferred to Bob",
tool_call_id=kwargs["tool_call_id"],
)
],
goto="bob",
graph=Command.PARENT,
)

custom_tool = MyCustomTool(
name="custom_transfer_to_bob", description="Transfer to bob"
name="custom_transfer_to_bob",
description="Transfer to bob",
args_schema=CustomToolSchema,
)
async_custom_tool = MyCustomTool(
name="async_custom_transfer_to_bob", description="Transfer to bob"
name="async_custom_transfer_to_bob",
description="Transfer to bob",
args_schema=CustomToolSchema,
)

# test mixing regular tools and tools returning commands
Expand Down Expand Up @@ -1465,7 +1488,7 @@ def add(a: int, b: int) -> int:
with pytest.raises(InvalidToolCommandError):

@dec_tool
def list_update_tool(tool_call_id: str):
def list_update_tool(tool_call_id: Annotated[str, InjectedToolCallId]):
"""My tool"""
return Command(
update={
Expand Down Expand Up @@ -1554,14 +1577,16 @@ class State(AgentState):
user_name: str

@dec_tool
def get_user_name():
def get_user_name(tool_call_id: Annotated[str, InjectedToolCallId]):
"""Retrieve user name"""
user_name = interrupt("Please provider user name:")
return Command(
update={
"user_name": user_name,
"messages": [
ToolMessage("Successfully retrieved user name", tool_call_id="")
ToolMessage(
"Successfully retrieved user name", tool_call_id=tool_call_id
)
],
}
)
Expand Down

0 comments on commit 91313f3

Please sign in to comment.