Skip to content

Commit

Permalink
Merge pull request #2638 from langchain-ai/nc/4dec/command
Browse files Browse the repository at this point in the history
lib: Merge GraphCommand and Command
  • Loading branch information
nfcampos authored Dec 4, 2024
2 parents 78e6b36 + df70e91 commit dcc2617
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 126 deletions.
3 changes: 1 addition & 2 deletions libs/langgraph/langgraph/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from langgraph.graph.graph import END, START, Graph
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.graph.state import GraphCommand, StateGraph
from langgraph.graph.state import StateGraph

__all__ = [
"END",
"START",
"Graph",
"StateGraph",
"GraphCommand",
"MessageGraph",
"add_messages",
"MessagesState",
Expand Down
48 changes: 12 additions & 36 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import dataclasses
import inspect
import logging
import typing
Expand All @@ -9,7 +8,6 @@
from typing import (
Any,
Callable,
Generic,
Literal,
NamedTuple,
Optional,
Expand Down Expand Up @@ -55,7 +53,7 @@
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import _DC_KWARGS, All, Checkpointer, Command, N, RetryPolicy
from langgraph.types import All, Checkpointer, Command, RetryPolicy
from langgraph.utils.fields import get_field_default
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable
Expand Down Expand Up @@ -84,22 +82,6 @@ def _get_node_name(node: RunnableLike) -> str:
raise TypeError(f"Unsupported node type: {type(node)}")


@dataclasses.dataclass(**_DC_KWARGS)
class GraphCommand(Generic[N], Command[N]):
"""One or more commands to update a StateGraph's state and go to, or send messages to nodes."""

goto: Union[str, Sequence[str]] = ()

def __repr__(self) -> str:
# get all non-None values
contents = ", ".join(
f"{key}={value!r}"
for key, value in dataclasses.asdict(self).items()
if value
)
return f"Command({contents})"


class StateNodeSpec(NamedTuple):
runnable: Runnable
metadata: Optional[dict[str, Any]]
Expand Down Expand Up @@ -392,7 +374,7 @@ def add_node(
input = input_hint
if (
(rtn := hints.get("return"))
and get_origin(rtn) in (Command, GraphCommand)
and get_origin(rtn) is Command
and (rargs := get_args(rtn))
and get_origin(rargs[0]) is Literal
and (vals := get_args(rargs[0]))
Expand Down Expand Up @@ -834,15 +816,12 @@ def _control_branch(value: Any) -> Sequence[Union[str, Send]]:
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value, GraphCommand):
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value.send, Send):
rtn.append(value.send)
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.send)
rtn.extend(value.goto)
return rtn


Expand All @@ -854,15 +833,12 @@ async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]:
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value, GraphCommand):
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value.send, Send):
rtn.append(value.send)
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.send)
rtn.extend(value.goto)
return rtn


Expand Down
11 changes: 6 additions & 5 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,18 @@ def map_command(
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if cmd.graph == Command.PARENT:
raise InvalidUpdateError("There is not parent graph")
if cmd.send:
if isinstance(cmd.send, (tuple, list)):
sends = cmd.send
if cmd.goto:
if isinstance(cmd.goto, (tuple, list)):
sends = cmd.goto
else:
sends = [cmd.send]
sends = [cmd.goto]
for send in sends:
if not isinstance(send, Send):
raise TypeError(
f"In Command.send, expected Send, got {type(send).__name__}"
f"In Command.goto, expected Send, got {type(send).__name__}"
)
yield (NULL_TASK_ID, PUSH if FF_SEND_V2 else TASKS, send)
# TODO handle goto str for state graph
if cmd.resume:
if isinstance(cmd.resume, dict) and all(is_task_id(k) for k in cmd.resume):
for tid, resume in cmd.resume.items():
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ class Command(Generic[N]):

graph: Optional[str] = None
update: Optional[dict[str, Any]] = None
send: Union[Send, Sequence[Send]] = ()
resume: Optional[Union[Any, dict[str, Any]]] = None
goto: Union[Send, Sequence[Union[Send, str]], str] = ()

def __repr__(self) -> str:
# get all non-None values
Expand Down
58 changes: 29 additions & 29 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
START,
)
from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt
from langgraph.graph import END, Graph, GraphCommand, StateGraph
from langgraph.graph import END, Graph, StateGraph
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.managed.shared_value import SharedValue
from langgraph.prebuilt.chat_agent_executor import create_tool_calling_executor
Expand Down Expand Up @@ -270,10 +270,10 @@ class State(TypedDict):
bar: str

