From a22e6bb084b83a82582819fe83336769109a67f4 Mon Sep 17 00:00:00 2001 From: Simon Farshid Date: Sun, 8 Sep 2024 17:00:10 -0700 Subject: [PATCH] feat: AttachmentAdapter.accept allow attachment adapters to specify supported file types (#780) --- .changeset/nine-gifts-promise.md | 6 ++++ .../with-ffmpeg/app/MyRuntimeProvider.tsx | 1 + .../src/ui/utils/vercelAttachmentAdapter.ts | 2 ++ .../composer/useComposerAddAttachment.tsx | 8 +++-- .../runtimes/attachment/AttachmentAdapter.ts | 1 + .../react/src/runtimes/core/ThreadRuntime.tsx | 1 + .../ExternalStoreThreadRuntime.tsx | 2 +- .../src/runtimes/local/LocalThreadRuntime.tsx | 5 +-- .../runtimes/utils/ThreadRuntimeComposer.tsx | 32 ++++++++++++------- 9 files changed, 42 insertions(+), 16 deletions(-) create mode 100644 .changeset/nine-gifts-promise.md diff --git a/.changeset/nine-gifts-promise.md b/.changeset/nine-gifts-promise.md new file mode 100644 index 000000000..55d984d59 --- /dev/null +++ b/.changeset/nine-gifts-promise.md @@ -0,0 +1,6 @@ +--- +"@assistant-ui/react-ai-sdk": patch +"@assistant-ui/react": patch +--- + +feat: AttachmentAdapter.accept allow attachment adapters to specify supported file types diff --git a/examples/with-ffmpeg/app/MyRuntimeProvider.tsx b/examples/with-ffmpeg/app/MyRuntimeProvider.tsx index 8ce722527..9e8b206dd 100644 --- a/examples/with-ffmpeg/app/MyRuntimeProvider.tsx +++ b/examples/with-ffmpeg/app/MyRuntimeProvider.tsx @@ -7,6 +7,7 @@ import { INTERNAL } from "@assistant-ui/react"; const { generateId } = INTERNAL; const attachmentAdapter: AttachmentAdapter = { + accept: "image/*,video/*,audio/*", async add({ file }) { return { id: generateId(), diff --git a/packages/react-ai-sdk/src/ui/utils/vercelAttachmentAdapter.ts b/packages/react-ai-sdk/src/ui/utils/vercelAttachmentAdapter.ts index e5ee3b0cd..7f5e8fc4f 100644 --- a/packages/react-ai-sdk/src/ui/utils/vercelAttachmentAdapter.ts +++ b/packages/react-ai-sdk/src/ui/utils/vercelAttachmentAdapter.ts @@ -2,6 +2,8 @@ import { AttachmentAdapter } from "@assistant-ui/react"; import { generateId } from "ai"; export const vercelAttachmentAdapter: AttachmentAdapter = { + accept: + "image/*, text/plain, text/html, text/markdown, text/csv, text/xml, text/json, text/css", async add({ file }) { return { id: generateId(), diff --git a/packages/react/src/primitive-hooks/composer/useComposerAddAttachment.tsx b/packages/react/src/primitive-hooks/composer/useComposerAddAttachment.tsx index 0f0e7eacf..1ae738da8 100644 --- a/packages/react/src/primitive-hooks/composer/useComposerAddAttachment.tsx +++ b/packages/react/src/primitive-hooks/composer/useComposerAddAttachment.tsx @@ -2,15 +2,19 @@ import { useCallback } from "react"; import { useThreadContext } from "../../context"; export const useComposerAddAttachment = () => { - const { useComposer } = useThreadContext(); + const { useComposer, useThreadRuntime } = useThreadContext(); const disabled = useComposer((c) => !c.isEditing); const callback = useCallback(() => { const { addAttachment } = useComposer.getState(); + const { attachmentAccept } = useThreadRuntime.getState().composer; const input = document.createElement("input"); input.type = "file"; + if (attachmentAccept !== "*") { + input.accept = attachmentAccept; + } input.onchange = (e) => { const file = (e.target as HTMLInputElement).files?.[0]; @@ -19,7 +23,7 @@ export const useComposerAddAttachment = () => { }; input.click(); - }, [useComposer]); + }, [useComposer, useThreadRuntime]); if (disabled) return null; return callback; diff --git a/packages/react/src/runtimes/attachment/AttachmentAdapter.ts b/packages/react/src/runtimes/attachment/AttachmentAdapter.ts index e51b3e15f..0eabccf57 100644 --- a/packages/react/src/runtimes/attachment/AttachmentAdapter.ts +++ b/packages/react/src/runtimes/attachment/AttachmentAdapter.ts @@ -4,6 +4,7 @@ import { } from "../../context/stores/Attachment"; export type AttachmentAdapter = { + accept: string; add(state: { file: File }): Promise; remove(attachment: ComposerAttachment): Promise; send(attachment: ComposerAttachment): Promise; diff --git a/packages/react/src/runtimes/core/ThreadRuntime.tsx b/packages/react/src/runtimes/core/ThreadRuntime.tsx index a4de5f26d..c214d1266 100644 --- a/packages/react/src/runtimes/core/ThreadRuntime.tsx +++ b/packages/react/src/runtimes/core/ThreadRuntime.tsx @@ -16,6 +16,7 @@ export type ThreadRuntime = ThreadActionsState & export declare namespace ThreadRuntime { export type Composer = Readonly<{ + attachmentAccept: string; attachments: ComposerAttachment[]; addAttachment: (file: File) => Promise; removeAttachment: (attachmentId: string) => Promise; diff --git a/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntime.tsx b/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntime.tsx index 79d65ede4..9a2fa9957 100644 --- a/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntime.tsx +++ b/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntime.tsx @@ -80,7 +80,7 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime { attachments: !!this.store.adapters?.attachments, }; - this.composer.attachmentAdapter = this._store.adapters?.attachments; + this.composer.setAttachmentAdapter(this._store.adapters?.attachments); if (oldStore) { // flush the converter cache when the convertMessage prop changes diff --git a/packages/react/src/runtimes/local/LocalThreadRuntime.tsx b/packages/react/src/runtimes/local/LocalThreadRuntime.tsx index a85bfe172..2b3baa8b5 100644 --- a/packages/react/src/runtimes/local/LocalThreadRuntime.tsx +++ b/packages/react/src/runtimes/local/LocalThreadRuntime.tsx @@ -81,8 +81,9 @@ export class LocalThreadRuntime implements ThreadRuntime { hasUpdates = true; } - this.composer.attachmentAdapter = options.adapters?.attachments; - const canAttach = this.composer.attachmentAdapter !== undefined; + this.composer.setAttachmentAdapter(options.adapters?.attachments); + + const canAttach = options.adapters?.attachments !== undefined; if (this.capabilities.attachments !== canAttach) { this.capabilities.attachments = canAttach; hasUpdates = true; diff --git a/packages/react/src/runtimes/utils/ThreadRuntimeComposer.tsx b/packages/react/src/runtimes/utils/ThreadRuntimeComposer.tsx index 05f3b7167..0ee2ec80d 100644 --- a/packages/react/src/runtimes/utils/ThreadRuntimeComposer.tsx +++ b/packages/react/src/runtimes/utils/ThreadRuntimeComposer.tsx @@ -4,7 +4,9 @@ import { AttachmentAdapter } from "../attachment/AttachmentAdapter"; import { ThreadRuntime } from "../core"; export class ThreadRuntimeComposer implements ThreadRuntime.Composer { - public attachmentAdapter?: AttachmentAdapter | undefined; + private _attachmentAdapter?: AttachmentAdapter | undefined; + + public attachmentAccept: string = "*"; constructor( private runtime: { @@ -14,31 +16,41 @@ export class ThreadRuntimeComposer implements ThreadRuntime.Composer { private notifySubscribers: () => void, ) {} + public setAttachmentAdapter(adapter: AttachmentAdapter | undefined) { + this._attachmentAdapter = adapter; + const accept = adapter?.accept ?? "*"; + if (this.attachmentAccept !== accept) { + this.attachmentAccept = accept; + return true; + } + return false; + } + private _attachments: ComposerAttachment[] = []; - get attachments() { + public get attachments() { return this._attachments; } async addAttachment(file: File) { - if (!this.attachmentAdapter) + if (!this._attachmentAdapter) throw new Error("Attachments are not supported"); - const attachment = await this.attachmentAdapter.add({ file }); + const attachment = await this._attachmentAdapter.add({ file }); this._attachments = [...this._attachments, attachment]; this.notifySubscribers(); } async removeAttachment(attachmentId: string) { - if (!this.attachmentAdapter) + if (!this._attachmentAdapter) throw new Error("Attachments are not supported"); const index = this._attachments.findIndex((a) => a.id === attachmentId); if (index === -1) throw new Error("Attachment not found"); const attachment = this._attachments[index]!; - await this.attachmentAdapter.remove(attachment); + await this._attachmentAdapter.remove(attachment); this._attachments = this._attachments.toSpliced(index, 1); this.notifySubscribers(); @@ -62,10 +74,10 @@ export class ThreadRuntimeComposer implements ThreadRuntime.Composer { } public async send() { - const attachments = this.attachmentAdapter + const attachments = this._attachmentAdapter ? await Promise.all( this.attachments.map( - async (a) => await this.attachmentAdapter!.send(a), + async (a) => await this._attachmentAdapter!.send(a), ), ) : []; @@ -73,9 +85,7 @@ export class ThreadRuntimeComposer implements ThreadRuntime.Composer { this.runtime.append({ parentId: this.runtime.messages.at(-1)?.id ?? null, role: "user", - content: this.text - ? [{ type: "text", text: this.text }] - : [], + content: this.text ? [{ type: "text", text: this.text }] : [], attachments, }); this.reset();