Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
mohit-mangal authored Dec 7, 2024
2 parents 04f52a5 + 6784a5a commit 8445e4e
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 245 deletions.
130 changes: 36 additions & 94 deletions docs/docs/tutorials/multi_agent/agent_supervisor.ipynb

Large diffs are not rendered by default.

158 changes: 82 additions & 76 deletions docs/docs/tutorials/multi_agent/hierarchical_agent_teams.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -289,15 +289,10 @@
"from langchain_core.language_models.chat_models import BaseChatModel\n",
"\n",
"from langgraph.graph import StateGraph, MessagesState, START, END\n",
"from langgraph.types import Command\n",
"from langchain_core.messages import HumanMessage, trim_messages\n",
"\n",
"\n",
"# The agent state is the input to each node in the graph\n",
"class AgentState(MessagesState):\n",
" # The 'next' field indicates where to route to next\n",
" next: str\n",
"\n",
"\n",
"def make_supervisor_node(llm: BaseChatModel, members: list[str]) -> str:\n",
" options = [\"FINISH\"] + members\n",
" system_prompt = (\n",
Expand All @@ -313,17 +308,17 @@
"\n",
" next: Literal[*options]\n",
"\n",
" def supervisor_node(state: MessagesState) -> MessagesState:\n",
" def supervisor_node(state: MessagesState) -> Command[Literal[*members, \"__end__\"]]:\n",
" \"\"\"An LLM-based router.\"\"\"\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" ] + state[\"messages\"]\n",
" response = llm.with_structured_output(Router).invoke(messages)\n",
" next_ = response[\"next\"]\n",
" if next_ == \"FINISH\":\n",
" next_ = END\n",
" goto = response[\"next\"]\n",
" if goto == \"FINISH\":\n",
" goto = END\n",
"\n",
" return {\"next\": next_}\n",
" return Command(goto=goto)\n",
"\n",
" return supervisor_node"
]
Expand Down Expand Up @@ -363,25 +358,33 @@
"search_agent = create_react_agent(llm, tools=[tavily_tool])\n",
"\n",
"\n",
"def search_node(state: AgentState) -> AgentState:\n",
"def search_node(state: MessagesState) -> Command[Literal[\"supervisor\"]]:\n",
" result = search_agent.invoke(state)\n",
" return {\n",
" \"messages\": [\n",
" HumanMessage(content=result[\"messages\"][-1].content, name=\"search\")\n",
" ]\n",
" }\n",
" return Command(\n",
" update={\n",
" \"messages\": [\n",
" HumanMessage(content=result[\"messages\"][-1].content, name=\"search\")\n",
" ]\n",
" },\n",
" # We want our workers to ALWAYS \"report back\" to the supervisor when done\n",
" goto=\"supervisor\",\n",
" )\n",
"\n",
"\n",
"web_scraper_agent = create_react_agent(llm, tools=[scrape_webpages])\n",
"\n",
"\n",
"def web_scraper_node(state: AgentState) -> AgentState:\n",
"def web_scraper_node(state: MessagesState) -> Command[Literal[\"supervisor\"]]:\n",
" result = web_scraper_agent.invoke(state)\n",
" return {\n",
" \"messages\": [\n",
" HumanMessage(content=result[\"messages\"][-1].content, name=\"web_scraper\")\n",
" ]\n",
" }\n",
" return Command(\n",
" update={\n",
" \"messages\": [\n",
" HumanMessage(content=result[\"messages\"][-1].content, name=\"web_scraper\")\n",
" ]\n",
" },\n",
" # We want our workers to ALWAYS \"report back\" to the supervisor when done\n",
" goto=\"supervisor\",\n",
" )\n",
"\n",
"\n",
"research_supervisor_node = make_supervisor_node(llm, [\"search\", \"web_scraper\"])"
Expand Down Expand Up @@ -412,14 +415,7 @@
"research_builder.add_node(\"search\", search_node)\n",
"research_builder.add_node(\"web_scraper\", web_scraper_node)\n",
"\n",
"# Define the control flow\n",
"research_builder.add_edge(START, \"supervisor\")\n",
"# We want our workers to ALWAYS \"report back\" to the supervisor when done\n",
"research_builder.add_edge(\"search\", \"supervisor\")\n",
"research_builder.add_edge(\"web_scraper\", \"supervisor\")\n",
"# Add the edges where routing applies\n",
"research_builder.add_conditional_edges(\"supervisor\", lambda state: state[\"next\"])\n",
"\n",
"research_graph = research_builder.compile()"
]
},
Expand Down Expand Up @@ -532,13 +528,17 @@
")\n",
"\n",
"\n",
"def doc_writing_node(state: AgentState) -> AgentState:\n",
"def doc_writing_node(state: MessagesState) -> Command[Literal[\"supervisor\"]]:\n",
" result = doc_writer_agent.invoke(state)\n",
" return {\n",
" \"messages\": [\n",
" HumanMessage(content=result[\"messages\"][-1].content, name=\"doc_writer\")\n",
" ]\n",
" }\n",
" return Command(\n",
" update={\n",
" \"messages\": [\n",
" HumanMessage(content=result[\"messages\"][-1].content, name=\"doc_writer\")\n",
" ]\n",
" },\n",
" # We want our workers to ALWAYS \"report back\" to the supervisor when done\n",
" goto=\"supervisor\",\n",
" )\n",
"\n",
"\n",
"note_taking_agent = create_react_agent(\n",
Expand All @@ -551,27 +551,37 @@
")\n",
"\n",
"\n",
"def note_taking_node(state: AgentState) -> AgentState:\n",
"def note_taking_node(state: MessagesState) -> Command[Literal[\"supervisor\"]]:\n",
" result = note_taking_agent.invoke(state)\n",
" return {\n",
" \"messages\": [\n",
" HumanMessage(content=result[\"messages\"][-1].content, name=\"note_taker\")\n",
" ]\n",
" }\n",
" return Command(\n",
" update={\n",
" \"messages\": [\n",
" HumanMessage(content=result[\"messages\"][-1].content, name=\"note_taker\")\n",
" ]\n",
" },\n",
" # We want our workers to ALWAYS \"report back\" to the supervisor when done\n",
" goto=\"supervisor\",\n",
" )\n",
"\n",
"\n",
"chart_generating_agent = create_react_agent(\n",
" llm, tools=[read_document, python_repl_tool]\n",
")\n",
"\n",
"\n",
"def chart_generating_node(state: AgentState) -> AgentState:\n",
"def chart_generating_node(state: MessagesState) -> Command[Literal[\"supervisor\"]]:\n",
" result = chart_generating_agent.invoke(state)\n",
" return {\n",
" \"messages\": [\n",
" HumanMessage(content=result[\"messages\"][-1].content, name=\"chart_generator\")\n",
" ]\n",
" }\n",
" return Command(\n",
" update={\n",
" \"messages\": [\n",
" HumanMessage(\n",
" content=result[\"messages\"][-1].content, name=\"chart_generator\"\n",
" )\n",
" ]\n",
" },\n",
" # We want our workers to ALWAYS \"report back\" to the supervisor when done\n",
" goto=\"supervisor\",\n",
" )\n",
"\n",
"\n",
"doc_writing_supervisor_node = make_supervisor_node(\n",
Expand Down Expand Up @@ -600,21 +610,13 @@
"outputs": [],
"source": [
"# Create the graph here\n",
"paper_writing_builder = StateGraph(AgentState)\n",
"paper_writing_builder = StateGraph(MessagesState)\n",
"paper_writing_builder.add_node(\"supervisor\", doc_writing_supervisor_node)\n",
"paper_writing_builder.add_node(\"doc_writer\", doc_writing_node)\n",
"paper_writing_builder.add_node(\"note_taker\", note_taking_node)\n",
"paper_writing_builder.add_node(\"chart_generator\", chart_generating_node)\n",
"\n",
"# Define the control flow\n",
"paper_writing_builder.add_edge(START, \"supervisor\")\n",
"# We want our workers to ALWAYS \"report back\" to the supervisor when done\n",
"paper_writing_builder.add_edge(\"doc_writer\", \"supervisor\")\n",
"paper_writing_builder.add_edge(\"note_taker\", \"supervisor\")\n",
"paper_writing_builder.add_edge(\"chart_generator\", \"supervisor\")\n",
"# Add the edges where routing applies\n",
"paper_writing_builder.add_conditional_edges(\"supervisor\", lambda state: state[\"next\"])\n",
"\n",
"paper_writing_graph = paper_writing_builder.compile()"
]
},
Expand Down Expand Up @@ -728,37 +730,41 @@
},
"outputs": [],
"source": [
"def call_research_team(state: AgentState) -> AgentState:\n",
"def call_research_team(state: MessagesState) -> Command[Literal[\"supervisor\"]]:\n",
" response = research_graph.invoke({\"messages\": state[\"messages\"][-1]})\n",
" return {\n",
" \"messages\": [\n",
" HumanMessage(content=response[\"messages\"][-1].content, name=\"research_team\")\n",
" ]\n",
" }\n",
" return Command(\n",
" update={\n",
" \"messages\": [\n",
" HumanMessage(\n",
" content=response[\"messages\"][-1].content, name=\"research_team\"\n",
" )\n",
" ]\n",
" },\n",
" goto=\"supervisor\",\n",
" )\n",
"\n",
"\n",
"def call_paper_writing_team(state: AgentState) -> AgentState:\n",
"def call_paper_writing_team(state: MessagesState) -> Command[Literal[\"supervisor\"]]:\n",
" response = paper_writing_graph.invoke({\"messages\": state[\"messages\"][-1]})\n",
" return {\n",
" \"messages\": [\n",
" HumanMessage(content=response[\"messages\"][-1].content, name=\"writing_team\")\n",
" ]\n",
" }\n",
" return Command(\n",
" update={\n",
" \"messages\": [\n",
" HumanMessage(\n",
" content=response[\"messages\"][-1].content, name=\"writing_team\"\n",
" )\n",
" ]\n",
" },\n",
" goto=\"supervisor\",\n",
" )\n",
"\n",
"\n",
"# Define the graph.\n",
"super_builder = StateGraph(AgentState)\n",
"super_builder = StateGraph(MessagesState)\n",
"super_builder.add_node(\"supervisor\", teams_supervisor_node)\n",
"super_builder.add_node(\"research_team\", call_research_team)\n",
"super_builder.add_node(\"writing_team\", call_paper_writing_team)\n",
"\n",
"# Define the control flow\n",
"super_builder.add_edge(START, \"supervisor\")\n",
"# We want our teams to ALWAYS \"report back\" to the top-level supervisor when done\n",
"super_builder.add_edge(\"research_team\", \"supervisor\")\n",
"super_builder.add_edge(\"writing_team\", \"supervisor\")\n",
"# Add the edges where routing applies\n",
"super_builder.add_conditional_edges(\"supervisor\", lambda state: state[\"next\"])\n",
"super_graph = super_builder.compile()"
]
},
Expand Down
100 changes: 41 additions & 59 deletions docs/docs/tutorials/multi_agent/multi-agent-collaboration.ipynb

