From 7c769390c64df951c1998ae1d3e314068b889e70 Mon Sep 17 00:00:00 2001 From: Simon Farshid Date: Mon, 14 Oct 2024 16:40:11 -0700 Subject: [PATCH] feat: ThreadRuntime.getMesssageById (#1012) --- .changeset/unlucky-clocks-mate.md | 5 ++ packages/react/src/api/MessageRuntime.ts | 9 ++++ packages/react/src/api/ThreadRuntime.ts | 49 ++++++++++++++++--- .../src/runtimes/core/ThreadRuntimeCore.tsx | 7 +++ 4 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 .changeset/unlucky-clocks-mate.md diff --git a/.changeset/unlucky-clocks-mate.md b/.changeset/unlucky-clocks-mate.md new file mode 100644 index 000000000..0dec6a3d3 --- /dev/null +++ b/.changeset/unlucky-clocks-mate.md @@ -0,0 +1,5 @@ +--- +"@assistant-ui/react": patch +--- + +feat: ThreadRuntime.getMesssageById diff --git a/packages/react/src/api/MessageRuntime.ts b/packages/react/src/api/MessageRuntime.ts index cebf82ea0..24e72e575 100644 --- a/packages/react/src/api/MessageRuntime.ts +++ b/packages/react/src/api/MessageRuntime.ts @@ -108,6 +108,9 @@ export type MessageState = ThreadMessage & { branchNumber: number; branchCount: number; + /** + * @deprecated This API is still under active development and might change without notice. + */ speech: SpeechState | undefined; submittedFeedback: SubmittedFeedback | undefined; }; @@ -119,7 +122,13 @@ export type MessageRuntime = { getState(): MessageState; reload(): void; + /** + * @deprecated This API is still under active development and might change without notice. + */ speak(): void; + /** + * @deprecated This API is still under active development and might change without notice. + */ stopSpeaking(): void; submitFeedback({ type }: { type: "positive" | "negative" }): void; switchToBranch({ diff --git a/packages/react/src/api/ThreadRuntime.ts b/packages/react/src/api/ThreadRuntime.ts index 9b1b489d4..e3ab921ee 100644 --- a/packages/react/src/api/ThreadRuntime.ts +++ b/packages/react/src/api/ThreadRuntime.ts @@ -80,6 +80,10 @@ export type ThreadState = Readonly<{ messages: readonly ThreadMessage[]; suggestions: readonly ThreadSuggestion[]; extras: unknown; + + /** + * @deprecated This API is still under active development and might change without notice. + */ speech: SpeechState | undefined; }>; @@ -117,7 +121,13 @@ export type ThreadRuntime = { export(): ExportedMessageRepository; import(repository: ExportedMessageRepository): void; getMesssageByIndex(idx: number): MessageRuntime; + getMesssageById(messageId: string): MessageRuntime; + + /** + * @deprecated This API is still under active development and might change without notice. + */ stopSpeaking: () => void; + unstable_on( event: "switched-to" | "run-start", callback: () => void, @@ -206,7 +216,9 @@ export type ThreadRuntime = { beginEdit: (messageId: string) => void; }; -export class ThreadRuntimeImpl implements ThreadRuntimeCore, ThreadRuntime { +export class ThreadRuntimeImpl + implements Omit, ThreadRuntime +{ // public path = "assistant.threads[main]"; // TODO /** @@ -390,12 +402,37 @@ export class ThreadRuntimeImpl implements ThreadRuntimeCore, ThreadRuntime { public getMesssageByIndex(idx: number) { if (idx < 0) throw new Error("Message index must be >= 0"); + return this._getMessageRuntime(() => { + const messages = this._threadBinding.getState().messages; + const message = messages[idx]; + if (!message) return undefined; + return { + message, + parentId: messages[idx - 1]?.id ?? null, + }; + }); + } + + public getMesssageById(messageId: string) { + return this._getMessageRuntime(() => + this._threadBinding.getState().getMessageById(messageId), + ); + } + + private _getMessageRuntime( + callback: () => + | { parentId: string | null; message: ThreadMessage } + | undefined, + ) { return new MessageRuntimeImpl( new ShallowMemoizeSubject({ getState: () => { - const { messages, speech: speechState } = this.getState(); - const message = messages[idx]; - if (!message) return SKIP_UPDATE; + const { message, parentId } = callback() ?? {}; + + const { messages, speech: speechState } = + this._threadBinding.getState(); + + if (!message || !parentId) return SKIP_UPDATE; const thread = this._threadBinding.getState(); @@ -406,8 +443,8 @@ export class ThreadRuntimeImpl implements ThreadRuntimeCore, ThreadRuntime { ...message, message, - isLast: idx === messages.length - 1, - parentId: messages[idx - 1]?.id ?? null, + isLast: messages.at(-1)?.id === message.id, + parentId, branches, branchNumber: branches.indexOf(message.id) + 1, diff --git a/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx b/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx index f8d1ddc7c..edc04e2db 100644 --- a/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx +++ b/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx @@ -44,6 +44,13 @@ export type SubmittedFeedback = Readonly<{ }>; export type ThreadRuntimeCore = Readonly<{ + getMessageById: (messageId: string) => + | { + parentId: string | null; + message: ThreadMessage; + } + | undefined; + getBranches: (messageId: string) => readonly string[]; switchToBranch: (branchId: string) => void;