Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lib: Add support for returning multiple commands from a node #2658

Merged
merged 3 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 90 additions & 49 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@
is_writable_managed_value,
)
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.pregel.write import (
ChannelWrite,
ChannelWriteEntry,
ChannelWriteTupleEntry,
)
from langgraph.store.base import BaseStore
from langgraph.types import All, Checkpointer, Command, RetryPolicy
from langgraph.utils.fields import get_field_default
Expand Down Expand Up @@ -608,33 +612,53 @@ def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None:
if is_writable_managed_value(v)
]

def _get_root(input: Any) -> Any:
if isinstance(input, Command):
def _get_root(input: Any) -> Optional[Sequence[tuple[str, Any]]]:
if (
isinstance(input, (list, tuple))
and input
and all(isinstance(i, Command) for i in input)
):
updates: list[tuple[str, Any]] = []
for i in input:
if i.graph == Command.PARENT:
continue
updates.extend(i._update_as_tuples())
return updates
elif isinstance(input, Command):
if input.graph == Command.PARENT:
return SKIP_WRITE
return input.update
else:
return input

# to avoid name collision below
node_key = key

def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
return ()
return input._update_as_tuples()
elif input is not None:
return [("__root__", input)]

def _get_updates(
input: Union[None, dict, Any],
) -> Optional[Sequence[tuple[str, Any]]]:
if input is None:
return SKIP_WRITE
return None
elif isinstance(input, dict):
if all(k not in output_keys for k in input):
raise InvalidUpdateError(
f"Expected node {node_key} to update at least one of {output_keys}, got {input}"
)
return input.get(key, SKIP_WRITE)
return [(k, v) for k, v in input.items() if k in output_keys]
elif isinstance(input, Command):
if input.graph == Command.PARENT:
return SKIP_WRITE
return _get_state_key(input.update, key=key)
return None
return input._update_as_tuples()
elif (
isinstance(input, (list, tuple))
and input
and all(isinstance(i, Command) for i in input)
):
updates: list[tuple[str, Any]] = []
for i in input:
if i.graph == Command.PARENT:
continue
updates.extend(i._update_as_tuples())
return updates
elif get_type_hints(type(input)):
value = getattr(input, key, SKIP_WRITE)
return value if value is not None else SKIP_WRITE
return [
(k, getattr(input, k))
for k in output_keys
if getattr(input, k, None) is not None
]
else:
msg = create_error_message(
message=f"Expected dict, got {input}",
Expand All @@ -643,14 +667,11 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
raise InvalidUpdateError(msg)

# state updaters
write_entries = (
[ChannelWriteEntry("__root__", skip_none=True, mapper=_get_root)]
if output_keys == ["__root__"]
else [
ChannelWriteEntry(key, mapper=partial(_get_state_key, key=key))
for key in output_keys
]
)
write_entries: list[Union[ChannelWriteEntry, ChannelWriteTupleEntry]] = [
ChannelWriteTupleEntry(
mapper=_get_root if output_keys == ["__root__"] else _get_updates
)
]

# add node and output channel
if key == START:
Expand Down Expand Up @@ -685,7 +706,7 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
writers=[
# publish to this channel and state keys
ChannelWrite(
[ChannelWriteEntry(key, key)] + write_entries,
write_entries + [ChannelWriteEntry(key, key)],
tags=[TAG_HIDDEN],
),
],
Expand Down Expand Up @@ -811,34 +832,54 @@ def _coerce_state(schema: Type[Any], input: dict[str, Any]) -> dict[str, Any]:
def _control_branch(value: Any) -> Sequence[Union[str, Send]]:
if isinstance(value, Send):
return [value]
if not isinstance(value, Command):
commands: list[Command] = []
if isinstance(value, Command):
commands.append(value)
elif (
isinstance(value, (list, tuple))
and value
and all(isinstance(i, Command) for i in value)
):
commands.extend(value)
else:
return EMPTY_SEQ
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
for command in commands:
if command.graph == Command.PARENT:
raise ParentCommand(command)
if isinstance(command.goto, Send):
rtn.append(command.goto)
elif isinstance(command.goto, str):
rtn.append(command.goto)
else:
rtn.extend(command.goto)
return rtn


async def _acontrol_branch(value: Any) -> Sequence[Union[str, Send]]:
if isinstance(value, Send):
return [value]
if not isinstance(value, Command):
commands: list[Command] = []
if isinstance(value, Command):
commands.append(value)
elif (
isinstance(value, (list, tuple))
and value
and all(isinstance(i, Command) for i in value)
):
commands.extend(value)
else:
return EMPTY_SEQ
if value.graph == Command.PARENT:
raise ParentCommand(value)
rtn: list[Union[str, Send]] = []
if isinstance(value.goto, Send):
rtn.append(value.goto)
elif isinstance(value.goto, str):
rtn.append(value.goto)
else:
rtn.extend(value.goto)
for command in commands:
if command.graph == Command.PARENT:
raise ParentCommand(command)
if isinstance(command.goto, Send):
rtn.append(command.goto)
elif isinstance(command.goto, str):
rtn.append(command.goto)
else:
rtn.extend(command.goto)
return rtn


Expand Down
6 changes: 1 addition & 5 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,7 @@ def map_command(
else:
yield (NULL_TASK_ID, RESUME, cmd.resume)
if cmd.update:
if not isinstance(cmd.update, dict):
raise TypeError(
f"Expected cmd.update to be a dict mapping channel names to update values, got {type(cmd.update).__name__}"
)
for k, v in cmd.update.items():
for k, v in cmd._update_as_tuples():
yield (NULL_TASK_ID, k, v)


Expand Down
71 changes: 44 additions & 27 deletions libs/langgraph/langgraph/pregel/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,40 @@ class ChannelWriteEntry(NamedTuple):
"""Function to transform the value before writing."""


class ChannelWriteTupleEntry(NamedTuple):
mapper: Callable[[Any], Optional[Sequence[tuple[str, Any]]]]
"""Function to extract tuples from value."""
value: Any = PASSTHROUGH
"""Value to write, or PASSTHROUGH to use the input."""


class ChannelWrite(RunnableCallable):
"""Implements th logic for sending writes to CONFIG_KEY_SEND.
"""Implements the logic for sending writes to CONFIG_KEY_SEND.
Can be used as a runnable or as a static method to call imperatively."""

writes: list[Union[ChannelWriteEntry, Send]]
writes: list[Union[ChannelWriteEntry, ChannelWriteTupleEntry, Send]]
"""Sequence of write entries or Send objects to write."""
require_at_least_one_of: Optional[Sequence[str]]
"""If defined, at least one of these channels must be written to."""

def __init__(
self,
writes: Sequence[Union[ChannelWriteEntry, Send]],
writes: Sequence[Union[ChannelWriteEntry, ChannelWriteTupleEntry, Send]],
*,
tags: Optional[Sequence[str]] = None,
require_at_least_one_of: Optional[Sequence[str]] = None,
):
super().__init__(func=self._write, afunc=self._awrite, name=None, tags=tags)
self.writes = cast(list[Union[ChannelWriteEntry, Send]], writes)
self.writes = cast(
list[Union[ChannelWriteEntry, ChannelWriteTupleEntry, Send]], writes
)
self.require_at_least_one_of = require_at_least_one_of

def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
if not name:
name = f"ChannelWrite<{','.join(w.channel if isinstance(w, ChannelWriteEntry) else w.node for w in self.writes)}>"
name = f"ChannelWrite<{','.join(w.channel if isinstance(w, ChannelWriteEntry) else '...' if isinstance(w, ChannelWriteTupleEntry) else w.node for w in self.writes)}>"
return super().get_name(suffix, name=name)

@property
Expand All @@ -79,6 +88,8 @@ def _write(self, input: Any, config: RunnableConfig) -> None:
writes = [
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
else ChannelWriteTupleEntry(write.mapper, input)
if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
else write
for write in self.writes
]
Expand All @@ -93,6 +104,8 @@ async def _awrite(self, input: Any, config: RunnableConfig) -> None:
writes = [
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
else ChannelWriteTupleEntry(write.mapper, input)
if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
else write
for write in self.writes
]
Expand All @@ -106,7 +119,7 @@ async def _awrite(self, input: Any, config: RunnableConfig) -> None:
@staticmethod
def do_write(
config: RunnableConfig,
writes: Sequence[Union[ChannelWriteEntry, Send]],
writes: Sequence[Union[ChannelWriteEntry, ChannelWriteTupleEntry, Send]],
require_at_least_one_of: Optional[Sequence[str]] = None,
) -> None:
# validate
Expand All @@ -118,32 +131,36 @@ def do_write(
)
if w.value is PASSTHROUGH:
raise InvalidUpdateError("PASSTHROUGH value must be replaced")
# split packets and entries
sends = [
(PUSH if FF_SEND_V2 else TASKS, packet)
for packet in writes
if isinstance(packet, Send)
]
entries = [write for write in writes if isinstance(write, ChannelWriteEntry)]
# process entries into values
values = [
write.mapper(write.value) if write.mapper is not None else write.value
for write in entries
]
values = [
(write.channel, val)
for val, write in zip(values, entries)
if not write.skip_none or val is not None
]
# filter out SKIP_WRITE values
filtered = [(chan, val) for chan, val in values if val is not SKIP_WRITE]
if isinstance(w, ChannelWriteTupleEntry):
if w.value is PASSTHROUGH:
raise InvalidUpdateError("PASSTHROUGH value must be replaced")
# assemble writes
tuples: list[tuple[str, Any]] = []
print(writes)
for w in writes:
if isinstance(w, Send):
tuples.append((PUSH if FF_SEND_V2 else TASKS, w))
elif isinstance(w, ChannelWriteTupleEntry):
if ww := w.mapper(w.value):
tuples.extend(ww)
elif isinstance(w, ChannelWriteEntry):
value = w.mapper(w.value) if w.mapper is not None else w.value
if value is SKIP_WRITE:
continue
if w.skip_none and value is None:
continue
tuples.append((w.channel, value))
else:
raise ValueError(f"Invalid write entry: {w}")
print(tuples, require_at_least_one_of)
# assert required channels
if require_at_least_one_of is not None:
if not {chan for chan, _ in filtered} & set(require_at_least_one_of):
if not {chan for chan, _ in tuples} & set(require_at_least_one_of):
raise InvalidUpdateError(
f"Must write to at least one of {require_at_least_one_of}"
)
write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND]
write(sends + filtered)
write(tuples)

@staticmethod
def is_writer(runnable: Runnable) -> bool:
Expand Down
13 changes: 12 additions & 1 deletion libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ class Command(Generic[N]):
"""

graph: Optional[str] = None
update: Optional[dict[str, Any]] = None
update: Union[dict[str, Any], Sequence[tuple[str, Any]]] = ()
resume: Optional[Union[Any, dict[str, Any]]] = None
goto: Union[Send, Sequence[Union[Send, str]], str] = ()

Expand All @@ -276,6 +276,17 @@ def __repr__(self) -> str:
)
return f"Command({contents})"

def _update_as_tuples(self) -> Sequence[tuple[str, Any]]:
if isinstance(self.update, dict):
return list(self.update.items())
elif isinstance(self.update, (list, tuple)) and all(
isinstance(t, tuple) and len(t) == 2 and isinstance(t[0], str)
for t in self.update
):
return self.update
else:
return [("__root__", self.update)]

PARENT: ClassVar[Literal["__parent__"]] = "__parent__"


Expand Down
14 changes: 1 addition & 13 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,6 @@ def logic(inp: str) -> str:
class State(TypedDict):
hello: str

def node_a(state: State) -> State:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooc, why remove this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't intended to raise exception, not sure why this was still here

# typo
return {"hell": "world"}

builder = StateGraph(State)
builder.add_node("a", node_a)
builder.set_entry_point("a")
builder.set_finish_point("a")
graph = builder.compile()
with pytest.raises(InvalidUpdateError):
graph.invoke({"hello": "there"})

graph = StateGraph(State)
graph.add_node("start", lambda x: x)
graph.add_edge("__start__", "start")
Expand Down Expand Up @@ -1919,7 +1907,7 @@ def __call__(self, state):
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Command):
return replace(state, update=update)
return [state, Command(update=update)]
else:
return update

Expand Down
Loading