Skip to content

Commit

Permalink
Merge pull request #2683 from langchain-ai/nc/9dec/invoke-command-goto
Browse files Browse the repository at this point in the history
lib: Add support for invoke(Command(goto=<str>))
  • Loading branch information
nfcampos authored Dec 10, 2024
2 parents 3d97b97 + a7ac9ff commit 60d742e
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 10 deletions.
1 change: 1 addition & 0 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def compile(
for key, node in self.nodes.items():
compiled.attach_node(key, node)

compiled.attach_branch(START, SELF, CONTROL_BRANCH, with_reader=False)
for key, node in self.nodes.items():
compiled.attach_branch(key, SELF, CONTROL_BRANCH, with_reader=False)

Expand Down
12 changes: 8 additions & 4 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
PUSH,
RESUME,
RETURN,
SELF,
START,
TAG_HIDDEN,
TASKS,
)
Expand Down Expand Up @@ -79,12 +81,14 @@ def map_command(
else:
sends = [cmd.goto]
for send in sends:
if not isinstance(send, Send):
if isinstance(send, Send):
yield (NULL_TASK_ID, PUSH if FF_SEND_V2 else TASKS, send)
elif isinstance(send, str):
yield (NULL_TASK_ID, f"branch:{START}:{SELF}:{send}", START)
else:
raise TypeError(
f"In Command.goto, expected Send, got {type(send).__name__}"
f"In Command.goto, expected Send/str, got {type(send).__name__}"
)
yield (NULL_TASK_ID, PUSH if FF_SEND_V2 else TASKS, send)
# TODO handle goto str for state graph
if cmd.resume:
if isinstance(cmd.resume, dict) and all(is_task_id(k) for k in cmd.resume):
for tid, resume in cmd.resume.items():
Expand Down
51 changes: 45 additions & 6 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7602,7 +7602,7 @@ class State(TypedDict):
content="result for query",
name="search_api",
tool_call_id="tool_call123",
id="00000000-0000-4000-8000-000000000033",
id="00000000-0000-4000-8000-000000000037",
)
]
},
Expand All @@ -7625,7 +7625,7 @@ class State(TypedDict):
content="result for another",
name="search_api",
tool_call_id="tool_call456",
id="00000000-0000-4000-8000-000000000041",
id="00000000-0000-4000-8000-000000000045",
)
]
},
Expand Down Expand Up @@ -8235,7 +8235,7 @@ class MoreState(TypedDict):
"__root__": [
HumanMessage(
content="what is weather in sf",
id="00000000-0000-4000-8000-000000000070",
id="00000000-0000-4000-8000-000000000078",
),
AIMessage(
content="",
Expand All @@ -8255,7 +8255,7 @@ class MoreState(TypedDict):
),
AIMessage(content="answer", id="ai2"),
AIMessage(
content="an extra message", id="00000000-0000-4000-8000-000000000092"
content="an extra message", id="00000000-0000-4000-8000-000000000100"
),
HumanMessage(content="what is weather in la"),
],
Expand Down Expand Up @@ -14940,8 +14940,8 @@ def node2(state: State):

# Start the graph and interrupt at the first node
graph.invoke({"foo": "abc"}, config)
result = graph.invoke(Command(update={"foo": "def"}), config)
assert result == {"foo": "def|node-1|node-2"}
result = graph.invoke(Command(resume="node1"), config)
assert result == {"foo": "abc|node-1|node-2"}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
Expand Down Expand Up @@ -15003,3 +15003,42 @@ def step4(state: State):
],
"plan": [],
}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_command_goto_with_static_breakpoints(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
"""Use Command goto with static breakpoints."""

checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict):
"""The graph state."""

foo: Annotated[str, operator.add]

def node1(state: State):
return {
"foo": "|node-1",
}

def node2(state: State):
return {
"foo": "|node-2",
}

builder = StateGraph(State)
builder.add_node("node1", node1)
builder.add_node("node2", node2)
builder.add_edge(START, "node1")
builder.add_edge("node1", "node2")

graph = builder.compile(checkpointer=checkpointer, interrupt_before=["node1"])

config = {"configurable": {"thread_id": str(uuid.uuid4())}}

# Start the graph and interrupt at the first node
graph.invoke({"foo": "abc"}, config)
result = graph.invoke(Command(goto=["node2"]), config)
assert result == {"foo": "abc|node-1|node-2|node-2"}
36 changes: 36 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13295,3 +13295,39 @@ def step4(state: State):
],
"plan": [],
}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_command_goto_with_static_breakpoints(checkpointer_name: str) -> None:
"""Use Command goto with static breakpoints."""

class State(TypedDict):
"""The graph state."""

foo: Annotated[str, operator.add]

def node1(state: State):
return {
"foo": "|node-1",
}

def node2(state: State):
return {
"foo": "|node-2",
}

builder = StateGraph(State)
builder.add_node("node1", node1)
builder.add_node("node2", node2)
builder.add_edge(START, "node1")
builder.add_edge("node1", "node2")

async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["node1"])

config = {"configurable": {"thread_id": str(uuid.uuid4())}}

# Start the graph and interrupt at the first node
await graph.ainvoke({"foo": "abc"}, config)
result = await graph.ainvoke(Command(goto=["node2"]), config)
assert result == {"foo": "abc|node-1|node-2|node-2"}

0 comments on commit 60d742e

Please sign in to comment.