Update chatbot and RAG assistant to use StateGraph in the backend #305
Conversation
```diff
@@ -12,6 +12,7 @@ export interface MessageDocument {
 export interface Message {
   id: string;
   type: string;
+  role?: string; // for chat_retrieval bot
```
I couldn't decide what to do about this... Calling `POST /runs/stream` requires the `role` field in the message. Otherwise, the request body will not be serialized appropriately in the backend and an error is raised in `langgraph.graph.message.add_messages` (the state reducer function).

It seems more appropriate for the client (frontend) to handle this than the API (backend), but I'm not sure.
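A hedged sketch of the constraint described above: the backend's message conversion needs both a `role` and a `content` key in each message dict (the key names mirror the traceback later in this thread). The helper below is illustrative only, not code from this PR.

```python
# Message dicts as they might arrive from the frontend. Without "role",
# the backend cannot rebuild a typed message from the dict.
message_without_role = {"content": "hello", "type": "human"}
message_with_role = {"role": "human", "content": "hello"}

def has_required_keys(msg: dict) -> bool:
    # Mirrors the check that fails in _convert_to_message: both keys needed.
    return {"role", "content"} <= msg.keys()

assert not has_required_keys(message_without_role)
assert has_required_keys(message_with_role)
```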
```diff
@@ -19,8 +25,13 @@ def loads(value: bytes) -> Checkpoint:


 class PostgresCheckpoint(BaseCheckpointSaver):
+    class Config:
+        arbitrary_types_allowed = True
+
     def __init__(
```
This change is required since the introduction of the `serde` API in `BaseCheckpointSaver`.
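A minimal sketch of why `arbitrary_types_allowed` is needed here, assuming pydantic's v1-style `Config` API as used in the diff above: pydantic refuses to build a model field for a type it has no validator for unless the flag is set. `Serde` is a hypothetical stand-in for the serializer object.

```python
from pydantic import BaseModel

class Serde:
    """Stand-in for an arbitrary, non-pydantic serializer type."""
    pass

class Checkpointer(BaseModel):
    serde: Serde  # pydantic has no built-in validator for this type

    class Config:
        # Without this flag, defining the model above raises an error
        # because pydantic cannot generate a schema for Serde.
        arbitrary_types_allowed = True

cp = Checkpointer(serde=Serde())
```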
backend/app/agent.py
Outdated
```diff
@@ -244,7 +246,10 @@ def __init__(
         llm=ConfigurableField(id="llm_type", name="LLM Type"),
         system_message=ConfigurableField(id="system_message", name="Instructions"),
     )
-    .with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
+    .with_types(
+        input_type=Union[Sequence[AnyMessage], Dict[str, Any]],
```
these types don't seem right, it accepts either a list of messages or a dict, not both?
Updated types to `Messages` / `Sequence[AnyMessage]` for input/output respectively.
backend/app/stream.py
Outdated
```python
state_chunk_msgs: Union[Sequence[AnyMessage], Dict[str, Any]] = event[
    "data"
]["chunk"]
if isinstance(state_chunk_msgs, Dict):
```
this should be lowercase `dict`
Good catch. I'll update all other instances of this error.
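To illustrate the point (a sketch, not code from this PR): subscripted `typing` aliases cannot be used with `isinstance` at all, and even the bare `typing.Dict` alias is deprecated for runtime checks, so the builtin `dict` is the idiomatic choice.

```python
from typing import Dict

value = {"messages": []}

# Idiomatic runtime check uses the builtin type.
assert isinstance(value, dict)

# The bare alias happens to work (it delegates to dict) but is deprecated.
assert isinstance(value, Dict)

# A subscripted alias raises TypeError at runtime.
subscripted_fails = False
try:
    isinstance(value, Dict[str, list])
except TypeError:
    subscripted_fails = True
assert subscripted_fails
```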
backend/app/chatbot.py
Outdated
```diff
@@ -15,7 +18,7 @@ def _get_messages(messages):

 chatbot = _get_messages | llm

-workflow = MessageGraph()
+workflow = StateGraph(Annotated[Sequence[BaseMessage], add_messages])
```
we should be using `list`?
It didn't make a difference after langchain-ai/langgraph#321 was merged. I'll update both instances to use `List`.

This is where the API feels ambiguous with respect to the underlying functionality: if there's a "preferred" or required type that should be used, an end user isn't necessarily aware of it. Something to think about for later.
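For readers unfamiliar with the pattern discussed here: the second argument of `Annotated` is the reducer LangGraph applies to state-channel updates. A minimal, self-contained sketch (the simplified reducer below stands in for `langgraph.graph.message.add_messages`, which additionally coerces dicts to message objects):

```python
from typing import Annotated, Sequence

def add_messages(left: Sequence, right: Sequence) -> list:
    # Simplified stand-in for langgraph.graph.message.add_messages:
    # appends the update to the existing channel value.
    return list(left) + list(right)

# The annotated type carries the reducer alongside the value type.
Messages = Annotated[Sequence[dict], add_messages]

# The reducer travels with the type and is retrievable via __metadata__,
# which is how a framework can discover it at graph-construction time.
reducer = Messages.__metadata__[0]
merged = reducer([{"content": "a"}], [{"content": "b"}])
```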
backend/app/retrieval.py
Outdated
```diff
@@ -39,6 +42,10 @@ def get_retrieval_executor(
     system_message: str,
     checkpoint: BaseCheckpointSaver,
 ):
+    class AgentState(TypedDict):
+        messages: Annotated[Sequence[BaseMessage], add_messages]
```
we should be using `list`?
frontend/src/App.tsx
Outdated
```typescript
// Each message must contain a `role` field.
input = {
  messages: input.map((msg: Message) => {
    msg.role = "human";
```
how come?
The following error is raised:
```
opengpts-backend | Traceback (most recent call last):
opengpts-backend | File "/backend/app/stream.py", line 63, in to_sse
opengpts-backend | async for chunk in messages_stream:
opengpts-backend | File "/backend/app/stream.py", line 23, in astream_state
opengpts-backend | async for event in app.astream_events(
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 4711, in astream_events
opengpts-backend | async for item in self.bound.astream_events(
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 1137, in astream_events
opengpts-backend | async for log in _astream_log_implementation( # type: ignore[misc]
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/tracers/log_stream.py", line 616, in _astream_log_implementation
opengpts-backend | await task
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/tracers/log_stream.py", line 570, in consume_astream
opengpts-backend | async for chunk in runnable.astream(input, config, **kwargs):
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/configurable.py", line 221, in astream
opengpts-backend | async for chunk in runnable.astream(input, config, **kwargs):
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 4698, in astream
opengpts-backend | async for item in self.bound.astream(
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/configurable.py", line 221, in astream
opengpts-backend | async for chunk in runnable.astream(input, config, **kwargs):
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 4698, in astream
opengpts-backend | async for item in self.bound.astream(
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langgraph/pregel/__init__.py", line 924, in astream
opengpts-backend | _apply_writes(checkpoint, channels, pending_writes)
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langgraph/pregel/__init__.py", line 1170, in _apply_writes
opengpts-backend | channels[chan].update(vals)
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langgraph/channels/binop.py", line 66, in update
opengpts-backend | self.value = self.operator(self.value, value)
opengpts-backend | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langgraph/graph/message.py", line 24, in add_messages
opengpts-backend | right = [message_chunk_to_message(m) for m in convert_to_messages(right)]
opengpts-backend | ^^^^^^^^^^^^^^^^^^^^^^^^^^
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/messages/utils.py", line 234, in convert_to_messages
opengpts-backend | return [_convert_to_message(m) for m in messages]
opengpts-backend | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/messages/utils.py", line 234, in <listcomp>
opengpts-backend | return [_convert_to_message(m) for m in messages]
opengpts-backend | ^^^^^^^^^^^^^^^^^^^^^^
opengpts-backend | File "/usr/local/lib/python3.11/site-packages/langchain_core/messages/utils.py", line 211, in _convert_to_message
opengpts-backend | raise ValueError(
opengpts-backend | ValueError: Message dict must contain 'role' and 'content' keys, got {'content': 'hello', 'additional_kwargs': {}, 'type': 'human', 'example': False, 'id': 'human-0.38348279892186277'}
```
FastAPI is not serializing the dict to an `AnyMessage` type (which contains `role`).
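A hedged Python analogue of the frontend fix discussed above: stamp a default `role` onto each message dict before POSTing so the backend's `convert_to_messages` can rebuild typed messages. `with_roles` is a hypothetical helper for illustration, not code from this PR.

```python
def with_roles(messages: list, default_role: str = "human") -> list:
    # Copy each message dict, adding "role" only when it is missing,
    # so messages that already carry a role are left untouched.
    return [{**m, "role": m.get("role", default_role)} for m in messages]

payload = with_roles([{"content": "hello", "type": "human"}])
```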
Thanks for the assist @nfcampos 😄
Summary

As a demonstration, `get_chatbot_executor()` and `get_retrieval_executor()` are updated to use `StateGraph`. Moving forward, when using the RAG assistant (`assistantType === "chat_retrieval"`), the `POST /runs` and `/runs/stream` API must be called with a request body whose `input` is a dict containing a `messages` key. All other assistant types require the existing request body format (e.g. `"input": [{...}]`).

Implementation

- `MessageGraph` in `get_chatbot_executor()` is migrated to `StateGraph`.
- `MessageGraph` in `get_retrieval_executor()` is migrated to `StateGraph`. The graph now accepts a `TypedDict` for the state. The interfaces for the corresponding nodes are updated accordingly.
- `GET /threads/<tid>/state` (storage layer) is updated to retrieve the graph state based on the assistant type.
- The frontend calls `POST /runs/stream` with the correct request body format based on the assistant type.

To Do

- Resolve `TODO`s in code: update the `langchain_core` dependency and implement the new API. See "core: forward config params to default" (langchain#20402) for details.
- Update `get_retrieval_executor()` to use the `StateGraph` state.
- Merge once the remaining `TODO`s are resolved.
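The two request body shapes discussed in this PR can be sketched as plain dicts. The exact key names (`input`, `messages`) are inferred from the discussion and frontend diff above, not from documented API reference, so treat this as illustrative only.

```python
# RAG assistant (assistantType === "chat_retrieval"): input is a dict
# with a "messages" key, matching the StateGraph's TypedDict state.
chat_retrieval_body = {
    "input": {"messages": [{"role": "human", "content": "hello"}]},
}

# All other assistant types keep the existing format: input is a bare
# list of message dicts.
default_body = {
    "input": [{"role": "human", "content": "hello"}],
}
```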