def node_a(state: State):
return GraphCommand(goto="b", update={"foo": "bar"})
return Command(goto="b", update={"foo": "bar"})

def node_b(state: State):
return GraphCommand(goto=END, update={"bar": "baz"})
return Command(goto=END, update={"bar": "baz"})

builder = StateGraph(State)
builder.add_node("a", node_a)
Expand Down Expand Up @@ -1925,8 +1925,8 @@ def __call__(self, state):

def send_for_fun(state):
return [
Send("2", Command(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("2", 4))),
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("2", 4))),
"3.1",
]

Expand All @@ -1947,8 +1947,8 @@ def route_to_three(state) -> Literal["3"]:
== [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='2', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"2|3",
"2|4",
"3",
Expand All @@ -1959,8 +1959,8 @@ def route_to_three(state) -> Literal["3"]:
"0",
"1",
"3.1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='2', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='2', arg=4))",
"3",
"2|3",
"2|4",
Expand Down Expand Up @@ -2000,15 +2000,15 @@ def __call__(self, state):
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, GraphCommand):
if isinstance(state, Command):
return replace(state, update=update)
else:
return update

def send_for_fun(state):
return [
Send("2", GraphCommand(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("flaky", 4))),
Send("2", Command(goto=Send("2", 3))),
Send("2", Command(goto=Send("flaky", 4))),
"3.1",
]

Expand All @@ -2030,8 +2030,8 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(["0"], thread1, debug=1) == [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
]
assert builder.nodes["2"].runnable.func.ticks == 3
Expand All @@ -2046,8 +2046,8 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(None, thread1, debug=1) == [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand All @@ -2069,8 +2069,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand Down Expand Up @@ -2105,8 +2105,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|Command(goto=Send(node='2', arg=3))",
"2|Command(goto=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
],
Expand All @@ -2123,8 +2123,8 @@ def route_to_three(state) -> Literal["3"]:
"writes": {
"1": ["1"],
"2": [
["2|Command(send=Send(node='2', arg=3))"],
["2|Command(send=Send(node='flaky', arg=4))"],
["2|Command(goto=Send(node='2', arg=3))"],
["2|Command(goto=Send(node='flaky', arg=4))"],
["2|3"],
],
"flaky": ["flaky|4"],
Expand Down Expand Up @@ -2209,7 +2209,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Command(send=Send(node='2', arg=3))"],
result=["2|Command(goto=Send(node='2', arg=3))"],
),
PregelTask(
id=AnyStr(),
Expand All @@ -2223,7 +2223,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Command(send=Send(node='flaky', arg=4))"],
result=["2|Command(goto=Send(node='flaky', arg=4))"],
),
PregelTask(
id=AnyStr(),
Expand Down Expand Up @@ -2786,10 +2786,10 @@ def test_send_react_interrupt_control(
tool_calls=[ToolCall(name="foo", args={"hi": [1, 2, 3]}, id=AnyStr())],
)

def agent(state) -> GraphCommand[Literal["foo"]]:
return GraphCommand(
def agent(state) -> Command[Literal["foo"]]:
return Command(
update={"messages": ai_message},
send=[Send(call["name"], call) for call in ai_message.tool_calls],
goto=[Send(call["name"], call) for call in ai_message.tool_calls],
)

foo_called = 0
Expand Down Expand Up @@ -14580,9 +14580,9 @@ def test_parent_command(request: pytest.FixtureRequest, checkpointer_name: str)
from langchain_core.tools import tool

@tool(return_direct=True)
def get_user_name() -> GraphCommand:
def get_user_name() -> Command:
"""Retrieve user name"""
return GraphCommand(update={"user_name": "Meow"}, graph=GraphCommand.PARENT)
return Command(update={"user_name": "Meow"}, graph=Command.PARENT)

subgraph_builder = StateGraph(MessagesState)
subgraph_builder.add_node("tool", get_user_name)
Expand Down
Loading

0 comments on commit dcc2617

Please sign in to comment.