Skip to content

Commit

Permalink
feat: add composer attachments state (#762)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Sep 7, 2024
1 parent cc5e7d4 commit e58d61b
Show file tree
Hide file tree
Showing 17 changed files with 165 additions and 31 deletions.
24 changes: 24 additions & 0 deletions apps/docs/components/docs/parameters/context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,30 @@ export const BaseComposerState: ParametersTableProps = {
required: true,
description: "A function to set the text of the composer.",
},
{
name: "attachments",
type: "readonly Attachment[]",
required: true,
description: "The current attachments of the composer.",
},
{
name: "addAttachment",
type: "(attachment: Attachment) => void",
required: true,
description: "A function to add an attachment to the composer.",
},
{
name: "removeAttachment",
type: "(attachmentId: string) => void",
required: true,
description: "A function to remove an attachment from the composer.",
},
{
name: "reset",
type: "() => void",
required: true,
description: "A function to reset the composer.",
},
],
};

Expand Down
16 changes: 14 additions & 2 deletions apps/docs/content/docs/reference/context.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,23 @@ const runtime = useThreadRuntime();
```tsx
const { useComposer } = useThreadContext();

const value = useComposer((m) => m.value);
const value = useComposer.getState().value;
const text = useComposer((m) => m.text);
const text = useComposer.getState().text;

const setText = useComposer((m) => m.setText);
const setText = useComposer.getState().setText;

const attachments = useComposer((m) => m.attachments);
const attachments = useComposer.getState().attachments;

const addAttachment = useComposer((m) => m.addAttachment);
const addAttachment = useComposer.getState().addAttachment;

const removeAttachment = useComposer((m) => m.removeAttachment);
const removeAttachment = useComposer.getState().removeAttachment;

const reset = useComposer((m) => m.reset);
const reset = useComposer.getState().reset;
```

