Skip to content

Commit

Permalink
feat: Edge/Local runtime AttachmentAdapter support (#774)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Sep 8, 2024
1 parent 19c6365 commit 44d08bd
Show file tree
Hide file tree
Showing 14 changed files with 99 additions and 29 deletions.
5 changes: 5 additions & 0 deletions .changeset/early-ears-juggle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@assistant-ui/react": patch
---

feat: styled components for attachments
6 changes: 6 additions & 0 deletions .changeset/unlucky-pants-destroy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@assistant-ui/react-playground": patch
"@assistant-ui/react": patch
---

feat: Edge/Local runtime AttachmentAdapter support
1 change: 1 addition & 0 deletions packages/react-playground/src/lib/playground-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ export class PlaygroundThreadRuntime implements ReactThreadRuntime {
private configProvider = new ProxyConfigProvider();

public readonly composer = new ThreadRuntimeComposer(
this,
this.notifySubscribers.bind(this),
);

Expand Down
17 changes: 4 additions & 13 deletions packages/react/src/context/stores/Composer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export type ComposerState = Readonly<{
setValue: (value: string) => void;

attachments: readonly Attachment[];
addAttachment: (attachment: Attachment) => void;
addAttachment: (file: File) => void;
removeAttachment: (attachmentId: string) => void;

text: string;
Expand Down Expand Up @@ -43,8 +43,8 @@ export const makeComposerStore = (
},

attachments: runtime.composer.attachments,
addAttachment: (attachment) => {
useThreadRuntime.getState().composer.addAttachment(attachment);
addAttachment: (file) => {
useThreadRuntime.getState().composer.addAttachment(file);
},
removeAttachment: (attachmentId) => {
useThreadRuntime.getState().composer.removeAttachment(attachmentId);
Expand All @@ -63,16 +63,7 @@ export const makeComposerStore = (

send: () => {
const runtime = useThreadRuntime.getState();
const text = runtime.composer.text;
const attachments = runtime.composer.attachments;
runtime.composer.reset();

runtime.append({
parentId: runtime.messages.at(-1)?.id ?? null,
role: "user",
content: text ? [{ type: "text", text }] : [],
attachments,
});
runtime.composer.send();
},
cancel: () => {
useThreadRuntime.getState().cancelRun();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { useCallback } from "react";
import { useThreadContext } from "../../context";
import { generateId } from "../../internal";

export const useComposerAddAttachment = () => {
const { useComposer } = useThreadContext();
Expand All @@ -16,12 +15,7 @@ export const useComposerAddAttachment = () => {
input.onchange = (e) => {
const file = (e.target as HTMLInputElement).files?.[0];
if (!file) return;
addAttachment({
id: generateId(),
type: "file", // TODO infer type from file extension or mimetype
name: file.name,
file,
});
addAttachment(file);
};

input.click();
Expand Down
10 changes: 10 additions & 0 deletions packages/react/src/runtimes/attachment/AttachmentAdapter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { Attachment } from "../../context/stores/Attachment";
import { CoreUserContentPart } from "../../types";

export type AttachmentAdapter = {
add(state: { file: File }): Promise<Attachment>;
send(attachment: Attachment): Promise<{
content: CoreUserContentPart[];
}>;
remove(attachment: Attachment): Promise<void>;
};
1 change: 1 addition & 0 deletions packages/react/src/runtimes/attachment/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export type { AttachmentAdapter } from "./AttachmentAdapter";
6 changes: 4 additions & 2 deletions packages/react/src/runtimes/core/ThreadRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ export type ThreadRuntime = ThreadActionsState &
export declare namespace ThreadRuntime {
export type Composer = Readonly<{
attachments: Attachment[];
addAttachment: (attachment: Attachment) => void;
removeAttachment: (attachmentId: string) => void;
addAttachment: (file: File) => Promise<void>;
removeAttachment: (attachmentId: string) => Promise<void>;

text: string;
setText: (value: string) => void;

reset: () => void;

send: () => void;
}>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {
private _store!: ExternalStoreAdapter<any>;

public readonly composer = new ThreadRuntimeComposer(
this,
this.notifySubscribers.bind(this),
);

Expand Down
1 change: 1 addition & 0 deletions packages/react/src/runtimes/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export * from "./edge";
export * from "./external-store";
export * from "./dangerous-in-browser";
export * from "./speech";
export * from "./attachment";
2 changes: 2 additions & 0 deletions packages/react/src/runtimes/local/LocalRuntimeOptions.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import type { CoreMessage } from "../../types";
import { AttachmentAdapter } from "../attachment/AttachmentAdapter";
import { SpeechSynthesisAdapter } from "../speech/SpeechAdapterTypes";

export type LocalRuntimeOptions = {
initialMessages?: readonly CoreMessage[] | undefined;
maxToolRoundtrips?: number | undefined;
adapters?:
| {
attachments?: AttachmentAdapter | undefined;
speech?: SpeechSynthesisAdapter | undefined;
}
| undefined;
Expand Down
14 changes: 13 additions & 1 deletion packages/react/src/runtimes/local/LocalThreadRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export class LocalThreadRuntime implements ThreadRuntime {
}

public readonly composer = new ThreadRuntimeComposer(
this,
this.notifySubscribers.bind(this),
);

Expand Down Expand Up @@ -72,11 +73,22 @@ export class LocalThreadRuntime implements ThreadRuntime {
public set options({ initialMessages, ...options }: LocalRuntimeOptions) {
this._options = options;

let hasUpdates = false;

const canSpeak = options.adapters?.speech !== undefined;
if (this.capabilities.speak !== canSpeak) {
this.capabilities.speak = canSpeak;
this.notifySubscribers();
hasUpdates = true;
}

this.composer.adapter = options.adapters?.attachments;
const canAttach = this.composer.adapter !== undefined;
if (this.capabilities.attachments !== canAttach) {
this.capabilities.attachments = canAttach;
hasUpdates = true;
}

if (hasUpdates) this.notifySubscribers();
}

public getBranches(messageId: string): string[] {
Expand Down
51 changes: 47 additions & 4 deletions packages/react/src/runtimes/utils/ThreadRuntimeComposer.tsx
Original file line number Diff line number Diff line change
@@ -1,22 +1,44 @@
import { Attachment } from "../../context/stores/Attachment";
import { AppendMessage } from "../../types";
import { AttachmentAdapter } from "../attachment/AttachmentAdapter";
import { ThreadRuntime } from "../core";

export class ThreadRuntimeComposer implements ThreadRuntime.Composer {
constructor(private notifySubscribers: () => void) {}
public adapter?: AttachmentAdapter | undefined;

constructor(
private runtime: {
messages: ThreadRuntime["messages"];
append: (message: AppendMessage) => void;
},
private notifySubscribers: () => void,
) {}

private _attachments: Attachment[] = [];

get attachments() {
return this._attachments;
}

addAttachment(attachment: Attachment) {
async addAttachment(file: File) {
if (!this.adapter) throw new Error("Attachments are not supported");

const attachment = await this.adapter.add({ file });

this._attachments = [...this._attachments, attachment];
this.notifySubscribers();
}

removeAttachment(attachmentId: string) {
this._attachments = this._attachments.filter((a) => a.id !== attachmentId);
async removeAttachment(attachmentId: string) {
if (!this.adapter) throw new Error("Attachments are not supported");

const index = this._attachments.findIndex((a) => a.id === attachmentId);
if (index === -1) throw new Error("Attachment not found");
const attachment = this._attachments[index]!;

await this.adapter.remove(attachment);

this._attachments = this._attachments.toSpliced(index, 1);
this.notifySubscribers();
}

Expand All @@ -36,4 +58,25 @@ export class ThreadRuntimeComposer implements ThreadRuntime.Composer {
this._attachments = [];
this.notifySubscribers();
}

public async send() {
const attachmentContentParts = this.adapter
? await Promise.all(
this.attachments.map(async (a) => {
const { content } = await this.adapter!.send(a);
return content;
}),
)
: [];

this.runtime.append({
parentId: this.runtime.messages.at(-1)?.id ?? null,
role: "user",
content: this.text
? [{ type: "text", text: this.text }, ...attachmentContentParts.flat()]
: [],
attachments: this.attachments,
});
this.reset();
}
}
5 changes: 3 additions & 2 deletions packages/react/src/ui/composer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,13 @@ const ComposerRemoveAttachment = forwardRef<
} = useThreadConfig();

const { useComposer } = useThreadContext();
const { useAttachment } = useAttachmentContext();
const handleRemoveAttachment = () => {
// TODO delete the correct attachment
useComposer
.getState()
.removeAttachment(useComposer.getState().attachments[0]?.id!);
.removeAttachment(useAttachment.getState().attachment.id);
};

return (
<TooltipIconButton
tooltip={tooltip}
Expand Down

0 comments on commit 44d08bd

Please sign in to comment.