diff --git a/.changeset/swift-pigs-sort.md b/.changeset/swift-pigs-sort.md new file mode 100644 index 000000000..0e31159f5 --- /dev/null +++ b/.changeset/swift-pigs-sort.md @@ -0,0 +1,6 @@ +--- +"@assistant-ui/react-playground": patch +"@assistant-ui/react": patch +--- + +feat: ThreadManagerRuntime diff --git a/packages/react-playground/src/lib/playground-runtime.ts b/packages/react-playground/src/lib/playground-runtime.ts index 5230457a0..59f36a570 100644 --- a/packages/react-playground/src/lib/playground-runtime.ts +++ b/packages/react-playground/src/lib/playground-runtime.ts @@ -49,36 +49,83 @@ const makeModelConfigStore = () => config: {}, })); -class PlaygroundRuntimeCore extends BaseAssistantRuntimeCore { - private readonly _proxyConfigProvider: InstanceType< - typeof ProxyConfigProvider - >; - - constructor(initialMessages: CoreMessage[], adapter: ChatModelAdapter) { - const cp = new ProxyConfigProvider(); - super( - new PlaygroundThreadRuntimeCore( - cp, - fromCoreMessages(initialMessages), - adapter, - ), - ); - this._proxyConfigProvider = cp; +type PlaygroundThreadFactory = ( + threadId: string, +) => PlaygroundThreadRuntimeCore; + +const EMPTY_ARRAY = [] as never[]; + +class PlaygroundThreadManagerRuntimeCore + implements INTERNAL.ThreadManagerRuntimeCore +{ + private _mainThread: PlaygroundThreadRuntimeCore; + + public get mainThread() { + return this._mainThread; } - public switchToNewThread() { - this.thread = new PlaygroundThreadRuntimeCore( - this._proxyConfigProvider, - [], - this.thread.adapter, - ); + public get threads() { + return EMPTY_ARRAY; } - public switchToThread(threadId: string | null) { - if (threadId !== null) - throw new Error("PlaygroundRuntime does not support switching threads"); + public get archivedThreads() { + return EMPTY_ARRAY; + } - this.switchToNewThread(); + constructor(private threadFactory: PlaygroundThreadFactory) { + this._mainThread = this.threadFactory(generateId()); + } + + public switchToThread(): void { + throw new Error("Method not implemented."); + } + + public switchToNewThread(): void { + this._mainThread = this.threadFactory(generateId()); + this.notifySubscribers(); + } + + public rename(): Promise { + throw new Error("Method not implemented."); + } + public archive(): Promise { + throw new Error("Method not implemented."); + } + public unarchive(): Promise { + throw new Error("Method not implemented."); + } + public delete(): Promise { + throw new Error("Method not implemented."); + } + + private _subscriptions = new Set<() => void>(); + + private notifySubscribers() { + for (const callback of this._subscriptions) callback(); + } + + public subscribe(callback: () => void): Unsubscribe { + this._subscriptions.add(callback); + return () => this._subscriptions.delete(callback); + } +} + +class PlaygroundRuntimeCore extends BaseAssistantRuntimeCore { + public readonly threadManager; + + constructor(adapter: ChatModelAdapter, initialMessages: CoreMessage[]) { + super(); + + this.threadManager = new PlaygroundThreadManagerRuntimeCore((threadId) => { + const thread = new PlaygroundThreadRuntimeCore( + this._proxyConfigProvider, + threadId, + fromCoreMessages(initialMessages), + adapter, + ); + initialMessages = []; + return thread; + }); } public override registerModelConfigProvider( @@ -108,7 +155,6 @@ export class PlaygroundThreadRuntimeCore implements INTERNAL.ThreadRuntimeCore { public tools: Record> = {}; - public readonly threadId = generateId(); public readonly isDisabled = false; public readonly capabilities = CAPABILITIES; public readonly extras = undefined; @@ -130,6 +176,7 @@ export class PlaygroundThreadRuntimeCore implements INTERNAL.ThreadRuntimeCore { constructor( configProvider: ModelConfigProvider, + public readonly threadId: string, private _messages: ThreadMessage[], public readonly adapter: ChatModelAdapter, ) { @@ -586,8 +633,8 @@ export const usePlaygroundRuntime = ({ const [runtime] = useState( () => new PlaygroundRuntimeCore( - initialMessages, new EdgeChatAdapter(runtimeOptions), + initialMessages, ), ); diff --git a/packages/react/src/api/AssistantRuntime.ts b/packages/react/src/api/AssistantRuntime.ts index 7fc7511be..1b14d4edd 100644 --- a/packages/react/src/api/AssistantRuntime.ts +++ b/packages/react/src/api/AssistantRuntime.ts @@ -7,9 +7,14 @@ import { ThreadRuntimeImpl, } from "./ThreadRuntime"; import { Unsubscribe } from "../types"; +import { + ThreadManagerRuntime, + ThreadManagerRuntimeImpl, +} from "./ThreadManagerRuntime"; export type AssistantRuntime = { thread: ThreadRuntime; + threadManager: ThreadManagerRuntime; switchToNewThread(): void; @@ -28,19 +33,25 @@ export type AssistantRuntime = { }; export class AssistantRuntimeImpl - implements Omit, AssistantRuntime + implements + Omit, + AssistantRuntime { + public readonly threadManager; + protected constructor( private readonly _core: AssistantRuntimeCore, private readonly _thread: ThreadRuntime, - ) {} + ) { + this.threadManager = new ThreadManagerRuntimeImpl(_core.threadManager); + } public get thread() { return this._thread; } public switchToNewThread() { - return this._core.switchToNewThread(); + return this._core.threadManager.switchToNewThread(); } public switchToThread(threadId: string): void; @@ -49,7 +60,8 @@ export class AssistantRuntimeImpl */ public switchToThread(threadId: string | null): void; public switchToThread(threadId: string | null) { - return this._core.switchToThread(threadId); + if (threadId === null) return this.switchToNewThread(); + return this._core.threadManager.switchToThread(threadId); } public registerModelConfigProvider(provider: ModelConfigProvider) { @@ -59,8 +71,8 @@ export class AssistantRuntimeImpl /** * @deprecated Thread is now static and never gets updated. This will be removed in 0.6.0. */ - public subscribe(callback: () => void) { - return this._core.subscribe(callback); + public subscribe() { + return () => {}; } protected static createMainThreadRuntime( @@ -75,8 +87,8 @@ export class AssistantRuntimeImpl ref: "threads.main", threadSelector: { type: "main" }, }, - getState: () => _core.thread, - subscribe: (callback) => _core.subscribe(callback), + getState: () => _core.threadManager.mainThread, + subscribe: (callback) => _core.threadManager.subscribe(callback), }), ); } diff --git a/packages/react/src/api/AttachmentRuntime.ts b/packages/react/src/api/AttachmentRuntime.ts index f994119a8..7ebb33f89 100644 --- a/packages/react/src/api/AttachmentRuntime.ts +++ b/packages/react/src/api/AttachmentRuntime.ts @@ -7,7 +7,7 @@ import { PendingAttachment, Unsubscribe, } from "../types"; -import { AttachmentRuntimePath } from "./PathTypes"; +import { AttachmentRuntimePath } from "./RuntimePathTypes"; type MessageAttachmentState = CompleteAttachment & { source: "message"; diff --git a/packages/react/src/api/ComposerRuntime.ts b/packages/react/src/api/ComposerRuntime.ts index 6312feb9d..88a942cdc 100644 --- a/packages/react/src/api/ComposerRuntime.ts +++ b/packages/react/src/api/ComposerRuntime.ts @@ -14,7 +14,7 @@ import { } from "./AttachmentRuntime"; import { ShallowMemoizeSubject } from "./subscribable/ShallowMemoizeSubject"; import { SKIP_UPDATE } from "./subscribable/SKIP_UPDATE"; -import { ComposerRuntimePath } from "./PathTypes"; +import { ComposerRuntimePath } from "./RuntimePathTypes"; export type ThreadComposerRuntimeCoreBinding = SubscribableWithState< ThreadComposerRuntimeCore | undefined, diff --git a/packages/react/src/api/ContentPartRuntime.ts b/packages/react/src/api/ContentPartRuntime.ts index 0a53f1bf0..c0008d497 100644 --- a/packages/react/src/api/ContentPartRuntime.ts +++ b/packages/react/src/api/ContentPartRuntime.ts @@ -8,7 +8,7 @@ import { ThreadRuntimeCoreBinding } from "./ThreadRuntime"; import { MessageStateBinding } from "./MessageRuntime"; import { SubscribableWithState } from "./subscribable/Subscribable"; import { Unsubscribe } from "../types"; -import { ContentPartRuntimePath } from "./PathTypes"; +import { ContentPartRuntimePath } from "./RuntimePathTypes"; export type ContentPartState = ( | ThreadUserContentPart diff --git a/packages/react/src/api/MessageRuntime.ts b/packages/react/src/api/MessageRuntime.ts index 8edf93a2f..45f5cd992 100644 --- a/packages/react/src/api/MessageRuntime.ts +++ b/packages/react/src/api/MessageRuntime.ts @@ -27,7 +27,7 @@ import { ContentPartRuntimeImpl, ContentPartState, } from "./ContentPartRuntime"; -import { MessageRuntimePath } from "./PathTypes"; +import { MessageRuntimePath } from "./RuntimePathTypes"; import { ThreadRuntimeCoreBinding } from "./ThreadRuntime"; import { NestedSubscriptionSubject } from "./subscribable/NestedSubscriptionSubject"; import { SKIP_UPDATE } from "./subscribable/SKIP_UPDATE"; diff --git a/packages/react/src/api/PathTypes.ts b/packages/react/src/api/RuntimePathTypes.ts similarity index 94% rename from packages/react/src/api/PathTypes.ts rename to packages/react/src/api/RuntimePathTypes.ts index fce794591..a870f98ca 100644 --- a/packages/react/src/api/PathTypes.ts +++ b/packages/react/src/api/RuntimePathTypes.ts @@ -1,3 +1,7 @@ +export type ThreadManagerRuntimePath = { + ref: string; +}; + export type ThreadRuntimePath = { ref: string; threadSelector: { type: "main" }; diff --git a/packages/react/src/api/ThreadManagerRuntime.ts b/packages/react/src/api/ThreadManagerRuntime.ts new file mode 100644 index 000000000..eda60c90d --- /dev/null +++ b/packages/react/src/api/ThreadManagerRuntime.ts @@ -0,0 +1,80 @@ +import { LazyMemoizeSubject } from "./subscribable/LazyMemoizeSubject"; +import { + ThreadManagerMetadata, + ThreadManagerRuntimeCore, +} from "../runtimes/core/ThreadManagerRuntimeCore"; +import { Unsubscribe } from "../types"; +import { ThreadManagerRuntimePath } from "./RuntimePathTypes"; + +export type ThreadManagerState = Readonly<{ + threads: readonly ThreadManagerMetadata[]; + archivedThreads: readonly ThreadManagerMetadata[]; +}>; + +export type ThreadManagerRuntime = Readonly<{ + path: ThreadManagerRuntimePath; + getState(): ThreadManagerState; + + rename(threadId: string, newTitle: string): Promise; + archive(threadId: string): Promise; + unarchive(threadId: string): Promise; + delete(threadId: string): Promise; + + subscribe(callback: () => void): Unsubscribe; +}>; + +const getThreadManagerState = ( + threadManager: ThreadManagerRuntimeCore, +): ThreadManagerState => { + return { + threads: threadManager.threads, + archivedThreads: threadManager.archivedThreads, + }; +}; + +const THREAD_MANAGER_PATH = { + ref: "threadManager", +}; + +export type ThreadManagerRuntimeCoreBinding = ThreadManagerRuntimeCore; + +export class ThreadManagerRuntimeImpl implements ThreadManagerRuntime { + public get path() { + return THREAD_MANAGER_PATH; + } + + private _getState; + constructor(private _core: ThreadManagerRuntimeCoreBinding) { + const stateBinding = new LazyMemoizeSubject({ + path: THREAD_MANAGER_PATH, + getState: () => getThreadManagerState(_core), + subscribe: (callback) => _core.subscribe(callback), + }); + + this._getState = stateBinding.getState.bind(stateBinding); + } + + public getState(): ThreadManagerState { + return this._getState(); + } + + public rename(threadId: string, newTitle: string): Promise { + return this._core.rename(threadId, newTitle); + } + + public archive(threadId: string): Promise { + return this._core.archive(threadId); + } + + public unarchive(threadId: string): Promise { + return this._core.unarchive(threadId); + } + + public delete(threadId: string): Promise { + return this._core.delete(threadId); + } + + public subscribe(callback: () => void): Unsubscribe { + return this._core.subscribe(callback); + } +} diff --git a/packages/react/src/api/ThreadRuntime.ts b/packages/react/src/api/ThreadRuntime.ts index a67ac2165..4d2a00e1c 100644 --- a/packages/react/src/api/ThreadRuntime.ts +++ b/packages/react/src/api/ThreadRuntime.ts @@ -33,7 +33,7 @@ import { import { LazyMemoizeSubject } from "./subscribable/LazyMemoizeSubject"; import { SKIP_UPDATE } from "./subscribable/SKIP_UPDATE"; import { ComposerRuntimeCore } from "../runtimes/core/ComposerRuntimeCore"; -import { MessageRuntimePath, ThreadRuntimePath } from "./PathTypes"; +import { MessageRuntimePath, ThreadRuntimePath } from "./RuntimePathTypes"; export type CreateAppendMessage = | string @@ -108,7 +108,7 @@ export const getThreadState = (runtime: ThreadRuntimeCore): ThreadState => { }); }; -export type ThreadRuntime = { +export type ThreadRuntime = Readonly<{ readonly path: ThreadRuntimePath; readonly composer: ThreadComposerRuntime; @@ -217,7 +217,7 @@ export type ThreadRuntime = { * @deprecated Use `getMesssageById(id).composer.beginEdit()` instead. This will be removed in 0.6.0. */ beginEdit: (messageId: string) => void; -}; +}>; export class ThreadRuntimeImpl implements Omit, ThreadRuntime diff --git a/packages/react/src/api/index.ts b/packages/react/src/api/index.ts index 33b6273b4..6baabbb0a 100644 --- a/packages/react/src/api/index.ts +++ b/packages/react/src/api/index.ts @@ -1,7 +1,10 @@ export type { AssistantRuntime } from "./AssistantRuntime"; export type { ThreadRuntime, ThreadState } from "./ThreadRuntime"; export type { MessageRuntime, MessageState } from "./MessageRuntime"; -export type { ContentPartRuntime } from "./ContentPartRuntime"; +export type { + ContentPartRuntime, + ContentPartState, +} from "./ContentPartRuntime"; export type { ComposerRuntime, ThreadComposerRuntime, @@ -10,3 +13,8 @@ export type { ThreadComposerState, ComposerState, } from "./ComposerRuntime"; +export type { AttachmentRuntime, AttachmentState } from "./AttachmentRuntime"; +export type { + ThreadManagerRuntime, + ThreadManagerState, +} from "./ThreadManagerRuntime"; diff --git a/packages/react/src/context/providers/AssistantRuntimeProvider.tsx b/packages/react/src/context/providers/AssistantRuntimeProvider.tsx index 2d366a58a..d0cc65bd5 100644 --- a/packages/react/src/context/providers/AssistantRuntimeProvider.tsx +++ b/packages/react/src/context/providers/AssistantRuntimeProvider.tsx @@ -27,18 +27,35 @@ const useAssistantToolUIsStore = () => { return useMemo(() => makeAssistantToolUIsStore(), []); }; +const useThreadManagerStore = (runtime: AssistantRuntime) => { + const [store] = useState(() => + create(() => runtime.threadManager.getState()), + ); + + useEffect(() => { + const updateState = () => + writableStore(store).setState(runtime.threadManager.getState(), true); + updateState(); + return runtime.threadManager.subscribe(updateState); + }, [runtime, store]); + + return store; +}; + export const AssistantRuntimeProviderImpl: FC< PropsWithChildren > = ({ children, runtime }) => { const useAssistantRuntime = useAssistantRuntimeStore(runtime); const useToolUIs = useAssistantToolUIsStore(); + const useThreadManager = useThreadManagerStore(runtime); const context = useMemo(() => { return { useToolUIs, useAssistantRuntime, useAssistantActions: useAssistantRuntime, + useThreadManager, }; - }, [useAssistantRuntime, useToolUIs]); + }, [useAssistantRuntime, useToolUIs, useThreadManager]); return ( diff --git a/packages/react/src/context/react/AssistantContext.ts b/packages/react/src/context/react/AssistantContext.ts index ada97612c..84eb566f5 100644 --- a/packages/react/src/context/react/AssistantContext.ts +++ b/packages/react/src/context/react/AssistantContext.ts @@ -7,10 +7,12 @@ import { createContextHook } from "./utils/createContextHook"; import { createContextStoreHook } from "./utils/createContextStoreHook"; import { UseBoundStore } from "zustand"; import { AssistantRuntime } from "../../api/AssistantRuntime"; +import { ThreadManagerState } from "../../api/ThreadManagerRuntime"; export type AssistantContextValue = { useToolUIs: UseBoundStore>; useAssistantRuntime: UseBoundStore>; + useThreadManager: UseBoundStore>; /** * @deprecated Use `useAssistantRuntime` instead. This will be removed in 0.6.0. @@ -65,3 +67,8 @@ export const { useToolUIs, useToolUIsStore } = createContextStoreHook( useAssistantContext, "useToolUIs", ); + +export const { useThreadManager } = createContextStoreHook( + useAssistantContext, + "useThreadManager", +); diff --git a/packages/react/src/context/react/index.ts b/packages/react/src/context/react/index.ts index 25f23218a..842c825c1 100644 --- a/packages/react/src/context/react/index.ts +++ b/packages/react/src/context/react/index.ts @@ -1,5 +1,6 @@ export { useAssistantRuntime, + useThreadManager, useToolUIs, useToolUIsStore, @@ -84,7 +85,9 @@ export { useMessageRuntime, useMessage, useEditComposer, + // TODO move out of runtime context after 0.6.0 useMessageUtils, + // TODO move out of runtime context after 0.6.0 useMessageUtilsStore, /** diff --git a/packages/react/src/internal.ts b/packages/react/src/internal.ts index d2308f3c1..27f9ac692 100644 --- a/packages/react/src/internal.ts +++ b/packages/react/src/internal.ts @@ -1,4 +1,5 @@ export type { ThreadRuntimeCore } from "./runtimes/core/ThreadRuntimeCore"; +export type { ThreadManagerRuntimeCore } from "./runtimes/core/ThreadManagerRuntimeCore"; export { DefaultThreadComposerRuntimeCore } from "./runtimes/composer/DefaultThreadComposerRuntimeCore"; export { ProxyConfigProvider } from "./utils/ProxyConfigProvider"; export { MessageRepository } from "./runtimes/utils/MessageRepository"; diff --git a/packages/react/src/runtimes/core/AssistantRuntimeCore.tsx b/packages/react/src/runtimes/core/AssistantRuntimeCore.tsx index d82d14739..90fe2d8e5 100644 --- a/packages/react/src/runtimes/core/AssistantRuntimeCore.tsx +++ b/packages/react/src/runtimes/core/AssistantRuntimeCore.tsx @@ -1,19 +1,9 @@ -import { ThreadRuntimeCore } from "./ThreadRuntimeCore"; import type { ModelConfigProvider } from "../../types/ModelConfigTypes"; import type { Unsubscribe } from "../../types/Unsubscribe"; +import { ThreadManagerRuntimeCore } from "./ThreadManagerRuntimeCore"; export type AssistantRuntimeCore = { - readonly thread: ThreadRuntimeCore; - - switchToNewThread: () => void; - - switchToThread(threadId: string): void; - /** - * @deprecated Use `switchToNewThread` instead. This will be removed in 0.6.0. - */ - switchToThread(threadId: string | null): void; + readonly threadManager: ThreadManagerRuntimeCore; registerModelConfigProvider: (provider: ModelConfigProvider) => Unsubscribe; - - subscribe: (callback: () => void) => Unsubscribe; }; diff --git a/packages/react/src/runtimes/core/BaseAssistantRuntimeCore.tsx b/packages/react/src/runtimes/core/BaseAssistantRuntimeCore.tsx index 85579aa9b..8eb2a15c5 100644 --- a/packages/react/src/runtimes/core/BaseAssistantRuntimeCore.tsx +++ b/packages/react/src/runtimes/core/BaseAssistantRuntimeCore.tsx @@ -1,40 +1,18 @@ import { type ModelConfigProvider } from "../../types/ModelConfigTypes"; import type { Unsubscribe } from "../../types/Unsubscribe"; import type { AssistantRuntimeCore } from "./AssistantRuntimeCore"; -import { ThreadRuntimeCore } from "./ThreadRuntimeCore"; +import { ProxyConfigProvider } from "../../utils/ProxyConfigProvider"; +import { ThreadManagerRuntimeCore } from "./ThreadManagerRuntimeCore"; -export abstract class BaseAssistantRuntimeCore< - TThreadRuntime extends ThreadRuntimeCore, -> implements AssistantRuntimeCore -{ - constructor(private _thread: TThreadRuntime) { - this._thread = _thread; - } - - get thread() { - return this._thread; - } - - set thread(thread: TThreadRuntime) { - this._thread = thread; - this.subscriptionHandler(); - } +export abstract class BaseAssistantRuntimeCore implements AssistantRuntimeCore { + protected readonly _proxyConfigProvider = new ProxyConfigProvider(); + public abstract get threadManager(): ThreadManagerRuntimeCore; - public abstract switchToNewThread(): void; + constructor() {} - public abstract registerModelConfigProvider( + public registerModelConfigProvider( provider: ModelConfigProvider, - ): Unsubscribe; - public abstract switchToThread(threadId: string | null): void; - - private _subscriptions = new Set<() => void>(); - - public subscribe(callback: () => void): Unsubscribe { - this._subscriptions.add(callback); - return () => this._subscriptions.delete(callback); + ): Unsubscribe { + return this._proxyConfigProvider.registerModelConfigProvider(provider); } - - private subscriptionHandler = () => { - for (const callback of this._subscriptions) callback(); - }; } diff --git a/packages/react/src/runtimes/core/ThreadManagerRuntimeCore.tsx b/packages/react/src/runtimes/core/ThreadManagerRuntimeCore.tsx new file mode 100644 index 000000000..9346e7864 --- /dev/null +++ b/packages/react/src/runtimes/core/ThreadManagerRuntimeCore.tsx @@ -0,0 +1,27 @@ +import { Unsubscribe } from "../../types"; +import { ThreadRuntimeCore } from "./ThreadRuntimeCore"; + +export type ThreadManagerMetadata = { + threadId: string; + title?: string; +}; + +export type ThreadManagerRuntimeCore = { + mainThread: ThreadRuntimeCore; + + threads: readonly ThreadManagerMetadata[]; + archivedThreads: readonly ThreadManagerMetadata[]; + + switchToThread(threadId: string): void; + switchToNewThread(): void; + + // getLoadThreadsPromise(): Promise; + // getLoadArchivedThreadsPromise(): Promise; + // create(): Promise; + rename(threadId: string, newTitle: string): Promise; + archive(threadId: string): Promise; + unarchive(threadId: string): Promise; + delete(threadId: string): Promise; + + subscribe(callback: () => void): Unsubscribe; +}; diff --git a/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx b/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx index 4e739193b..419095191 100644 --- a/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx +++ b/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx @@ -45,6 +45,7 @@ export type SubmittedFeedback = Readonly<{ export type ThreadRuntimeEventType = | "switched-to" + | "switched-away" | "run-start" | "model-config-update"; diff --git a/packages/react/src/runtimes/core/subscribeToMainThread.ts b/packages/react/src/runtimes/core/subscribeToMainThread.ts index 06f52c94e..13fface33 100644 --- a/packages/react/src/runtimes/core/subscribeToMainThread.ts +++ b/packages/react/src/runtimes/core/subscribeToMainThread.ts @@ -12,7 +12,7 @@ export const subscribeToMainThread = ( let cleanup: Unsubscribe | undefined; const inner = () => { cleanup?.(); - cleanup = runtime.thread.subscribe(callback); + cleanup = runtime.threadManager.mainThread.subscribe(callback); if (!first) { callback(); @@ -20,7 +20,7 @@ export const subscribeToMainThread = ( first = false; }; - const unsubscribe = runtime.subscribe(inner); + const unsubscribe = runtime.threadManager.mainThread.subscribe(inner); inner(); return () => { diff --git a/packages/react/src/runtimes/external-store/ExternalStoreAdapter.tsx b/packages/react/src/runtimes/external-store/ExternalStoreAdapter.tsx index bf20aad6f..78e6ee720 100644 --- a/packages/react/src/runtimes/external-store/ExternalStoreAdapter.tsx +++ b/packages/react/src/runtimes/external-store/ExternalStoreAdapter.tsx @@ -3,8 +3,24 @@ import { AttachmentAdapter } from "../attachment"; import { AddToolResultOptions, ThreadSuggestion } from "../core"; import { FeedbackAdapter } from "../feedback/FeedbackAdapter"; import { SpeechSynthesisAdapter } from "../speech/SpeechAdapterTypes"; +import { ThreadManagerMetadata } from "../core/ThreadManagerRuntimeCore"; import { ThreadMessageLike } from "./ThreadMessageLike"; +export type ExternalStoreThreadManagerAdapter = { + threadId?: string | undefined; + threads?: readonly ThreadManagerMetadata[] | undefined; + archivedThreads?: readonly ThreadManagerMetadata[] | undefined; + onSwitchToNewThread?: (() => Promise | void) | undefined; + onSwitchToThread?: ((threadId: string) => Promise | void) | undefined; + onRename?: ( + threadId: string, + newTitle: string, + ) => (Promise | void) | undefined; + onArchive?: ((threadId: string) => Promise | void) | undefined; + onUnarchive?: ((threadId: string) => Promise | void) | undefined; + onDelete?: ((threadId: string) => Promise | void) | undefined; +}; + export type ExternalStoreMessageConverter = ( message: T, idx: number, @@ -15,7 +31,20 @@ type ExternalStoreMessageConverterAdapter = { }; type ExternalStoreAdapterBase = { + /** + * @deprecated Use `adapters.threadManager.threadId` instead. This will be removed in 0.6.0. + */ threadId?: string | undefined; + + /** + * @deprecated Use `adapters.threadManager.onSwitchToThread` instead. This will be removed in 0.6.0. + */ + onSwitchToThread?: ((threadId: string) => Promise | void) | undefined; + /** + * @deprecated Use `adapters.threadManager.onSwitchToNewThread` instead. This will be removed in 0.6.0. + */ + onSwitchToNewThread?: (() => Promise | void) | undefined; + isDisabled?: boolean | undefined; isRunning?: boolean | undefined; messages: T[]; @@ -30,13 +59,12 @@ type ExternalStoreAdapterBase = { onAddToolResult?: | ((options: AddToolResultOptions) => Promise | void) | undefined; - onSwitchToThread?: ((threadId: string) => Promise | void) | undefined; - onSwitchToNewThread?: (() => Promise | void) | undefined; convertMessage?: ExternalStoreMessageConverter | undefined; adapters?: { attachments?: AttachmentAdapter | undefined; speech?: SpeechSynthesisAdapter | undefined; feedback?: FeedbackAdapter | undefined; + threadManager?: ExternalStoreThreadManagerAdapter | undefined; }; unstable_capabilities?: | { diff --git a/packages/react/src/runtimes/external-store/ExternalStoreRuntimeCore.tsx b/packages/react/src/runtimes/external-store/ExternalStoreRuntimeCore.tsx index de4708585..a9e3451ea 100644 --- a/packages/react/src/runtimes/external-store/ExternalStoreRuntimeCore.tsx +++ b/packages/react/src/runtimes/external-store/ExternalStoreRuntimeCore.tsx @@ -1,56 +1,39 @@ -import { BaseAssistantRuntimeCore, ProxyConfigProvider } from "../../internal"; -import { ModelConfigProvider } from "../../types"; +import { BaseAssistantRuntimeCore } from "../../internal"; +import { ExternalStoreThreadManagerRuntimeCore } from "./ExternalStoreThreadManagementAdapter"; import { ExternalStoreAdapter } from "./ExternalStoreAdapter"; import { ExternalStoreThreadRuntimeCore } from "./ExternalStoreThreadRuntimeCore"; -export class ExternalStoreRuntimeCore extends BaseAssistantRuntimeCore { - private readonly _proxyConfigProvider; +const getThreadManagerAdapter = (store: ExternalStoreAdapter) => { + return { + threadId: store.threadId, + onSwitchToNewThread: store.onSwitchToNewThread, + onSwitchToThread: store.onSwitchToThread, + ...store.adapters?.threadManager, + }; +}; - constructor(store: ExternalStoreAdapter) { - const provider = new ProxyConfigProvider(); - super(new ExternalStoreThreadRuntimeCore(provider, store)); - this._proxyConfigProvider = provider; - } - - public getModelConfig() { - return this._proxyConfigProvider.getModelConfig(); - } +export class ExternalStoreRuntimeCore extends BaseAssistantRuntimeCore { + public readonly threadManager; - public registerModelConfigProvider(provider: ModelConfigProvider) { - return this._proxyConfigProvider.registerModelConfigProvider(provider); - } - - public async switchToNewThread() { - if (!this.thread.store.onSwitchToNewThread) - throw new Error("Runtime does not support switching to new threads."); + private _store: ExternalStoreAdapter; - this.thread = new ExternalStoreThreadRuntimeCore( - this._proxyConfigProvider, - { - ...this.thread.store, - messages: [], - }, + constructor(store: ExternalStoreAdapter) { + super(); + this._store = store; + this.threadManager = new ExternalStoreThreadManagerRuntimeCore( + getThreadManagerAdapter(store), + (threadId) => + new ExternalStoreThreadRuntimeCore( + this._proxyConfigProvider, + threadId, + this._store, + ), ); - await this.thread.store.onSwitchToNewThread!(); - this.thread._notifyEventSubscribers("switched-to"); } - public async switchToThread(threadId: string | null) { - if (threadId !== null) { - if (!this.thread.store.onSwitchToThread) - throw new Error("Runtime does not support switching threads."); - - this.thread = new ExternalStoreThreadRuntimeCore( - this._proxyConfigProvider, - { - ...this.thread.store, - messages: [], // ignore messages until rerender - }, - ); - await this.thread.store.onSwitchToThread!(threadId); - this.thread._notifyEventSubscribers("switched-to"); - } else { - this.switchToNewThread(); - } + public setStore(store: ExternalStoreAdapter) { + this._store = store; + this.threadManager.setAdapter(getThreadManagerAdapter(store)); + this.threadManager.mainThread.setStore(store); } } diff --git a/packages/react/src/runtimes/external-store/ExternalStoreThreadManagementAdapter.tsx b/packages/react/src/runtimes/external-store/ExternalStoreThreadManagementAdapter.tsx new file mode 100644 index 000000000..0e2701a4c --- /dev/null +++ b/packages/react/src/runtimes/external-store/ExternalStoreThreadManagementAdapter.tsx @@ -0,0 +1,124 @@ +import type { Unsubscribe } from "../../types"; +import { ExternalStoreThreadRuntimeCore } from "./ExternalStoreThreadRuntimeCore"; +import { ThreadManagerRuntimeCore } from "../core/ThreadManagerRuntimeCore"; +import { ExternalStoreThreadManagerAdapter } from "./ExternalStoreAdapter"; + +export type ExternalStoreThreadFactory = ( + threadId: string, +) => ExternalStoreThreadRuntimeCore; + +const EMPTY_ARRAY = Object.freeze([]); +const DEFAULT_THREAD_ID = "DEFAULT_THREAD_ID"; + +export class ExternalStoreThreadManagerRuntimeCore + implements ThreadManagerRuntimeCore +{ + public get threads() { + return this.adapter.threads ?? EMPTY_ARRAY; + } + + public get archivedThreads() { + return this.adapter.archivedThreads ?? EMPTY_ARRAY; + } + + private _mainThread: ExternalStoreThreadRuntimeCore; + + public get mainThread() { + return this._mainThread; + } + + constructor( + private adapter: ExternalStoreThreadManagerAdapter = {}, + private threadFactory: ExternalStoreThreadFactory, + ) { + this._mainThread = this.threadFactory(DEFAULT_THREAD_ID); + } + + public setAdapter(adapter: ExternalStoreThreadManagerAdapter) { + const previousAdapter = this.adapter; + this.adapter = adapter; + + const newThreadId = adapter.threadId ?? DEFAULT_THREAD_ID; + const newThreads = adapter.threads ?? EMPTY_ARRAY; + const newArchivedThreads = adapter.archivedThreads ?? EMPTY_ARRAY; + + if ( + previousAdapter.threadId === newThreadId && + previousAdapter.threads === newThreads && + previousAdapter.archivedThreads === newArchivedThreads + ) { + return; + } + + if (previousAdapter.threadId !== newThreadId) { + this._mainThread._notifyEventSubscribers("switched-away"); + this._mainThread = this.threadFactory(newThreadId); + this._mainThread._notifyEventSubscribers("switched-to"); + } + + this._notifySubscribers(); + } + + public switchToThread(threadId: string): void { + if (this._mainThread?.threadId === threadId) return; + const onSwitchToThread = this.adapter.onSwitchToThread; + if (!onSwitchToThread) + throw new Error( + "External store adapter does not support switching to thread", + ); + onSwitchToThread(threadId); + } + + public switchToNewThread(): void { + const onSwitchToNewThread = this.adapter.onSwitchToNewThread; + if (!onSwitchToNewThread) + throw new Error( + "External store adapter does not support switching to new thread", + ); + + onSwitchToNewThread(); + } + + public async rename(threadId: string, newTitle: string): Promise { + const onRename = this.adapter.onRename; + if (!onRename) + throw new Error("External store adapter does not support renaming"); + + onRename(threadId, newTitle); + } + + public async archive(threadId: string): Promise { + const onArchive = this.adapter.onArchive; + if (!onArchive) + throw new Error("External store adapter does not support archiving"); + + onArchive(threadId); + } + + public async unarchive(threadId: string): Promise { + const onUnarchive = this.adapter.onUnarchive; + if (!onUnarchive) + throw new Error("External store adapter does not support unarchiving"); + + onUnarchive(threadId); + } + + public async delete(threadId: string): Promise { + const onDelete = this.adapter.onDelete; + if (!onDelete) + throw new Error("External store adapter does not support deleting"); + + onDelete(threadId); + } + + private _subscriptions = new Set<() => void>(); + + public subscribe(callback: () => void): Unsubscribe { + this._subscriptions.add(callback); + return () => this._subscriptions.delete(callback); + } + + private _notifySubscribers() { + for (const callback of this._subscriptions) callback(); + } +} diff --git a/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntimeCore.tsx b/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntimeCore.tsx index d0ca5f9b2..8853d0185 100644 --- a/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntimeCore.tsx +++ b/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntimeCore.tsx @@ -10,7 +10,6 @@ import { ThreadMessageConverter } from "./ThreadMessageConverter"; import { getAutoStatus, isAutoStatus } from "./auto-status"; import { fromThreadMessageLike } from "./ThreadMessageLike"; import { getThreadMessageText } from "../../utils/getThreadMessageText"; -import { generateId } from "../../internal"; import { RuntimeCapabilities, ThreadRuntimeCore, @@ -67,7 +66,7 @@ export class ExternalStoreThreadRuntimeCore private _store!: ExternalStoreAdapter; public override beginEdit(messageId: string) { - if (!this.store.onEdit) + if (!this._store.onEdit) throw new Error("Runtime does not support editing."); super.beginEdit(messageId); @@ -75,20 +74,17 @@ export class ExternalStoreThreadRuntimeCore constructor( configProvider: ModelConfigProvider, + threadId: string, store: ExternalStoreAdapter, ) { super(configProvider); - this.store = store; + this.threadId = threadId; + this.setStore(store); } - public get store() { - return this._store; - } - - public set store(store: ExternalStoreAdapter) { + public setStore(store: ExternalStoreAdapter) { if (this._store === store) return; - this.threadId = store.threadId ?? this.threadId ?? generateId(); const isRunning = store.isRunning ?? false; this.isDisabled = store.isDisabled ?? false; @@ -103,8 +99,8 @@ export class ExternalStoreThreadRuntimeCore cancel: this._store.onCancel !== undefined, speech: this._store.adapters?.speech !== undefined, unstable_copy: this._store.unstable_capabilities?.copy !== false, // default true - attachments: !!this.store.adapters?.attachments, - feedback: !!this.store.adapters?.feedback, + attachments: !!this._store.adapters?.attachments, + feedback: !!this._store.adapters?.feedback, }; if (oldStore) { diff --git a/packages/react/src/runtimes/external-store/useExternalStoreRuntime.tsx b/packages/react/src/runtimes/external-store/useExternalStoreRuntime.tsx index a077603bf..54cd3f2a7 100644 --- a/packages/react/src/runtimes/external-store/useExternalStoreRuntime.tsx +++ b/packages/react/src/runtimes/external-store/useExternalStoreRuntime.tsx @@ -8,7 +8,7 @@ export const useExternalStoreRuntime = (store: ExternalStoreAdapter) => { const [runtime] = useState(() => new ExternalStoreRuntimeCore(store)); useEffect(() => { - runtime.thread.store = store; + runtime.setStore(store); }); return useMemo( diff --git a/packages/react/src/runtimes/local/LocalRuntimeCore.tsx b/packages/react/src/runtimes/local/LocalRuntimeCore.tsx index 8538eb110..e7f278fd9 100644 --- a/packages/react/src/runtimes/local/LocalRuntimeCore.tsx +++ b/packages/react/src/runtimes/local/LocalRuntimeCore.tsx @@ -1,41 +1,57 @@ -import { type ModelConfigProvider } from "../../types/ModelConfigTypes"; import type { CoreMessage } from "../../types/AssistantTypes"; import { BaseAssistantRuntimeCore } from "../core/BaseAssistantRuntimeCore"; -import type { ChatModelAdapter } from "./ChatModelAdapter"; -import { ProxyConfigProvider } from "../../internal"; import { LocalThreadRuntimeCore } from "./LocalThreadRuntimeCore"; -import { LocalRuntimeOptions } from "./LocalRuntimeOptions"; +import { LocalRuntimeOptionsBase } from "./LocalRuntimeOptions"; import { fromCoreMessages } from "../edge/converters/fromCoreMessage"; +import { LocalThreadManagerRuntimeCore } from "./LocalThreadManagerRuntimeCore"; +import { ExportedMessageRepository } from "../utils/MessageRepository"; -export class LocalRuntimeCore extends BaseAssistantRuntimeCore { - private readonly _proxyConfigProvider: ProxyConfigProvider; +const getExportFromInitialMessages = ( + initialMessages: readonly CoreMessage[], +): ExportedMessageRepository => { + const messages = fromCoreMessages(initialMessages); + return { + messages: messages.map((m, idx) => ({ + parentId: messages[idx - 1]?.id ?? null, + message: m, + })), + }; +}; - constructor(adapter: ChatModelAdapter, options: LocalRuntimeOptions) { - const proxyConfigProvider = new ProxyConfigProvider(); - super(new LocalThreadRuntimeCore(proxyConfigProvider, adapter, options)); - this._proxyConfigProvider = proxyConfigProvider; - } - public registerModelConfigProvider(provider: ModelConfigProvider) { - return this._proxyConfigProvider.registerModelConfigProvider(provider); - } +export class LocalRuntimeCore extends BaseAssistantRuntimeCore { + public readonly threadManager; - public switchToNewThread() { - const { initialMessages, ...options } = this.thread.options; + private _options: LocalRuntimeOptionsBase; - this.thread = new LocalThreadRuntimeCore( - this._proxyConfigProvider, - this.thread.adapter, - options, - ); - this.thread._notifyEventSubscribers("switched-to"); - } + constructor( + options: LocalRuntimeOptionsBase, + initialMessages?: CoreMessage[], + ) { + super(); + + this._options = options; - public switchToThread(threadId: string | null) { - if (threadId !== null) { - throw new Error("LocalRuntime does not yet support switching threads"); + this.threadManager = new LocalThreadManagerRuntimeCore((threadId, data) => { + const thread = new LocalThreadRuntimeCore( + this._proxyConfigProvider, + threadId, + this._options, + ); + thread.import(data); + return thread; + }); + + if (initialMessages) { + this.threadManager.mainThread.import( + getExportFromInitialMessages(initialMessages), + ); } + } + + public setOptions(options: LocalRuntimeOptionsBase) { + this._options = options; - this.switchToNewThread(); + this.threadManager.mainThread.setOptions(options); } public reset({ @@ -43,15 +59,11 @@ export class LocalRuntimeCore extends BaseAssistantRuntimeCore ({ - parentId: messages[idx - 1]?.id ?? null, - message: m, - })), - }); + this.threadManager.mainThread.import( + getExportFromInitialMessages(initialMessages), + ); } } diff --git a/packages/react/src/runtimes/local/LocalRuntimeOptions.tsx b/packages/react/src/runtimes/local/LocalRuntimeOptions.tsx index 901aa260e..3f2d9adc7 100644 --- a/packages/react/src/runtimes/local/LocalRuntimeOptions.tsx +++ b/packages/react/src/runtimes/local/LocalRuntimeOptions.tsx @@ -2,21 +2,26 @@ import type { CoreMessage } from "../../types"; import { AttachmentAdapter } from "../attachment/AttachmentAdapter"; import { FeedbackAdapter } from "../feedback/FeedbackAdapter"; import { SpeechSynthesisAdapter } from "../speech/SpeechAdapterTypes"; +import { ChatModelAdapter } from "./ChatModelAdapter"; -export type LocalRuntimeOptions = { - initialMessages?: readonly CoreMessage[] | undefined; +export type LocalRuntimeOptionsBase = { maxSteps?: number | undefined; /** * @deprecated Use `maxSteps` (which is `maxToolRoundtrips` + 1; if you set `maxToolRoundtrips` to 2, set `maxSteps` to 3) instead. This field will be removed in v0.6. */ maxToolRoundtrips?: number | undefined; - adapters?: - | { - attachments?: AttachmentAdapter | undefined; - speech?: SpeechSynthesisAdapter | undefined; - feedback?: FeedbackAdapter | undefined; - } - | undefined; + adapters: { + chatModel: ChatModelAdapter; + attachments?: AttachmentAdapter | undefined; + speech?: SpeechSynthesisAdapter | undefined; + feedback?: FeedbackAdapter | undefined; + }; +}; + +// TODO align LocalRuntimeOptions with LocalRuntimeOptionsBase +export type LocalRuntimeOptions = Omit & { + initialMessages?: readonly CoreMessage[] | undefined; + adapters?: Omit | undefined; }; export const splitLocalRuntimeOptions = ( diff --git a/packages/react/src/runtimes/local/LocalThreadManagerRuntimeCore.tsx b/packages/react/src/runtimes/local/LocalThreadManagerRuntimeCore.tsx new file mode 100644 index 000000000..22eea492d --- /dev/null +++ b/packages/react/src/runtimes/local/LocalThreadManagerRuntimeCore.tsx @@ -0,0 +1,161 @@ +import type { Unsubscribe } from "../../types"; +import { + ThreadManagerMetadata, + ThreadManagerRuntimeCore, +} from "../core/ThreadManagerRuntimeCore"; +import { ExportedMessageRepository } from "../utils/MessageRepository"; +import { generateId } from "../../utils/idUtils"; +import { LocalThreadRuntimeCore } from "./LocalThreadRuntimeCore"; + +export type LocalThreadData = { + data: ExportedMessageRepository; + metadata: ThreadManagerMetadata; + isArchived: boolean; +}; + +export type LocalThreadFactory = ( + threadId: string, + data: ExportedMessageRepository, +) => LocalThreadRuntimeCore; + +export class LocalThreadManagerRuntimeCore implements ThreadManagerRuntimeCore { + private _threadData = new Map(); + + private _threads: readonly ThreadManagerMetadata[] = []; + private _archivedThreads: readonly ThreadManagerMetadata[] = []; + + public get threads() { + return this._threads; + } + + public get archivedThreads() { + return this._archivedThreads; + } + + private _mainThread: LocalThreadRuntimeCore; + + public get mainThread(): LocalThreadRuntimeCore { + return this._mainThread; + } + + constructor(private _threadFactory: LocalThreadFactory) { + this._mainThread = this._threadFactory(generateId(), { messages: [] }); + } + + public switchToThread(threadId: string): void { + if (this._mainThread.threadId === threadId) return; + + const data = this._threadData.get(threadId); + if (!data) throw new Error("Thread not found"); + + const thread = this._threadFactory(threadId, data.data); + this._performThreadSwitch(thread); + } + + public switchToNewThread(): void { + if (!this._mainThread) return; + + const thread = this._threadFactory(generateId(), { messages: [] }); + this._performThreadSwitch(thread); + } + + private _performThreadSwitch(newThreadCore: LocalThreadRuntimeCore) { + if (this._mainThread) { + const data = this._threadData.get(this._mainThread.threadId); + if (!data) throw new Error("Thread not found"); + + const exprt = this._mainThread.export(); + data.data = exprt; + } + + this._mainThread._notifyEventSubscribers("switched-away"); + this._mainThread = newThreadCore; + newThreadCore._notifyEventSubscribers("switched-to"); + + this._notifySubscribers(); + } + + private _performMoveOp( + threadId: string, + operation: "archive" | "unarchive" | "delete", + ) { + const data = this._threadData.get(threadId); + if (!data) throw new Error("Thread not found"); + + if (operation === "archive" && data.isArchived) return; + if (operation === "unarchive" && !data.isArchived) return; + + if (operation === "archive") { + data.isArchived = true; + this._archivedThreads = [...this._archivedThreads, data.metadata]; + } + if (operation === "unarchive") { + data.isArchived = false; + this._threads = [...this._threads, data.metadata]; + } + if (operation === "delete") { + this._threadData.delete(threadId); + } + + if ( + operation === "archive" || + (operation === "delete" && data.isArchived) + ) { + this._archivedThreads = this._archivedThreads.filter( + (t) => t.threadId !== threadId, + ); + } + + if ( + operation === "unarchive" || + (operation === "delete" && !data.isArchived) + ) { + this._threads = this._threads.filter((t) => t.threadId !== threadId); + } + + this._notifySubscribers(); + } + + public async rename(threadId: string, newTitle: string): Promise { + const data = this._threadData.get(threadId); + if (!data) throw new Error("Thread not found"); + + data.metadata = { + ...data.metadata, + title: newTitle, + }; + + const threadList = data.isArchived ? this.archivedThreads : this.threads; + const idx = threadList.findIndex((t) => t.threadId === threadId); + const updatedThreadList = threadList.toSpliced(idx, 1, data.metadata); + if (data.isArchived) { + this._archivedThreads = updatedThreadList; + } else { + this._threads = updatedThreadList; + } + this._notifySubscribers(); + } + + public async archive(threadId: string): Promise { + this._performMoveOp(threadId, "archive"); + } + + public async unarchive(threadId: string): Promise { + this._performMoveOp(threadId, "unarchive"); + } + + public async delete(threadId: string): Promise { + this._performMoveOp(threadId, "delete"); + } + + private _subscriptions = new Set<() => void>(); + + public subscribe(callback: () => void): Unsubscribe { + this._subscriptions.add(callback); + return () => this._subscriptions.delete(callback); + } + + private _notifySubscribers() { + for (const callback of this._subscriptions) callback(); + } +} diff --git a/packages/react/src/runtimes/local/LocalThreadRuntimeCore.tsx b/packages/react/src/runtimes/local/LocalThreadRuntimeCore.tsx index ec631512e..5e4c2528f 100644 --- a/packages/react/src/runtimes/local/LocalThreadRuntimeCore.tsx +++ b/packages/react/src/runtimes/local/LocalThreadRuntimeCore.tsx @@ -4,10 +4,10 @@ import type { AppendMessage, ThreadAssistantMessage, } from "../../types"; -import { fromCoreMessage, fromCoreMessages } from "../edge"; -import type { ChatModelAdapter, ChatModelRunResult } from "./ChatModelAdapter"; +import { fromCoreMessage } from "../edge"; +import type { ChatModelRunResult } from "./ChatModelAdapter"; import { shouldContinue } from "./shouldContinue"; -import { LocalRuntimeOptions } from "./LocalRuntimeOptions"; +import { LocalRuntimeOptionsBase } from "./LocalRuntimeOptions"; import { AddToolResultOptions, ThreadSuggestion, @@ -32,44 +32,32 @@ export class LocalThreadRuntimeCore private abortController: AbortController | null = null; - public readonly threadId: string; public readonly isDisabled = false; public readonly suggestions: readonly ThreadSuggestion[] = []; public get adapters() { - return this.options.adapters; + return this._options.adapters; } constructor( configProvider: ModelConfigProvider, - public adapter: ChatModelAdapter, - { initialMessages, ...options }: LocalRuntimeOptions, + public readonly threadId: string, + options: LocalRuntimeOptionsBase, ) { super(configProvider); - this.threadId = generateId(); - this.options = options; - if (initialMessages) { - let parentId: string | null = null; - const messages = fromCoreMessages(initialMessages); - for (const message of messages) { - this.repository.addOrUpdateMessage(parentId, message); - parentId = message.id; - } - } + this._options = options; } - private _options!: LocalRuntimeOptions; - - public get options() { - return this._options; - } + private _options: LocalRuntimeOptionsBase; public get extras() { return undefined; } - public set options({ initialMessages, ...options }: LocalRuntimeOptions) { + public setOptions(options: LocalRuntimeOptionsBase) { + if (this._options === options) return; + this._options = options; let hasUpdates = false; @@ -174,9 +162,9 @@ export class LocalThreadRuntimeCore this._notifySubscribers(); }; - const maxSteps = this.options.maxSteps - ? this.options.maxSteps - : (this.options.maxToolRoundtrips ?? 1) + 1; + const maxSteps = this._options.maxSteps + ? this._options.maxSteps + : (this._options.maxToolRoundtrips ?? 1) + 1; const steps = message.metadata?.steps?.length ?? 0; if (steps >= maxSteps) { @@ -197,7 +185,7 @@ export class LocalThreadRuntimeCore } try { - const promiseOrGenerator = this.adapter.run({ + const promiseOrGenerator = this.adapters.chatModel.run({ messages, abortSignal: this.abortController.signal, config: this.getModelConfig(), diff --git a/packages/react/src/runtimes/local/useLocalRuntime.tsx b/packages/react/src/runtimes/local/useLocalRuntime.tsx index f7035fb80..a3b59bed4 100644 --- a/packages/react/src/runtimes/local/useLocalRuntime.tsx +++ b/packages/react/src/runtimes/local/useLocalRuntime.tsx @@ -1,6 +1,6 @@ "use client"; -import { useInsertionEffect, useMemo, useState } from "react"; +import { useEffect, useMemo, useState } from "react"; import type { ChatModelAdapter } from "./ChatModelAdapter"; import { LocalRuntimeCore } from "./LocalRuntimeCore"; import { LocalRuntimeOptions } from "./LocalRuntimeOptions"; @@ -39,11 +39,18 @@ export const useLocalRuntime = ( adapter: ChatModelAdapter, options: LocalRuntimeOptions = {}, ) => { - const [runtime] = useState(() => new LocalRuntimeCore(adapter, options)); - - useInsertionEffect(() => { - runtime.thread.adapter = adapter; - runtime.thread.options = options; + const opt = { + ...options, + adapters: { + ...options.adapters, + chatModel: adapter, + }, + }; + + const [runtime] = useState(() => new LocalRuntimeCore(opt)); + + useEffect(() => { + runtime.setOptions(opt); }); return useMemo(() => LocalRuntimeImpl.create(runtime), [runtime]);