Large diffs are not rendered by default.

32 changes: 19 additions & 13 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,21 +613,24 @@ def attach_node(self, key: str, node: Optional[StateNodeSpec]) -> None:
]

def _get_root(input: Any) -> Optional[Sequence[tuple[str, Any]]]:
if (
if isinstance(input, Command):
if input.graph == Command.PARENT:
return ()
return input._update_as_tuples()
elif (
isinstance(input, (list, tuple))
and input
and all(isinstance(i, Command) for i in input)
and any(isinstance(i, Command) for i in input)
):
updates: list[tuple[str, Any]] = []
for i in input:
if i.graph == Command.PARENT:
continue
updates.extend(i._update_as_tuples())
if isinstance(i, Command):
if i.graph == Command.PARENT:
continue
updates.extend(i._update_as_tuples())
else:
updates.append(("__root__", i))
return updates
elif isinstance(input, Command):
if input.graph == Command.PARENT:
return ()
return input._update_as_tuples()
elif input is not None:
return [("__root__", input)]

Expand All @@ -645,13 +648,16 @@ def _get_updates(
elif (
isinstance(input, (list, tuple))
and input
and all(isinstance(i, Command) for i in input)
and any(isinstance(i, Command) for i in input)
):
updates: list[tuple[str, Any]] = []
for i in input:
if i.graph == Command.PARENT:
continue
updates.extend(i._update_as_tuples())
if isinstance(i, Command):
if i.graph == Command.PARENT:
continue
updates.extend(i._update_as_tuples())
else:
updates.extend(_get_updates(i) or ())
return updates
elif get_type_hints(type(input)):
return [
Expand Down
28 changes: 28 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14807,3 +14807,31 @@ def ask_age(s: State):
assert [event for event in graph.stream(Command(resume="19"), thread1)] == [
{"node": {"age": 19}},
]


def test_root_mixed_return() -> None:
def my_node(state: list[str]):
return [Command(update=["a"]), ["b"]]

graph = StateGraph(Annotated[list[str], operator.add])

graph.add_node(my_node)
graph.add_edge(START, "my_node")
graph = graph.compile()

assert graph.invoke([]) == ["a", "b"]


def test_dict_mixed_return() -> None:
class State(TypedDict):
foo: Annotated[str, operator.add]

def my_node(state: State):
return [Command(update={"foo": "a"}), {"foo": "b"}]

graph = StateGraph(State)
graph.add_node(my_node)
graph.add_edge(START, "my_node")
graph = graph.compile()

assert graph.invoke({"foo": ""}) == {"foo": "ab"}
6 changes: 3 additions & 3 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -9670,14 +9670,14 @@ async def outer_2(state: State):
),
(FloatBetween(0.2, 0.4), ((), {"outer_1": {"my_key": " and parallel"}})),
(
FloatBetween(0.5, 0.7),
FloatBetween(0.5, 0.8),
(
(AnyStr("inner:"),),
{"inner_2": {"my_key": " and there", "my_other_key": "got here"}},
),
),
(FloatBetween(0.5, 0.7), ((), {"inner": {"my_key": "got here and there"}})),
(FloatBetween(0.5, 0.7), ((), {"outer_2": {"my_key": " and back again"}})),
(FloatBetween(0.5, 0.8), ((), {"inner": {"my_key": "got here and there"}})),
(FloatBetween(0.5, 0.8), ((), {"outer_2": {"my_key": " and back again"}})),
]


Expand Down

0 comments on commit 8445e4e

Please sign in to comment.