diff --git a/.changeset/brave-laws-exist.md b/.changeset/brave-laws-exist.md new file mode 100644 index 000000000..4d0c0face --- /dev/null +++ b/.changeset/brave-laws-exist.md @@ -0,0 +1,7 @@ +--- +"@assistant-ui/react-playground": patch +"@assistant-ui/react-ai-sdk": patch +"@assistant-ui/react": patch +--- + +feat(runtime): BranchPicker feature detection diff --git a/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx b/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx index 664b0b462..8806d8a6d 100644 --- a/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx +++ b/packages/react-ai-sdk/src/ui/use-assistant/VercelUseAssistantThreadRuntime.tsx @@ -13,6 +13,7 @@ import { hasUpcomingMessage } from "./VercelUseAssistantRuntime"; const EMPTY_BRANCHES: readonly string[] = Object.freeze([]); const CAPABILITIES = Object.freeze({ + switchToBranch: false, edit: false, reload: false, cancel: false, diff --git a/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatThreadRuntime.tsx b/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatThreadRuntime.tsx index 10384516d..571ee7148 100644 --- a/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatThreadRuntime.tsx +++ b/packages/react-ai-sdk/src/ui/use-chat/VercelUseChatThreadRuntime.tsx @@ -24,6 +24,7 @@ export const hasUpcomingMessage = ( }; const CAPABILITIES = Object.freeze({ + switchToBranch: true, edit: true, reload: true, cancel: true, diff --git a/packages/react-playground/src/lib/playground-runtime.ts b/packages/react-playground/src/lib/playground-runtime.ts index 992f2d583..64114e598 100644 --- a/packages/react-playground/src/lib/playground-runtime.ts +++ b/packages/react-playground/src/lib/playground-runtime.ts @@ -73,6 +73,7 @@ class PlaygroundRuntime } const CAPABILITIES = Object.freeze({ + switchToBranch: false, edit: false, reload: false, cancel: true, diff --git a/packages/react/src/context/stores/ThreadActions.ts b/packages/react/src/context/stores/ThreadActions.ts index a2faac0f4..4b690c890 100644 --- a/packages/react/src/context/stores/ThreadActions.ts +++ b/packages/react/src/context/stores/ThreadActions.ts @@ -9,13 +9,17 @@ export type AddToolResultOptions = { result: any; }; +export type RuntimeCapabilities = { + switchToBranch: boolean; + edit: boolean; + reload: boolean; + cancel: boolean; + copy: boolean; +}; + export type ThreadActionsState = Readonly<{ - capabilities: Readonly<{ - edit: boolean; - reload: boolean; - cancel: boolean; - copy: boolean; - }>; + capabilities: Readonly; + getBranches: (messageId: string) => readonly string[]; switchToBranch: (branchId: string) => void; diff --git a/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntime.tsx b/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntime.tsx index 154a8e7ba..0033dbf3e 100644 --- a/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntime.tsx +++ b/packages/react/src/runtimes/external-store/ExternalStoreThreadRuntime.tsx @@ -25,6 +25,7 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime { public get capabilities() { return { + switchToBranch: this.store.setMessages !== undefined, edit: this.store.onEdit !== undefined, reload: this.store.onReload !== undefined, cancel: this.store.onCancel !== undefined, @@ -46,6 +47,9 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime { } public switchToBranch(branchId: string): void { + if (!this.store.setMessages) + throw new Error("Runtime does not support switching branches."); + this.repository.switchToBranch(branchId); this.updateMessages(this.repository.getMessages()); } diff --git a/packages/react/src/runtimes/local/LocalThreadRuntime.tsx b/packages/react/src/runtimes/local/LocalThreadRuntime.tsx index d5be43e1a..d915050a5 100644 --- a/packages/react/src/runtimes/local/LocalThreadRuntime.tsx +++ b/packages/react/src/runtimes/local/LocalThreadRuntime.tsx @@ -15,6 +15,7 @@ import { shouldContinue } from "./shouldContinue"; import { LocalRuntimeOptions } from "./LocalRuntimeOptions"; const CAPABILITIES = Object.freeze({ + switchToBranch: true, edit: true, reload: true, cancel: true, diff --git a/packages/react/src/ui/branch-picker.tsx b/packages/react/src/ui/branch-picker.tsx index ac2cdad6f..3f10ca2af 100644 --- a/packages/react/src/ui/branch-picker.tsx +++ b/packages/react/src/ui/branch-picker.tsx @@ -10,8 +10,18 @@ import { import { withDefaults } from "./utils/withDefaults"; import { useThreadConfig } from "./thread-config"; import { BranchPickerPrimitive } from "../primitives"; +import { useThreadContext } from "../context"; + +const useAllowBranchPicker = () => { + const { branchPicker: { allowBranchPicker = true } = {} } = useThreadConfig(); + const { useThreadActions } = useThreadContext(); + const branchPickerSupported = useThreadActions((t) => t.capabilities.edit); + return branchPickerSupported && allowBranchPicker; +}; const BranchPicker: FC = () => { + const allowBranchPicker = useAllowBranchPicker(); + if (!allowBranchPicker) return null; return ( diff --git a/packages/react/src/ui/thread-config.tsx b/packages/react/src/ui/thread-config.tsx index c939bd60a..251b7530f 100644 --- a/packages/react/src/ui/thread-config.tsx +++ b/packages/react/src/ui/thread-config.tsx @@ -37,6 +37,10 @@ export type AssistantMessageConfig = { | undefined; }; +export type BranchPickerConfig = { + allowBranchPicker?: boolean | undefined; +}; + export type StringsConfig = { assistantModal?: { open: { @@ -119,6 +123,8 @@ export type ThreadConfig = { assistantMessage?: AssistantMessageConfig; userMessage?: UserMessageConfig; + branchPicker?: BranchPickerConfig; + strings?: StringsConfig; };