diff --git a/.changeset/tough-buttons-refuse.md b/.changeset/tough-buttons-refuse.md new file mode 100644 index 000000000..5d60cf338 --- /dev/null +++ b/.changeset/tough-buttons-refuse.md @@ -0,0 +1,7 @@ +--- +"@assistant-ui/react-ai-sdk": minor +"@assistant-ui/react-ui": minor +"@assistant-ui/react": minor +--- + +feat: system message support diff --git a/apps/docs/components/docs/parameters/context.tsx b/apps/docs/components/docs/parameters/context.tsx index 7302092c3..b9c0cebbf 100644 --- a/apps/docs/components/docs/parameters/context.tsx +++ b/apps/docs/components/docs/parameters/context.tsx @@ -334,9 +334,28 @@ export const ContentPartState: ParametersTableProps = { }, { name: "status", - type: "'done' | 'in_progress' | 'error'", + type: "MessageStatus", required: true, description: "The current content part status.", + children: [ + { + type: "MessageStatus", + parameters: [ + { + name: "type", + type: "'in_progress' | 'done' | 'error'", + required: true, + description: "The status.", + }, + { + name: "error", + type: "unknown", + required: false, + description: "The error object if the status is 'error'.", + }, + ], + }, + ], }, ], }; diff --git a/examples/with-openai-assistants/app/page.tsx b/examples/with-openai-assistants/app/page.tsx index 2a37d1cfc..0066e05fb 100644 --- a/examples/with-openai-assistants/app/page.tsx +++ b/examples/with-openai-assistants/app/page.tsx @@ -20,7 +20,7 @@ const WeatherTool = makeAssistantToolUI({

get_weather({JSON.stringify(part.args)}) diff --git a/packages/react-ai-sdk/src/core/convertToCoreMessage.ts b/packages/react-ai-sdk/src/core/convertToCoreMessage.ts index 6f242f959..2bfeaaf52 100644 --- a/packages/react-ai-sdk/src/core/convertToCoreMessage.ts +++ b/packages/react-ai-sdk/src/core/convertToCoreMessage.ts @@ -8,6 +8,10 @@ import type { } from "ai"; export const convertToCoreMessage = (message: ThreadMessage): CoreMessage[] => { + if (message.role === "system") { + return [{ role: "system", content: message.content[0].text }]; + } + const expandedMessages: CoreMessage[] = [ { role: message.role, diff --git a/packages/react-ai-sdk/src/rsc/useVercelRSCSync.tsx b/packages/react-ai-sdk/src/rsc/useVercelRSCSync.tsx index dde46eff9..2b3158ae1 100644 --- a/packages/react-ai-sdk/src/rsc/useVercelRSCSync.tsx +++ b/packages/react-ai-sdk/src/rsc/useVercelRSCSync.tsx @@ -22,7 +22,7 @@ const vercelToThreadMessage = ( role: message.role, content: [{ type: "ui", display: message.display }], createdAt: message.createdAt ?? new Date(), - ...{ status: "done" }, + ...{ status: { type: "done" } }, [symbolInnerRSCMessage]: rawMessage, }; }; diff --git a/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx b/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx index e1b91ed92..d03174ead 100644 --- a/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx +++ b/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx @@ -90,7 +90,7 @@ export class VercelUseAssistantThreadRuntime implements ReactThreadRuntime { vm.push({ id: "__optimistic__result", createdAt: new Date(), - status: "in_progress", + status: { type: "in_progress" }, role: "assistant", content: [{ type: "text", text: "" }], }); diff --git a/packages/react-ai-sdk/src/ui/utils/useVercelAIThreadSync.tsx b/packages/react-ai-sdk/src/ui/utils/useVercelAIThreadSync.tsx index a0911ceda..0aad6707d 100644 --- a/packages/react-ai-sdk/src/ui/utils/useVercelAIThreadSync.tsx +++ b/packages/react-ai-sdk/src/ui/utils/useVercelAIThreadSync.tsx @@ -3,6 +3,7 @@ import type { TextContentPart, ThreadMessage, ToolCallContentPart, + MessageStatus, } from "@assistant-ui/react"; import type { Message } from "ai"; import { useEffect, useMemo } from "react"; @@ -23,7 +24,7 @@ const getIsRunning = (vercel: VercelHelpers) => { const vercelToThreadMessage = ( messages: Message[], - status: "in_progress" | "done" | "error", + status: MessageStatus, ): VercelAIThreadMessage => { const firstMessage = messages[0]; if (!firstMessage) throw new Error("No messages found"); @@ -34,7 +35,8 @@ const vercelToThreadMessage = ( [symbolInnerAIMessage]: messages, }; - switch (firstMessage.role) { + const role = firstMessage.role; + switch (role) { case "user": if (messages.length > 1) { throw new Error( @@ -48,6 +50,13 @@ const vercelToThreadMessage = ( content: [{ type: "text", text: firstMessage.content }], }; + case "system": + return { + ...common, + role: "system", + content: [{ type: "text", text: firstMessage.content }], + }; + case "data": case "assistant": { const res: AssistantMessage = { @@ -97,8 +106,9 @@ const vercelToThreadMessage = ( } default: + const _unsupported: "function" | "tool" = role; throw new Error( - `123 You have a message with an unsupported role. The role ${firstMessage.role} is not supported.`, + `You have a message with an unsupported role. The role ${_unsupported} is not supported.`, ); } }; @@ -151,13 +161,17 @@ export const useVercelAIThreadSync = ( useEffect(() => { const lastMessageId = vercel.messages.at(-1)?.id; const convertCallback: ConverterCallback = (messages, cache) => { - const status = - lastMessageId === messages[0].id && isRunning ? "in_progress" : "done"; + const status: MessageStatus = { + type: + lastMessageId === messages[0].id && isRunning + ? "in_progress" + : "done", + }; if ( cache && shallowArrayEqual(cache.content, messages) && - (cache.role === "user" || cache.status === status) + (cache.role !== "assistant" || cache.status.type === status.type) ) return cache; diff --git a/packages/react-ui/src/components/markdown-text.tsx b/packages/react-ui/src/components/markdown-text.tsx index 866d508f1..9bfeec59c 100644 --- a/packages/react-ui/src/components/markdown-text.tsx +++ b/packages/react-ui/src/components/markdown-text.tsx @@ -22,7 +22,7 @@ export const makeMarkdownText = ({

@@ -32,5 +32,8 @@ export const makeMarkdownText = ({ }; MarkdownTextImpl.displayName = "MarkdownText"; - return memo(MarkdownTextImpl, (prev, next) => prev.status === next.status); + return memo( + MarkdownTextImpl, + (prev, next) => prev.status.type === next.status.type, + ); }; diff --git a/packages/react-ui/src/components/text.tsx b/packages/react-ui/src/components/text.tsx index 015a6a747..3a2de6384 100644 --- a/packages/react-ui/src/components/text.tsx +++ b/packages/react-ui/src/components/text.tsx @@ -10,7 +10,7 @@ export const Text: FC = ({ status }) => {

diff --git a/packages/react-ui/src/components/thread.tsx b/packages/react-ui/src/components/thread.tsx index 366c5935a..9134ab852 100644 --- a/packages/react-ui/src/components/thread.tsx +++ b/packages/react-ui/src/components/thread.tsx @@ -67,11 +67,14 @@ export const ThreadViewportFooter = withDefaults("div", { ThreadViewportFooter.displayName = "ThreadViewportFooter"; +const SystemMessage = () => null; + export const ThreadMessages: FC<{ components?: { UserMessage?: ComponentType | undefined; EditComposer?: ComponentType | undefined; AssistantMessage?: ComponentType | undefined; + SystemMessage?: ComponentType | undefined; }; }> = ({ components, ...rest }) => { return ( @@ -80,6 +83,7 @@ export const ThreadMessages: FC<{ UserMessage: components?.UserMessage ?? UserMessage, EditComposer: components?.EditComposer ?? EditComposer, AssistantMessage: components?.AssistantMessage ?? AssistantMessage, + SystemMessage: components?.SystemMessage ?? SystemMessage, }} {...rest} /> diff --git a/packages/react/src/context/providers/ContentPartProvider.tsx b/packages/react/src/context/providers/ContentPartProvider.tsx index 1fbd2df7e..e7d1ba752 100644 --- a/packages/react/src/context/providers/ContentPartProvider.tsx +++ b/packages/react/src/context/providers/ContentPartProvider.tsx @@ -7,11 +7,14 @@ import type { ContentPartContextValue } from "../react/ContentPartContext"; import { useMessageContext } from "../react/MessageContext"; import type { MessageState } from "../stores"; import type { ContentPartState } from "../stores/ContentPart"; +import { MessageStatus } from "../../types"; type ContentPartProviderProps = PropsWithChildren<{ partIndex: number; }>; +const DONE_STATUS: MessageStatus = { type: "done" }; + const syncContentPart = ( { message }: MessageState, useContentPart: ContentPartContextValue["useContentPart"], @@ -20,9 +23,10 @@ const syncContentPart = ( const part = message.content[partIndex]; if (!part) return; - const messageStatus = message.role === "assistant" ? message.status : "done"; + const messageStatus = + message.role === "assistant" ? message.status : DONE_STATUS; const status = - partIndex === message.content.length - 1 ? messageStatus : "done"; + partIndex === message.content.length - 1 ? messageStatus : DONE_STATUS; // if the content part is the same, don't update const currentState = useContentPart.getState(); diff --git a/packages/react/src/context/stores/ContentPart.ts b/packages/react/src/context/stores/ContentPart.ts index 4f273d732..d18076a92 100644 --- a/packages/react/src/context/stores/ContentPart.ts +++ b/packages/react/src/context/stores/ContentPart.ts @@ -1,32 +1,33 @@ import type { ImageContentPart, + MessageStatus, TextContentPart, ToolCallContentPart, UIContentPart, } from "../../types/AssistantTypes"; export type TextContentPartState = Readonly<{ - status: "in_progress" | "done" | "error"; + status: MessageStatus; part: TextContentPart; }>; export type ImageContentPartState = Readonly<{ - status: "in_progress" | "done" | "error"; + status: MessageStatus; part: ImageContentPart; }>; export type UIContentPartState = Readonly<{ - status: "in_progress" | "done" | "error"; + status: MessageStatus; part: UIContentPart; }>; export type ToolCallContentPartState = Readonly<{ - status: "in_progress" | "done" | "error"; + status: MessageStatus; part: ToolCallContentPart; }>; export type ContentPartState = Readonly<{ - status: "in_progress" | "done" | "error"; + status: MessageStatus; part: | TextContentPart | ImageContentPart diff --git a/packages/react/src/primitive-hooks/message/useMessageIf.tsx b/packages/react/src/primitive-hooks/message/useMessageIf.tsx index aeb9dbff8..d295c8d12 100644 --- a/packages/react/src/primitive-hooks/message/useMessageIf.tsx +++ b/packages/react/src/primitive-hooks/message/useMessageIf.tsx @@ -6,6 +6,7 @@ import { useCombinedStore } from "../../utils/combined/useCombinedStore"; type MessageIfFilters = { user: boolean | undefined; assistant: boolean | undefined; + system: boolean | undefined; hasBranches: boolean | undefined; copied: boolean | undefined; lastOrHover: boolean | undefined; @@ -22,6 +23,7 @@ export const useMessageIf = (props: UseMessageIfProps) => { if (props.user && message.role !== "user") return false; if (props.assistant && message.role !== "assistant") return false; + if (props.system && message.role !== "system") return false; if (props.lastOrHover === true && !isHovering && !isLast) return false; diff --git a/packages/react/src/primitives/contentPart/ContentPartInProgress.tsx b/packages/react/src/primitives/contentPart/ContentPartInProgress.tsx index c806a3f29..a93080b7f 100644 --- a/packages/react/src/primitives/contentPart/ContentPartInProgress.tsx +++ b/packages/react/src/primitives/contentPart/ContentPartInProgress.tsx @@ -3,11 +3,11 @@ import { useContentPartContext } from "../../context"; export type ContentPartPrimitiveInProgressProps = PropsWithChildren; -export const ContentPartPrimitiveInProgress: FC = ({ - children, -}) => { +export const ContentPartPrimitiveInProgress: FC< + ContentPartPrimitiveInProgressProps +> = ({ children }) => { const { useContentPart } = useContentPartContext(); - const isInProgress = useContentPart((c) => c.status === "in_progress"); + const isInProgress = useContentPart((c) => c.status.type === "in_progress"); return isInProgress ? children : null; }; diff --git a/packages/react/src/primitives/thread/ThreadMessages.tsx b/packages/react/src/primitives/thread/ThreadMessages.tsx index 2e9de6500..81228c3b0 100644 --- a/packages/react/src/primitives/thread/ThreadMessages.tsx +++ b/packages/react/src/primitives/thread/ThreadMessages.tsx @@ -13,15 +13,19 @@ export type ThreadPrimitiveMessagesProps = { UserMessage?: ComponentType | undefined; EditComposer?: ComponentType | undefined; AssistantMessage?: ComponentType | undefined; + SystemMessage?: ComponentType | undefined; } | { Message?: ComponentType | undefined; UserMessage: ComponentType; EditComposer?: ComponentType | undefined; AssistantMessage: ComponentType; + SystemMessage?: ComponentType | undefined; }; }; +const DEFAULT_SYSTEM_MESSAGE = () => null; + const getComponents = ( components: ThreadPrimitiveMessagesProps["components"], ) => { @@ -34,6 +38,7 @@ const getComponents = ( components.UserMessage ?? (components.Message as ComponentType), AssistantMessage: components.AssistantMessage ?? (components.Message as ComponentType), + SystemMessage: components.SystemMessage ?? DEFAULT_SYSTEM_MESSAGE, }; }; @@ -46,7 +51,7 @@ const ThreadMessageImpl: FC = ({ messageIndex, components, }) => { - const { UserMessage, EditComposer, AssistantMessage } = + const { UserMessage, EditComposer, AssistantMessage, SystemMessage } = getComponents(components); return ( @@ -61,6 +66,9 @@ const ThreadMessageImpl: FC = ({ + + + ); }; @@ -72,7 +80,8 @@ const ThreadMessage = memo( prev.components.Message === next.components.Message && prev.components.UserMessage === next.components.UserMessage && prev.components.EditComposer === next.components.EditComposer && - prev.components.AssistantMessage === next.components.AssistantMessage, + prev.components.AssistantMessage === next.components.AssistantMessage && + prev.components.SystemMessage === next.components.SystemMessage, ); export const ThreadPrimitiveMessagesImpl: FC = ({ @@ -103,5 +112,6 @@ export const ThreadPrimitiveMessages = memo( prev.components?.Message === next.components?.Message && prev.components?.UserMessage === next.components?.UserMessage && prev.components?.EditComposer === next.components?.EditComposer && - prev.components?.AssistantMessage === next.components?.AssistantMessage, + prev.components?.AssistantMessage === next.components?.AssistantMessage && + prev.components?.SystemMessage === next.components?.SystemMessage, ); diff --git a/packages/react/src/runtime/local/LocalRuntime.tsx b/packages/react/src/runtime/local/LocalRuntime.tsx index 09bd6177b..93229a2f8 100644 --- a/packages/react/src/runtime/local/LocalRuntime.tsx +++ b/packages/react/src/runtime/local/LocalRuntime.tsx @@ -104,7 +104,7 @@ class LocalThreadRuntime implements ThreadRuntime { const message: AssistantMessage = { id, role: "assistant", - status: "in_progress", + status: { type: "in_progress" }, content: [{ type: "text", text: "" }], createdAt: new Date(), }; @@ -132,11 +132,10 @@ class LocalThreadRuntime implements ThreadRuntime { updateHandler(result); } - message.status = "done"; + message.status = { type: "done" }; this.repository.addOrUpdateMessage(parentId, { ...message }); } catch (e) { - (message as any).status = "error"; - (message as any).error = e; + message.status = { type: "error", error: e }; this.repository.addOrUpdateMessage(parentId, { ...message }); console.error(e); } finally { diff --git a/packages/react/src/types/AssistantTypes.ts b/packages/react/src/types/AssistantTypes.ts index 084e242bb..88dd3f5c1 100644 --- a/packages/react/src/types/AssistantTypes.ts +++ b/packages/react/src/types/AssistantTypes.ts @@ -40,13 +40,13 @@ type MessageCommonProps = { createdAt: Date; }; -type MessageStatusProps = +export type MessageStatus = | { - status: "in_progress" | "done"; + type: "in_progress" | "done"; error?: undefined; } | { - status: "error"; + type: "error"; error: unknown; }; @@ -60,11 +60,11 @@ export type UserMessage = MessageCommonProps & { content: UserContentPart[]; }; -export type AssistantMessage = MessageCommonProps & - MessageStatusProps & { - role: "assistant"; - content: AssistantContentPart[]; - }; +export type AssistantMessage = MessageCommonProps & { + role: "assistant"; + content: AssistantContentPart[]; + status: MessageStatus; +}; export type AppendMessage = { parentId: string | null; @@ -84,11 +84,11 @@ export type CoreUserMessage = MessageCommonProps & { content: CoreUserContentPart[]; }; -export type CoreAssistantMessage = MessageCommonProps & - MessageStatusProps & { - role: "assistant"; - content: CoreAssistantContentPart[]; - }; +export type CoreAssistantMessage = MessageCommonProps & { + role: "assistant"; + content: CoreAssistantContentPart[]; + status: MessageStatus; +}; export type CoreThreadMessage = | SystemMessage diff --git a/packages/react/src/types/ContentPartComponentTypes.tsx b/packages/react/src/types/ContentPartComponentTypes.tsx index f1f19278f..1ddfa360a 100644 --- a/packages/react/src/types/ContentPartComponentTypes.tsx +++ b/packages/react/src/types/ContentPartComponentTypes.tsx @@ -2,34 +2,33 @@ import type { ComponentType } from "react"; import type { ImageContentPart, + MessageStatus, TextContentPart, ToolCallContentPart, UIContentPart, } from "./AssistantTypes"; -type ContentPartStatus = "done" | "in_progress" | "error"; - export type TextContentPartProps = { part: TextContentPart; - status: ContentPartStatus; + status: MessageStatus; }; export type TextContentPartComponent = ComponentType; export type ImageContentPartProps = { part: ImageContentPart; - status: ContentPartStatus; + status: MessageStatus; }; export type ImageContentPartComponent = ComponentType; export type UIContentPartProps = { part: UIContentPart; - status: ContentPartStatus; + status: MessageStatus; }; export type UIContentPartComponent = ComponentType; export type ToolCallContentPartProps = { part: ToolCallContentPart; - status: ContentPartStatus; + status: MessageStatus; addResult: (result: any) => void; }; diff --git a/packages/react/src/types/index.ts b/packages/react/src/types/index.ts index e00bd6b44..dbe74f7f0 100644 --- a/packages/react/src/types/index.ts +++ b/packages/react/src/types/index.ts @@ -10,6 +10,7 @@ export type { ImageContentPart, ToolCallContentPart, UIContentPart, + MessageStatus, // core message types CoreUserContentPart,