diff --git a/.changeset/brown-clocks-dream.md b/.changeset/brown-clocks-dream.md new file mode 100644 index 000000000..b86452ff1 --- /dev/null +++ b/.changeset/brown-clocks-dream.md @@ -0,0 +1,5 @@ +--- +"@assistant-ui/react": patch +--- + +feat(runtimes/edge): dynamic model creator functions diff --git a/packages/react/src/runtimes/edge/EdgeRuntimeRequestOptions.ts b/packages/react/src/runtimes/edge/EdgeRuntimeRequestOptions.ts index de5d43417..3ca4065dc 100644 --- a/packages/react/src/runtimes/edge/EdgeRuntimeRequestOptions.ts +++ b/packages/react/src/runtimes/edge/EdgeRuntimeRequestOptions.ts @@ -1,7 +1,13 @@ import { CoreMessage } from "../../types"; import { LanguageModelV1FunctionTool } from "@ai-sdk/provider"; -export type EdgeRuntimeRequestOptions = { +export type LanguageModelConfig = { + apiKey?: string | undefined; + baseUrl?: string | undefined; + modelName?: string | undefined; +}; + +export type EdgeRuntimeRequestOptions = LanguageModelConfig & { system?: string | undefined; messages: CoreMessage[]; tools: LanguageModelV1FunctionTool[]; diff --git a/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts b/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts index e78b57396..8fd35e028 100644 --- a/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts +++ b/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts @@ -10,7 +10,10 @@ import { } from "@ai-sdk/provider"; import { CoreAssistantMessage, CoreMessage } from "../../types/AssistantTypes"; import { assistantEncoderStream } from "./streams/assistantEncoderStream"; -import { EdgeRuntimeRequestOptions } from "./EdgeRuntimeRequestOptions"; +import { + EdgeRuntimeRequestOptions, + LanguageModelConfig, +} from "./EdgeRuntimeRequestOptions"; import { toLanguageModelMessages } from "./converters/toLanguageModelMessages"; import { z } from "zod"; import { Tool } from "../../types"; @@ -53,8 +56,12 @@ type FinishResult = { | undefined; }; +type LanguageModelCreator = ( + config: LanguageModelConfig, +) => Promise | LanguageModelV1; + type CreateEdgeRuntimeAPIOptions = LanguageModelSettings & { - model: LanguageModelV1; + model: LanguageModelV1 | LanguageModelCreator; system?: string; tools?: Record>; toolChoice?: LanguageModelV1ToolChoice; @@ -70,7 +77,7 @@ const voidStream = () => { }; export const createEdgeRuntimeAPI = ({ - model, + model: modelOrCreator, system: serverSystem, tools: serverTools = {}, toolChoice, @@ -86,6 +93,9 @@ export const createEdgeRuntimeAPI = ({ system: clientSystem, tools: clientTools, messages, + apiKey, + baseUrl, + modelName, } = (await request.json()) as EdgeRuntimeRequestOptions; const systemMessages = []; @@ -101,6 +111,11 @@ export const createEdgeRuntimeAPI = ({ } } + const model = + typeof modelOrCreator === "function" + ? await modelOrCreator({ apiKey, baseUrl, modelName }) + : modelOrCreator; + let stream: ReadableStream; const streamResult = await streamMessage({ ...(settings as Partial),