From b4b3ac6f57adad2b0458026622c1e0ceb07c6c9f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 4 Dec 2024 15:12:03 -0800 Subject: [PATCH 1/2] lib: Merge GraphCommand and Command - Now we have only Command - Command(goto=) combines the previous functionality of Command(send=) and Command(goto=) --- libs/langgraph/langgraph/graph/__init__.py | 3 +- libs/langgraph/langgraph/graph/state.py | 48 ++++---------- libs/langgraph/langgraph/pregel/io.py | 9 +-- libs/langgraph/langgraph/types.py | 2 +- libs/langgraph/tests/test_pregel.py | 58 ++++++++--------- libs/langgraph/tests/test_pregel_async.py | 68 ++++++++++---------- libs/scheduler-kafka/tests/test_push.py | 18 +++--- libs/scheduler-kafka/tests/test_push_sync.py | 18 +++--- libs/sdk-py/langgraph_sdk/schema.py | 2 +- 9 files changed, 101 insertions(+), 125 deletions(-) diff --git a/libs/langgraph/langgraph/graph/__init__.py b/libs/langgraph/langgraph/graph/__init__.py index 241106a3a..c81ad9903 100644 --- a/libs/langgraph/langgraph/graph/__init__.py +++ b/libs/langgraph/langgraph/graph/__init__.py @@ -1,13 +1,12 @@ from langgraph.graph.graph import END, START, Graph from langgraph.graph.message import MessageGraph, MessagesState, add_messages -from langgraph.graph.state import GraphCommand, StateGraph +from langgraph.graph.state import StateGraph __all__ = [ "END", "START", "Graph", "StateGraph", - "GraphCommand", "MessageGraph", "add_messages", "MessagesState", diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index 1a7208a2a..e63f25111 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -1,4 +1,3 @@ -import dataclasses import inspect import logging import typing @@ -9,7 +8,6 @@ from typing import ( Any, Callable, - Generic, Literal, NamedTuple, Optional, @@ -55,7 +53,7 @@ from langgraph.pregel.read import ChannelRead, PregelNode from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore -from langgraph.types import _DC_KWARGS, All, Checkpointer, Command, N, RetryPolicy +from langgraph.types import All, Checkpointer, Command, RetryPolicy from langgraph.utils.fields import get_field_default from langgraph.utils.pydantic import create_model from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable @@ -84,22 +82,6 @@ def _get_node_name(node: RunnableLike) -> str: raise TypeError(f"Unsupported node type: {type(node)}") -@dataclasses.dataclass(**_DC_KWARGS) -class GraphCommand(Generic[N], Command[N]): - """One or more commands to update a StateGraph's state and go to, or send messages to nodes.""" - - goto: Union[str, Sequence[str]] = () - - def __repr__(self) -> str: - # get all non-None values - contents = ", ".join( - f"{key}={value!r}" - for key, value in dataclasses.asdict(self).items() - if value - ) - return f"Command({contents})" - - class StateNodeSpec(NamedTuple): runnable: Runnable metadata: Optional[dict[str, Any]] @@ -392,7 +374,7 @@ def add_node( input = input_hint if ( (rtn := hints.get("return")) - and get_origin(rtn) in (Command, GraphCommand) + and get_origin(rtn) is Command and (rargs := get_args(rtn)) and get_origin(rargs[0]) is Literal and (vals := get_args(rargs[0])) @@ -834,15 +816,12 @@ def _control_branch(value: Any) -> Sequence[Union[str, Send]]: if value.graph == Command.PARENT: raise ParentCommand(value) rtn: list[Union[str, Send]] = [] - if isinstance(value, GraphCommand): - if isinstance(value.goto, str): - rtn.append(value.goto) - else: - rtn.extend(value.goto) - if isinstance(value.send, Send): - rtn.append(value.send) + if isinstance(value.goto, Send): + rtn.append(value.goto) + elif isinstance(value.goto, str): + rtn.append(value.goto) else: - rtn.extend(value.send) + rtn.extend(value.goto) return rtn @@ -854,15 +833,12 @@ async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]: if value.graph == Command.PARENT: raise ParentCommand(value) rtn: list[Union[str, Send]] = [] - if isinstance(value, GraphCommand): - if isinstance(value.goto, str): - rtn.append(value.goto) - else: - rtn.extend(value.goto) - if isinstance(value.send, Send): - rtn.append(value.send) + if isinstance(value.goto, Send): + rtn.append(value.goto) + elif isinstance(value.goto, str): + rtn.append(value.goto) else: - rtn.extend(value.send) + rtn.extend(value.goto) return rtn diff --git a/libs/langgraph/langgraph/pregel/io.py b/libs/langgraph/langgraph/pregel/io.py index ed2c28938..c1fed349a 100644 --- a/libs/langgraph/langgraph/pregel/io.py +++ b/libs/langgraph/langgraph/pregel/io.py @@ -72,17 +72,18 @@ def map_command( """Map input chunk to a sequence of pending writes in the form (channel, value).""" if cmd.graph == Command.PARENT: raise InvalidUpdateError("There is not parent graph") - if cmd.send: + if cmd.goto: if isinstance(cmd.send, (tuple, list)): - sends = cmd.send + sends = cmd.goto else: - sends = [cmd.send] + sends = [cmd.goto] for send in sends: if not isinstance(send, Send): raise TypeError( - f"In Command.send, expected Send, got {type(send).__name__}" + f"In Command.goto, expected Send, got {type(send).__name__}" ) yield (NULL_TASK_ID, PUSH if FF_SEND_V2 else TASKS, send) + # TODO handle goto str for state graph if cmd.resume: if isinstance(cmd.resume, dict) and all(is_task_id(k) for k in cmd.resume): for tid, resume in cmd.resume.items(): diff --git a/libs/langgraph/langgraph/types.py b/libs/langgraph/langgraph/types.py index 4047a28f7..67c7e53f8 100644 --- a/libs/langgraph/langgraph/types.py +++ b/libs/langgraph/langgraph/types.py @@ -249,8 +249,8 @@ class Command(Generic[N]): graph: Optional[str] = None update: Optional[dict[str, Any]] = None - send: Union[Send, Sequence[Send]] = () resume: Optional[Union[Any, dict[str, Any]]] = None + goto: Union[Send, Sequence[Union[Send, str]], str] = () def __repr__(self) -> str: # get all non-None values diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 4f768badd..68071594b 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -65,7 +65,7 @@ START, ) from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt -from langgraph.graph import END, Graph, GraphCommand, StateGraph +from langgraph.graph import END, Graph, StateGraph from langgraph.graph.message import MessageGraph, MessagesState, add_messages from langgraph.managed.shared_value import SharedValue from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor @@ -270,10 +270,10 @@ class State(TypedDict): bar: str def node_a(state: State): - return GraphCommand(goto="b", update={"foo": "bar"}) + return Command(goto="b", update={"foo": "bar"}) def node_b(state: State): - return GraphCommand(goto=END, update={"bar": "baz"}) + return Command(goto=END, update={"bar": "baz"}) builder = StateGraph(State) builder.add_node("a", node_a) @@ -1925,8 +1925,8 @@ def __call__(self, state): def send_for_fun(state): return [ - Send("2", Command(send=Send("2", 3))), - Send("2", GraphCommand(send=Send("2", 4))), + Send("2", Command(goto=Send("2", 3))), + Send("2", Command(goto=Send("2", 4))), "3.1", ] @@ -1947,8 +1947,8 @@ def route_to_three(state) -> Literal["3"]: == [ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='2', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='2', arg=4))", "2|3", "2|4", "3", @@ -1959,8 +1959,8 @@ def route_to_three(state) -> Literal["3"]: "0", "1", "3.1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='2', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='2', arg=4))", "3", "2|3", "2|4", @@ -2000,15 +2000,15 @@ def __call__(self, state): if isinstance(state, list) else ["|".join((self.name, str(state)))] ) - if isinstance(state, GraphCommand): + if isinstance(state, Command): return replace(state, update=update) else: return update def send_for_fun(state): return [ - Send("2", GraphCommand(send=Send("2", 3))), - Send("2", GraphCommand(send=Send("flaky", 4))), + Send("2", Command(goto=Send("2", 3))), + Send("2", Command(goto=Send("flaky", 4))), "3.1", ] @@ -2030,8 +2030,8 @@ def route_to_three(state) -> Literal["3"]: assert graph.invoke(["0"], thread1, debug=1) == [ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='flaky', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='flaky', arg=4))", "2|3", ] assert builder.nodes["2"].runnable.func.ticks == 3 @@ -2046,8 +2046,8 @@ def route_to_three(state) -> Literal["3"]: assert graph.invoke(None, thread1, debug=1) == [ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='flaky', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", @@ -2069,8 +2069,8 @@ def route_to_three(state) -> Literal["3"]: values=[ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='flaky', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", @@ -2105,8 +2105,8 @@ def route_to_three(state) -> Literal["3"]: values=[ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='flaky', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", ], @@ -2123,8 +2123,8 @@ def route_to_three(state) -> Literal["3"]: "writes": { "1": ["1"], "2": [ - ["2|Command(send=Send(node='2', arg=3))"], - ["2|Command(send=Send(node='flaky', arg=4))"], + ["2|Command(goto=Send(node='2', arg=3))"], + ["2|Command(goto=Send(node='flaky', arg=4))"], ["2|3"], ], "flaky": ["flaky|4"], @@ -2209,7 +2209,7 @@ def route_to_three(state) -> Literal["3"]: error=None, interrupts=(), state=None, - result=["2|Command(send=Send(node='2', arg=3))"], + result=["2|Command(goto=Send(node='2', arg=3))"], ), PregelTask( id=AnyStr(), @@ -2223,7 +2223,7 @@ def route_to_three(state) -> Literal["3"]: error=None, interrupts=(), state=None, - result=["2|Command(send=Send(node='flaky', arg=4))"], + result=["2|Command(goto=Send(node='flaky', arg=4))"], ), PregelTask( id=AnyStr(), @@ -2786,10 +2786,10 @@ def test_send_react_interrupt_control( tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())], ) - def agent(state) -> GraphCommand[Literal["foo"]]: - return GraphCommand( + def agent(state) -> Command[Literal["foo"]]: + return Command( update={"messages": ai_message}, - send=[Send(call["name"], call) for call in ai_message.tool_calls], + goto=[Send(call["name"], call) for call in ai_message.tool_calls], ) foo_called = 0 @@ -14580,9 +14580,9 @@ def test_parent_command(request: pytest.FixtureRequest, checkpointer_name: str) from langchain_core.tools import tool @tool(return_direct=True) - def get_user_name() -> GraphCommand: + def get_user_name() -> Command: """Retrieve user name""" - return GraphCommand(update={"user_name": "Meow"}, graph=GraphCommand.PARENT) + return Command(update={"user_name": "Meow"}, graph=Command.PARENT) subgraph_builder = StateGraph(MessagesState) subgraph_builder.add_node("tool", get_user_name) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index addb18b2b..538730c78 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -62,7 +62,7 @@ START, ) from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt -from langgraph.graph import END, Graph, GraphCommand, StateGraph +from langgraph.graph import END, Graph, StateGraph from langgraph.graph.message import MessageGraph, MessagesState, add_messages from langgraph.managed.shared_value import SharedValue from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor @@ -2580,8 +2580,8 @@ async def __call__(self, state): async def send_for_fun(state): return [ - Send("2", Command(send=Send("2", 3))), - Send("2", GraphCommand(send=Send("2", 4))), + Send("2", Command(goto=Send("2", 3))), + Send("2", Command(goto=Send("2", 4))), "3.1", ] @@ -2602,8 +2602,8 @@ async def route_to_three(state) -> Literal["3"]: == [ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='2', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='2', arg=4))", "2|3", "2|4", "3", @@ -2614,8 +2614,8 @@ async def route_to_three(state) -> Literal["3"]: "0", "1", "3.1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='2', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='2', arg=4))", "3", "2|3", "2|4", @@ -2632,16 +2632,16 @@ async def route_to_three(state) -> Literal["3"]: assert await graph.ainvoke(["0"], thread1) == [ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='2', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='2', arg=4))", "2|3", "2|4", ] assert await graph.ainvoke(None, thread1) == [ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='2', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='2', arg=4))", "2|3", "2|4", "3", @@ -2677,15 +2677,15 @@ def __call__(self, state): if isinstance(state, list) else ["|".join((self.name, str(state)))] ) - if isinstance(state, GraphCommand): + if isinstance(state, Command): return replace(state, update=update) else: return update def send_for_fun(state): return [ - Send("2", GraphCommand(send=Send("2", 3))), - Send("2", GraphCommand(send=Send("flaky", 4))), + Send("2", Command(goto=Send("2", 3))), + Send("2", Command(goto=Send("flaky", 4))), "3.1", ] @@ -2708,8 +2708,8 @@ def route_to_three(state) -> Literal["3"]: assert await graph.ainvoke(["0"], thread1, debug=1) == [ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='flaky', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='flaky', arg=4))", "2|3", ] assert builder.nodes["2"].runnable.func.ticks == 3 @@ -2718,8 +2718,8 @@ def route_to_three(state) -> Literal["3"]: assert await graph.ainvoke(None, thread1, debug=1) == [ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='flaky', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", @@ -2736,8 +2736,8 @@ def route_to_three(state) -> Literal["3"]: values=[ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='flaky', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", @@ -2772,8 +2772,8 @@ def route_to_three(state) -> Literal["3"]: values=[ "0", "1", - "2|Command(send=Send(node='2', arg=3))", - "2|Command(send=Send(node='flaky', arg=4))", + "2|Command(goto=Send(node='2', arg=3))", + "2|Command(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", ], @@ -2790,8 +2790,8 @@ def route_to_three(state) -> Literal["3"]: "writes": { "1": ["1"], "2": [ - ["2|Command(send=Send(node='2', arg=3))"], - ["2|Command(send=Send(node='flaky', arg=4))"], + ["2|Command(goto=Send(node='2', arg=3))"], + ["2|Command(goto=Send(node='flaky', arg=4))"], ["2|3"], ], "flaky": ["flaky|4"], @@ -2876,7 +2876,7 @@ def route_to_three(state) -> Literal["3"]: error=None, interrupts=(), state=None, - result=["2|Command(send=Send(node='2', arg=3))"], + result=["2|Command(goto=Send(node='2', arg=3))"], ), PregelTask( id=AnyStr(), @@ -2890,7 +2890,7 @@ def route_to_three(state) -> Literal["3"]: error=None, interrupts=(), state=None, - result=["2|Command(send=Send(node='flaky', arg=4))"], + result=["2|Command(goto=Send(node='flaky', arg=4))"], ), PregelTask( id=AnyStr(), @@ -3448,9 +3448,9 @@ async def test_send_react_interrupt_control( ) async def agent(state) -> Command[Literal["foo"]]: - return GraphCommand( + return Command( update={"messages": ai_message}, - send=[Send(call["name"], call) for call in ai_message.tool_calls], + goto=[Send(call["name"], call) for call in ai_message.tool_calls], ) foo_called = 0 @@ -3761,13 +3761,13 @@ async def route_to_three(state) -> Literal["3"]: @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_max_concurrency_control(checkpointer_name: str) -> None: - async def node1(state) -> GraphCommand[Literal["2"]]: - return GraphCommand(update=["1"], send=[Send("2", idx) for idx in range(100)]) + async def node1(state) -> Command[Literal["2"]]: + return Command(update=["1"], goto=[Send("2", idx) for idx in range(100)]) node2_currently = 0 node2_max_currently = 0 - async def node2(state) -> GraphCommand[Literal["3"]]: + async def node2(state) -> Command[Literal["3"]]: nonlocal node2_currently, node2_max_currently node2_currently += 1 if node2_currently > node2_max_currently: @@ -3775,7 +3775,7 @@ async def node2(state) -> GraphCommand[Literal["3"]]: await asyncio.sleep(0.1) node2_currently -= 1 - return GraphCommand(update=[state], goto="3") + return Command(update=[state], goto="3") async def node3(state) -> Literal["3"]: return ["3"] @@ -12788,9 +12788,9 @@ async def test_parent_command(checkpointer_name: str) -> None: from langchain_core.tools import tool @tool(return_direct=True) - def get_user_name() -> GraphCommand: + def get_user_name() -> Command: """Retrieve user name""" - return GraphCommand(update={"user_name": "Meow"}, graph=GraphCommand.PARENT) + return Command(update={"user_name": "Meow"}, graph=Command.PARENT) subgraph_builder = StateGraph(MessagesState) subgraph_builder.add_node("tool", get_user_name) diff --git a/libs/scheduler-kafka/tests/test_push.py b/libs/scheduler-kafka/tests/test_push.py index 15e9211a2..3d2e4d43d 100644 --- a/libs/scheduler-kafka/tests/test_push.py +++ b/libs/scheduler-kafka/tests/test_push.py @@ -11,10 +11,10 @@ from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.constants import FF_SEND_V2, START from langgraph.errors import NodeInterrupt -from langgraph.graph.state import CompiledStateGraph, GraphCommand, StateGraph +from langgraph.graph.state import CompiledStateGraph, StateGraph from langgraph.scheduler.kafka import serde from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics -from langgraph.types import Send +from langgraph.types import Command, Send from tests.any import AnyDict from tests.drain import drain_topics_async @@ -48,15 +48,15 @@ def __call__(self, state): if isinstance(state, list) else ["|".join((self.name, str(state)))] ) - if isinstance(state, GraphCommand): + if isinstance(state, Command): return state.copy(update=update) else: return update def send_for_fun(state): return [ - Send("2", GraphCommand(send=Send("2", 3))), - Send("2", GraphCommand(send=Send("flaky", 4))), + Send("2", Command(goto=Send("2", 3))), + Send("2", Command(goto=Send("flaky", 4))), "3.1", ] @@ -105,8 +105,8 @@ async def test_push_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> == [ "0", "1", - "2|Control(send=Send(node='2', arg=3))", - "2|Control(send=Send(node='flaky', arg=4))", + "2|Control(goto=Send(node='2', arg=3))", + "2|Control(goto=Send(node='flaky', arg=4))", "2|3", ] ) @@ -182,8 +182,8 @@ async def test_push_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> == [ "0", "1", - "2|Control(send=Send(node='2', arg=3))", - "2|Control(send=Send(node='flaky', arg=4))", + "2|Control(goto=Send(node='2', arg=3))", + "2|Control(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", diff --git a/libs/scheduler-kafka/tests/test_push_sync.py b/libs/scheduler-kafka/tests/test_push_sync.py index 27cd96cb7..ee33d613e 100644 --- a/libs/scheduler-kafka/tests/test_push_sync.py +++ b/libs/scheduler-kafka/tests/test_push_sync.py @@ -10,11 +10,11 @@ from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.constants import FF_SEND_V2, START from langgraph.errors import NodeInterrupt -from langgraph.graph.state import CompiledStateGraph, GraphCommand, StateGraph +from langgraph.graph.state import CompiledStateGraph, StateGraph from langgraph.scheduler.kafka import serde from langgraph.scheduler.kafka.default_sync import DefaultProducer from langgraph.scheduler.kafka.types import MessageToOrchestrator, Topics -from langgraph.types import Send +from langgraph.types import Command, Send from tests.any import AnyDict from tests.drain import drain_topics @@ -48,15 +48,15 @@ def __call__(self, state): if isinstance(state, list) else ["|".join((self.name, str(state)))] ) - if isinstance(state, GraphCommand): + if isinstance(state, Command): return state.copy(update=update) else: return update def send_for_fun(state): return [ - Send("2", GraphCommand(send=Send("2", 3))), - Send("2", GraphCommand(send=Send("flaky", 4))), + Send("2", Command(goto=Send("2", 3))), + Send("2", Command(goto=Send("flaky", 4))), "3.1", ] @@ -106,8 +106,8 @@ def test_push_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> None: == [ "0", "1", - "2|Control(send=Send(node='2', arg=3))", - "2|Control(send=Send(node='flaky', arg=4))", + "2|Control(goto=Send(node='2', arg=3))", + "2|Control(goto=Send(node='flaky', arg=4))", "2|3", ] ) @@ -184,8 +184,8 @@ def test_push_graph(topics: Topics, acheckpointer: BaseCheckpointSaver) -> None: == [ "0", "1", - "2|Control(send=Send(node='2', arg=3))", - "2|Control(send=Send(node='flaky', arg=4))", + "2|Control(goto=Send(node='2', arg=3))", + "2|Control(goto=Send(node='flaky', arg=4))", "2|3", "flaky|4", "3", diff --git a/libs/sdk-py/langgraph_sdk/schema.py b/libs/sdk-py/langgraph_sdk/schema.py index 1ccae3e89..6237ea5bd 100644 --- a/libs/sdk-py/langgraph_sdk/schema.py +++ b/libs/sdk-py/langgraph_sdk/schema.py @@ -373,6 +373,6 @@ class Send(TypedDict): class Command(TypedDict, total=False): - send: Union[Send, Sequence[Send]] + goto: Union[Send, str, Sequence[Union[Send, str]]] update: dict[str, Any] resume: Any From df70e91daecac6b7d2b187b7c67a5c725e7dbe11 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 4 Dec 2024 15:13:55 -0800 Subject: [PATCH 2/2] Lint --- libs/langgraph/langgraph/pregel/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/io.py b/libs/langgraph/langgraph/pregel/io.py index c1fed349a..b2596d3ad 100644 --- a/libs/langgraph/langgraph/pregel/io.py +++ b/libs/langgraph/langgraph/pregel/io.py @@ -73,7 +73,7 @@ def map_command( if cmd.graph == Command.PARENT: raise InvalidUpdateError("There is not parent graph") if cmd.goto: - if isinstance(cmd.send, (tuple, list)): + if isinstance(cmd.goto, (tuple, list)): sends = cmd.goto else: sends = [cmd.goto]