Skip to content

Commit

Permalink
Merge pull request #2693 from langchain-ai/eugene/fix_test
Browse files Browse the repository at this point in the history
langgraph[patch]: Fix unit test for Command(update)
  • Loading branch information
nfcampos authored Dec 10, 2024
2 parents 0f287d9 + 081b2cb commit a7d1ecb
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 23 deletions.
24 changes: 12 additions & 12 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,18 +422,6 @@ def tick(
self.status = "out_of_steps"
return False

# apply NULL writes
if null_writes := [
w[1:] for w in self.checkpoint_pending_writes if w[0] == NULL_TASK_ID
]:
mv_writes = apply_writes(
self.checkpoint,
self.channels,
[PregelTaskWrites((), INPUT, null_writes, [])],
self.checkpointer_get_next_version,
)
for key, values in mv_writes.items():
self._update_mv(key, values)
# prepare next tasks
self.tasks = prepare_next_tasks(
self.checkpoint,
Expand Down Expand Up @@ -552,6 +540,18 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None:
# save writes
for tid, ws in writes.items():
self.put_writes(tid, ws)
# apply NULL writes
if null_writes := [
w[1:] for w in self.checkpoint_pending_writes if w[0] == NULL_TASK_ID
]:
mv_writes = apply_writes(
self.checkpoint,
self.channels,
[PregelTaskWrites((), INPUT, null_writes, [])],
self.checkpointer_get_next_version,
)
for key, values in mv_writes.items():
self._update_mv(key, values)
# proceed past previous checkpoint
if is_resuming:
self.checkpoint["versions_seen"].setdefault(INTERRUPT, {})
Expand Down
20 changes: 9 additions & 11 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14906,9 +14906,14 @@ def my_node(state: State):
assert graph.invoke({"foo": ""}) == {"foo": "ab"}


def test_command_with_static_breakpoints() -> None:
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_command_with_static_breakpoints(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
"""Test that we can use Command to resume and update with static breakpoints."""

checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict):
"""The graph state."""

Expand All @@ -14930,20 +14935,13 @@ def node2(state: State):
builder.add_edge(START, "node1")
builder.add_edge("node1", "node2")

# A checkpointer must be enabled for interrupts to work!
checkpointer = MemorySaver()
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["node1"])

config = {
"configurable": {
"thread_id": uuid.uuid4(),
}
}
config = {"configurable": {"thread_id": str(uuid.uuid4())}}

# Start the graph and interrupt at the first node
graph.invoke({"foo": "abc"}, config)
result = graph.invoke(Command(resume="node1"), config)
assert result == {"foo": "abc|node-1|node-2"}
result = graph.invoke(Command(update={"foo": "def"}), config)
assert result == {"foo": "def|node-1|node-2"}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
Expand Down
35 changes: 35 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13201,6 +13201,41 @@ async def ask_age(s: State):
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_command_with_static_breakpoints(checkpointer_name: str) -> None:
"""Test that we can use Command to resume and update with static breakpoints."""

class State(TypedDict):
"""The graph state."""

foo: str

def node1(state: State):
return {
"foo": state["foo"] + "|node-1",
}

def node2(state: State):
return {
"foo": state["foo"] + "|node-2",
}

builder = StateGraph(State)
builder.add_node("node1", node1)
builder.add_node("node2", node2)
builder.add_edge(START, "node1")
builder.add_edge("node1", "node2")

async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["node1"])
config = {"configurable": {"thread_id": str(uuid.uuid4())}}

# Start the graph and interrupt at the first node
await graph.ainvoke({"foo": "abc"}, config)
result = await graph.ainvoke(Command(update={"foo": "def"}), config)
assert result == {"foo": "def|node-1|node-2"}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_multistep_plan(checkpointer_name: str):
from langchain_core.messages import AnyMessage
Expand Down

0 comments on commit a7d1ecb

Please sign in to comment.