Skip to content

Commit

Permalink
feat: system message support (#435)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jul 9, 2024
1 parent 5a54b8c commit 679cd54
Show file tree
Hide file tree
Showing 19 changed files with 117 additions and 50 deletions.
7 changes: 7 additions & 0 deletions .changeset/tough-buttons-refuse.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@assistant-ui/react-ai-sdk": minor
"@assistant-ui/react-ui": minor
"@assistant-ui/react": minor
---

feat: system message support
21 changes: 20 additions & 1 deletion apps/docs/components/docs/parameters/context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,28 @@ export const ContentPartState: ParametersTableProps = {
},
{
name: "status",
type: "'done' | 'in_progress' | 'error'",
type: "MessageStatus",
required: true,
description: "The current content part status.",
children: [
{
type: "MessageStatus",
parameters: [
{
name: "type",
type: "'in_progress' | 'done' | 'error'",
required: true,
description: "The status.",
},
{
name: "error",
type: "unknown",
required: false,
description: "The error object if the status is 'error'.",
},
],
},
],
},
],
};
Expand Down
2 changes: 1 addition & 1 deletion examples/with-openai-assistants/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const WeatherTool = makeAssistantToolUI<WeatherArgs, WeatherResult>({
<p
className={cn(
"my-4 text-center font-mono text-sm font-bold text-blue-500 first:mt-0",
status === "in_progress" && "animate-pulse",
status.type === "in_progress" && "animate-pulse",
)}
>
get_weather({JSON.stringify(part.args)})
Expand Down
4 changes: 4 additions & 0 deletions packages/react-ai-sdk/src/core/convertToCoreMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import type {
} from "ai";

export const convertToCoreMessage = (message: ThreadMessage): CoreMessage[] => {
if (message.role === "system") {
return [{ role: "system", content: message.content[0].text }];
}

const expandedMessages: CoreMessage[] = [
{
role: message.role,
Expand Down
2 changes: 1 addition & 1 deletion packages/react-ai-sdk/src/rsc/useVercelRSCSync.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const vercelToThreadMessage = <T,>(
role: message.role,
content: [{ type: "ui", display: message.display }],
createdAt: message.createdAt ?? new Date(),
...{ status: "done" },
...{ status: { type: "done" } },
[symbolInnerRSCMessage]: rawMessage,
};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ export class VercelUseAssistantThreadRuntime implements ReactThreadRuntime {
vm.push({
id: "__optimistic__result",
createdAt: new Date(),
status: "in_progress",
status: { type: "in_progress" },
role: "assistant",
content: [{ type: "text", text: "" }],
});
Expand Down
26 changes: 20 additions & 6 deletions packages/react-ai-sdk/src/ui/utils/useVercelAIThreadSync.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type {
TextContentPart,
ThreadMessage,
ToolCallContentPart,
MessageStatus,
} from "@assistant-ui/react";
import type { Message } from "ai";
import { useEffect, useMemo } from "react";
Expand All @@ -23,7 +24,7 @@ const getIsRunning = (vercel: VercelHelpers) => {

const vercelToThreadMessage = (
messages: Message[],
status: "in_progress" | "done" | "error",
status: MessageStatus,
): VercelAIThreadMessage => {
const firstMessage = messages[0];
if (!firstMessage) throw new Error("No messages found");
Expand All @@ -34,7 +35,8 @@ const vercelToThreadMessage = (
[symbolInnerAIMessage]: messages,
};

switch (firstMessage.role) {
const role = firstMessage.role;
switch (role) {
case "user":
if (messages.length > 1) {
throw new Error(
Expand All @@ -48,6 +50,13 @@ const vercelToThreadMessage = (
content: [{ type: "text", text: firstMessage.content }],
};

case "system":
return {
...common,
role: "system",
content: [{ type: "text", text: firstMessage.content }],
};

case "data":
case "assistant": {
const res: AssistantMessage = {
Expand Down Expand Up @@ -97,8 +106,9 @@ const vercelToThreadMessage = (
}

default:
const _unsupported: "function" | "tool" = role;
throw new Error(
`123 You have a message with an unsupported role. The role ${firstMessage.role} is not supported.`,
`You have a message with an unsupported role. The role ${_unsupported} is not supported.`,
);
}
};
Expand Down Expand Up @@ -151,13 +161,17 @@ export const useVercelAIThreadSync = (
useEffect(() => {
const lastMessageId = vercel.messages.at(-1)?.id;
const convertCallback: ConverterCallback<Chunk> = (messages, cache) => {
const status =
lastMessageId === messages[0].id && isRunning ? "in_progress" : "done";
const status: MessageStatus = {
type:
lastMessageId === messages[0].id && isRunning
? "in_progress"
: "done",
};

if (
cache &&
shallowArrayEqual(cache.content, messages) &&
(cache.role === "user" || cache.status === status)
(cache.role !== "assistant" || cache.status.type === status.type)
)
return cache;

Expand Down
7 changes: 5 additions & 2 deletions packages/react-ui/src/components/markdown-text.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export const makeMarkdownText = ({
<div
className={classNames(
"aui-md-root",
status === "in_progress" && "aui-md-in-progress",
status.type === "in_progress" && "aui-md-in-progress",
className,
)}
>
Expand All @@ -32,5 +32,8 @@ export const makeMarkdownText = ({
};
MarkdownTextImpl.displayName = "MarkdownText";

return memo(MarkdownTextImpl, (prev, next) => prev.status === next.status);
return memo(
MarkdownTextImpl,
(prev, next) => prev.status.type === next.status.type,
);
};
2 changes: 1 addition & 1 deletion packages/react-ui/src/components/text.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export const Text: FC<TextContentPartProps> = ({ status }) => {
<p
className={classNames(
"aui-text",
status === "in_progress" && "aui-text-in-progress",
status.type === "in_progress" && "aui-text-in-progress",
)}
>
<ContentPartPrimitive.Text />
Expand Down
4 changes: 4 additions & 0 deletions packages/react-ui/src/components/thread.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ export const ThreadViewportFooter = withDefaults("div", {

ThreadViewportFooter.displayName = "ThreadViewportFooter";

const SystemMessage = () => null;

export const ThreadMessages: FC<{
components?: {
UserMessage?: ComponentType | undefined;
EditComposer?: ComponentType | undefined;
AssistantMessage?: ComponentType | undefined;
SystemMessage?: ComponentType | undefined;
};
}> = ({ components, ...rest }) => {
return (
Expand All @@ -80,6 +83,7 @@ export const ThreadMessages: FC<{
UserMessage: components?.UserMessage ?? UserMessage,
EditComposer: components?.EditComposer ?? EditComposer,
AssistantMessage: components?.AssistantMessage ?? AssistantMessage,
SystemMessage: components?.SystemMessage ?? SystemMessage,
}}
{...rest}
/>
Expand Down
8 changes: 6 additions & 2 deletions packages/react/src/context/providers/ContentPartProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import type { ContentPartContextValue } from "../react/ContentPartContext";
import { useMessageContext } from "../react/MessageContext";
import type { MessageState } from "../stores";
import type { ContentPartState } from "../stores/ContentPart";
import { MessageStatus } from "../../types";

type ContentPartProviderProps = PropsWithChildren<{
partIndex: number;
}>;

const DONE_STATUS: MessageStatus = { type: "done" };

const syncContentPart = (
{ message }: MessageState,
useContentPart: ContentPartContextValue["useContentPart"],
Expand All @@ -20,9 +23,10 @@ const syncContentPart = (
const part = message.content[partIndex];
if (!part) return;

const messageStatus = message.role === "assistant" ? message.status : "done";
const messageStatus =
message.role === "assistant" ? message.status : DONE_STATUS;
const status =
partIndex === message.content.length - 1 ? messageStatus : "done";
partIndex === message.content.length - 1 ? messageStatus : DONE_STATUS;

// if the content part is the same, don't update
const currentState = useContentPart.getState();
Expand Down
11 changes: 6 additions & 5 deletions packages/react/src/context/stores/ContentPart.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
import type {
ImageContentPart,
MessageStatus,
TextContentPart,
ToolCallContentPart,
UIContentPart,
} from "../../types/AssistantTypes";

export type TextContentPartState = Readonly<{
status: "in_progress" | "done" | "error";
status: MessageStatus;
part: TextContentPart;
}>;

export type ImageContentPartState = Readonly<{
status: "in_progress" | "done" | "error";
status: MessageStatus;
part: ImageContentPart;
}>;

export type UIContentPartState = Readonly<{
status: "in_progress" | "done" | "error";
status: MessageStatus;
part: UIContentPart;
}>;

export type ToolCallContentPartState = Readonly<{
status: "in_progress" | "done" | "error";
status: MessageStatus;
part: ToolCallContentPart;
}>;

export type ContentPartState = Readonly<{
status: "in_progress" | "done" | "error";
status: MessageStatus;
part:
| TextContentPart
| ImageContentPart
Expand Down
2 changes: 2 additions & 0 deletions packages/react/src/primitive-hooks/message/useMessageIf.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { useCombinedStore } from "../../utils/combined/useCombinedStore";
type MessageIfFilters = {
user: boolean | undefined;
assistant: boolean | undefined;
system: boolean | undefined;
hasBranches: boolean | undefined;
copied: boolean | undefined;
lastOrHover: boolean | undefined;
Expand All @@ -22,6 +23,7 @@ export const useMessageIf = (props: UseMessageIfProps) => {

if (props.user && message.role !== "user") return false;
if (props.assistant && message.role !== "assistant") return false;
if (props.system && message.role !== "system") return false;

if (props.lastOrHover === true && !isHovering && !isLast) return false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ import { useContentPartContext } from "../../context";

export type ContentPartPrimitiveInProgressProps = PropsWithChildren;

export const ContentPartPrimitiveInProgress: FC<ContentPartPrimitiveInProgressProps> = ({
children,
}) => {
export const ContentPartPrimitiveInProgress: FC<
ContentPartPrimitiveInProgressProps
> = ({ children }) => {
const { useContentPart } = useContentPartContext();
const isInProgress = useContentPart((c) => c.status === "in_progress");
const isInProgress = useContentPart((c) => c.status.type === "in_progress");

return isInProgress ? children : null;
};
Expand Down
16 changes: 13 additions & 3 deletions packages/react/src/primitives/thread/ThreadMessages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@ export type ThreadPrimitiveMessagesProps = {
UserMessage?: ComponentType | undefined;
EditComposer?: ComponentType | undefined;
AssistantMessage?: ComponentType | undefined;
SystemMessage?: ComponentType | undefined;
}
| {
Message?: ComponentType | undefined;
UserMessage: ComponentType;
EditComposer?: ComponentType | undefined;
AssistantMessage: ComponentType;
SystemMessage?: ComponentType | undefined;
};
};

const DEFAULT_SYSTEM_MESSAGE = () => null;

const getComponents = (
components: ThreadPrimitiveMessagesProps["components"],
) => {
Expand All @@ -34,6 +38,7 @@ const getComponents = (
components.UserMessage ?? (components.Message as ComponentType),
AssistantMessage:
components.AssistantMessage ?? (components.Message as ComponentType),
SystemMessage: components.SystemMessage ?? DEFAULT_SYSTEM_MESSAGE,
};
};

Expand All @@ -46,7 +51,7 @@ const ThreadMessageImpl: FC<ThreadMessageProps> = ({
messageIndex,
components,
}) => {
const { UserMessage, EditComposer, AssistantMessage } =
const { UserMessage, EditComposer, AssistantMessage, SystemMessage } =
getComponents(components);
return (
<MessageProvider messageIndex={messageIndex}>
Expand All @@ -61,6 +66,9 @@ const ThreadMessageImpl: FC<ThreadMessageProps> = ({
<MessagePrimitiveIf assistant>
<AssistantMessage />
</MessagePrimitiveIf>
<MessagePrimitiveIf system>
<SystemMessage />
</MessagePrimitiveIf>
</MessageProvider>
);
};
Expand All @@ -72,7 +80,8 @@ const ThreadMessage = memo(
prev.components.Message === next.components.Message &&
prev.components.UserMessage === next.components.UserMessage &&
prev.components.EditComposer === next.components.EditComposer &&
prev.components.AssistantMessage === next.components.AssistantMessage,
prev.components.AssistantMessage === next.components.AssistantMessage &&
prev.components.SystemMessage === next.components.SystemMessage,
);

export const ThreadPrimitiveMessagesImpl: FC<ThreadPrimitiveMessagesProps> = ({
Expand Down Expand Up @@ -103,5 +112,6 @@ export const ThreadPrimitiveMessages = memo(
prev.components?.Message === next.components?.Message &&
prev.components?.UserMessage === next.components?.UserMessage &&
prev.components?.EditComposer === next.components?.EditComposer &&
prev.components?.AssistantMessage === next.components?.AssistantMessage,
prev.components?.AssistantMessage === next.components?.AssistantMessage &&
prev.components?.SystemMessage === next.components?.SystemMessage,
);
7 changes: 3 additions & 4 deletions packages/react/src/runtime/local/LocalRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class LocalThreadRuntime implements ThreadRuntime {
const message: AssistantMessage = {
id,
role: "assistant",
status: "in_progress",
status: { type: "in_progress" },
content: [{ type: "text", text: "" }],
createdAt: new Date(),
};
Expand Down Expand Up @@ -132,11 +132,10 @@ class LocalThreadRuntime implements ThreadRuntime {
updateHandler(result);
}

message.status = "done";
message.status = { type: "done" };
this.repository.addOrUpdateMessage(parentId, { ...message });
} catch (e) {
(message as any).status = "error";
(message as any).error = e;
message.status = { type: "error", error: e };
this.repository.addOrUpdateMessage(parentId, { ...message });
console.error(e);
} finally {
Expand Down
Loading

0 comments on commit 679cd54

Please sign in to comment.