<ParametersTable {...ComposerState} />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,13 @@ const Composer: FC = () => {
const text = composer.text;
if (!text) return;

composer.setText("");
composer.reset();

useThreadActions.getState().append({
parentId: useThreadMessages.getState().at(-1)?.id ?? null,
role,
content: [{ type: "text", text }],
attachments: composer.attachments,
});

setRole("user");
Expand Down
17 changes: 9 additions & 8 deletions packages/react-playground/src/lib/playground-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ import { LanguageModelV1FunctionTool } from "@ai-sdk/provider";
import { useState } from "react";
import { create } from "zustand";

const { BaseAssistantRuntime, ProxyConfigProvider, generateId } = INTERNAL;
const {
BaseAssistantRuntime,
ProxyConfigProvider,
generateId,
ThreadRuntimeComposer,
} = INTERNAL;

const makeModelConfigStore = () =>
create<ModelConfig>(() => ({
Expand Down Expand Up @@ -103,13 +108,9 @@ export class PlaygroundThreadRuntime implements ReactThreadRuntime {

private configProvider = new ProxyConfigProvider();

public readonly composer = {
text: "",
setText: (value: string) => {
this.composer.text = value;
this.notifySubscribers();
},
};
public readonly composer = new ThreadRuntimeComposer(
this.notifySubscribers.bind(this),
);

constructor(
configProvider: ModelConfigProvider,
Expand Down
1 change: 1 addition & 0 deletions packages/react/src/context/providers/MessageProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ const useMessageContext = (messageIndex: number) => {
parentId,
role: "user",
content: [{ type: "text", text }, ...nonTextParts],
attachments: message.attachments,
});
},
});
Expand Down
2 changes: 2 additions & 0 deletions packages/react/src/context/providers/ThreadProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ export const ThreadProvider: FC<PropsWithChildren<ThreadProviderProps>> = ({
const composerState = context.useComposer.getState();
if (
thread.composer.text !== composerState.text ||
thread.composer.attachments !== composerState.attachments ||
state.capabilities.cancel !== composerState.canCancel
) {
(context.useComposer as unknown as StoreApi<ComposerState>).setState({
text: thread.composer.text,
attachments: thread.composer.attachments,
canCancel: state.capabilities.cancel,
});
}
Expand Down
11 changes: 11 additions & 0 deletions packages/react/src/context/stores/Attachment.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
export type Attachment = {
id: string;
type: "image" | "document" | "file";
name: string;

file?: File;
};

export type AttachmentState = Readonly<{
attachment: Attachment;
}>;
28 changes: 24 additions & 4 deletions packages/react/src/context/stores/Composer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@ import { create } from "zustand";
import { ReadonlyStore } from "../ReadonlyStore";
import { Unsubscribe } from "../../types/Unsubscribe";
import { ThreadContextValue } from "../react";
import { Attachment } from "./Attachment";

export type ComposerState = Readonly<{
/** @deprecated Use `text` instead. */
value: string;
/** @deprecated Use `setText` instead. */
setValue: (value: string) => void;

attachments: readonly Attachment[];
addAttachment: (attachment: Attachment) => void;
removeAttachment: (attachmentId: string) => void;

text: string;
setText: (value: string) => void;

reset: () => void;

canCancel: boolean;
isEditing: true;

Expand All @@ -35,9 +42,20 @@ export const makeComposerStore = (
get().setText(value);
},

attachments: runtime.composer.attachments,
addAttachment: (attachment) => {
useThreadRuntime.getState().composer.addAttachment(attachment);
},
removeAttachment: (attachmentId) => {
useThreadRuntime.getState().composer.removeAttachment(attachmentId);
},
reset: () => {
useThreadRuntime.getState().composer.reset();
},

text: runtime.composer.text,
setText: (value) => {
useThreadRuntime.getState().composer.setText(value);
setText: (text) => {
useThreadRuntime.getState().composer.setText(text);
},

canCancel: runtime.capabilities.cancel,
Expand All @@ -46,12 +64,14 @@ export const makeComposerStore = (
send: () => {
const runtime = useThreadRuntime.getState();
const text = runtime.composer.text;
runtime.composer.setText("");
const attachments = runtime.composer.attachments;
runtime.composer.reset();

runtime.append({
parentId: runtime.messages.at(-1)?.id ?? null,
role: "user",
content: [{ type: "text", text }],
content: text ? [{ type: "text", text }] : [],
attachments,
});
},
cancel: () => {
Expand Down
3 changes: 3 additions & 0 deletions packages/react/src/hooks/useAppendMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type CreateAppendMessage =
parentId?: string | null | undefined;
role?: AppendMessage["role"] | undefined;
content: AppendMessage["content"];
attachments?: AppendMessage["attachments"] | undefined;
};

const toAppendMessage = (
Expand All @@ -19,6 +20,7 @@ const toAppendMessage = (
parentId: useThreadMessages.getState().at(-1)?.id ?? null,
role: "user",
content: [{ type: "text", text: message }],
attachments: [],
};
}

Expand All @@ -27,6 +29,7 @@ const toAppendMessage = (
message.parentId ?? useThreadMessages.getState().at(-1)?.id ?? null,
role: message.role ?? "user",
content: message.content,
attachments: message.attachments ?? [],
} as AppendMessage;
};

Expand Down
1 change: 1 addition & 0 deletions packages/react/src/internal.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export { ThreadRuntimeComposer } from "./runtimes/utils/ThreadRuntimeComposer";
export { ProxyConfigProvider } from "./utils/ProxyConfigProvider";
export { MessageRepository } from "./runtimes/utils/MessageRepository";
export { BaseAssistantRuntime } from "./runtimes/core/BaseAssistantRuntime";
Expand Down
7 changes: 7 additions & 0 deletions packages/react/src/runtimes/core/ThreadRuntime.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { Attachment } from "../../context/stores/Attachment";
import { RuntimeCapabilities } from "../../context/stores/Thread";
import { ThreadActionsState } from "../../context/stores/ThreadActions";
import { ThreadMessage } from "../../types";
Expand All @@ -15,7 +16,13 @@ export type ThreadRuntime = ThreadActionsState &

export declare namespace ThreadRuntime {
export type Composer = Readonly<{
attachments: Attachment[];
addAttachment: (attachment: Attachment) => void;
removeAttachment: (attachmentId: string) => void;

text: string;
setText: (value: string) => void;

reset: () => void;
}>;
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { Attachment } from "../../../context/stores/Attachment";
import { generateId } from "../../../internal";
import {
ThreadMessage,
Expand All @@ -13,7 +14,8 @@ export const fromCoreMessages = (
};

export const fromCoreMessage = (
message: CoreMessage,
// TODO clean up this type
message: CoreMessage & { attachments?: readonly Attachment[] | undefined },
{
id = generateId(),
status = { type: "complete", reason: "unknown" } as MessageStatus,
Expand Down Expand Up @@ -47,6 +49,7 @@ export const fromCoreMessage = (
...commonProps,
role,
content: message.content,
attachments: message.attachments ?? [],
} satisfies ThreadMessage;

case "system":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { fromThreadMessageLike } from "./ThreadMessageLike";
import { RuntimeCapabilities } from "../../context/stores/Thread";
import { getThreadMessageText } from "../../utils/getThreadMessageText";
import { generateId } from "../../internal";
import { ThreadRuntimeComposer } from "../utils/ThreadRuntimeComposer";

export const hasUpcomingMessage = (
isRunning: boolean,
Expand Down Expand Up @@ -46,13 +47,9 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {

private _store!: ExternalStoreAdapter<any>;

public readonly composer = {
text: "",
setText: (value: string) => {
this.composer.text = value;
this.notifySubscribers();
},
};
public readonly composer = new ThreadRuntimeComposer(
this.notifySubscribers.bind(this),
);

constructor(store: ExternalStoreAdapter<any>) {
this.store = store;
Expand Down
11 changes: 10 additions & 1 deletion packages/react/src/runtimes/external-store/ThreadMessageLike.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { Attachment } from "../../context/stores/Attachment";
import {
MessageStatus,
TextContentPart,
Expand Down Expand Up @@ -27,14 +28,15 @@ export type ThreadMessageLike = {
id?: string | undefined;
createdAt?: Date | undefined;
status?: MessageStatus | undefined;
attachments?: Attachment[] | undefined;
};

export const fromThreadMessageLike = (
like: ThreadMessageLike,
fallbackId: string,
fallbackStatus: MessageStatus,
): ThreadMessage => {
const { role, id, createdAt, status } = like;
const { role, id, createdAt, attachments, status } = like;
const common = {
id: id ?? fallbackId,
createdAt: createdAt ?? new Date(),
Expand All @@ -45,6 +47,12 @@ export const fromThreadMessageLike = (
? [{ type: "text" as const, text: like.content }]
: like.content;

if (role !== "user" && attachments)
throw new Error("Attachments are only supported for user messages");
// TODO add in 0.6
// if (role !== "assistant" && status)
// throw new Error("Status is only supported for assistant messages");

switch (role) {
case "assistant":
return {
Expand Down Expand Up @@ -97,6 +105,7 @@ export const fromThreadMessageLike = (
}
}
}),
attachments: attachments ?? [],
} satisfies ThreadUserMessage;

case "system":
Expand Down
13 changes: 6 additions & 7 deletions packages/react/src/runtimes/local/LocalThreadRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
MessageRepository,
} from "../utils/MessageRepository";
import type { ChatModelAdapter, ChatModelRunResult } from "./ChatModelAdapter";
import { ThreadRuntimeComposer } from "../utils/ThreadRuntimeComposer";
import { shouldContinue } from "./shouldContinue";
import { LocalRuntimeOptions } from "./LocalRuntimeOptions";
import { ThreadRuntime } from "../core";
Expand Down Expand Up @@ -40,13 +41,9 @@ export class LocalThreadRuntime implements ThreadRuntime {
return this.repository.getMessages();
}

public readonly composer = {
text: "",
setText: (value: string) => {
this.composer.text = value;
this.notifySubscribers();
},
};
public readonly composer = new ThreadRuntimeComposer(
this.notifySubscribers.bind(this),
);

constructor(
private configProvider: ModelConfigProvider,
Expand Down Expand Up @@ -91,6 +88,7 @@ export class LocalThreadRuntime implements ThreadRuntime {
}

public async append(message: AppendMessage): Promise<void> {
// TODO add support for assistant appends
if (message.role !== "user")
throw new Error(
"Only appending user messages are supported in LocalRuntime. This is likely an internal bug in assistant-ui.",
Expand All @@ -102,6 +100,7 @@ export class LocalThreadRuntime implements ThreadRuntime {
id: userMessageId,
role: "user",
content: message.content,
attachments: message.attachments ?? [],
createdAt: new Date(),
};
this.repository.addOrUpdateMessage(message.parentId, userMessage);
Expand Down
Loading

0 comments on commit e58d61b

Please sign in to comment.