diff --git a/.changeset/nervous-students-unite.md b/.changeset/nervous-students-unite.md new file mode 100644 index 000000000..187d74d54 --- /dev/null +++ b/.changeset/nervous-students-unite.md @@ -0,0 +1,5 @@ +--- +"@assistant-ui/react": patch +--- + +feat: useThreadModelConfig API diff --git a/packages/react/src/api/AssistantRuntime.ts b/packages/react/src/api/AssistantRuntime.ts index f17bef00e..7fc7511be 100644 --- a/packages/react/src/api/AssistantRuntime.ts +++ b/packages/react/src/api/AssistantRuntime.ts @@ -56,7 +56,6 @@ export class AssistantRuntimeImpl return this._core.registerModelConfigProvider(provider); } - // TODO events for thread switching /** * @deprecated Thread is now static and never gets updated. This will be removed in 0.6.0. */ diff --git a/packages/react/src/api/ThreadRuntime.ts b/packages/react/src/api/ThreadRuntime.ts index a458a6bea..a67ac2165 100644 --- a/packages/react/src/api/ThreadRuntime.ts +++ b/packages/react/src/api/ThreadRuntime.ts @@ -6,6 +6,7 @@ import { ThreadRuntimeCore, SpeechState, SubmittedFeedback, + ThreadRuntimeEventType, } from "../runtimes/core/ThreadRuntimeCore"; import { ExportedMessageRepository } from "../runtimes/utils/MessageRepository"; import { @@ -133,10 +134,7 @@ export type ThreadRuntime = { */ stopSpeaking: () => void; - unstable_on( - event: "switched-to" | "run-start", - callback: () => void, - ): Unsubscribe; + unstable_on(event: ThreadRuntimeEventType, callback: () => void): Unsubscribe; // Legacy methods with deprecations @@ -501,7 +499,7 @@ export class ThreadRuntimeImpl >(); public unstable_on( - event: "switched-to" | "run-start", + event: ThreadRuntimeEventType, callback: () => void, ): Unsubscribe { let subject = this._eventListenerNestedSubscriptions.get(event); diff --git a/packages/react/src/context/react/ThreadContext.ts b/packages/react/src/context/react/ThreadContext.ts index 8450d1955..079a57e0c 100644 --- a/packages/react/src/context/react/ThreadContext.ts +++ b/packages/react/src/context/react/ThreadContext.ts @@ -1,6 +1,6 @@ "use client"; -import { createContext } from "react"; +import { createContext, useEffect, useState } from "react"; import type { ThreadViewportState } from "../stores/ThreadViewport"; import { ReadonlyStore } from "../ReadonlyStore"; import { UseBoundStore } from "zustand"; @@ -8,7 +8,7 @@ import { createContextHook } from "./utils/createContextHook"; import { createContextStoreHook } from "./utils/createContextStoreHook"; import { ThreadRuntime } from "../../api"; import { ThreadState } from "../../api/ThreadRuntime"; -import { ThreadMessage } from "../../types"; +import { ModelConfig, ThreadMessage } from "../../types"; import { ThreadComposerState } from "../../api/ComposerRuntime"; export type ThreadContextValue = { @@ -88,3 +88,23 @@ export const { useViewport: useThreadViewport, useViewportStore: useThreadViewportStore, } = createContextStoreHook(useThreadContext, "useViewport"); + +export function useThreadModelConfig(options?: { + optional?: false | undefined; +}): ModelConfig; +export function useThreadModelConfig(options?: { + optional?: boolean | undefined; +}): ModelConfig | null; +export function useThreadModelConfig(options?: { + optional?: boolean | undefined; +}): ModelConfig | null { + const [, rerender] = useState({}); + + const runtime = useThreadRuntime(options); + useEffect(() => { + return runtime?.unstable_on("model-config-update", () => rerender({})); + }, [runtime]); + + if (!runtime) return null; + return runtime?.getModelConfig(); +} diff --git a/packages/react/src/context/react/index.ts b/packages/react/src/context/react/index.ts index 3e5841dbb..25f23218a 100644 --- a/packages/react/src/context/react/index.ts +++ b/packages/react/src/context/react/index.ts @@ -30,6 +30,7 @@ export { useThreadRuntime, useThread, useThreadComposer, + useThreadModelConfig, /** * @deprecated Use `useThread().messages` instead. This will be removed in 0.6.0. diff --git a/packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx b/packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx index fb61ef438..f405bc3ff 100644 --- a/packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx +++ b/packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx @@ -16,6 +16,7 @@ import { SpeechState, RuntimeCapabilities, SubmittedFeedback, + ThreadRuntimeEventType, } from "../core/ThreadRuntimeCore"; import { DefaultEditComposerRuntimeCore } from "../composer/DefaultEditComposerRuntimeCore"; import { SpeechSynthesisAdapter } from "../speech"; @@ -52,7 +53,11 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore { public readonly composer = new DefaultThreadComposerRuntimeCore(this); - constructor(private configProvider: ModelConfigProvider) {} + constructor(private configProvider: ModelConfigProvider) { + this.configProvider.subscribe?.(() => { + this._notifyEventSubscribers("model-config-update"); + }); + } public getModelConfig() { return this.configProvider.getModelConfig(); @@ -94,7 +99,7 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore { for (const callback of this._subscriptions) callback(); } - public _notifyEventSubscribers(event: "switched-to" | "run-start") { + public _notifyEventSubscribers(event: ThreadRuntimeEventType) { const subscribers = this._eventSubscribers.get(event); if (!subscribers) return; @@ -173,7 +178,7 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore { private _eventSubscribers = new Map void>>(); - public unstable_on(event: "switched-to" | "run-start", callback: () => void) { + public unstable_on(event: ThreadRuntimeEventType, callback: () => void) { const subscribers = this._eventSubscribers.get(event); if (!subscribers) { this._eventSubscribers.set(event, new Set([callback])); diff --git a/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx b/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx index edc04e2db..4e739193b 100644 --- a/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx +++ b/packages/react/src/runtimes/core/ThreadRuntimeCore.tsx @@ -43,6 +43,11 @@ export type SubmittedFeedback = Readonly<{ type: "negative" | "positive"; }>; +export type ThreadRuntimeEventType = + | "switched-to" + | "run-start" + | "model-config-update"; + export type ThreadRuntimeCore = Readonly<{ getMessageById: (messageId: string) => | { @@ -86,8 +91,5 @@ export type ThreadRuntimeCore = Readonly<{ import(repository: ExportedMessageRepository): void; export(): ExportedMessageRepository; - unstable_on( - event: "switched-to" | "run-start", - callback: () => void, - ): Unsubscribe; + unstable_on(event: ThreadRuntimeEventType, callback: () => void): Unsubscribe; }>; diff --git a/packages/react/src/types/ModelConfigTypes.ts b/packages/react/src/types/ModelConfigTypes.ts index 33ea6b40c..90e0efb5f 100644 --- a/packages/react/src/types/ModelConfigTypes.ts +++ b/packages/react/src/types/ModelConfigTypes.ts @@ -1,5 +1,6 @@ import { z } from "zod"; import type { JSONSchema7 } from "json-schema"; +import { Unsubscribe } from "./Unsubscribe"; export const LanguageModelV1CallSettingsSchema = z.object({ maxTokens: z.number().int().positive().optional(), @@ -47,7 +48,10 @@ export type ModelConfig = { config?: LanguageModelConfig | undefined; }; -export type ModelConfigProvider = { getModelConfig: () => ModelConfig }; +export type ModelConfigProvider = { + getModelConfig: () => ModelConfig; + subscribe?: (callback: () => void) => Unsubscribe; +}; export const mergeModelConfigs = ( configSet: Set, diff --git a/packages/react/src/utils/ProxyConfigProvider.ts b/packages/react/src/utils/ProxyConfigProvider.ts index 79d6075b3..26f81788c 100644 --- a/packages/react/src/utils/ProxyConfigProvider.ts +++ b/packages/react/src/utils/ProxyConfigProvider.ts @@ -1,4 +1,3 @@ -"use client"; import { type ModelConfigProvider, mergeModelConfigs, @@ -13,8 +12,25 @@ export class ProxyConfigProvider implements ModelConfigProvider { registerModelConfigProvider(provider: ModelConfigProvider) { this._providers.add(provider); + const unsubscribe = provider.subscribe?.(() => { + this.notifySubscribers(); + }); + this.notifySubscribers(); return () => { this._providers.delete(provider); + unsubscribe?.(); + this.notifySubscribers(); }; } + + private _subscribers = new Set<() => void>(); + + notifySubscribers() { + for (const callback of this._subscribers) callback(); + } + + subscribe(callback: () => void) { + this._subscribers.add(callback); + return () => this._subscribers.delete(callback); + } }