Skip to content

Commit

Permalink
feat(langgraph): add support for switching threads (#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Sep 18, 2024
1 parent ff426b0 commit e4863bb
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 13 deletions.
5 changes: 5 additions & 0 deletions .changeset/quick-buckets-drum.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@assistant-ui/react-langgraph": patch
---

feat(langgraph): add support for switching threads
6 changes: 6 additions & 0 deletions .changeset/twenty-drinks-pull.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@assistant-ui/react-ai-sdk": patch
"@assistant-ui/react": patch
---

feat(runtimes/external): add onSwitchToNewThread callback
5 changes: 5 additions & 0 deletions .changeset/violet-fans-pull.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@assistant-ui/react": patch
---

feat: add attachmentAccept to ThreadComposer
14 changes: 13 additions & 1 deletion examples/with-langgraph/app/MyRuntimeProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import { useRef } from "react";
import { AssistantRuntimeProvider } from "@assistant-ui/react";
import { useLangGraphRuntime } from "@assistant-ui/react-langgraph";
import { createThread, sendMessage } from "@/lib/chatApi";
import { createThread, getThreadState, sendMessage } from "@/lib/chatApi";
import { LangChainMessage } from "@assistant-ui/react-langgraph";

export function MyRuntimeProvider({
children,
Expand All @@ -24,6 +25,17 @@ export function MyRuntimeProvider({
messages,
});
},
onSwitchToNewThread: async () => {
const { thread_id } = await createThread();
threadIdRef.current = thread_id;
},
onSwitchToThread: async (threadId) => {
const state = await getThreadState(threadId);
threadIdRef.current = threadId;
return {
messages: (state.values as { messages: LangChainMessage[] }).messages,
};
},
});

return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export const useVercelUseAssistantRuntime = (
onNew: async (message) => {
await assistantHelpers.append(await toCreateMessage(message));
},
onNewThread: () => {
onSwitchToNewThread: () => {
assistantHelpers.messages = [];
assistantHelpers.input = "";
assistantHelpers.setMessages([]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export const useVercelUseChatRuntime = (
onAddToolResult: ({ toolCallId, result }) => {
chatHelpers.addToolResult({ toolCallId, result });
},
onNewThread: () => {
onSwitchToNewThread: () => {
chatHelpers.messages = [];
chatHelpers.input = "";
chatHelpers.setMessages([]);
Expand Down
2 changes: 1 addition & 1 deletion packages/react-langgraph/src/useLangGraphMessages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ export const useLangGraphMessages = <TMessage>({
[stream],
);

return { messages, sendMessage };
return { messages, sendMessage, setMessages };
};
20 changes: 19 additions & 1 deletion packages/react-langgraph/src/useLangGraphRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ const getPendingToolCalls = (messages: LangChainMessage[]) => {
export const useLangGraphRuntime = ({
threadId,
stream,
onSwitchToNewThread,
onSwitchToThread,
}: {
threadId?: string | undefined;
stream: (messages: LangChainMessage[]) => Promise<
Expand All @@ -35,8 +37,12 @@ export const useLangGraphRuntime = ({
data: any;
}>
>;
onSwitchToNewThread?: () => Promise<void> | void;
onSwitchToThread?: (
threadId: string,
) => Promise<{ messages: LangChainMessage[] }>;
}): ExternalStoreRuntime => {
const { messages, sendMessage } = useLangGraphMessages({
const { messages, sendMessage, setMessages } = useLangGraphMessages({
stream,
});

Expand Down Expand Up @@ -93,5 +99,17 @@ export const useLangGraphRuntime = ({
},
]);
},
onSwitchToNewThread: !onSwitchToNewThread
? undefined
: async () => {
await onSwitchToNewThread();
setMessages([]);
},
onSwitchToThread: !onSwitchToThread
? undefined
: async (threadId) => {
const { messages } = await onSwitchToThread(threadId);
setMessages(messages);
},
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ type ExternalStoreAdapterBase<T> = {
onEdit?: ((message: AppendMessage) => Promise<void>) | undefined;
onReload?: ((parentId: string | null) => Promise<void>) | undefined;
onCancel?: (() => Promise<void>) | undefined;
onNewThread?: (() => Promise<void> | void) | undefined;
onAddToolResult?:
| ((options: AddToolResultOptions) => Promise<void> | void)
| undefined;
onSwitchThread?:
| ((threadId: string | null) => Promise<void> | void)
| undefined;
onSwitchToThread?: ((threadId: string) => Promise<void> | void) | undefined;
onSwitchToNewThread?: (() => Promise<void> | void) | undefined;
onSpeak?:
| ((message: ThreadMessage) => SpeechSynthesisAdapter.Utterance)
| undefined;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,26 @@ export class ExternalStoreRuntime extends BaseAssistantRuntime<ExternalStoreThre
}

public async switchToNewThread() {
if (!this.store.onNewThread)
if (!this.store.onSwitchToNewThread)
throw new Error("Runtime does not support switching to new threads.");

this.thread = new ExternalStoreThreadRuntime({
messages: [],
onNew: this.store.onNew,
});
await this.store.onNewThread();
await this.store.onSwitchToNewThread();
}

public async switchToThread(threadId: string | null) {
if (threadId !== null) {
if (!this.store.onSwitchThread)
if (!this.store.onSwitchToThread)
throw new Error("Runtime does not support switching threads.");

this.thread = new ExternalStoreThreadRuntime({
messages: [],
onNew: this.store.onNew,
});
this.store.onSwitchThread(threadId);
this.store.onSwitchToThread(threadId);
} else {
this.switchToNewThread();
}
Expand Down

0 comments on commit e4863bb

Please sign in to comment.