Skip to content

Commit

Permalink
Merge pull request #2695 from langchain-ai/nc/10dec/multistep-plan
Browse files Browse the repository at this point in the history
lib: Add unit test for multistep planner graph
  • Loading branch information
nfcampos authored Dec 10, 2024
2 parents 1fd9da6 + ef6c5b4 commit 0f287d9
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
61 changes: 61 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14944,3 +14944,64 @@ def node2(state: State):
graph.invoke({"foo": "abc"}, config)
result = graph.invoke(Command(resume="node1"), config)
assert result == {"foo": "abc|node-1|node-2"}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multistep_plan(request: pytest.FixtureRequest, checkpointer_name: str):
from langchain_core.messages import AnyMessage

checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict, total=False):
plan: list[Union[str, list[str]]]
messages: Annotated[list[AnyMessage], add_messages]

def planner(state: State):
if state.get("plan") is None:
# create plan somehow
plan = ["step1", ["step2", "step3"], "step4"]
# pick the first step to execute next
first_step, *plan = plan
# put the rest of plan in state
return Command(goto=first_step, update={"plan": plan})
elif state["plan"]:
# go to the next step of the plan
next_step, *next_plan = state["plan"]
return Command(goto=next_step, update={"plan": next_plan})
else:
# the end of the plan
pass

def step1(state: State):
return Command(goto="planner", update={"messages": [("human", "step1")]})

def step2(state: State):
return Command(goto="planner", update={"messages": [("human", "step2")]})

def step3(state: State):
return Command(goto="planner", update={"messages": [("human", "step3")]})

def step4(state: State):
return Command(goto="planner", update={"messages": [("human", "step4")]})

builder = StateGraph(State)
builder.add_node(planner)
builder.add_node(step1)
builder.add_node(step2)
builder.add_node(step3)
builder.add_node(step4)
builder.add_edge(START, "planner")
graph = builder.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "1"}}

assert graph.invoke({"messages": [("human", "start")]}, config) == {
"messages": [
_AnyIdHumanMessage(content="start"),
_AnyIdHumanMessage(content="step1"),
_AnyIdHumanMessage(content="step2"),
_AnyIdHumanMessage(content="step3"),
_AnyIdHumanMessage(content="step4"),
],
"plan": [],
}
61 changes: 61 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13199,3 +13199,64 @@ async def ask_age(s: State):
] == [
{"node": {"age": 19}},
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_multistep_plan(checkpointer_name: str):
from langchain_core.messages import AnyMessage

class State(TypedDict, total=False):
plan: list[Union[str, list[str]]]
messages: Annotated[list[AnyMessage], add_messages]

def planner(state: State):
if state.get("plan") is None:
# create plan somehow
plan = ["step1", ["step2", "step3"], "step4"]
# pick the first step to execute next
first_step, *plan = plan
# put the rest of plan in state
return Command(goto=first_step, update={"plan": plan})
elif state["plan"]:
# go to the next step of the plan
next_step, *next_plan = state["plan"]
return Command(goto=next_step, update={"plan": next_plan})
else:
# the end of the plan
pass

def step1(state: State):
return Command(goto="planner", update={"messages": [("human", "step1")]})

def step2(state: State):
return Command(goto="planner", update={"messages": [("human", "step2")]})

def step3(state: State):
return Command(goto="planner", update={"messages": [("human", "step3")]})

def step4(state: State):
return Command(goto="planner", update={"messages": [("human", "step4")]})

builder = StateGraph(State)
builder.add_node(planner)
builder.add_node(step1)
builder.add_node(step2)
builder.add_node(step3)
builder.add_node(step4)
builder.add_edge(START, "planner")

async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "1"}}

assert await graph.ainvoke({"messages": [("human", "start")]}, config) == {
"messages": [
_AnyIdHumanMessage(content="start"),
_AnyIdHumanMessage(content="step1"),
_AnyIdHumanMessage(content="step2"),
_AnyIdHumanMessage(content="step3"),
_AnyIdHumanMessage(content="step4"),
],
"plan": [],
}

0 comments on commit 0f287d9

Please sign in to comment.