diff --git a/packages/react/src/runtime/core/BaseAssistantRuntime.tsx b/packages/react/src/runtime/core/BaseAssistantRuntime.tsx index 9b69ee2ae4..841b56cb4f 100644 --- a/packages/react/src/runtime/core/BaseAssistantRuntime.tsx +++ b/packages/react/src/runtime/core/BaseAssistantRuntime.tsx @@ -1,13 +1,29 @@ +import { ReactThreadRuntime } from "../../../dist"; import type { AppendMessage } from "../../types/AssistantTypes"; import { type ModelConfigProvider } from "../../types/ModelConfigTypes"; import type { Unsubscribe } from "../../types/Unsubscribe"; import type { AssistantRuntime } from "./AssistantRuntime"; -import { ThreadRuntime } from "./ThreadRuntime"; -export abstract class BaseAssistantRuntime - implements AssistantRuntime +export abstract class BaseAssistantRuntime< + TThreadRuntime extends ReactThreadRuntime, +> implements AssistantRuntime { - constructor(protected thread: TThreadRuntime) {} + constructor(private _thread: TThreadRuntime) { + this._thread = _thread; + this._unsubscribe = this._thread.subscribe(this.subscriptionHandler); + } + + private _unsubscribe: Unsubscribe; + + get thread() { + return this._thread; + } + + set thread(thread: TThreadRuntime) { + this._unsubscribe(); + this._thread = thread; + this._unsubscribe = this._thread.subscribe(this.subscriptionHandler); + } public abstract registerModelConfigProvider( provider: ModelConfigProvider, @@ -46,7 +62,18 @@ export abstract class BaseAssistantRuntime return this.thread.addToolResult(toolCallId, result); } + private _subscriptions = new Set<() => void>(); + public subscribe(callback: () => void): Unsubscribe { - return this.thread.subscribe(callback); + this._subscriptions.add(callback); + return () => this._subscriptions.delete(callback); + } + + private subscriptionHandler = () => { + for (const callback of this._subscriptions) callback(); + }; + + public get unstable_synchronizer() { + return this.thread.unstable_synchronizer; } }