
Update chatbot and RAG assistant to use StateGraph in the backend #305

Merged
merged 24 commits into langchain-ai:main on Apr 18, 2024

Conversation

andrewnguonly
Contributor

@andrewnguonly andrewnguonly commented Apr 16, 2024

Summary

As a demonstration, get_chatbot_executor() and get_retrieval_executor() are updated to use StateGraph. Moving forward, when using the RAG assistant (assistantType === "chat_retrieval"), the POST /runs and POST /runs/stream APIs must be called with the following request body:

{
    // input is an object instead of a list
    "input": {
        // messages key must be present
        "messages": [
            {
                "content": "hello!", // content key must be present
                "role": "human",     // role key must be present
                ...
            }
        ],
        ...
    },
    ...
}

All other assistant types require the existing request body format (e.g. "input": [{...}]).
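For illustration, the two `input` shapes described above can be compared in a short Python sketch. The payload shapes come from this summary and from the frontend `Message` interface (`type`, optional `role`); `build_run_input` itself is a hypothetical helper, not part of the codebase:

```python
# Sketch of the two request body formats described above. The payload
# shapes come from the PR summary; build_run_input is a hypothetical
# helper used only for illustration.

def build_run_input(assistant_type, content):
    """Return the `input` value for POST /runs or POST /runs/stream."""
    if assistant_type == "chat_retrieval":
        # RAG assistant: input is an object with a required "messages" key,
        # and each message must carry "content" and "role".
        return {"messages": [{"content": content, "role": "human"}]}
    # All other assistant types keep the existing list-of-messages format.
    return [{"content": content, "type": "human"}]

rag_input = build_run_input("chat_retrieval", "hello!")
other_input = build_run_input("chatbot", "hello!")
```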

Implementation

  1. The MessageGraph in get_chatbot_executor() is migrated to StateGraph.
  2. The MessageGraph in get_retrieval_executor() is migrated to StateGraph. The graph now accepts a TypedDict for the state. The interfaces for the corresponding nodes are updated accordingly.
  3. API GET /threads/<tid>/state (storage layer) is updated to retrieve the graph state based on the assistant type.
  4. Frontend is updated to call the API POST /runs/stream with the correct request body format based on the assistant type.
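The MessageGraph-to-StateGraph migration in steps 1 and 2 amounts to moving from "the state is the message list" to "the state is a dict whose messages channel is merged through a reducer". A dependency-free sketch of that reducer pattern (a simplified stand-in for langgraph's behavior, not the library itself):

```python
from typing import TypedDict

class AgentState(TypedDict):
    # Mirrors the AgentState TypedDict added to get_retrieval_executor();
    # in the real code, messages is Annotated[...] with langgraph's
    # add_messages reducer.
    messages: list

def add_messages(left, right):
    # Simplified stand-in for langgraph.graph.message.add_messages:
    # the reducer appends a node's new messages instead of replacing state.
    return list(left) + list(right)

def apply_node_update(state, update):
    # A StateGraph merges each node's partial update into the state
    # through the channel's reducer, rather than overwriting it.
    return {"messages": add_messages(state["messages"], update.get("messages", []))}

state = {"messages": [{"content": "hello!", "role": "human"}]}
state = apply_node_update(state, {"messages": [{"content": "hi!", "role": "ai"}]})
```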

To Do

  • Update API documentation (if necessary, details TBD).
  • Remove TODOs in code. Update langchain_core dependency and implement new API. See core: forward config params to default langchain#20402 for details.
  • Add new keys to the get_retrieval_executor() StateGraph state.
  • The code that's causing the broken unit tests will be removed when the TODOs are resolved.
  • Retest everything once all of the To Do items are resolved.

@andrewnguonly andrewnguonly requested a review from nfcampos April 16, 2024 05:57
@@ -12,6 +12,7 @@ export interface MessageDocument {
export interface Message {
id: string;
type: string;
role?: string; // for chat_retrieval bot
Contributor Author

I couldn't decide what to do about this...

Calling POST /runs/stream requires the role field in each message. Otherwise, the request body is not deserialized correctly in the backend and an error is raised in langgraph.graph.message.add_messages (the state reducer function).

It seems more appropriate for the client (frontend) to handle this instead of the API (backend), but I'm not sure.
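The alternative the comment weighs is normalizing on the backend instead: default the role before the reducer ever sees the message. A hypothetical sketch of that server-side option (`ensure_role` is not in the codebase, and the frontend fix in this PR takes the client-side route):

```python
def ensure_role(message, default_role="human"):
    # Hypothetical server-side alternative: default the "role" key if the
    # client omitted it, so add_messages can still convert the dict.
    if "role" not in message:
        return {**message, "role": default_role}
    return message
```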

@ptgoetz ptgoetz added enhancement New feature or request documentation Improvements or additions to documentation backend Changes to the backend service labels Apr 16, 2024
@@ -19,8 +25,13 @@ def loads(value: bytes) -> Checkpoint:


class PostgresCheckpoint(BaseCheckpointSaver):
class Config:
arbitrary_types_allowed = True
def __init__(
Contributor Author

This change is required since the introduction of the serde API in BaseCheckpointSaver.

@andrewnguonly andrewnguonly changed the title Draft: Update chatbot and RAG assistant to use StateGraph in the backend Update chatbot and RAG assistant to use StateGraph in the backend Apr 17, 2024
@@ -244,7 +246,10 @@ def __init__(
llm=ConfigurableField(id="llm_type", name="LLM Type"),
system_message=ConfigurableField(id="system_message", name="Instructions"),
)
.with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
.with_types(
input_type=Union[Sequence[AnyMessage], Dict[str, Any]],
Contributor

these types don't seem right, it either accepts list of messages or dict, not both?

Contributor Author

Updated types to Messages/Sequence[AnyMessage] for input/output respectively.

state_chunk_msgs: Union[Sequence[AnyMessage], Dict[str, Any]] = event[
"data"
]["chunk"]
if isinstance(state_chunk_msgs, Dict):
Contributor

this should be lowercase dict

Contributor Author

Good catch. I'll update all other instances of this error.

@@ -15,7 +18,7 @@ def _get_messages(messages):

chatbot = _get_messages | llm

workflow = MessageGraph()
workflow = StateGraph(Annotated[Sequence[BaseMessage], add_messages])
Contributor

we should be using list?

Contributor Author

It didn't make a difference after this PR was merged: langchain-ai/langgraph#321. I'll update both instances to use List.

This is where the API feels ambiguous with respect to the underlying functionality. If there's a "preferred" or required type that should be used, then an end user isn't necessarily aware of it. Something to think about for later.

@@ -39,6 +42,10 @@ def get_retrieval_executor(
system_message: str,
checkpoint: BaseCheckpointSaver,
):
class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], add_messages]
Contributor

we should be using list?

// Each message must contain a `role` field.
input = {
messages: input.map((msg: Message) => {
msg.role = "human";
Contributor

how come?

Contributor Author

@andrewnguonly andrewnguonly Apr 18, 2024

The following error is raised:

opengpts-backend   | Traceback (most recent call last):
opengpts-backend   |   File "/backend/app/stream.py", line 63, in to_sse
opengpts-backend   |     async for chunk in messages_stream:
opengpts-backend   |   File "/backend/app/stream.py", line 23, in astream_state
opengpts-backend   |     async for event in app.astream_events(
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 4711, in astream_events
opengpts-backend   |     async for item in self.bound.astream_events(
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 1137, in astream_events
opengpts-backend   |     async for log in _astream_log_implementation(  # type: ignore[misc]
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/tracers/log_stream.py", line 616, in _astream_log_implementation
opengpts-backend   |     await task
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/tracers/log_stream.py", line 570, in consume_astream
opengpts-backend   |     async for chunk in runnable.astream(input, config, **kwargs):
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/configurable.py", line 221, in astream
opengpts-backend   |     async for chunk in runnable.astream(input, config, **kwargs):
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 4698, in astream
opengpts-backend   |     async for item in self.bound.astream(
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/configurable.py", line 221, in astream
opengpts-backend   |     async for chunk in runnable.astream(input, config, **kwargs):
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 4698, in astream
opengpts-backend   |     async for item in self.bound.astream(
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langgraph/pregel/__init__.py", line 924, in astream
opengpts-backend   |     _apply_writes(checkpoint, channels, pending_writes)
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langgraph/pregel/__init__.py", line 1170, in _apply_writes
opengpts-backend   |     channels[chan].update(vals)
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langgraph/channels/binop.py", line 66, in update
opengpts-backend   |     self.value = self.operator(self.value, value)
opengpts-backend   |                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langgraph/graph/message.py", line 24, in add_messages
opengpts-backend   |     right = [message_chunk_to_message(m) for m in convert_to_messages(right)]
opengpts-backend   |                                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/messages/utils.py", line 234, in convert_to_messages
opengpts-backend   |     return [_convert_to_message(m) for m in messages]
opengpts-backend   |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/messages/utils.py", line 234, in <listcomp>
opengpts-backend   |     return [_convert_to_message(m) for m in messages]
opengpts-backend   |             ^^^^^^^^^^^^^^^^^^^^^^
opengpts-backend   |   File "/usr/local/lib/python3.11/site-packages/langchain_core/messages/utils.py", line 211, in _convert_to_message
opengpts-backend   |     raise ValueError(
opengpts-backend   | ValueError: Message dict must contain 'role' and 'content' keys, got {'content': 'hello', 'additional_kwargs': {}, 'type': 'human', 'example': False, 'id': 'human-0.38348279892186277'}

FastAPI is not serializing the dict to an AnyMessage type (which contains role).

@andrewnguonly
Contributor Author

Thanks for the assist @nfcampos 😄

@nfcampos nfcampos merged commit f0c25df into langchain-ai:main Apr 18, 2024
6 checks passed

3 participants