From 378ee997f552928f46d96bba44ecbb01580ffb73 Mon Sep 17 00:00:00 2001 From: Simon Farshid Date: Tue, 8 Oct 2024 18:29:59 -0700 Subject: [PATCH] feat: rename maxToolRoundtrips to maxSteps (#955) --- .changeset/shiny-snakes-float.md | 5 ++ .changeset/tidy-fishes-hammer.md | 5 ++ .../with-ffmpeg/app/MyRuntimeProvider.tsx | 2 +- .../react-playground/src/lib/converters.ts | 2 +- .../src/lib/playground-runtime.ts | 1 + .../converters/fromLanguageModelMessages.ts | 6 +-- .../src/runtimes/edge/createEdgeRuntimeAPI.ts | 12 +++-- .../edge/streams/AssistantStreamChunkType.ts | 29 ++++++++++- .../edge/streams/assistantDecoderStream.ts | 33 +++++++++++++ .../edge/streams/assistantEncoderStream.ts | 9 ++++ .../runtimes/edge/streams/runResultStream.ts | 48 +++++++++++++++---- .../runtimes/edge/streams/toolResultStream.ts | 17 +++++++ .../src/runtimes/local/ChatModelAdapter.tsx | 9 +++- .../runtimes/local/LocalRuntimeOptions.tsx | 15 +++++- .../runtimes/local/LocalThreadRuntimeCore.tsx | 36 ++++++-------- packages/react/src/types/AssistantTypes.ts | 20 ++++++-- 16 files changed, 202 insertions(+), 47 deletions(-) create mode 100644 .changeset/shiny-snakes-float.md create mode 100644 .changeset/tidy-fishes-hammer.md diff --git a/.changeset/shiny-snakes-float.md b/.changeset/shiny-snakes-float.md new file mode 100644 index 000000000..243a4c602 --- /dev/null +++ b/.changeset/shiny-snakes-float.md @@ -0,0 +1,5 @@ +--- +"@assistant-ui/react": patch +--- + +refactor: rename maxToolRoundtrips to maxSteps diff --git a/.changeset/tidy-fishes-hammer.md b/.changeset/tidy-fishes-hammer.md new file mode 100644 index 000000000..b72cea55b --- /dev/null +++ b/.changeset/tidy-fishes-hammer.md @@ -0,0 +1,5 @@ +--- +"@assistant-ui/react": patch +--- + +feat: server-side tool roundtrips support diff --git a/examples/with-ffmpeg/app/MyRuntimeProvider.tsx b/examples/with-ffmpeg/app/MyRuntimeProvider.tsx index 7ef8b77c1..deab4cd81 100644 --- a/examples/with-ffmpeg/app/MyRuntimeProvider.tsx +++ b/examples/with-ffmpeg/app/MyRuntimeProvider.tsx @@ -42,7 +42,7 @@ export function MyRuntimeProvider({ }>) { const runtime = useEdgeRuntime({ api: "/api/chat", - maxToolRoundtrips: 3, + maxSteps: 4, adapters: { attachments: attachmentAdapter, }, diff --git a/packages/react-playground/src/lib/converters.ts b/packages/react-playground/src/lib/converters.ts index af2686df2..832f3658d 100644 --- a/packages/react-playground/src/lib/converters.ts +++ b/packages/react-playground/src/lib/converters.ts @@ -76,7 +76,7 @@ const threadMessagesFromOpenAI = ( messages: OpenAI.ChatCompletionMessageParam[], ) => { const lms = fromOpenAIMessages(messages); - return fromLanguageModelMessages(lms, { mergeRoundtrips: false }); + return fromLanguageModelMessages(lms, { mergeSteps: false }); }; const threadMessagesToOpenAI = ( diff --git a/packages/react-playground/src/lib/playground-runtime.ts b/packages/react-playground/src/lib/playground-runtime.ts index 044c4384a..0951eab58 100644 --- a/packages/react-playground/src/lib/playground-runtime.ts +++ b/packages/react-playground/src/lib/playground-runtime.ts @@ -522,6 +522,7 @@ class PlaygroundThreadRuntime extends ThreadRuntimeImpl { export const usePlaygroundRuntime = ({ initialMessages, maxToolRoundtrips, + maxSteps, ...runtimeOptions }: EdgeRuntimeOptions & { initialMessages: CoreMessage[]; diff --git a/packages/react/src/runtimes/edge/converters/fromLanguageModelMessages.ts b/packages/react/src/runtimes/edge/converters/fromLanguageModelMessages.ts index 8e84fcd19..9a2fa69f6 100644 --- a/packages/react/src/runtimes/edge/converters/fromLanguageModelMessages.ts +++ b/packages/react/src/runtimes/edge/converters/fromLanguageModelMessages.ts @@ -2,12 +2,12 @@ import { LanguageModelV1Message } from "@ai-sdk/provider"; import { CoreMessage, ToolCallContentPart } from "../../../types"; type fromLanguageModelMessagesOptions = { - mergeRoundtrips: boolean; + mergeSteps: boolean; }; export const fromLanguageModelMessages = ( lm: LanguageModelV1Message[], - { mergeRoundtrips }: fromLanguageModelMessagesOptions, + { mergeSteps }: fromLanguageModelMessagesOptions, ): CoreMessage[] => { const messages: CoreMessage[] = []; @@ -74,7 +74,7 @@ export const fromLanguageModelMessages = ( return part; }); - if (mergeRoundtrips) { + if (mergeSteps) { const previousMessage = messages[messages.length - 1]; if (previousMessage?.role === "assistant") { previousMessage.content.push(...newContent); diff --git a/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts b/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts index 704a0304f..575ec1216 100644 --- a/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts +++ b/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts @@ -8,7 +8,7 @@ import { import { CoreMessage, ThreadMessage, - ThreadRoundtrip, + ThreadStep, } from "../../types/AssistantTypes"; import { assistantEncoderStream } from "./streams/assistantEncoderStream"; import { EdgeRuntimeRequestOptionsSchema } from "./EdgeRuntimeRequestOptions"; @@ -33,7 +33,11 @@ import { z } from "zod"; type FinishResult = { messages: CoreMessage[]; metadata: { - roundtrips: ThreadRoundtrip[]; + /** + * @deprecated Use `steps` instead. This field will be removed in v0.6. + */ + roundtrips: ThreadStep[]; + steps: ThreadStep[]; }; }; @@ -157,7 +161,9 @@ export const getEdgeRuntimeStream = async ({ metadata: { // TODO // eslint-disable-next-line @typescript-eslint/no-non-null-asserted-optional-chain - roundtrips: lastChunk.metadata?.roundtrips!, + roundtrips: lastChunk.metadata?.steps!, + // eslint-disable-next-line @typescript-eslint/no-non-null-asserted-optional-chain + steps: lastChunk.metadata?.steps!, }, }); }, diff --git a/packages/react/src/runtimes/edge/streams/AssistantStreamChunkType.ts b/packages/react/src/runtimes/edge/streams/AssistantStreamChunkType.ts index 94915cbca..115866344 100644 --- a/packages/react/src/runtimes/edge/streams/AssistantStreamChunkType.ts +++ b/packages/react/src/runtimes/edge/streams/AssistantStreamChunkType.ts @@ -2,15 +2,24 @@ import { LanguageModelV1StreamPart } from "@ai-sdk/provider"; export enum AssistantStreamChunkType { TextDelta = "0", + Data = "2", + Error = "3", + ToolCall = "9", + ToolCallResult = "a", ToolCallBegin = "b", ToolCallDelta = "c", - ToolCallResult = "a", - Error = "3", Finish = "d", + StepFinish = "e", } export type AssistantStreamChunk = { [AssistantStreamChunkType.TextDelta]: string; + [AssistantStreamChunkType.Data]: unknown; + [AssistantStreamChunkType.ToolCall]: { + toolCallId: string; + toolName: string; + args: unknown; + }; [AssistantStreamChunkType.ToolCallBegin]: { toolCallId: string; toolName: string; @@ -24,6 +33,21 @@ export type AssistantStreamChunk = { result: any; }; [AssistantStreamChunkType.Error]: unknown; + [AssistantStreamChunkType.StepFinish]: { + finishReason: + | "stop" + | "length" + | "content-filter" + | "tool-calls" + | "error" + | "other" + | "unknown"; + usage: { + promptTokens: number; + completionTokens: number; + }; + isContinued: boolean; + }; [AssistantStreamChunkType.Finish]: Omit< LanguageModelV1StreamPart & { type: "finish"; @@ -31,3 +55,4 @@ export type AssistantStreamChunk = { "type" >; }; + \ No newline at end of file diff --git a/packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts b/packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts index a67c596b9..af497233d 100644 --- a/packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts +++ b/packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts @@ -87,6 +87,39 @@ export function assistantDecoderStream() { }); break; } + + case AssistantStreamChunkType.ToolCall: { + const { toolCallId, toolName, args } = value; + const argsText = JSON.stringify(args); + controller.enqueue({ + type: "tool-call-delta", + toolCallType: "function", + toolCallId, + toolName, + argsTextDelta: argsText, + }); + controller.enqueue({ + type: "tool-call", + toolCallType: "function", + toolCallId: toolCallId, + toolName: toolName, + args: argsText, + }); + break; + } + + case AssistantStreamChunkType.StepFinish: { + controller.enqueue({ + type: "step-finish", + ...value, + }); + break; + } + + // TODO + case AssistantStreamChunkType.Data: + break; + default: { const unhandledType: never = type; throw new Error(`Unhandled chunk type: ${unhandledType}`); diff --git a/packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts b/packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts index 38e3a68cf..8beb1eccf 100644 --- a/packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts +++ b/packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts @@ -59,6 +59,15 @@ export function assistantEncoderStream() { break; } + case "step-finish": { + const { type, ...rest } = chunk; + controller.enqueue({ + type: AssistantStreamChunkType.StepFinish, + value: rest, + }); + break; + } + case "finish": { const { type, ...rest } = chunk; controller.enqueue({ diff --git a/packages/react/src/runtimes/edge/streams/runResultStream.ts b/packages/react/src/runtimes/edge/streams/runResultStream.ts index a567cdc31..aa87c485b 100644 --- a/packages/react/src/runtimes/edge/streams/runResultStream.ts +++ b/packages/react/src/runtimes/edge/streams/runResultStream.ts @@ -53,6 +53,11 @@ export function runResultStream() { controller.enqueue(message); break; } + case "step-finish": { + message = appendOrUpdateStepFinish(message, chunk); + controller.enqueue(message); + break; + } case "finish": { message = appendOrUpdateFinish(message, chunk); controller.enqueue(message); @@ -160,29 +165,56 @@ const appendOrUpdateToolResult = ( }; }; +const appendOrUpdateStepFinish = ( + message: ChatModelRunResult, + chunk: ToolResultStreamPart & { type: "step-finish" }, +): ChatModelRunResult => { + const { type, ...rest } = chunk; + const steps = [ + ...(message.metadata?.steps ?? []), + { + usage: rest.usage, + }, + ]; + return { + ...message, + status: getStatus(chunk), + metadata: { + ...message.metadata, + roundtrips: steps, + steps, + }, + }; +}; + const appendOrUpdateFinish = ( message: ChatModelRunResult, chunk: LanguageModelV1StreamPart & { type: "finish" }, ): ChatModelRunResult => { const { type, ...rest } = chunk; + + const steps = [ + ...(message.metadata?.steps ?? []), + { + logprobs: rest.logprobs, + usage: rest.usage, + }, + ]; return { ...message, status: getStatus(chunk), metadata: { ...message.metadata, - roundtrips: [ - ...(message.metadata?.roundtrips ?? []), - { - logprobs: rest.logprobs, - usage: rest.usage, - }, - ], + roundtrips: steps, + steps, }, }; }; const getStatus = ( - chunk: LanguageModelV1StreamPart & { type: "finish" }, + chunk: + | (LanguageModelV1StreamPart & { type: "finish" }) + | (ToolResultStreamPart & { type: "step-finish" }), ): MessageStatus => { if (chunk.finishReason === "tool-calls") { return { diff --git a/packages/react/src/runtimes/edge/streams/toolResultStream.ts b/packages/react/src/runtimes/edge/streams/toolResultStream.ts index 5b8285d9e..58e14c938 100644 --- a/packages/react/src/runtimes/edge/streams/toolResultStream.ts +++ b/packages/react/src/runtimes/edge/streams/toolResultStream.ts @@ -12,6 +12,22 @@ export type ToolResultStreamPart = toolName: string; result: unknown; isError?: boolean; + } + | { + type: "step-finish"; + finishReason: + | "stop" + | "length" + | "content-filter" + | "tool-calls" + | "error" + | "other" + | "unknown"; + usage: { + promptTokens: number; + completionTokens: number; + }; + isContinued: boolean; }; export function toolResultStream( @@ -87,6 +103,7 @@ export function toolResultStream( case "text-delta": case "tool-call-delta": case "tool-result": + case "step-finish": case "finish": case "error": case "response-metadata": diff --git a/packages/react/src/runtimes/local/ChatModelAdapter.tsx b/packages/react/src/runtimes/local/ChatModelAdapter.tsx index 6be9ed009..96e167fd4 100644 --- a/packages/react/src/runtimes/local/ChatModelAdapter.tsx +++ b/packages/react/src/runtimes/local/ChatModelAdapter.tsx @@ -1,9 +1,10 @@ "use client"; + import type { MessageStatus, ThreadAssistantContentPart, ThreadMessage, - ThreadRoundtrip, + ThreadStep, } from "../../types/AssistantTypes"; import type { ModelConfig } from "../../types/ModelConfigTypes"; @@ -16,7 +17,11 @@ export type ChatModelRunResult = { content?: ThreadAssistantContentPart[]; status?: MessageStatus; metadata?: { - roundtrips?: ThreadRoundtrip[]; + /** + * @deprecated Use `steps` instead. This field will be removed in v0.6. + */ + roundtrips?: ThreadStep[]; + steps?: ThreadStep[]; custom?: Record; }; }; diff --git a/packages/react/src/runtimes/local/LocalRuntimeOptions.tsx b/packages/react/src/runtimes/local/LocalRuntimeOptions.tsx index a695418f9..901aa260e 100644 --- a/packages/react/src/runtimes/local/LocalRuntimeOptions.tsx +++ b/packages/react/src/runtimes/local/LocalRuntimeOptions.tsx @@ -5,6 +5,10 @@ import { SpeechSynthesisAdapter } from "../speech/SpeechAdapterTypes"; export type LocalRuntimeOptions = { initialMessages?: readonly CoreMessage[] | undefined; + maxSteps?: number | undefined; + /** + * @deprecated Use `maxSteps` (which is `maxToolRoundtrips` + 1; if you set `maxToolRoundtrips` to 2, set `maxSteps` to 3) instead. This field will be removed in v0.6. + */ maxToolRoundtrips?: number | undefined; adapters?: | { @@ -18,9 +22,16 @@ export type LocalRuntimeOptions = { export const splitLocalRuntimeOptions = ( options: T, ) => { - const { initialMessages, maxToolRoundtrips, adapters, ...rest } = options; + const { initialMessages, maxToolRoundtrips, maxSteps, adapters, ...rest } = + options; + return { - localRuntimeOptions: { initialMessages, maxToolRoundtrips, adapters }, + localRuntimeOptions: { + initialMessages, + maxToolRoundtrips, + maxSteps, + adapters, + }, otherOptions: rest, }; }; diff --git a/packages/react/src/runtimes/local/LocalThreadRuntimeCore.tsx b/packages/react/src/runtimes/local/LocalThreadRuntimeCore.tsx index 36e6b724f..b63e8a3fa 100644 --- a/packages/react/src/runtimes/local/LocalThreadRuntimeCore.tsx +++ b/packages/react/src/runtimes/local/LocalThreadRuntimeCore.tsx @@ -180,9 +180,14 @@ export class LocalThreadRuntimeCore implements ThreadRuntimeCore { this.abortController = new AbortController(); const initialContent = message.content; - const initialRoundtrips = message.metadata?.roundtrips; + const initialSteps = message.metadata?.steps; const initalCustom = message.metadata?.custom; const updateMessage = (m: Partial) => { + const newSteps = m.metadata?.steps || m.metadata?.roundtrips; + const steps = newSteps + ? [...(initialSteps ?? []), ...newSteps] + : undefined; + message = { ...message, ...(m.content @@ -190,26 +195,12 @@ export class LocalThreadRuntimeCore implements ThreadRuntimeCore { : undefined), status: m.status ?? message.status, // TODO deprecated, remove in v0.6 - ...(m.metadata?.roundtrips - ? { - roundtrips: [ - ...(initialRoundtrips ?? []), - ...m.metadata.roundtrips, - ], - } - : undefined), + ...(steps ? { roundtrips: steps } : undefined), ...(m.metadata ? { metadata: { ...message.metadata, - ...(m.metadata.roundtrips - ? { - roundtrips: [ - ...(initialRoundtrips ?? []), - ...m.metadata.roundtrips, - ], - } - : undefined), + ...(steps ? { roundtrips: steps, steps } : undefined), ...(m.metadata?.custom ? { custom: { ...(initalCustom ?? {}), ...m.metadata.custom }, @@ -223,10 +214,13 @@ export class LocalThreadRuntimeCore implements ThreadRuntimeCore { this.notifySubscribers(); }; - const maxToolRoundtrips = this.options.maxToolRoundtrips ?? 1; - const toolRoundtrips = message.metadata?.roundtrips?.length ?? 0; - if (toolRoundtrips > maxToolRoundtrips) { - // reached max tool roundtrips + const maxSteps = this.options.maxSteps + ? this.options.maxSteps + : (this.options.maxToolRoundtrips ?? 1) + 1; + + const steps = message.metadata?.steps?.length ?? 0; + if (steps >= maxSteps) { + // reached max tool steps updateMessage({ status: { type: "incomplete", diff --git a/packages/react/src/types/AssistantTypes.ts b/packages/react/src/types/AssistantTypes.ts index 6acf37c6c..44f6c4b85 100644 --- a/packages/react/src/types/AssistantTypes.ts +++ b/packages/react/src/types/AssistantTypes.ts @@ -51,7 +51,15 @@ type MessageCommonProps = { createdAt: Date; }; -export type ThreadRoundtrip = { +/** + * @deprecated Use `ThreadStep` instead. This type will be removed in v0.6. + */ +export type ThreadRoundtrip = ThreadStep; + +export type ThreadStep = { + /** + * @deprecated This field will be removed in v0.6. Submit feedback if you need this functionality. + */ logprobs?: LanguageModelV1LogProbs | undefined; usage?: | { @@ -122,11 +130,15 @@ export type ThreadAssistantMessage = MessageCommonProps & { content: ThreadAssistantContentPart[]; status: MessageStatus; /** - * @deprecated Use `metadata.roundtrips` instead. + * @deprecated Use `metadata.steps` instead. */ - roundtrips?: ThreadRoundtrip[] | undefined; + roundtrips?: ThreadStep[] | undefined; metadata?: { - roundtrips?: ThreadRoundtrip[] | undefined; + /** + * @deprecated Use `steps` instead. This field will be removed in v0.6. + */ + roundtrips?: ThreadStep[] | undefined; + steps?: ThreadStep[] | undefined; custom?: Record | undefined; }; };