From 26244cccf221d661206932b502cfdba3e1763a86 Mon Sep 17 00:00:00 2001 From: Simon Farshid Date: Sat, 13 Jul 2024 00:50:18 -0700 Subject: [PATCH] feat: edge runtime maxToolRoundtrips support (#490) --- .../src/runtimes/edge/EdgeChatAdapter.ts | 55 +++++++++++++++++-- .../src/runtimes/edge/createEdgeRuntimeAPI.ts | 48 ++++++++-------- .../runtimes/edge/streams/runResultStream.ts | 5 +- .../src/runtimes/local/ChatModelAdapter.tsx | 2 +- .../react/src/runtimes/local/LocalRuntime.tsx | 4 +- 5 files changed, 83 insertions(+), 31 deletions(-) diff --git a/packages/react/src/runtimes/edge/EdgeChatAdapter.ts b/packages/react/src/runtimes/edge/EdgeChatAdapter.ts index 1a0f5774f..57e942bd0 100644 --- a/packages/react/src/runtimes/edge/EdgeChatAdapter.ts +++ b/packages/react/src/runtimes/edge/EdgeChatAdapter.ts @@ -1,3 +1,4 @@ +import { ThreadAssistantContentPart, ThreadMessage } from "../../types"; import { ChatModelAdapter, ChatModelRunOptions } from "../local"; import { ChatModelRunResult } from "../local/ChatModelAdapter"; import { toCoreMessages } from "./converters/toCoreMessages"; @@ -27,12 +28,16 @@ export function asAsyncIterable( } export type EdgeChatAdapterOptions = { api: string; + maxToolRoundtrips?: number; }; export class EdgeChatAdapter implements ChatModelAdapter { constructor(private options: EdgeChatAdapterOptions) {} - async run({ messages, abortSignal, config, onUpdate }: ChatModelRunOptions) { + async roundtrip( + initialContent: ThreadAssistantContentPart[], + { messages, abortSignal, config, onUpdate }: ChatModelRunOptions, + ) { const result = await fetch(this.options.api, { method: "POST", headers: { @@ -53,14 +58,56 @@ export class EdgeChatAdapter implements ChatModelAdapter { .pipeThrough(chunkByLineStream()) .pipeThrough(assistantDecoderStream()) .pipeThrough(toolResultStream(config.tools)) - .pipeThrough(runResultStream()); + .pipeThrough(runResultStream(initialContent)); + let message: ThreadMessage | undefined; let update: ChatModelRunResult | undefined; for await (update of asAsyncIterable(stream)) { - onUpdate(update); + message = onUpdate(update); } if (update === undefined) throw new Error("No data received from Edge Runtime"); - return update; + + return [message, update] as const; + } + + async run({ messages, abortSignal, config, onUpdate }: ChatModelRunOptions) { + let roundtripAllowance = this.options.maxToolRoundtrips ?? 1; + let usage = { + promptTokens: 0, + completionTokens: 0, + }; + let result; + let assistantMessage; + do { + [assistantMessage, result] = await this.roundtrip(result?.content ?? [], { + messages: assistantMessage ? [...messages, assistantMessage] : messages, + abortSignal, + config, + onUpdate, + }); + if (result.status?.type === "done") { + usage.promptTokens += result.status.usage?.promptTokens ?? 0; + usage.completionTokens += result.status.usage?.completionTokens ?? 0; + } + } while ( + result.status?.type === "done" && + result.status.finishReason === "tool-calls" && + result.content.every((c) => c.type !== "tool-call" || !!c.result) && + roundtripAllowance-- > 0 + ); + + // add usage across all roundtrips + if (result.status?.type === "done" && usage.promptTokens > 0) { + result = { + ...result, + status: { + ...result.status, + usage, + }, + }; + } + + return result; } } diff --git a/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts b/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts index 18fbdb452..e78b57396 100644 --- a/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts +++ b/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts @@ -128,29 +128,31 @@ export const createEdgeRuntimeAPI = ({ let serverStream = tees[1]; if (onFinish) { - serverStream = serverStream.pipeThrough(runResultStream()).pipeThrough( - new TransformStream({ - transform(chunk) { - if (chunk.status?.type !== "done") return; - const resultingMessages = [ - ...messages, - { - role: "assistant", - content: chunk.content, - } as CoreAssistantMessage, - ]; - onFinish({ - finishReason: chunk.status.finishReason!, - usage: chunk.status.usage!, - messages: resultingMessages, - logProbs: chunk.status.logprops, - warnings: streamResult.warnings, - rawCall: streamResult.rawCall, - rawResponse: streamResult.rawResponse, - }); - }, - }), - ); + serverStream = serverStream + .pipeThrough(runResultStream([])) + .pipeThrough( + new TransformStream({ + transform(chunk) { + if (chunk.status?.type !== "done") return; + const resultingMessages = [ + ...messages, + { + role: "assistant", + content: chunk.content, + } as CoreAssistantMessage, + ]; + onFinish({ + finishReason: chunk.status.finishReason!, + usage: chunk.status.usage!, + messages: resultingMessages, + logProbs: chunk.status.logprops, + warnings: streamResult.warnings, + rawCall: streamResult.rawCall, + rawResponse: streamResult.rawResponse, + }); + }, + }), + ); } // drain the server stream diff --git a/packages/react/src/runtimes/edge/streams/runResultStream.ts b/packages/react/src/runtimes/edge/streams/runResultStream.ts index 91121b0c5..e7d330b99 100644 --- a/packages/react/src/runtimes/edge/streams/runResultStream.ts +++ b/packages/react/src/runtimes/edge/streams/runResultStream.ts @@ -2,10 +2,11 @@ import { ChatModelRunResult } from "../../local/ChatModelAdapter"; import { parsePartialJson } from "../partial-json/parse-partial-json"; import { LanguageModelV1StreamPart } from "@ai-sdk/provider"; import { ToolResultStreamPart } from "./toolResultStream"; +import { ThreadAssistantContentPart } from "../../../types"; -export function runResultStream() { +export function runResultStream(initialContent: ThreadAssistantContentPart[]) { let message: ChatModelRunResult = { - content: [], + content: initialContent, }; const currentToolCall = { toolCallId: "", argsText: "" }; diff --git a/packages/react/src/runtimes/local/ChatModelAdapter.tsx b/packages/react/src/runtimes/local/ChatModelAdapter.tsx index 00565a4ae..8cccc57ae 100644 --- a/packages/react/src/runtimes/local/ChatModelAdapter.tsx +++ b/packages/react/src/runtimes/local/ChatModelAdapter.tsx @@ -19,7 +19,7 @@ export type ChatModelRunOptions = { messages: ThreadMessage[]; abortSignal: AbortSignal; config: ModelConfig; - onUpdate: (result: ChatModelRunUpdate) => void; + onUpdate: (result: ChatModelRunUpdate) => ThreadMessage; }; export type ChatModelAdapter = { diff --git a/packages/react/src/runtimes/local/LocalRuntime.tsx b/packages/react/src/runtimes/local/LocalRuntime.tsx index 4ca4930b6..0205f8b21 100644 --- a/packages/react/src/runtimes/local/LocalRuntime.tsx +++ b/packages/react/src/runtimes/local/LocalRuntime.tsx @@ -136,8 +136,10 @@ class LocalThreadRuntime implements ThreadRuntime { try { const updateHandler = ({ content }: ChatModelRunResult) => { message.content = content; - this.repository.addOrUpdateMessage(parentId, { ...message }); + const newMessage = { ...message }; + this.repository.addOrUpdateMessage(parentId, newMessage); this.notifySubscribers(); + return newMessage; }; const result = await this.adapter.run({ messages,