Skip to content

Commit

Permalink
Handle Command returned from node (in addition to GraphCommand)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Dec 3, 2024
1 parent 5e3c326 commit a203dde
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
20 changes: 11 additions & 9 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,15 +829,16 @@ def _coerce_state(schema: Type[Any], input: dict[str, Any]) -> dict[str, Any]:
def _control_branch(value: Any) -> Sequence[Union[str, Send]]:
if isinstance(value, Send):
return [value]
if not isinstance(value, GraphCommand):
if not isinstance(value, Command):
return EMPTY_SEQ
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value, GraphCommand):
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value.send, Send):
rtn.append(value.send)
else:
Expand All @@ -853,10 +854,11 @@ async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]:
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value, GraphCommand):
if isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
if isinstance(value.send, Send):
rtn.append(value.send)
else:
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1925,7 +1925,7 @@ def __call__(self, state):

def send_for_fun(state):
return [
Send("2", GraphCommand(send=Send("2", 3))),
Send("2", Command(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("2", 4))),
"3.1",
]
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2573,14 +2573,14 @@ async def __call__(self, state):
if isinstance(state, list) # or isinstance(state, Control)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, GraphCommand):
if isinstance(state, Command):
return replace(state, update=update)
else:
return update

async def send_for_fun(state):
return [
Send("2", GraphCommand(send=Send("2", 3))),
Send("2", Command(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("2", 4))),
"3.1",
]
Expand Down

0 comments on commit a203dde

Please sign in to comment.