From 62e9f1907e2627993ceda2bb23f9438f10980d94 Mon Sep 17 00:00:00 2001 From: Simon Farshid Date: Wed, 26 Jun 2024 22:49:54 -0700 Subject: [PATCH] feat: AssistantRuntime.newThread (#335) --- .changeset/brown-wasps-lick.md | 6 +++ .../react-ai-sdk/src/rsc/VercelRSCRuntime.tsx | 38 +++++++++++-- .../react-ai-sdk/src/ui/VercelAIRuntime.tsx | 40 +++++++++++--- packages/react/src/internal.ts | 1 + .../src/runtime/core/AssistantRuntime.tsx | 3 ++ .../src/runtime/core/BaseAssistantRuntime.tsx | 53 +++++++++++++++++++ .../react/src/runtime/local/LocalRuntime.tsx | 48 +++++++++++++---- packages/react/src/types/index.ts | 2 +- 8 files changed, 167 insertions(+), 24 deletions(-) create mode 100644 .changeset/brown-wasps-lick.md create mode 100644 packages/react/src/runtime/core/BaseAssistantRuntime.tsx diff --git a/.changeset/brown-wasps-lick.md b/.changeset/brown-wasps-lick.md new file mode 100644 index 000000000..ac15d2248 --- /dev/null +++ b/.changeset/brown-wasps-lick.md @@ -0,0 +1,6 @@ +--- +"@assistant-ui/react-ai-sdk": patch +"@assistant-ui/react": patch +--- + +feat: AssistantRuntime newThread diff --git a/packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx b/packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx index f2ec52331..c185a4d8a 100644 --- a/packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx +++ b/packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx @@ -1,9 +1,8 @@ "use client"; import { - type AppendMessage, - type AssistantRuntime, INTERNAL, + type AppendMessage, type ReactThreadRuntime, type ThreadMessage, type Unsubscribe, @@ -13,13 +12,42 @@ import type { VercelRSCAdapter } from "./VercelRSCAdapter"; import type { VercelRSCMessage } from "./VercelRSCMessage"; import { useVercelRSCSync } from "./useVercelRSCSync"; -const { ProxyConfigProvider } = INTERNAL; +const { ProxyConfigProvider, BaseAssistantRuntime } = INTERNAL; const EMPTY_BRANCHES: readonly string[] = Object.freeze([]); -export class VercelRSCRuntime +export class VercelRSCRuntime< + T extends WeakKey = VercelRSCMessage, +> extends BaseAssistantRuntime> { + constructor(adapter: VercelRSCAdapter) { + super(new VercelRSCThreadRuntime(adapter)); + } + + public set adapter(adapter: VercelRSCAdapter) { + this.thread.adapter = adapter; + } + + public onAdapterUpdated() { + return this.thread.onAdapterUpdated(); + } + + public registerModelConfigProvider() { + // no-op + return () => {}; + } + + public newThread() { + this.thread = new VercelRSCThreadRuntime(this.thread.adapter); + } + + public switchToThread() { + throw new Error("VercelRSCRuntime does not yet support switching threads"); + } +} + +class VercelRSCThreadRuntime extends ProxyConfigProvider - implements AssistantRuntime, ReactThreadRuntime + implements ReactThreadRuntime { private useAdapter: UseBoundStore }>>; diff --git a/packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx b/packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx index c5b9dfe29..74b59d1f9 100644 --- a/packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx +++ b/packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx @@ -1,8 +1,4 @@ -import type { - AssistantRuntime, - ReactThreadRuntime, - Unsubscribe, -} from "@assistant-ui/react"; +import type { ReactThreadRuntime, Unsubscribe } from "@assistant-ui/react"; import type { AppendMessage, ThreadMessage } from "@assistant-ui/react"; import { INTERNAL } from "@assistant-ui/react"; import type { Message } from "ai"; @@ -13,15 +9,43 @@ import { sliceMessagesUntil } from "./utils/sliceMessagesUntil"; import { useVercelAIComposerSync } from "./utils/useVercelAIComposerSync"; import { useVercelAIThreadSync } from "./utils/useVercelAIThreadSync"; -const { ProxyConfigProvider, MessageRepository } = INTERNAL; +const { ProxyConfigProvider, MessageRepository, BaseAssistantRuntime } = + INTERNAL; const hasUpcomingMessage = (isRunning: boolean, messages: ThreadMessage[]) => { return isRunning && messages[messages.length - 1]?.role !== "assistant"; }; -export class VercelAIRuntime +export class VercelAIRuntime extends BaseAssistantRuntime { + constructor(vercel: VercelHelpers) { + super(new VercelAIThreadRuntime(vercel)); + } + + public set vercel(vercel: VercelHelpers) { + this.thread.vercel = vercel; + } + + public onVercelUpdated() { + return this.thread.onVercelUpdated(); + } + + public registerModelConfigProvider() { + // no-op + return () => {}; + } + + public newThread() { + this.thread = new VercelAIThreadRuntime(this.thread.vercel); + } + + public switchToThread() { + throw new Error("VercelAIRuntime does not yet support switching threads"); + } +} + +class VercelAIThreadRuntime extends ProxyConfigProvider - implements AssistantRuntime, ReactThreadRuntime + implements ReactThreadRuntime { private _subscriptions = new Set<() => void>(); private repository = new MessageRepository(); diff --git a/packages/react/src/internal.ts b/packages/react/src/internal.ts index 02c362a89..d4818a0d9 100644 --- a/packages/react/src/internal.ts +++ b/packages/react/src/internal.ts @@ -1,2 +1,3 @@ export { ProxyConfigProvider } from "./utils/ProxyConfigProvider"; export { MessageRepository } from "./runtime/utils/MessageRepository"; +export { BaseAssistantRuntime } from "./runtime/core/BaseAssistantRuntime"; diff --git a/packages/react/src/runtime/core/AssistantRuntime.tsx b/packages/react/src/runtime/core/AssistantRuntime.tsx index 74da5df98..3916a090b 100644 --- a/packages/react/src/runtime/core/AssistantRuntime.tsx +++ b/packages/react/src/runtime/core/AssistantRuntime.tsx @@ -3,5 +3,8 @@ import type { Unsubscribe } from "../../types/Unsubscribe"; import type { ThreadRuntime } from "./ThreadRuntime"; export type AssistantRuntime = ThreadRuntime & { + newThread: () => void; + switchToThread: (threadId: string) => void; + registerModelConfigProvider: (provider: ModelConfigProvider) => Unsubscribe; }; diff --git a/packages/react/src/runtime/core/BaseAssistantRuntime.tsx b/packages/react/src/runtime/core/BaseAssistantRuntime.tsx new file mode 100644 index 000000000..c479e4e5e --- /dev/null +++ b/packages/react/src/runtime/core/BaseAssistantRuntime.tsx @@ -0,0 +1,53 @@ +import type { AppendMessage } from "../../types/AssistantTypes"; +import { type ModelConfigProvider } from "../../types/ModelConfigTypes"; +import type { Unsubscribe } from "../../types/Unsubscribe"; +import type { AssistantRuntime } from "./AssistantRuntime"; +import { ThreadRuntime } from "./ThreadRuntime"; + +export abstract class BaseAssistantRuntime + implements AssistantRuntime +{ + constructor(protected thread: TThreadRuntime) {} + + public abstract registerModelConfigProvider( + provider: ModelConfigProvider, + ): Unsubscribe; + public abstract newThread(): void; + public abstract switchToThread(threadId: string): void; + + public get messages() { + return this.thread.messages; + } + + public get isRunning() { + return this.thread.isRunning; + } + + public getBranches(messageId: string): readonly string[] { + return this.thread.getBranches(messageId); + } + + public switchToBranch(branchId: string): void { + return this.thread.switchToBranch(branchId); + } + + public append(message: AppendMessage): void { + return this.thread.append(message); + } + + public startRun(parentId: string | null): void { + return this.thread.startRun(parentId); + } + + public cancelRun(): void { + return this.thread.cancelRun(); + } + + public addToolResult(toolCallId: string, result: any) { + return this.thread.addToolResult(toolCallId, result); + } + + public subscribe(callback: () => void): Unsubscribe { + return this.thread.subscribe(callback); + } +} diff --git a/packages/react/src/runtime/local/LocalRuntime.tsx b/packages/react/src/runtime/local/LocalRuntime.tsx index b2c7aa559..213dff5c0 100644 --- a/packages/react/src/runtime/local/LocalRuntime.tsx +++ b/packages/react/src/runtime/local/LocalRuntime.tsx @@ -8,14 +8,44 @@ import { mergeModelConfigs, } from "../../types/ModelConfigTypes"; import type { Unsubscribe } from "../../types/Unsubscribe"; -import type { AssistantRuntime } from "../core/AssistantRuntime"; +import { ThreadRuntime } from "../core"; import { MessageRepository } from "../utils/MessageRepository"; import { generateId } from "../utils/idUtils"; +import { BaseAssistantRuntime } from "../core/BaseAssistantRuntime"; import type { ChatModelAdapter, ChatModelRunResult } from "./ChatModelAdapter"; -export class LocalRuntime implements AssistantRuntime { +export class LocalRuntime extends BaseAssistantRuntime { + private readonly _configProviders: Set; + + constructor(adapter: ChatModelAdapter) { + const configProviders = new Set(); + super(new LocalThreadRuntime(configProviders, adapter)); + this._configProviders = configProviders; + } + + public set adapter(adapter: ChatModelAdapter) { + this.thread.adapter = adapter; + } + + registerModelConfigProvider(provider: ModelConfigProvider) { + this._configProviders.add(provider); + return () => this._configProviders.delete(provider); + } + + public newThread() { + return (this.thread = new LocalThreadRuntime( + this._configProviders, + this.thread.adapter, + )); + } + + public switchToThread() { + throw new Error("LocalRuntime does not yet support switching threads"); + } +} + +class LocalThreadRuntime implements ThreadRuntime { private _subscriptions = new Set<() => void>(); - private _configProviders = new Set(); private abortController: AbortController | null = null; private repository = new MessageRepository(); @@ -27,7 +57,10 @@ export class LocalRuntime implements AssistantRuntime { return this.abortController != null; } - constructor(public adapter: ChatModelAdapter) {} + constructor( + private _configProviders: Set, + public adapter: ChatModelAdapter, + ) {} public getBranches(messageId: string): string[] { return this.repository.getBranches(messageId); @@ -117,12 +150,7 @@ export class LocalRuntime implements AssistantRuntime { return () => this._subscriptions.delete(callback); } - registerModelConfigProvider(provider: ModelConfigProvider) { - this._configProviders.add(provider); - return () => this._configProviders.delete(provider); - } - addToolResult() { - throw new Error("LocalRuntime does not yet support tool results"); + throw new Error("LocalRuntime does not yet support adding tool results"); } } diff --git a/packages/react/src/types/index.ts b/packages/react/src/types/index.ts index c7328709a..906f78b00 100644 --- a/packages/react/src/types/index.ts +++ b/packages/react/src/types/index.ts @@ -23,6 +23,6 @@ export type { ToolCallContentPartComponent, } from "./ContentPartComponentTypes"; -export type { ModelConfig } from "./ModelConfigTypes"; +export type { ModelConfig, ModelConfigProvider } from "./ModelConfigTypes"; export type { Unsubscribe } from "./Unsubscribe";