From 179e8699d9b8e7df7f2353bc76f356faeef315a0 Mon Sep 17 00:00:00 2001 From: Simon Farshid Date: Thu, 27 Jun 2024 00:44:03 -0700 Subject: [PATCH] fix: reimplement ProxyConfigProvider in AI SDK (#344) --- .../react-ai-sdk/src/rsc/VercelRSCRuntime.tsx | 15 +++++++++------ .../react-ai-sdk/src/ui/VercelAIRuntime.tsx | 18 ++++++++++-------- packages/react-ui/components/ui/tooltip.tsx | 4 +--- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx b/packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx index 711d3062b..b8a73ce25 100644 --- a/packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx +++ b/packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx @@ -11,6 +11,7 @@ import { type StoreApi, type UseBoundStore, create } from "zustand"; import type { VercelRSCAdapter } from "./VercelRSCAdapter"; import type { VercelRSCMessage } from "./VercelRSCMessage"; import { useVercelRSCSync } from "./useVercelRSCSync"; +import { ModelConfigProvider } from "@assistant-ui/react"; const { ProxyConfigProvider, BaseAssistantRuntime } = INTERNAL; @@ -19,6 +20,8 @@ const EMPTY_BRANCHES: readonly string[] = Object.freeze([]); export class VercelRSCRuntime< T extends WeakKey = VercelRSCMessage, > extends BaseAssistantRuntime> { + private readonly _proxyConfigProvider = new ProxyConfigProvider(); + constructor(adapter: VercelRSCAdapter) { super(new VercelRSCThreadRuntime(adapter)); } @@ -31,9 +34,12 @@ export class VercelRSCRuntime< return this.thread.onAdapterUpdated(); } - public registerModelConfigProvider() { - // no-op - return () => {}; + public getModelConfig() { + return this._proxyConfigProvider.getModelConfig(); + } + + public registerModelConfigProvider(provider: ModelConfigProvider) { + return this._proxyConfigProvider.registerModelConfigProvider(provider); } public switchToThread() { @@ -42,7 +48,6 @@ export class VercelRSCRuntime< } class VercelRSCThreadRuntime - extends ProxyConfigProvider implements ReactThreadRuntime { private useAdapter: UseBoundStore }>>; @@ -53,8 +58,6 @@ class VercelRSCThreadRuntime public messages: ThreadMessage[] = []; constructor(public adapter: VercelRSCAdapter) { - super(); - this.useAdapter = create(() => ({ adapter, })); diff --git a/packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx b/packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx index 3185d0575..65791c5c0 100644 --- a/packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx +++ b/packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx @@ -12,6 +12,7 @@ import type { VercelHelpers } from "./utils/VercelHelpers"; import { sliceMessagesUntil } from "./utils/sliceMessagesUntil"; import { useVercelAIComposerSync } from "./utils/useVercelAIComposerSync"; import { useVercelAIThreadSync } from "./utils/useVercelAIThreadSync"; +import { ModelConfigProvider } from "@assistant-ui/react"; const { ProxyConfigProvider, MessageRepository, BaseAssistantRuntime } = INTERNAL; @@ -21,6 +22,8 @@ const hasUpcomingMessage = (isRunning: boolean, messages: ThreadMessage[]) => { }; export class VercelAIRuntime extends BaseAssistantRuntime { + private readonly _proxyConfigProvider = new ProxyConfigProvider(); + constructor(vercel: VercelHelpers) { super(new VercelAIThreadRuntime(vercel)); } @@ -33,9 +36,12 @@ export class VercelAIRuntime extends BaseAssistantRuntime return this.thread.onVercelUpdated(); } - public registerModelConfigProvider() { - // no-op - return () => {}; + public getModelConfig() { + return this._proxyConfigProvider.getModelConfig(); + } + + public registerModelConfigProvider(provider: ModelConfigProvider) { + return this._proxyConfigProvider.registerModelConfigProvider(provider); } public switchToThread(threadId: string | null) { @@ -53,10 +59,7 @@ export class VercelAIRuntime extends BaseAssistantRuntime } } -class VercelAIThreadRuntime - extends ProxyConfigProvider - implements ReactThreadRuntime -{ +class VercelAIThreadRuntime implements ReactThreadRuntime { private _subscriptions = new Set<() => void>(); private repository = new MessageRepository(); private assistantOptimisticId: string | null = null; @@ -67,7 +70,6 @@ class VercelAIThreadRuntime public isRunning = false; constructor(public vercel: VercelHelpers) { - super(); this.useVercel = create(() => ({ vercel, })); diff --git a/packages/react-ui/components/ui/tooltip.tsx b/packages/react-ui/components/ui/tooltip.tsx index 4ab0847e4..0f3650022 100644 --- a/packages/react-ui/components/ui/tooltip.tsx +++ b/packages/react-ui/components/ui/tooltip.tsx @@ -5,8 +5,6 @@ import * as TooltipPrimitive from "@radix-ui/react-tooltip"; import { cn } from "@/lib/utils"; -const TooltipProvider = TooltipPrimitive.Provider; - const Tooltip = TooltipPrimitive.Root; const TooltipTrigger = TooltipPrimitive.Trigger; @@ -27,4 +25,4 @@ const TooltipContent = React.forwardRef< )); TooltipContent.displayName = TooltipPrimitive.Content.displayName; -export { Tooltip, TooltipTrigger, TooltipContent, TooltipProvider }; +export { Tooltip, TooltipTrigger, TooltipContent };