Skip to content

Commit

Permalink
feat: AttachmentAdapter.accept allow attachment adapters to specify s…
Browse files Browse the repository at this point in the history
…upported file types (#780)
  • Loading branch information
Yonom authored Sep 9, 2024
1 parent f2ec0c7 commit a22e6bb
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 16 deletions.
6 changes: 6 additions & 0 deletions .changeset/nine-gifts-promise.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@assistant-ui/react-ai-sdk": patch
"@assistant-ui/react": patch
---

feat: AttachmentAdapter.accept allow attachment adapters to specify supported file types
1 change: 1 addition & 0 deletions examples/with-ffmpeg/app/MyRuntimeProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { INTERNAL } from "@assistant-ui/react";
const { generateId } = INTERNAL;

const attachmentAdapter: AttachmentAdapter = {
accept: "image/*,video/*,audio/*",
async add({ file }) {
return {
id: generateId(),
Expand Down
2 changes: 2 additions & 0 deletions packages/react-ai-sdk/src/ui/utils/vercelAttachmentAdapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { AttachmentAdapter } from "@assistant-ui/react";
import { generateId } from "ai";

export const vercelAttachmentAdapter: AttachmentAdapter = {
accept:
"image/*, text/plain, text/html, text/markdown, text/csv, text/xml, text/json, text/css",
async add({ file }) {
return {
id: generateId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@ import { useCallback } from "react";
import { useThreadContext } from "../../context";

export const useComposerAddAttachment = () => {
const { useComposer } = useThreadContext();
const { useComposer, useThreadRuntime } = useThreadContext();

const disabled = useComposer((c) => !c.isEditing);

const callback = useCallback(() => {
const { addAttachment } = useComposer.getState();
const { attachmentAccept } = useThreadRuntime.getState().composer;

const input = document.createElement("input");
input.type = "file";
if (attachmentAccept !== "*") {
input.accept = attachmentAccept;
}

input.onchange = (e) => {
const file = (e.target as HTMLInputElement).files?.[0];
Expand All @@ -19,7 +23,7 @@ export const useComposerAddAttachment = () => {
};

input.click();
}, [useComposer]);
}, [useComposer, useThreadRuntime]);

if (disabled) return null;
return callback;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
} from "../../context/stores/Attachment";

export type AttachmentAdapter = {
accept: string;
add(state: { file: File }): Promise<ComposerAttachment>;
remove(attachment: ComposerAttachment): Promise<void>;
send(attachment: ComposerAttachment): Promise<MessageAttachment>;
Expand Down
1 change: 1 addition & 0 deletions packages/react/src/runtimes/core/ThreadRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export type ThreadRuntime = ThreadActionsState &

export declare namespace ThreadRuntime {
export type Composer = Readonly<{
attachmentAccept: string;
attachments: ComposerAttachment[];
addAttachment: (file: File) => Promise<void>;
removeAttachment: (attachmentId: string) => Promise<void>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {
attachments: !!this.store.adapters?.attachments,
};

this.composer.attachmentAdapter = this._store.adapters?.attachments;
this.composer.setAttachmentAdapter(this._store.adapters?.attachments);

if (oldStore) {
// flush the converter cache when the convertMessage prop changes
Expand Down
5 changes: 3 additions & 2 deletions packages/react/src/runtimes/local/LocalThreadRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ export class LocalThreadRuntime implements ThreadRuntime {
hasUpdates = true;
}

this.composer.attachmentAdapter = options.adapters?.attachments;
const canAttach = this.composer.attachmentAdapter !== undefined;
this.composer.setAttachmentAdapter(options.adapters?.attachments);

const canAttach = options.adapters?.attachments !== undefined;
if (this.capabilities.attachments !== canAttach) {
this.capabilities.attachments = canAttach;
hasUpdates = true;
Expand Down
32 changes: 21 additions & 11 deletions packages/react/src/runtimes/utils/ThreadRuntimeComposer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import { AttachmentAdapter } from "../attachment/AttachmentAdapter";
import { ThreadRuntime } from "../core";

export class ThreadRuntimeComposer implements ThreadRuntime.Composer {
public attachmentAdapter?: AttachmentAdapter | undefined;
private _attachmentAdapter?: AttachmentAdapter | undefined;

public attachmentAccept: string = "*";

constructor(
private runtime: {
Expand All @@ -14,31 +16,41 @@ export class ThreadRuntimeComposer implements ThreadRuntime.Composer {
private notifySubscribers: () => void,
) {}

public setAttachmentAdapter(adapter: AttachmentAdapter | undefined) {
this._attachmentAdapter = adapter;
const accept = adapter?.accept ?? "*";
if (this.attachmentAccept !== accept) {
this.attachmentAccept = accept;
return true;
}
return false;
}

private _attachments: ComposerAttachment[] = [];

get attachments() {
public get attachments() {
return this._attachments;
}

async addAttachment(file: File) {
if (!this.attachmentAdapter)
if (!this._attachmentAdapter)
throw new Error("Attachments are not supported");

const attachment = await this.attachmentAdapter.add({ file });
const attachment = await this._attachmentAdapter.add({ file });

this._attachments = [...this._attachments, attachment];
this.notifySubscribers();
}

async removeAttachment(attachmentId: string) {
if (!this.attachmentAdapter)
if (!this._attachmentAdapter)
throw new Error("Attachments are not supported");

const index = this._attachments.findIndex((a) => a.id === attachmentId);
if (index === -1) throw new Error("Attachment not found");
const attachment = this._attachments[index]!;

await this.attachmentAdapter.remove(attachment);
await this._attachmentAdapter.remove(attachment);

this._attachments = this._attachments.toSpliced(index, 1);
this.notifySubscribers();
Expand All @@ -62,20 +74,18 @@ export class ThreadRuntimeComposer implements ThreadRuntime.Composer {
}

public async send() {
const attachments = this.attachmentAdapter
const attachments = this._attachmentAdapter
? await Promise.all(
this.attachments.map(
async (a) => await this.attachmentAdapter!.send(a),
async (a) => await this._attachmentAdapter!.send(a),
),
)
: [];

this.runtime.append({
parentId: this.runtime.messages.at(-1)?.id ?? null,
role: "user",
content: this.text
? [{ type: "text", text: this.text }]
: [],
content: this.text ? [{ type: "text", text: this.text }] : [],
attachments,
});
this.reset();
Expand Down

0 comments on commit a22e6bb

Please sign in to comment.