Skip to content

Commit

Permalink
Retry subgraph starting at failing node (#1695)
Browse files Browse the repository at this point in the history
* Failing test

* Ensure retried subgraphs resume from current point (if any)

* Lint

* Cleanup Test

---------

Co-authored-by: Nuno Campos <[email protected]>
  • Loading branch information
hinthornw and nfcampos authored Sep 16, 2024
1 parent fa792e9 commit 0b1d0eb
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
14 changes: 11 additions & 3 deletions libs/langgraph/langgraph/pregel/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import time
from typing import Optional

from langgraph.constants import CONFIG_KEY_RESUMING
from langgraph.errors import GraphInterrupt
from langgraph.pregel.types import PregelExecutableTask, RetryPolicy
from langgraph.utils.config import patch_configurable

logger = logging.getLogger(__name__)

Expand All @@ -18,12 +20,13 @@ def run_with_retry(
retry_policy = task.retry_policy or retry_policy
interval = retry_policy.initial_interval if retry_policy else 0
attempts = 0
config = task.config
while True:
try:
# clear any writes from previous attempts
task.writes.clear()
# run the task
task.proc.invoke(task.input, task.config)
task.proc.invoke(task.input, config)
# if successful, end
break
except GraphInterrupt:
Expand Down Expand Up @@ -56,6 +59,8 @@ def run_with_retry(
f"Retrying task {task.name} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})


async def arun_with_retry(
Expand All @@ -67,16 +72,17 @@ async def arun_with_retry(
retry_policy = task.retry_policy or retry_policy
interval = retry_policy.initial_interval if retry_policy else 0
attempts = 0
config = task.config
while True:
try:
# clear any writes from previous attempts
task.writes.clear()
# run the task
if stream:
async for _ in task.proc.astream(task.input, task.config):
async for _ in task.proc.astream(task.input, config):
pass
else:
await task.proc.ainvoke(task.input, task.config)
await task.proc.ainvoke(task.input, config)
# if successful, end
break
except GraphInterrupt:
Expand Down Expand Up @@ -109,3 +115,5 @@ async def arun_with_retry(
f"Retrying task {task.name} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
53 changes: 53 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11023,3 +11023,56 @@ def _node(state: State):
app = parent.compile()

assert app.get_graph(xray=True).draw_mermaid() == snapshot


def test_subgraph_retries():
class State(TypedDict):
count: int

class ChildState(State):
some_list: Annotated[list, operator.add]

called_times = 0

class RandomError(ValueError):
"""This will be retried on."""

def parent_node(state: State):
return {"count": state["count"] + 1}

def child_node_a(state: ChildState):
nonlocal called_times
# We want it to retry only on node_b
# NOT re-compute the whole graph.
assert not called_times
called_times += 1
return {"some_list": ["val"]}

def child_node_b(state: ChildState):
raise RandomError("First attempt fails")

child = StateGraph(ChildState)
child.add_node(child_node_a)
child.add_node(child_node_b)
child.add_edge("__start__", "child_node_a")
child.add_edge("child_node_a", "child_node_b")

parent = StateGraph(State)
parent.add_node("parent_node", parent_node)
parent.add_node(
"child_graph",
child.compile(),
retry=RetryPolicy(
max_attempts=3,
retry_on=(RandomError,),
backoff_factor=0.0001,
initial_interval=0.0001,
),
)
parent.add_edge("parent_node", "child_graph")
parent.set_entry_point("parent_node")

checkpointer = MemorySaver()
app = parent.compile(checkpointer=checkpointer)
with pytest.raises(RandomError):
app.invoke({"count": 0}, {"configurable": {"thread_id": "foo"}})

0 comments on commit 0b1d0eb

Please sign in to comment.