diff --git a/.changeset/quick-masks-jump.md b/.changeset/quick-masks-jump.md new file mode 100644 index 000000000..f1f5684cc --- /dev/null +++ b/.changeset/quick-masks-jump.md @@ -0,0 +1,6 @@ +--- +"@assistant-ui/react-playground": patch +"@assistant-ui/react": patch +--- + +feat: Runtime.path API diff --git a/packages/react-playground/src/lib/playground-runtime.ts b/packages/react-playground/src/lib/playground-runtime.ts index af9e5b2b0..5230457a0 100644 --- a/packages/react-playground/src/lib/playground-runtime.ts +++ b/packages/react-playground/src/lib/playground-runtime.ts @@ -522,6 +522,15 @@ export class PlaygroundThreadRuntimeCore implements INTERNAL.ThreadRuntimeCore { public export(): never { throw new Error("Playground does not support exporting messages."); } + + public getMessageById(messageId: string) { + const idx = this.messages.findIndex((m) => m.id === messageId); + if (idx === -1) return undefined; + return { + message: this.messages[idx]!, + parentId: this.messages[idx - 1]?.id ?? null, + }; + } } type PlaygroundThreadRuntime = ThreadRuntime & { @@ -558,7 +567,7 @@ class PlaygroundRuntimeImpl public static override create(_core: PlaygroundRuntimeCore) { return new PlaygroundRuntimeImpl( _core, - AssistantRuntimeImpl.createThreadRuntime( + AssistantRuntimeImpl.createMainThreadRuntime( _core, PlaygroundThreadRuntimeImpl, ), diff --git a/packages/react/src/api/AssistantRuntime.ts b/packages/react/src/api/AssistantRuntime.ts index 69d31eca7..f17bef00e 100644 --- a/packages/react/src/api/AssistantRuntime.ts +++ b/packages/react/src/api/AssistantRuntime.ts @@ -28,7 +28,7 @@ export type AssistantRuntime = { }; export class AssistantRuntimeImpl - implements AssistantRuntimeCore, AssistantRuntime + implements Omit, AssistantRuntime { protected constructor( private readonly _core: AssistantRuntimeCore, @@ -64,7 +64,7 @@ export class AssistantRuntimeImpl return this._core.subscribe(callback); } - protected static createThreadRuntime( + protected static createMainThreadRuntime( _core: AssistantRuntimeCore, CustomThreadRuntime: new ( binding: ThreadRuntimeCoreBinding, @@ -72,6 +72,10 @@ export class AssistantRuntimeImpl ) { return new CustomThreadRuntime( new NestedSubscriptionSubject({ + path: { + ref: "threads.main", + threadSelector: { type: "main" }, + }, getState: () => _core.thread, subscribe: (callback) => _core.subscribe(callback), }), @@ -86,7 +90,7 @@ export class AssistantRuntimeImpl ) { return new AssistantRuntimeImpl( _core, - AssistantRuntimeImpl.createThreadRuntime(_core, CustomThreadRuntime), + AssistantRuntimeImpl.createMainThreadRuntime(_core, CustomThreadRuntime), ) as AssistantRuntime; } } diff --git a/packages/react/src/api/AttachmentRuntime.ts b/packages/react/src/api/AttachmentRuntime.ts index f025d1625..f994119a8 100644 --- a/packages/react/src/api/AttachmentRuntime.ts +++ b/packages/react/src/api/AttachmentRuntime.ts @@ -7,6 +7,7 @@ import { PendingAttachment, Unsubscribe, } from "../types"; +import { AttachmentRuntimePath } from "./PathTypes"; type MessageAttachmentState = CompleteAttachment & { source: "message"; @@ -38,13 +39,17 @@ export type AttachmentState = | MessageAttachmentState; type AttachmentSnapshotBinding = - SubscribableWithState; + SubscribableWithState< + AttachmentState & { source: Source }, + AttachmentRuntimePath & { attachmentSource: Source } + >; type AttachmentRuntimeSource = AttachmentState["source"]; export type AttachmentRuntime< TSource extends AttachmentRuntimeSource = AttachmentRuntimeSource, > = { + path: AttachmentRuntimePath & { attachmentSource: TSource }; readonly source: TSource; getState(): AttachmentState & { source: TSource }; remove(): Promise; @@ -55,6 +60,10 @@ export abstract class AttachmentRuntimeImpl< Source extends AttachmentRuntimeSource = AttachmentRuntimeSource, > implements AttachmentRuntime { + public get path() { + return this._core.path; + } + public abstract get source(): Source; constructor(private _core: AttachmentSnapshotBinding) {} diff --git a/packages/react/src/api/ComposerRuntime.ts b/packages/react/src/api/ComposerRuntime.ts index d2c9f342d..6312feb9d 100644 --- a/packages/react/src/api/ComposerRuntime.ts +++ b/packages/react/src/api/ComposerRuntime.ts @@ -14,13 +14,21 @@ import { } from "./AttachmentRuntime"; import { ShallowMemoizeSubject } from "./subscribable/ShallowMemoizeSubject"; import { SKIP_UPDATE } from "./subscribable/SKIP_UPDATE"; +import { ComposerRuntimePath } from "./PathTypes"; export type ThreadComposerRuntimeCoreBinding = SubscribableWithState< - ThreadComposerRuntimeCore | undefined + ThreadComposerRuntimeCore | undefined, + ComposerRuntimePath & { composerSource: "thread" } +>; + +export type EditComposerRuntimeCoreBinding = SubscribableWithState< + ComposerRuntimeCore | undefined, + ComposerRuntimePath & { composerSource: "edit" } >; export type ComposerRuntimeCoreBinding = SubscribableWithState< - ComposerRuntimeCore | undefined + ComposerRuntimeCore | undefined, + ComposerRuntimePath >; type LegacyEditComposerState = Readonly<{ @@ -165,6 +173,7 @@ const getEditComposerState = ( }; export type ComposerRuntime = { + path: ComposerRuntimePath; readonly type: "edit" | "thread"; getState(): ComposerState; @@ -207,6 +216,10 @@ export type ComposerRuntime = { export abstract class ComposerRuntimeImpl implements ComposerRuntimeCore, ComposerRuntime { + public get path() { + return this._core.path; + } + public abstract get type(): "edit" | "thread"; constructor(protected _core: ComposerRuntimeCoreBinding) {} @@ -318,6 +331,7 @@ export type ThreadComposerRuntime = Omit< ComposerRuntime, "getState" | "getAttachmentByIndex" > & { + readonly path: ComposerRuntimePath & { composerSource: "thread" }; readonly type: "thread"; getState(): ThreadComposerState; @@ -335,6 +349,12 @@ export class ThreadComposerRuntimeImpl extends ComposerRuntimeImpl implements ThreadComposerRuntime, ThreadComposerState { + public override get path() { + return this._core.path as ComposerRuntimePath & { + composerSource: "thread"; + }; + } + public get type() { return "thread" as const; } @@ -343,10 +363,12 @@ export class ThreadComposerRuntimeImpl constructor(core: ThreadComposerRuntimeCoreBinding) { const stateBinding = new LazyMemoizeSubject({ + path: core.path, getState: () => getThreadComposerState(core.getState()), subscribe: (callback) => core.subscribe(callback), }); super({ + path: core.path, getState: () => core.getState(), subscribe: (callback) => stateBinding.subscribe(callback), }); @@ -364,6 +386,12 @@ export class ThreadComposerRuntimeImpl public getAttachmentByIndex(idx: number) { return new ThreadComposerAttachmentRuntimeImpl( new ShallowMemoizeSubject({ + path: { + ...this.path, + attachmentSource: "thread-composer", + attachmentSelector: { type: "index", index: idx }, + ref: this.path.ref + `${this.path.ref}.attachments[${idx}]`, + }, getState: () => { const attachments = this.getState().attachments; const attachment = attachments[idx]; @@ -386,6 +414,7 @@ export type EditComposerRuntime = Omit< ComposerRuntime, "getState" | "getAttachmentByIndex" > & { + readonly path: ComposerRuntimePath & { composerSource: "edit" }; readonly type: "edit"; getState(): EditComposerState; @@ -405,21 +434,27 @@ export class EditComposerRuntimeImpl extends ComposerRuntimeImpl implements EditComposerRuntime, EditComposerState { + public override get path() { + return this._core.path as ComposerRuntimePath & { composerSource: "edit" }; + } + public get type() { return "edit" as const; } private _getState; constructor( - core: ComposerRuntimeCoreBinding, + core: EditComposerRuntimeCoreBinding, private _beginEdit: () => void, ) { const stateBinding = new LazyMemoizeSubject({ + path: core.path, getState: () => getEditComposerState(core.getState(), this._beginEdit), subscribe: (callback) => core.subscribe(callback), }); super({ + path: core.path, getState: () => core.getState(), subscribe: (callback) => stateBinding.subscribe(callback), }); @@ -445,6 +480,12 @@ export class EditComposerRuntimeImpl public getAttachmentByIndex(idx: number) { return new EditComposerAttachmentRuntimeImpl( new ShallowMemoizeSubject({ + path: { + ...this.path, + attachmentSource: "edit-composer", + attachmentSelector: { type: "index", index: idx }, + ref: this.path.ref + `${this.path.ref}.attachments[${idx}]`, + }, getState: () => { const attachments = this.getState().attachments; const attachment = attachments[idx]; diff --git a/packages/react/src/api/ContentPartRuntime.ts b/packages/react/src/api/ContentPartRuntime.ts index 5e6d2527b..0a53f1bf0 100644 --- a/packages/react/src/api/ContentPartRuntime.ts +++ b/packages/react/src/api/ContentPartRuntime.ts @@ -8,6 +8,7 @@ import { ThreadRuntimeCoreBinding } from "./ThreadRuntime"; import { MessageStateBinding } from "./MessageRuntime"; import { SubscribableWithState } from "./subscribable/Subscribable"; import { Unsubscribe } from "../types"; +import { ContentPartRuntimePath } from "./PathTypes"; export type ContentPartState = ( | ThreadUserContentPart @@ -20,15 +21,24 @@ export type ContentPartState = ( status: ContentPartStatus | ToolCallContentPartStatus; }; -type ContentPartSnapshotBinding = SubscribableWithState; +type ContentPartSnapshotBinding = SubscribableWithState< + ContentPartState, + ContentPartRuntimePath +>; export type ContentPartRuntime = { + path: ContentPartRuntimePath; + getState(): ContentPartState; addToolResult(result: any): void; subscribe(callback: () => void): Unsubscribe; }; export class ContentPartRuntimeImpl implements ContentPartRuntime { + public get path() { + return this.contentBinding.path; + } + constructor( private contentBinding: ContentPartSnapshotBinding, private messageApi: MessageStateBinding, diff --git a/packages/react/src/api/MessageRuntime.ts b/packages/react/src/api/MessageRuntime.ts index 4a1385386..42669f52a 100644 --- a/packages/react/src/api/MessageRuntime.ts +++ b/packages/react/src/api/MessageRuntime.ts @@ -149,6 +149,10 @@ export type MessageRuntime = { }; export class MessageRuntimeImpl implements MessageRuntime { + public get path() { + return this._core.path; + } + constructor( private _core: MessageStateBinding, private _threadBinding: ThreadRuntimeCoreBinding, diff --git a/packages/react/src/api/PathTypes.ts b/packages/react/src/api/PathTypes.ts new file mode 100644 index 000000000..fce794591 --- /dev/null +++ b/packages/react/src/api/PathTypes.ts @@ -0,0 +1,47 @@ +export type ThreadRuntimePath = { + ref: string; + threadSelector: { type: "main" }; +}; + +export type MessageRuntimePath = ThreadRuntimePath & { + messageSelector: + | { type: "messageId"; messageId: string } + | { type: "index"; index: number }; +}; + +export type ContentPartRuntimePath = MessageRuntimePath & { + contentPartSelector: + | { type: "index"; index: number } + | { type: "toolCallId"; toolCallId: string }; +}; + +export type AttachmentRuntimePath = ( + | (MessageRuntimePath & { + attachmentSource: "message" | "edit-composer"; + }) + | (ThreadRuntimePath & { + attachmentSource: "thread-composer"; + }) +) & { + attachmentSelector: + | { + type: "index"; + index: number; + } + | { + type: "index"; + index: number; + } + | { + type: "index"; + index: number; + }; +}; + +export type ComposerRuntimePath = + | (ThreadRuntimePath & { + composerSource: "thread"; + }) + | (MessageRuntimePath & { + composerSource: "edit"; + }); diff --git a/packages/react/src/api/ThreadRuntime.ts b/packages/react/src/api/ThreadRuntime.ts index e3ab921ee..3eccc3e84 100644 --- a/packages/react/src/api/ThreadRuntime.ts +++ b/packages/react/src/api/ThreadRuntime.ts @@ -32,6 +32,7 @@ import { import { LazyMemoizeSubject } from "./subscribable/LazyMemoizeSubject"; import { SKIP_UPDATE } from "./subscribable/SKIP_UPDATE"; import { ComposerRuntimeCore } from "../runtimes/core/ComposerRuntimeCore"; +import { MessageRuntimePath, ThreadRuntimePath } from "./PathTypes"; export type CreateAppendMessage = | string @@ -67,10 +68,12 @@ const toAppendMessage = ( } as AppendMessage; }; -export type ThreadRuntimeCoreBinding = - SubscribableWithState & { - outerSubscribe(callback: () => void): Unsubscribe; - }; +export type ThreadRuntimeCoreBinding = SubscribableWithState< + ThreadRuntimeCore, + ThreadRuntimePath +> & { + outerSubscribe(callback: () => void): Unsubscribe; +}; export type ThreadState = Readonly<{ threadId: string; @@ -105,7 +108,9 @@ export const getThreadState = (runtime: ThreadRuntimeCore): ThreadState => { }; export type ThreadRuntime = { - composer: ThreadComposerRuntime; + readonly path: ThreadRuntimePath; + + readonly composer: ThreadComposerRuntime; getState(): ThreadState; /** @@ -219,7 +224,9 @@ export type ThreadRuntime = { export class ThreadRuntimeImpl implements Omit, ThreadRuntime { - // public path = "assistant.threads[main]"; // TODO + public get path() { + return this._threadBinding.path; + } /** * @deprecated Use `getState().threadId` instead. This will be removed in 0.6.0. @@ -284,13 +291,16 @@ export class ThreadRuntimeImpl private _threadBinding: ThreadRuntimeCoreBinding & { getStateState(): ThreadState; }; + constructor(threadBinding: ThreadRuntimeCoreBinding) { const stateBinding = new LazyMemoizeSubject({ + path: threadBinding.path, getState: () => getThreadState(threadBinding.getState()), subscribe: (callback) => threadBinding.subscribe(callback), }); this._threadBinding = { + path: threadBinding.path, getState: () => threadBinding.getState(), getStateState: () => stateBinding.getState(), outerSubscribe: (callback) => threadBinding.outerSubscribe(callback), @@ -300,6 +310,11 @@ export class ThreadRuntimeImpl public readonly composer = new ThreadComposerRuntimeImpl( new NestedSubscriptionSubject({ + path: { + ...this.path, + ref: this.path.ref + `${this.path.ref}.composer`, + composerSource: "thread", + }, getState: () => this._threadBinding.getState().composer, subscribe: (callback) => this._threadBinding.subscribe(callback), }), @@ -402,30 +417,46 @@ export class ThreadRuntimeImpl public getMesssageByIndex(idx: number) { if (idx < 0) throw new Error("Message index must be >= 0"); - return this._getMessageRuntime(() => { - const messages = this._threadBinding.getState().messages; - const message = messages[idx]; - if (!message) return undefined; - return { - message, - parentId: messages[idx - 1]?.id ?? null, - }; - }); + return this._getMessageRuntime( + { + ...this.path, + ref: this.path.ref + `${this.path.ref}.messages[${idx}]`, + messageSelector: { type: "index", index: idx }, + }, + () => { + const messages = this._threadBinding.getState().messages; + const message = messages[idx]; + if (!message) return undefined; + return { + message, + parentId: messages[idx - 1]?.id ?? null, + }; + }, + ); } public getMesssageById(messageId: string) { - return this._getMessageRuntime(() => - this._threadBinding.getState().getMessageById(messageId), + return this._getMessageRuntime( + { + ...this.path, + ref: + this.path.ref + + `${this.path.ref}.messages[messageId=${JSON.stringify(messageId)}]`, + messageSelector: { type: "messageId", messageId: messageId }, + }, + () => this._threadBinding.getState().getMessageById(messageId), ); } private _getMessageRuntime( + path: MessageRuntimePath, callback: () => | { parentId: string | null; message: ThreadMessage } | undefined, ) { return new MessageRuntimeImpl( new ShallowMemoizeSubject({ + path, getState: () => { const { message, parentId } = callback() ?? {}; @@ -464,7 +495,7 @@ export class ThreadRuntimeImpl private _eventListenerNestedSubscriptions = new Map< string, - NestedSubscriptionSubject + NestedSubscriptionSubject >(); public unstable_on( @@ -474,6 +505,7 @@ export class ThreadRuntimeImpl let subject = this._eventListenerNestedSubscriptions.get(event); if (!subject) { subject = new NestedSubscriptionSubject({ + path: this.path, getState: () => ({ subscribe: (callback) => this._threadBinding.getState().unstable_on(event, callback), diff --git a/packages/react/src/api/subscribable/LazyMemoizeSubject.ts b/packages/react/src/api/subscribable/LazyMemoizeSubject.ts index f0fcaff47..7a3edfedc 100644 --- a/packages/react/src/api/subscribable/LazyMemoizeSubject.ts +++ b/packages/react/src/api/subscribable/LazyMemoizeSubject.ts @@ -2,16 +2,22 @@ import { BaseSubject } from "./BaseSubject"; import { SKIP_UPDATE } from "./SKIP_UPDATE"; import { SubscribableWithState } from "./Subscribable"; -export class LazyMemoizeSubject +export class LazyMemoizeSubject extends BaseSubject - implements SubscribableWithState + implements SubscribableWithState { - constructor(private binding: SubscribableWithState) { + public get path() { + return this.binding.path; + } + + constructor( + private binding: SubscribableWithState, + ) { super(); } private _previousStateDirty = true; - private _previousState: T | undefined; + private _previousState: TState | undefined; public getState = () => { if (!this.isConnected || this._previousStateDirty) { const newState = this.binding.getState(); diff --git a/packages/react/src/api/subscribable/NestedSubscriptionSubject.ts b/packages/react/src/api/subscribable/NestedSubscriptionSubject.ts index 50718e06f..953594ed0 100644 --- a/packages/react/src/api/subscribable/NestedSubscriptionSubject.ts +++ b/packages/react/src/api/subscribable/NestedSubscriptionSubject.ts @@ -6,11 +6,15 @@ import { SubscribableWithState, } from "./Subscribable"; -export class NestedSubscriptionSubject +export class NestedSubscriptionSubject extends BaseSubject - implements SubscribableWithState, NestedSubscribable + implements SubscribableWithState, NestedSubscribable { - constructor(private binding: NestedSubscribable) { + public get path() { + return this.binding.path; + } + + constructor(private binding: NestedSubscribable) { super(); } diff --git a/packages/react/src/api/subscribable/ShallowMemoizeSubject.ts b/packages/react/src/api/subscribable/ShallowMemoizeSubject.ts index 85495d7ba..0386a43fc 100644 --- a/packages/react/src/api/subscribable/ShallowMemoizeSubject.ts +++ b/packages/react/src/api/subscribable/ShallowMemoizeSubject.ts @@ -3,11 +3,17 @@ import { BaseSubject } from "./BaseSubject"; import { SubscribableWithState } from "./Subscribable"; import { SKIP_UPDATE } from "./SKIP_UPDATE"; -export class ShallowMemoizeSubject +export class ShallowMemoizeSubject extends BaseSubject - implements SubscribableWithState + implements SubscribableWithState { - constructor(private binding: SubscribableWithState) { + public get path() { + return this.binding.path; + } + + constructor( + private binding: SubscribableWithState, + ) { super(); const state = binding.getState(); if (state === SKIP_UPDATE) @@ -15,7 +21,7 @@ export class ShallowMemoizeSubject this._previousState = state; } - private _previousState: T; + private _previousState: TState; public getState = () => { if (!this.isConnected) this._syncState(); return this._previousState; diff --git a/packages/react/src/api/subscribable/Subscribable.ts b/packages/react/src/api/subscribable/Subscribable.ts index 66e6b8855..1801269cf 100644 --- a/packages/react/src/api/subscribable/Subscribable.ts +++ b/packages/react/src/api/subscribable/Subscribable.ts @@ -4,9 +4,12 @@ export type Subscribable = { subscribe: (callback: () => void) => Unsubscribe; }; -export type SubscribableWithState = Subscribable & { +export type SubscribableWithState = Subscribable & { + path: TPath; getState: () => TState; }; -export type NestedSubscribable = - SubscribableWithState; +export type NestedSubscribable< + TState extends Subscribable | undefined, + TPath, +> = SubscribableWithState; diff --git a/packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx b/packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx index 7e8fb79e6..fb61ef438 100644 --- a/packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx +++ b/packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx @@ -21,6 +21,7 @@ import { DefaultEditComposerRuntimeCore } from "../composer/DefaultEditComposerR import { SpeechSynthesisAdapter } from "../speech"; import { FeedbackAdapter } from "../feedback/FeedbackAdapter"; import { AttachmentAdapter } from "../attachment"; +import { getThreadMessageText } from "../../utils/getThreadMessageText"; type BaseThreadAdapters = { speech?: SpeechSynthesisAdapter | undefined; @@ -76,6 +77,10 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore { this._notifySubscribers(); } + public getMessageById(messageId: string) { + return this.repository.getMessage(messageId); + } + public getBranches(messageId: string): string[] { return this.repository.getBranches(messageId); } @@ -129,7 +134,7 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore { this._stopSpeaking?.(); - const utterance = adapter.speak(message); + const utterance = adapter.speak(getThreadMessageText(message)); const unsub = utterance.subscribe(() => { if (utterance.status.type === "ended") { this._stopSpeaking = undefined; diff --git a/packages/react/src/runtimes/local/useLocalRuntime.tsx b/packages/react/src/runtimes/local/useLocalRuntime.tsx index cf12ee4ce..f7035fb80 100644 --- a/packages/react/src/runtimes/local/useLocalRuntime.tsx +++ b/packages/react/src/runtimes/local/useLocalRuntime.tsx @@ -30,7 +30,7 @@ class LocalRuntimeImpl extends AssistantRuntimeImpl implements LocalRuntime { public static override create(_core: LocalRuntimeCore) { return new LocalRuntimeImpl( _core, - AssistantRuntimeImpl.createThreadRuntime(_core, ThreadRuntimeImpl), + AssistantRuntimeImpl.createMainThreadRuntime(_core, ThreadRuntimeImpl), ) as LocalRuntime; } }