diff --git a/.changeset/quiet-impalas-accept.md b/.changeset/quiet-impalas-accept.md new file mode 100644 index 0000000000..53ef322154 --- /dev/null +++ b/.changeset/quiet-impalas-accept.md @@ -0,0 +1,6 @@ +--- +"@assistant-ui/react-ai-sdk": minor +"@assistant-ui/react": patch +--- + +refactor: rewrite ai-sdk integration to use external runtime diff --git a/packages/react-ai-sdk/src/ui/getVercelAIMessage.tsx b/packages/react-ai-sdk/src/ui/getVercelAIMessage.tsx index 9a090a5693..91a94f1712 100644 --- a/packages/react-ai-sdk/src/ui/getVercelAIMessage.tsx +++ b/packages/react-ai-sdk/src/ui/getVercelAIMessage.tsx @@ -1,12 +1,9 @@ -import type { ThreadMessage } from "@assistant-ui/react"; +import { + getExternalStoreMessage, + type ThreadMessage, +} from "@assistant-ui/react"; import type { Message } from "ai"; -export const symbolInnerAIMessage = Symbol("innerVercelAIUIMessage"); - -export type VercelAIThreadMessage = ThreadMessage & { - [symbolInnerAIMessage]?: Message[]; -}; - export const getVercelAIMessage = (message: ThreadMessage) => { - return (message as VercelAIThreadMessage)[symbolInnerAIMessage]; + return getExternalStoreMessage(message) as Message[]; }; diff --git a/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantRuntime.tsx b/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantRuntime.tsx deleted file mode 100644 index a654147544..0000000000 --- a/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantRuntime.tsx +++ /dev/null @@ -1,54 +0,0 @@ -import { - type ThreadMessage, - INTERNAL, - ModelConfigProvider, -} from "@assistant-ui/react"; -import { useAssistant } from "ai/react"; -import { VercelUseAssistantThreadRuntime } from "./VercelUseAssistantThreadRuntime"; - -const { ProxyConfigProvider, BaseAssistantRuntime } = INTERNAL; - -export const hasUpcomingMessage = ( - isRunning: boolean, - messages: ThreadMessage[], -) => { - return isRunning && messages[messages.length - 1]?.role !== "assistant"; -}; - -export class VercelUseAssistantRuntime extends BaseAssistantRuntime { - private readonly _proxyConfigProvider = new ProxyConfigProvider(); - - constructor(vercel: ReturnType) { - super(new VercelUseAssistantThreadRuntime(vercel)); - } - - public set vercel(vercel: ReturnType) { - this.thread.vercel = vercel; - } - - public onVercelUpdated() { - return this.thread.onVercelUpdated(); - } - - public getModelConfig() { - return this._proxyConfigProvider.getModelConfig(); - } - - public registerModelConfigProvider(provider: ModelConfigProvider) { - return this._proxyConfigProvider.registerModelConfigProvider(provider); - } - - public switchToThread(threadId: string | null) { - if (threadId) { - throw new Error("VercelAIRuntime does not yet support switching threads"); - } - - // clear the vercel state (otherwise, it will be captured by the MessageRepository) - this.thread.vercel.messages = []; - this.thread.vercel.input = ""; - this.thread.vercel.setMessages([]); - this.thread.vercel.setInput(""); - - this.thread = new VercelUseAssistantThreadRuntime(this.thread.vercel); - } -} diff --git a/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx b/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx deleted file mode 100644 index f577437d12..0000000000 --- a/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx +++ /dev/null @@ -1,135 +0,0 @@ -import { - type ReactThreadRuntime, - type Unsubscribe, - type AppendMessage, - type ThreadMessage, -} from "@assistant-ui/react"; -import { type StoreApi, type UseBoundStore, create } from "zustand"; -import { useVercelAIComposerSync } from "../utils/useVercelAIComposerSync"; -import { useVercelAIThreadSync } from "../utils/useVercelAIThreadSync"; -import { useAssistant } from "ai/react"; -import { hasUpcomingMessage } from "./VercelUseAssistantRuntime"; - -const EMPTY_BRANCHES: readonly string[] = Object.freeze([]); - -const CAPABILITIES = Object.freeze({ - switchToBranch: false, - edit: false, - reload: false, - cancel: false, - copy: true, -}); - -export class VercelUseAssistantThreadRuntime implements ReactThreadRuntime { - private _subscriptions = new Set<() => void>(); - - public readonly capabilities = CAPABILITIES; - - private useVercel: UseBoundStore< - StoreApi<{ vercel: ReturnType }> - >; - - public messages: readonly ThreadMessage[] = []; - - public readonly composer = { - text: "", - setText: (value: string) => { - this.composer.text = value; - - for (const callback of this._subscriptions) callback(); - }, - }; - - public readonly isDisabled = false; - - constructor(public vercel: ReturnType) { - this.useVercel = create(() => ({ - vercel, - })); - } - - public getBranches(): readonly string[] { - return EMPTY_BRANCHES; - } - - public switchToBranch(): void { - throw new Error( - "VercelUseAssistantRuntime does not support switching branches.", - ); - } - - public async append(message: AppendMessage): Promise { - // add user message - if (message.role !== "user") - throw new Error( - "Only appending user messages are supported in VercelUseAssistantRuntime. This is likely an internal bug in assistant-ui.", - ); - if (message.content.length !== 1 || message.content[0]?.type !== "text") - throw new Error("VercelUseAssistantRuntime only supports text content."); - - if (message.parentId !== (this.messages.at(-1)?.id ?? null)) - throw new Error( - "VercelUseAssistantRuntime does not support editing messages.", - ); - - await this.vercel.append({ - role: "user", - content: message.content[0].text, - }); - } - - public async startRun(): Promise { - throw new Error("VercelUseAssistantRuntime does not support reloading."); - } - - public cancelRun(): void { - const previousMessage = this.vercel.messages.at(-1); - - this.vercel.stop(); - if (previousMessage?.role === "user") { - this.vercel.setInput(previousMessage.content); - } - } - - public subscribe(callback: () => void): Unsubscribe { - this._subscriptions.add(callback); - return () => this._subscriptions.delete(callback); - } - - public onVercelUpdated() { - if (this.useVercel.getState().vercel !== this.vercel) { - this.useVercel.setState({ vercel: this.vercel }); - } - } - - private updateData = (isRunning: boolean, vm: ThreadMessage[]) => { - if (hasUpcomingMessage(isRunning, vm)) { - vm.push({ - id: "__optimistic__result", - createdAt: new Date(), - status: { type: "running" }, - role: "assistant", - content: [], - }); - } - - this.messages = vm; - - for (const callback of this._subscriptions) callback(); - }; - - unstable_synchronizer = () => { - const { vercel } = this.useVercel(); - - useVercelAIThreadSync(vercel, this.updateData); - useVercelAIComposerSync(vercel); - - return null; - }; - - addToolResult() { - throw new Error( - "VercelUseAssistantRuntime does not support adding tool results.", - ); - } -} diff --git a/packages/react-ai-sdk/src/ui/use-assistant/useVercelUseAssistantRuntime.tsx b/packages/react-ai-sdk/src/ui/use-assistant/useVercelUseAssistantRuntime.tsx index 8c07c35d6b..b4ae3889ea 100644 --- a/packages/react-ai-sdk/src/ui/use-assistant/useVercelUseAssistantRuntime.tsx +++ b/packages/react-ai-sdk/src/ui/use-assistant/useVercelUseAssistantRuntime.tsx @@ -1,20 +1,38 @@ import type { useAssistant } from "ai/react"; -import { useEffect, useInsertionEffect, useState } from "react"; -import { VercelUseAssistantRuntime } from "./VercelUseAssistantRuntime"; +import { useExternalStoreRuntime } from "@assistant-ui/react"; +import { useCachedChunkedMessages } from "../utils/useCachedChunkedMessages"; +import { convertMessage } from "../utils/convertMessage"; +import { useInputSync } from "../utils/useInputSync"; export const useVercelUseAssistantRuntime = ( assistantHelpers: ReturnType, ) => { - const [runtime] = useState( - () => new VercelUseAssistantRuntime(assistantHelpers), - ); + const messages = useCachedChunkedMessages(assistantHelpers.messages); + const runtime = useExternalStoreRuntime({ + isRunning: assistantHelpers.status === "in_progress", + messages, + onCancel: async () => assistantHelpers.stop(), + onNew: async (message) => { + if (message.content.length !== 1 || message.content[0]?.type !== "text") + throw new Error( + "VercelUseAssistantRuntime only supports text content.", + ); - useInsertionEffect(() => { - runtime.vercel = assistantHelpers; - }); - useEffect(() => { - runtime.onVercelUpdated(); + await assistantHelpers.append({ + role: message.role, + content: message.content[0].text, + }); + }, + onNewThread: () => { + assistantHelpers.messages = []; + assistantHelpers.input = ""; + assistantHelpers.setMessages([]); + assistantHelpers.setInput(""); + }, + convertMessage, }); + useInputSync(assistantHelpers, runtime); + return runtime; }; diff --git a/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatRuntime.tsx b/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatRuntime.tsx deleted file mode 100644 index 16c78c9c25..0000000000 --- a/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatRuntime.tsx +++ /dev/null @@ -1,45 +0,0 @@ -import { INTERNAL, ModelConfigProvider } from "@assistant-ui/react"; -import { useChat } from "ai/react"; -import { VercelUseChatThreadRuntime } from "./VercelUseChatThreadRuntime"; - -const { ProxyConfigProvider, BaseAssistantRuntime } = INTERNAL; - -export class VercelUseChatRuntime extends BaseAssistantRuntime { - private readonly _proxyConfigProvider = new ProxyConfigProvider(); - - constructor(vercel: ReturnType) { - super(new VercelUseChatThreadRuntime(vercel)); - } - - public set vercel(vercel: ReturnType) { - this.thread.vercel = vercel; - } - - public onVercelUpdated() { - return this.thread.onVercelUpdated(); - } - - public getModelConfig() { - return this._proxyConfigProvider.getModelConfig(); - } - - public registerModelConfigProvider(provider: ModelConfigProvider) { - return this._proxyConfigProvider.registerModelConfigProvider(provider); - } - - public switchToThread(threadId: string | null) { - if (threadId) { - throw new Error( - "VercelAIRuntime does not yet support switching threads.", - ); - } - - // clear the vercel state (otherwise, it will be captured by the MessageRepository) - this.thread.vercel.messages = []; - this.thread.vercel.input = ""; - this.thread.vercel.setMessages([]); - this.thread.vercel.setInput(""); - - this.thread = new VercelUseChatThreadRuntime(this.thread.vercel); - } -} diff --git a/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatThreadRuntime.tsx b/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatThreadRuntime.tsx deleted file mode 100644 index 651e5df083..0000000000 --- a/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatThreadRuntime.tsx +++ /dev/null @@ -1,187 +0,0 @@ -import { - type ReactThreadRuntime, - type Unsubscribe, - type AppendMessage, - type ThreadMessage, - AddToolResultOptions, - INTERNAL, -} from "@assistant-ui/react"; -import type { Message } from "ai"; -import { type StoreApi, type UseBoundStore, create } from "zustand"; -import { useChat } from "ai/react"; -import { getVercelAIMessage } from "../getVercelAIMessage"; -import { sliceMessagesUntil } from "../utils/sliceMessagesUntil"; -import { useVercelAIComposerSync } from "../utils/useVercelAIComposerSync"; -import { useVercelAIThreadSync } from "../utils/useVercelAIThreadSync"; - -const { MessageRepository } = INTERNAL; - -export const hasUpcomingMessage = ( - isRunning: boolean, - messages: ThreadMessage[], -) => { - return isRunning && messages[messages.length - 1]?.role !== "assistant"; -}; - -const CAPABILITIES = Object.freeze({ - switchToBranch: true, - edit: true, - reload: true, - cancel: true, - copy: true, -}); - -export class VercelUseChatThreadRuntime implements ReactThreadRuntime { - private _subscriptions = new Set<() => void>(); - private repository = new MessageRepository(); - private assistantOptimisticId: string | null = null; - - private useVercel: UseBoundStore< - StoreApi<{ vercel: ReturnType }> - >; - - public readonly capabilities = CAPABILITIES; - - public messages: ThreadMessage[] = []; - public readonly isDisabled = false; - - public readonly composer = { - text: "", - setText: (value: string) => { - this.composer.text = value; - - for (const callback of this._subscriptions) callback(); - }, - }; - - constructor(public vercel: ReturnType) { - this.useVercel = create(() => ({ - vercel, - })); - } - - public getBranches(messageId: string): string[] { - return this.repository.getBranches(messageId); - } - - public switchToBranch(branchId: string): void { - this.repository.switchToBranch(branchId); - this.updateVercelMessages(this.repository.getMessages()); - } - - public async append(message: AppendMessage): Promise { - // add user message - if (message.content.length !== 1 || message.content[0]?.type !== "text") - throw new Error( - "Only text content is supported by VercelUseChatRuntime. Use the Edge runtime for image support.", - ); - - const newMessages = sliceMessagesUntil( - this.vercel.messages, - message.parentId, - ); - this.vercel.setMessages(newMessages); - - await this.vercel.append({ - role: message.role, - content: message.content[0].text, - }); - } - - public async startRun(parentId: string | null): Promise { - const newMessages = sliceMessagesUntil(this.vercel.messages, parentId); - this.vercel.setMessages(newMessages); - - await this.vercel.reload(); - } - - public cancelRun(): void { - const previousMessage = this.vercel.messages.at(-1); - - this.vercel.stop(); - - if (this.assistantOptimisticId) { - this.repository.deleteMessage(this.assistantOptimisticId); - this.assistantOptimisticId = null; - } - - let messages = this.repository.getMessages(); - if ( - previousMessage?.role === "user" && - previousMessage.id === messages.at(-1)?.id // ensure the previous message is a leaf node - ) { - this.vercel.setInput(previousMessage.content); - this.repository.deleteMessage(previousMessage.id); - - messages = this.repository.getMessages(); - } - - // resync messages - setTimeout(() => { - this.updateVercelMessages(messages); - }, 0); - } - - public subscribe(callback: () => void): Unsubscribe { - this._subscriptions.add(callback); - return () => this._subscriptions.delete(callback); - } - - private updateVercelMessages = (messages: ThreadMessage[]) => { - this.vercel.setMessages( - messages - .flatMap(getVercelAIMessage) - .filter((m): m is Message => m != null), - ); - }; - - public onVercelUpdated() { - if (this.useVercel.getState().vercel !== this.vercel) { - this.useVercel.setState({ vercel: this.vercel }); - } - } - - private updateData = (isRunning: boolean, vm: ThreadMessage[]) => { - for (let i = 0; i < vm.length; i++) { - const message = vm[i]!; - const parent = vm[i - 1]; - this.repository.addOrUpdateMessage(parent?.id ?? null, message); - } - - if (this.assistantOptimisticId) { - this.repository.deleteMessage(this.assistantOptimisticId); - this.assistantOptimisticId = null; - } - - if (hasUpcomingMessage(isRunning, vm)) { - this.assistantOptimisticId = this.repository.appendOptimisticMessage( - vm.at(-1)?.id ?? null, - { - role: "assistant", - content: [], - }, - ); - } - - this.repository.resetHead( - this.assistantOptimisticId ?? vm.at(-1)?.id ?? null, - ); - - this.messages = this.repository.getMessages(); - - for (const callback of this._subscriptions) callback(); - }; - - unstable_synchronizer = () => { - const { vercel } = this.useVercel(); - - useVercelAIThreadSync(vercel, this.updateData); - useVercelAIComposerSync(vercel); - - return null; - }; - - addToolResult({ toolCallId, result }: AddToolResultOptions) { - this.vercel.addToolResult({ toolCallId, result }); - } -} diff --git a/packages/react-ai-sdk/src/ui/use-chat/useVercelUseChatRuntime.tsx b/packages/react-ai-sdk/src/ui/use-chat/useVercelUseChatRuntime.tsx index 46ce88862d..4e0cdacab8 100644 --- a/packages/react-ai-sdk/src/ui/use-chat/useVercelUseChatRuntime.tsx +++ b/packages/react-ai-sdk/src/ui/use-chat/useVercelUseChatRuntime.tsx @@ -1,18 +1,66 @@ import type { useChat } from "ai/react"; -import { useEffect, useInsertionEffect, useState } from "react"; -import { VercelUseChatRuntime } from "./VercelUseChatRuntime"; +import { useCachedChunkedMessages } from "../utils/useCachedChunkedMessages"; +import { convertMessage } from "../utils/convertMessage"; +import { useExternalStoreRuntime } from "@assistant-ui/react"; +import { useInputSync } from "../utils/useInputSync"; +import { sliceMessagesUntil } from "../utils/sliceMessagesUntil"; export const useVercelUseChatRuntime = ( chatHelpers: ReturnType, ) => { - const [runtime] = useState(() => new VercelUseChatRuntime(chatHelpers)); + const messages = useCachedChunkedMessages(chatHelpers.messages); + const runtime = useExternalStoreRuntime({ + isRunning: chatHelpers.isLoading, + messages, + setMessages: (messages) => chatHelpers.setMessages(messages.flat()), + onCancel: async () => chatHelpers.stop(), + onNew: async (message) => { + if (message.content.length !== 1 || message.content[0]?.type !== "text") + throw new Error( + "Only text content is supported by VercelUseChatRuntime. Use the Edge runtime for image support.", + ); + await chatHelpers.append({ + role: message.role, + content: message.content[0].text, + }); + }, + onEdit: async (message) => { + if (message.content.length !== 1 || message.content[0]?.type !== "text") + throw new Error( + "Only text content is supported by VercelUseChatRuntime. Use the Edge runtime for image support.", + ); - useInsertionEffect(() => { - runtime.vercel = chatHelpers; - }); - useEffect(() => { - runtime.onVercelUpdated(); + const newMessages = sliceMessagesUntil( + chatHelpers.messages, + message.parentId, + ); + chatHelpers.setMessages(newMessages); + + await chatHelpers.append({ + role: message.role, + content: message.content[0].text, + }); + }, + onReload: async (parentId: string | null) => { + const newMessages = sliceMessagesUntil(chatHelpers.messages, parentId); + chatHelpers.setMessages(newMessages); + + await chatHelpers.reload(); + }, + onAddToolResult: ({ toolCallId, result }) => { + chatHelpers.addToolResult({ toolCallId, result }); + }, + // onCopy // TODO + onNewThread: () => { + chatHelpers.messages = []; + chatHelpers.input = ""; + chatHelpers.setMessages([]); + chatHelpers.setInput(""); + }, + convertMessage, }); + useInputSync(chatHelpers, runtime); + return runtime; }; diff --git a/packages/react-ai-sdk/src/ui/utils/VercelHelpers.tsx b/packages/react-ai-sdk/src/ui/utils/VercelHelpers.tsx deleted file mode 100644 index da52ab4fa0..0000000000 --- a/packages/react-ai-sdk/src/ui/utils/VercelHelpers.tsx +++ /dev/null @@ -1,5 +0,0 @@ -import type { useAssistant, useChat } from "ai/react"; - -export type VercelHelpers = - | ReturnType - | ReturnType; diff --git a/packages/react-ai-sdk/src/ui/utils/convertMessage.ts b/packages/react-ai-sdk/src/ui/utils/convertMessage.ts new file mode 100644 index 0000000000..d8414bd2d1 --- /dev/null +++ b/packages/react-ai-sdk/src/ui/utils/convertMessage.ts @@ -0,0 +1,90 @@ +import { Message } from "ai"; +import { ThreadMessageLike } from "@assistant-ui/react"; +import { ToolCallContentPart } from "@assistant-ui/react"; +import { TextContentPart } from "@assistant-ui/react"; + +export const convertMessage = (messages: Message[]): ThreadMessageLike => { + const firstMessage = messages[0]; + if (!firstMessage) throw new Error("No messages found"); + + const common = { + id: firstMessage.id, + createdAt: firstMessage.createdAt ?? new Date(), + }; + + switch (firstMessage.role) { + case "user": + if (messages.length > 1) { + throw new Error( + "Multiple user messages found. This is likely an internal bug in assistant-ui.", + ); + } + + return { + ...common, + role: "user", + content: [{ type: "text", text: firstMessage.content }], + }; + + case "system": + return { + ...common, + role: "system", + content: [{ type: "text", text: firstMessage.content }], + }; + + case "data": + case "assistant": { + const res: ThreadMessageLike = { + ...common, + role: "assistant", + content: messages.flatMap((message) => { + return [ + ...(message.content + ? [{ type: "text", text: message.content } as TextContentPart] + : []), + ...(message.toolInvocations?.map( + (t) => + ({ + type: "tool-call", + toolName: t.toolName, + toolCallId: t.toolCallId, + argsText: JSON.stringify(t.args), + args: t.args, + result: "result" in t ? t.result : undefined, + }) satisfies ToolCallContentPart, + ) ?? []), + ...(typeof message.data === "object" && + !Array.isArray(message.data) && + message.data?.["type"] === "tool-call" + ? [message.data as ToolCallContentPart] + : []), + ]; + }), + }; + + for (const message of messages) { + if ( + typeof message.data === "object" && + !Array.isArray(message.data) && + message.data?.["type"] === "tool-result" + ) { + const toolCallId = message.data["toolCallId"]; + const toolContent = res.content.find( + (c) => c.type === "tool-call" && c.toolCallId === toolCallId, + ) as ToolCallContentPart | undefined; + if (!toolContent) throw new Error("Tool call not found"); + toolContent.result = message.data["result"]; + } + } + + return res; + } + + default: + const _unsupported: "function" | "tool" = firstMessage.role; + throw new Error( + `You have a message with an unsupported role. The role ${_unsupported} is not supported.`, + ); + } +}; diff --git a/packages/react-ai-sdk/src/ui/utils/useCachedChunkedMessages.ts b/packages/react-ai-sdk/src/ui/utils/useCachedChunkedMessages.ts new file mode 100644 index 0000000000..cc38eb79ee --- /dev/null +++ b/packages/react-ai-sdk/src/ui/utils/useCachedChunkedMessages.ts @@ -0,0 +1,53 @@ +import { Message } from "ai"; +import { useMemo } from "react"; + +type Chunk = [Message, ...Message[]]; +const hasItems = (messages: Message[]): messages is Chunk => + messages.length > 0; + +const chunkedMessages = (messages: Message[]): Chunk[] => { + const chunks: Chunk[] = []; + let currentChunk: Message[] = []; + + for (const message of messages) { + if (message.role === "assistant" || message.role === "data") { + currentChunk.push(message); + } else { + if (hasItems(currentChunk)) { + chunks.push(currentChunk); + currentChunk = []; + } + chunks.push([message]); + } + } + + if (hasItems(currentChunk)) { + chunks.push(currentChunk); + } + + return chunks; +}; + +const shallowArrayEqual = (a: unknown[], b: unknown[]) => { + if (a.length !== b.length) return false; + for (let i = 0; i < a.length; i++) { + if (a[i] !== b[i]) return false; + } + return true; +}; + +export const useCachedChunkedMessages = (messages: Message[]) => { + const cache = useMemo(() => new WeakMap(), []); + + return useMemo(() => { + return chunkedMessages(messages).map((m) => { + const key = m[0]; + if (!key) return m; + + const cached = cache.get(key); + if (cached && shallowArrayEqual(cached, m)) return cached; + cache.set(key, m); + return m; + }); + }, [messages, cache]); +}; diff --git a/packages/react-ai-sdk/src/ui/utils/useInputSync.tsx b/packages/react-ai-sdk/src/ui/utils/useInputSync.tsx new file mode 100644 index 0000000000..b4084701b0 --- /dev/null +++ b/packages/react-ai-sdk/src/ui/utils/useInputSync.tsx @@ -0,0 +1,33 @@ +import { useRef, useEffect } from "react"; +import { + ExternalStoreRuntime, + subscribeToMainThread, +} from "@assistant-ui/react"; +import { useAssistant, useChat } from "ai/react"; + +type VercelHelpers = + | ReturnType + | ReturnType; + +export const useInputSync = ( + helpers: VercelHelpers, + runtime: ExternalStoreRuntime, +) => { + // sync input from vercel to assistant-ui + const helpersRef = useRef(helpers); + useEffect(() => { + helpersRef.current = helpers; + if (runtime.thread.composer.text !== helpers.input) { + runtime.thread.composer.setText(helpers.input); + } + }, [helpers, runtime]); + + // sync input from assistant-ui to vercel + useEffect(() => { + return subscribeToMainThread(runtime, () => { + if (runtime.thread.composer.text !== helpersRef.current.input) { + helpersRef.current.setInput(runtime.thread.composer.text); + } + }); + }, [runtime]); +}; diff --git a/packages/react-ai-sdk/src/ui/utils/useVercelAIComposerSync.tsx b/packages/react-ai-sdk/src/ui/utils/useVercelAIComposerSync.tsx deleted file mode 100644 index 6ee30d2ccf..0000000000 --- a/packages/react-ai-sdk/src/ui/utils/useVercelAIComposerSync.tsx +++ /dev/null @@ -1,22 +0,0 @@ -import { useThreadContext, ComposerState } from "@assistant-ui/react"; -import { useEffect } from "react"; -import type { VercelHelpers } from "./VercelHelpers"; -import { StoreApi } from "zustand"; - -// two way sync between vercel helpers input state and composer text state -export const useVercelAIComposerSync = (vercel: VercelHelpers) => { - const { useComposer, useThreadRuntime } = useThreadContext(); - - useEffect(() => { - useThreadRuntime.getState().composer.setText(vercel.input); - }, [useComposer, useThreadRuntime, vercel.input]); - - useEffect(() => { - (useComposer as unknown as StoreApi).setState({ - setText: (t) => { - vercel.setInput(t); - useThreadRuntime.getState().composer.setText(t); - }, - }); - }, [useComposer, useThreadRuntime, vercel]); -}; diff --git a/packages/react-ai-sdk/src/ui/utils/useVercelAIThreadSync.tsx b/packages/react-ai-sdk/src/ui/utils/useVercelAIThreadSync.tsx deleted file mode 100644 index 9648befaae..0000000000 --- a/packages/react-ai-sdk/src/ui/utils/useVercelAIThreadSync.tsx +++ /dev/null @@ -1,192 +0,0 @@ -import type { - TextContentPart, - ThreadMessage, - ToolCallContentPart, - MessageStatus, - ThreadAssistantMessage, -} from "@assistant-ui/react"; -import type { Message } from "ai"; -import { useEffect, useMemo } from "react"; -import { - type ConverterCallback, - ThreadMessageConverter, -} from "../../utils/ThreadMessageConverter"; -import { - type VercelAIThreadMessage, - symbolInnerAIMessage, -} from "../getVercelAIMessage"; -import type { VercelHelpers } from "./VercelHelpers"; - -const getIsRunning = (vercel: VercelHelpers) => { - if ("isLoading" in vercel) return vercel.isLoading; - return vercel.status === "in_progress"; -}; - -const vercelToThreadMessage = ( - messages: Message[], - status: MessageStatus, -): VercelAIThreadMessage => { - const firstMessage = messages[0]; - if (!firstMessage) throw new Error("No messages found"); - - const common = { - id: firstMessage.id, - createdAt: firstMessage.createdAt ?? new Date(), - [symbolInnerAIMessage]: messages, - }; - - switch (firstMessage.role) { - case "user": - if (messages.length > 1) { - throw new Error( - "Multiple user messages found. This is likely an internal bug in assistant-ui.", - ); - } - - return { - ...common, - role: "user", - content: [{ type: "text", text: firstMessage.content }], - }; - - case "system": - return { - ...common, - role: "system", - content: [{ type: "text", text: firstMessage.content }], - }; - - case "data": - case "assistant": { - const res: ThreadAssistantMessage = { - ...common, - role: "assistant", - content: messages.flatMap((message) => { - return [ - ...(message.content - ? [{ type: "text", text: message.content } as TextContentPart] - : []), - ...(message.toolInvocations?.map( - (t) => - ({ - type: "tool-call", - toolName: t.toolName, - toolCallId: t.toolCallId, - argsText: JSON.stringify(t.args), - args: t.args, - result: "result" in t ? t.result : undefined, - }) satisfies ToolCallContentPart, - ) ?? []), - ...(typeof message.data === "object" && - !Array.isArray(message.data) && - message.data?.["type"] === "tool-call" - ? [message.data as ToolCallContentPart] - : []), - ]; - }), - status, - }; - - for (const message of messages) { - if ( - typeof message.data === "object" && - !Array.isArray(message.data) && - message.data?.["type"] === "tool-result" - ) { - const toolCallId = message.data["toolCallId"]; - const toolContent = res.content.find( - (c) => c.type === "tool-call" && c.toolCallId === toolCallId, - ) as ToolCallContentPart | undefined; - if (!toolContent) throw new Error("Tool call not found"); - toolContent.result = message.data["result"]; - } - } - - return res; - } - - default: - const _unsupported: "function" | "tool" = firstMessage.role; - throw new Error( - `You have a message with an unsupported role. The role ${_unsupported} is not supported.`, - ); - } -}; - -type Chunk = [Message, ...Message[]]; -const hasItems = (messages: Message[]): messages is Chunk => - messages.length > 0; - -const chunkedMessages = (messages: Message[]): Chunk[] => { - const chunks: Chunk[] = []; - let currentChunk: Message[] = []; - - for (const message of messages) { - if (message.role === "assistant" || message.role === "data") { - currentChunk.push(message); - } else { - if (hasItems(currentChunk)) { - chunks.push(currentChunk); - currentChunk = []; - } - chunks.push([message]); - } - } - - if (hasItems(currentChunk)) { - chunks.push(currentChunk); - } - - return chunks; -}; - -const shallowArrayEqual = (a: unknown[], b: unknown[]) => { - if (a.length !== b.length) return false; - for (let i = 0; i < a.length; i++) { - if (a[i] !== b[i]) return false; - } - return true; -}; - -type UpdateDataCallback = (isRunning: boolean, vm: ThreadMessage[]) => void; - -export const useVercelAIThreadSync = ( - vercel: VercelHelpers, - updateData: UpdateDataCallback, -) => { - const isRunning = getIsRunning(vercel); - - const converter = useMemo(() => new ThreadMessageConverter(), []); - - useEffect(() => { - const lastMessageId = vercel.messages.at(-1)?.id; - const convertCallback: ConverterCallback = (messages, cache) => { - const status: MessageStatus = - lastMessageId === messages[0].id && isRunning - ? { - type: "running", - } - : { - type: "complete", - reason: "unknown", - }; - - if ( - cache && - shallowArrayEqual(cache.content, messages) && - (cache.role !== "assistant" || cache.status.type === status.type) - ) - return cache; - - return vercelToThreadMessage(messages, status); - }; - - const messages = converter.convertMessages( - chunkedMessages(vercel.messages), - convertCallback, - (m) => m[0], - ); - - updateData(isRunning, messages); - }, [updateData, isRunning, vercel.messages, converter]); -}; diff --git a/packages/react-ai-sdk/src/utils/ThreadMessageConverter.ts b/packages/react-ai-sdk/src/utils/ThreadMessageConverter.ts deleted file mode 100644 index 6c1f571082..0000000000 --- a/packages/react-ai-sdk/src/utils/ThreadMessageConverter.ts +++ /dev/null @@ -1,24 +0,0 @@ -import type { ThreadMessage } from "@assistant-ui/react"; - -export type ConverterCallback = ( - message: TIn, - cache: ThreadMessage | undefined, -) => ThreadMessage; - -export class ThreadMessageConverter { - private readonly cache = new WeakMap(); - - convertMessages( - messages: TIn[], - converter: ConverterCallback, - keyMapper: (m: TIn) => WeakKey = (key) => key, - ): ThreadMessage[] { - return messages.map((m) => { - const key = keyMapper(m); - const cached = this.cache.get(key); - const newMessage = converter(m, cached); - this.cache.set(key, newMessage); - return newMessage; - }); - } -} diff --git a/packages/react/src/context/providers/ThreadProvider.tsx b/packages/react/src/context/providers/ThreadProvider.tsx index c00c44b3cd..f0c8fc1927 100644 --- a/packages/react/src/context/providers/ThreadProvider.tsx +++ b/packages/react/src/context/providers/ThreadProvider.tsx @@ -1,5 +1,5 @@ import type { FC, PropsWithChildren } from "react"; -import { useCallback, useInsertionEffect, useState } from "react"; +import { useEffect, useInsertionEffect, useState } from "react"; import type { ReactThreadRuntime } from "../../runtimes/core/ReactThreadRuntime"; import type { ThreadContextValue } from "../react/ThreadContext"; import { ThreadContext } from "../react/ThreadContext"; @@ -21,7 +21,7 @@ import { makeThreadRuntimeStore, ThreadRuntimeStore, } from "../stores/ThreadRuntime"; -import { useManagedRef } from "../../utils/hooks/useManagedRef"; +import { subscribeToMainThread } from "../../runtimes"; type ThreadProviderProps = { provider: ThreadRuntimeWithSubscribe; @@ -50,65 +50,56 @@ export const ThreadProvider: FC> = ({ }); // TODO it might make sense to move this into the make* functions - const threadRef = useManagedRef( - useCallback( - (thread: ReactThreadRuntime) => { - const onThreadUpdate = () => { - const oldState = context.useThread.getState(); - const state = getThreadStateFromRuntime(thread); - if ( - oldState.isDisabled !== state.isDisabled || - oldState.isRunning !== state.isRunning || - // TODO ensure capabilities is memoized - oldState.capabilities !== state.capabilities - ) { - (context.useThread as unknown as StoreApi).setState( - state, - true, - ); - } + useEffect(() => { + const onThreadUpdate = () => { + const thread = provider.thread; - if (thread.messages !== context.useThreadMessages.getState()) { - ( - context.useThreadMessages as unknown as StoreApi - ).setState(thread.messages, true); - } + const oldState = context.useThread.getState(); + const state = getThreadStateFromRuntime(thread); + if ( + oldState.isDisabled !== state.isDisabled || + oldState.isRunning !== state.isRunning || + // TODO ensure capabilities is memoized + oldState.capabilities !== state.capabilities + ) { + (context.useThread as unknown as StoreApi).setState( + state, + true, + ); + } - const composerState = context.useComposer.getState(); - if ( - thread.composer.text !== composerState.text || - state.capabilities.cancel !== composerState.canCancel - ) { - ( - context.useComposer as unknown as StoreApi - ).setState({ - text: thread.composer.text, - canCancel: state.capabilities.cancel, - }); - } - }; + if (thread.messages !== context.useThreadMessages.getState()) { + ( + context.useThreadMessages as unknown as StoreApi + ).setState(thread.messages, true); + } - onThreadUpdate(); - return thread.subscribe(onThreadUpdate); - }, - [context], - ), - ); - - useInsertionEffect(() => { - const unsubscribe = provider.subscribe(() => { - ( - context.useThreadRuntime as unknown as StoreApi - ).setState(provider.thread, true); - threadRef(provider.thread); - }); - threadRef(provider.thread); - return () => { - unsubscribe(); - threadRef(null); + const composerState = context.useComposer.getState(); + if ( + thread.composer.text !== composerState.text || + state.capabilities.cancel !== composerState.canCancel + ) { + (context.useComposer as unknown as StoreApi).setState({ + text: thread.composer.text, + canCancel: state.capabilities.cancel, + }); + } }; + + onThreadUpdate(); + return subscribeToMainThread(provider, onThreadUpdate); }, [provider, context]); + useInsertionEffect( + () => + provider.subscribe(() => { + ( + context.useThreadRuntime as unknown as StoreApi + ).setState(provider.thread, true); + }), + [provider, context], + ); + // subscribe to thread updates const Synchronizer = context.useThreadRuntime( (t) => (t as ReactThreadRuntime).unstable_synchronizer, diff --git a/packages/react/src/runtimes/core/index.ts b/packages/react/src/runtimes/core/index.ts index e9e92ff013..c4255c5088 100644 --- a/packages/react/src/runtimes/core/index.ts +++ b/packages/react/src/runtimes/core/index.ts @@ -1,3 +1,5 @@ export type { AssistantRuntime } from "./AssistantRuntime"; export type { ThreadRuntime } from "./ThreadRuntime"; export type { ReactThreadRuntime } from "./ReactThreadRuntime"; + +export { subscribeToMainThread } from "./subscribeToMainThread"; diff --git a/packages/react/src/runtimes/core/subscribeToMainThread.ts b/packages/react/src/runtimes/core/subscribeToMainThread.ts new file mode 100644 index 0000000000..f8bca30e1f --- /dev/null +++ b/packages/react/src/runtimes/core/subscribeToMainThread.ts @@ -0,0 +1,27 @@ +import { Unsubscribe } from "../../types"; +import { ThreadRuntimeWithSubscribe } from "./AssistantRuntime"; + +export const subscribeToMainThread = ( + runtime: ThreadRuntimeWithSubscribe, + callback: () => void, +) => { + let first = true; + let cleanup: Unsubscribe | undefined; + const inner = () => { + cleanup?.(); + cleanup = runtime.thread.subscribe(callback); + + if (!first) { + callback(); + } + first = false; + }; + + const unsubscribe = runtime.subscribe(inner); + inner(); + + return () => { + unsubscribe(); + cleanup?.(); + }; +}; diff --git a/packages/react/src/runtimes/external-store/ExternalStoreRuntime.tsx b/packages/react/src/runtimes/external-store/ExternalStoreRuntime.tsx index 931758dfab..5b199e34bd 100644 --- a/packages/react/src/runtimes/external-store/ExternalStoreRuntime.tsx +++ b/packages/react/src/runtimes/external-store/ExternalStoreRuntime.tsx @@ -22,17 +22,25 @@ export class ExternalStoreRuntime extends BaseAssistantRuntime = (cache, m, idx) => { - if (!store.convertMessage) return m; - - const isLast = idx === store.messages.length - 1; - const autoStatus = getAutoStatus(isLast, isRunning); - - if ( - cache && - (cache.role !== "assistant" || - !isAutoStatus(cache.status) || - cache.status === autoStatus) - ) - return cache; - - const newMessage = fromThreadMessageLike( - store.convertMessage(m, idx), - idx.toString(), - autoStatus, - ); - (newMessage as any)[symbolInnerMessage] = m; - return newMessage; - }; - - const messages = this.converter.convertMessages( - store.messages, - convertCallback, - ); + const messages = !store.convertMessage + ? store.messages + : this.converter.convertMessages(store.messages, (cache, m, idx) => { + if (!store.convertMessage) return m; + + const isLast = idx === store.messages.length - 1; + const autoStatus = getAutoStatus(isLast, isRunning); + + if ( + cache && + (cache.role !== "assistant" || + !isAutoStatus(cache.status) || + cache.status === autoStatus) + ) + return cache; + + const newMessage = fromThreadMessageLike( + store.convertMessage(m, idx), + idx.toString(), + autoStatus, + ); + (newMessage as any)[symbolInnerMessage] = m; + return newMessage; + }); for (let i = 0; i < messages.length; i++) { const message = messages[i]!; @@ -189,6 +184,20 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime { } let messages = this.repository.getMessages(); + const previousMessage = messages[messages.length - 1]; + if ( + previousMessage?.role === "user" && + previousMessage.id === messages.at(-1)?.id // ensure the previous message is a leaf node + ) { + this.repository.deleteMessage(previousMessage.id); + if (!this.composer.text.trim()) { + this.composer.setText(getThreadMessageText(previousMessage)); + } + + messages = this.repository.getMessages(); + } else { + this.notifySubscribers(); + } // resync messages (for reloading, to restore the previous branch) setTimeout(() => { diff --git a/packages/react/src/runtimes/external-store/ThreadMessageConverter.ts b/packages/react/src/runtimes/external-store/ThreadMessageConverter.ts index 64f0167eba..f25d2339e7 100644 --- a/packages/react/src/runtimes/external-store/ThreadMessageConverter.ts +++ b/packages/react/src/runtimes/external-store/ThreadMessageConverter.ts @@ -12,13 +12,11 @@ export class ThreadMessageConverter { convertMessages( messages: TIn[], converter: ConverterCallback, - keyMapper: (m: TIn) => WeakKey = (key) => key, ): ThreadMessage[] { return messages.map((m, idx) => { - const key = keyMapper(m); - const cached = this.cache.get(key); + const cached = this.cache.get(m); const newMessage = converter(cached, m, idx); - this.cache.set(key, newMessage); + this.cache.set(m, newMessage); return newMessage; }); } diff --git a/packages/react/src/runtimes/external-store/useExternalStoreRuntime.tsx b/packages/react/src/runtimes/external-store/useExternalStoreRuntime.tsx index 73390ea9d9..3811db2e01 100644 --- a/packages/react/src/runtimes/external-store/useExternalStoreRuntime.tsx +++ b/packages/react/src/runtimes/external-store/useExternalStoreRuntime.tsx @@ -1,11 +1,11 @@ -import { useInsertionEffect, useState } from "react"; +import { useEffect, useState } from "react"; import { ExternalStoreRuntime } from "./ExternalStoreRuntime"; import { ExternalStoreAdapter } from "./ExternalStoreAdapter"; export const useExternalStoreRuntime = (store: ExternalStoreAdapter) => { const [runtime] = useState(() => new ExternalStoreRuntime(store)); - useInsertionEffect(() => { + useEffect(() => { runtime.store = store; });