From 915b5b78db4e6fd0c508fcc34530c8306a5af50f Mon Sep 17 00:00:00 2001 From: Simon Farshid Date: Tue, 13 Aug 2024 14:49:55 -0700 Subject: [PATCH] feat: add streamUtils (#664) --- .changeset/empty-snakes-try.md | 5 ++ .../src/runtimes/edge/EdgeChatAdapter.ts | 5 +- .../src/runtimes/edge/createEdgeRuntimeAPI.ts | 3 +- packages/react/src/runtimes/edge/index.ts | 2 + .../edge/streams/AssistantStreamChunkType.ts | 46 ++++++------- .../edge/streams/assistantDecoderStream.ts | 29 ++++---- .../edge/streams/assistantEncoderStream.ts | 66 +++++++++---------- .../streams/utils/PipeableTransformStream.ts | 10 +++ .../runtimes/edge/streams/utils/StreamPart.ts | 3 + .../streams/{ => utils}/chunkByLineStream.ts | 0 .../src/runtimes/edge/streams/utils/index.ts | 12 ++++ .../streams/utils/streamPartDecoderStream.ts | 29 ++++++++ .../streams/utils/streamPartEncoderStream.ts | 23 +++++++ 13 files changed, 150 insertions(+), 83 deletions(-) create mode 100644 .changeset/empty-snakes-try.md create mode 100644 packages/react/src/runtimes/edge/streams/utils/PipeableTransformStream.ts create mode 100644 packages/react/src/runtimes/edge/streams/utils/StreamPart.ts rename packages/react/src/runtimes/edge/streams/{ => utils}/chunkByLineStream.ts (100%) create mode 100644 packages/react/src/runtimes/edge/streams/utils/index.ts create mode 100644 packages/react/src/runtimes/edge/streams/utils/streamPartDecoderStream.ts create mode 100644 packages/react/src/runtimes/edge/streams/utils/streamPartEncoderStream.ts diff --git a/.changeset/empty-snakes-try.md b/.changeset/empty-snakes-try.md new file mode 100644 index 000000000..2038a9af8 --- /dev/null +++ b/.changeset/empty-snakes-try.md @@ -0,0 +1,5 @@ +--- +"@assistant-ui/react": patch +--- + +feat: expose streamUtils diff --git a/packages/react/src/runtimes/edge/EdgeChatAdapter.ts b/packages/react/src/runtimes/edge/EdgeChatAdapter.ts index 82d2838fc..8774cd0d2 100644 --- a/packages/react/src/runtimes/edge/EdgeChatAdapter.ts +++ b/packages/react/src/runtimes/edge/EdgeChatAdapter.ts @@ -4,7 +4,7 @@ import { toCoreMessages } from "./converters/toCoreMessages"; import { toLanguageModelTools } from "./converters/toLanguageModelTools"; import { EdgeRuntimeRequestOptions } from "./EdgeRuntimeRequestOptions"; import { assistantDecoderStream } from "./streams/assistantDecoderStream"; -import { chunkByLineStream } from "./streams/chunkByLineStream"; +import { streamPartDecoderStream } from "./streams/utils/streamPartDecoderStream"; import { runResultStream } from "./streams/runResultStream"; import { toolResultStream } from "./streams/toolResultStream"; @@ -53,8 +53,7 @@ export class EdgeChatAdapter implements ChatModelAdapter { } const stream = result - .body!.pipeThrough(new TextDecoderStream()) - .pipeThrough(chunkByLineStream()) + .body!.pipeThrough(streamPartDecoderStream()) .pipeThrough(assistantDecoderStream()) .pipeThrough(toolResultStream(config.tools)) .pipeThrough(runResultStream()); diff --git a/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts b/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts index 99fb381a9..06bac509f 100644 --- a/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts +++ b/packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts @@ -27,6 +27,7 @@ import { } from "../../types/ModelConfigTypes"; import { ChatModelRunResult } from "../local"; import { toCoreMessage } from "./converters/toCoreMessages"; +import { streamPartEncoderStream } from "./streams/utils/streamPartEncoderStream"; type FinishResult = { messages: CoreMessage[]; @@ -174,7 +175,7 @@ export const createEdgeRuntimeAPI = ({ return new Response( stream .pipeThrough(assistantEncoderStream()) - .pipeThrough(new TextEncoderStream()), + .pipeThrough(streamPartEncoderStream()), { headers: { contentType: "text/plain; charset=utf-8", diff --git a/packages/react/src/runtimes/edge/index.ts b/packages/react/src/runtimes/edge/index.ts index 5af5b62ba..7c09f0423 100644 --- a/packages/react/src/runtimes/edge/index.ts +++ b/packages/react/src/runtimes/edge/index.ts @@ -1,5 +1,7 @@ export * from "./converters"; +export * from "./streams/utils"; + export { useEdgeRuntime, type EdgeRuntimeOptions } from "./useEdgeRuntime"; export { EdgeChatAdapter } from "./EdgeChatAdapter"; export type { EdgeRuntimeRequestOptions } from "./EdgeRuntimeRequestOptions"; diff --git a/packages/react/src/runtimes/edge/streams/AssistantStreamChunkType.ts b/packages/react/src/runtimes/edge/streams/AssistantStreamChunkType.ts index 229ab804a..c62b95836 100644 --- a/packages/react/src/runtimes/edge/streams/AssistantStreamChunkType.ts +++ b/packages/react/src/runtimes/edge/streams/AssistantStreamChunkType.ts @@ -9,30 +9,22 @@ export enum AssistantStreamChunkType { Finish = "F", } -export type AssistantStreamChunkTuple = - | [AssistantStreamChunkType.TextDelta, string] - | [ - AssistantStreamChunkType.ToolCallBegin, - { - id: string; - name: string; - }, - ] - | [AssistantStreamChunkType.ToolCallArgsTextDelta, string] - | [ - AssistantStreamChunkType.ToolCallResult, - { - id: string; - result: any; - }, - ] - | [AssistantStreamChunkType.Error, unknown] - | [ - AssistantStreamChunkType.Finish, - Omit< - LanguageModelV1StreamPart & { - type: "finish"; - }, - "type" - >, - ]; +export type AssistantStreamChunk = { + [AssistantStreamChunkType.TextDelta]: string; + [AssistantStreamChunkType.ToolCallBegin]: { + id: string; + name: string; + }; + [AssistantStreamChunkType.ToolCallArgsTextDelta]: string; + [AssistantStreamChunkType.ToolCallResult]: { + id: string; + result: any; + }; + [AssistantStreamChunkType.Error]: unknown; + [AssistantStreamChunkType.Finish]: Omit< + LanguageModelV1StreamPart & { + type: "finish"; + }, + "type" + >; +}; diff --git a/packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts b/packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts index 47cba731b..a96f6afe0 100644 --- a/packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts +++ b/packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts @@ -1,7 +1,8 @@ import { - AssistantStreamChunkTuple, + AssistantStreamChunk, AssistantStreamChunkType, } from "./AssistantStreamChunkType"; +import { StreamPart } from "./utils/StreamPart"; import { ToolResultStreamPart } from "./toolResultStream"; export function assistantDecoderStream() { @@ -10,14 +11,15 @@ export function assistantDecoderStream() { | { id: string; name: string; argsText: string } | undefined; - return new TransformStream({ - transform(chunk, controller) { - const [code, value] = parseStreamPart(chunk); - + return new TransformStream< + StreamPart, + ToolResultStreamPart + >({ + transform({ type, value }, controller) { if ( currentToolCall && - code !== AssistantStreamChunkType.ToolCallArgsTextDelta && - code !== AssistantStreamChunkType.Error + type !== AssistantStreamChunkType.ToolCallArgsTextDelta && + type !== AssistantStreamChunkType.Error ) { controller.enqueue({ type: "tool-call", @@ -29,7 +31,7 @@ export function assistantDecoderStream() { currentToolCall = undefined; } - switch (code) { + switch (type) { case AssistantStreamChunkType.TextDelta: { controller.enqueue({ type: "text-delta", @@ -80,19 +82,10 @@ export function assistantDecoderStream() { break; } default: { - const unhandledType: never = code; + const unhandledType: never = type; throw new Error(`Unhandled chunk type: ${unhandledType}`); } } }, }); } - -const parseStreamPart = (part: string): AssistantStreamChunkTuple => { - const index = part.indexOf(":"); - if (index === -1) throw new Error("Invalid stream part"); - return [ - part.slice(0, index) as AssistantStreamChunkType, - JSON.parse(part.slice(index + 1)), - ] as const; -}; diff --git a/packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts b/packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts index 688c5a738..cfbb05c4f 100644 --- a/packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts +++ b/packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts @@ -1,41 +1,42 @@ import { - AssistantStreamChunkTuple, + AssistantStreamChunk, AssistantStreamChunkType, } from "./AssistantStreamChunkType"; +import { StreamPart } from "./utils/StreamPart"; import { ToolResultStreamPart } from "./toolResultStream"; export function assistantEncoderStream() { const toolCalls = new Set(); - return new TransformStream({ + return new TransformStream< + ToolResultStreamPart, + StreamPart + >({ transform(chunk, controller) { const chunkType = chunk.type; switch (chunkType) { case "text-delta": { - controller.enqueue( - formatStreamPart( - AssistantStreamChunkType.TextDelta, - chunk.textDelta, - ), - ); + controller.enqueue({ + type: AssistantStreamChunkType.TextDelta, + value: chunk.textDelta, + }); break; } case "tool-call-delta": { if (!toolCalls.has(chunk.toolCallId)) { toolCalls.add(chunk.toolCallId); - controller.enqueue( - formatStreamPart(AssistantStreamChunkType.ToolCallBegin, { + controller.enqueue({ + type: AssistantStreamChunkType.ToolCallBegin, + value: { id: chunk.toolCallId, name: chunk.toolName, - }), - ); + }, + }); } - controller.enqueue( - formatStreamPart( - AssistantStreamChunkType.ToolCallArgsTextDelta, - chunk.argsTextDelta, - ), - ); + controller.enqueue({ + type: AssistantStreamChunkType.ToolCallArgsTextDelta, + value: chunk.argsTextDelta, + }); break; } @@ -44,27 +45,30 @@ export function assistantEncoderStream() { break; case "tool-result": { - controller.enqueue( - formatStreamPart(AssistantStreamChunkType.ToolCallResult, { + controller.enqueue({ + type: AssistantStreamChunkType.ToolCallResult, + value: { id: chunk.toolCallId, result: chunk.result, - }), - ); + }, + }); break; } case "finish": { const { type, ...rest } = chunk; - controller.enqueue( - formatStreamPart(AssistantStreamChunkType.Finish, rest), - ); + controller.enqueue({ + type: AssistantStreamChunkType.Finish, + value: rest, + }); break; } case "error": { - controller.enqueue( - formatStreamPart(AssistantStreamChunkType.Error, chunk.error), - ); + controller.enqueue({ + type: AssistantStreamChunkType.Error, + value: chunk.error, + }); break; } default: { @@ -75,9 +79,3 @@ export function assistantEncoderStream() { }, }); } - -export function formatStreamPart( - ...[code, value]: AssistantStreamChunkTuple -): string { - return `${code}:${JSON.stringify(value)}\n`; -} diff --git a/packages/react/src/runtimes/edge/streams/utils/PipeableTransformStream.ts b/packages/react/src/runtimes/edge/streams/utils/PipeableTransformStream.ts new file mode 100644 index 000000000..e97e9fccd --- /dev/null +++ b/packages/react/src/runtimes/edge/streams/utils/PipeableTransformStream.ts @@ -0,0 +1,10 @@ +export class PipeableTransformStream extends TransformStream { + constructor(transform: (readable: ReadableStream) => ReadableStream) { + super(); + const readable = transform(super.readable as any); + Object.defineProperty(this, "readable", { + value: readable, + writable: false, + }); + } +} diff --git a/packages/react/src/runtimes/edge/streams/utils/StreamPart.ts b/packages/react/src/runtimes/edge/streams/utils/StreamPart.ts new file mode 100644 index 000000000..2c7f73a08 --- /dev/null +++ b/packages/react/src/runtimes/edge/streams/utils/StreamPart.ts @@ -0,0 +1,3 @@ +export type StreamPart> = { + [K in keyof T]: { type: K; value: T[K] }; +}[keyof T]; diff --git a/packages/react/src/runtimes/edge/streams/chunkByLineStream.ts b/packages/react/src/runtimes/edge/streams/utils/chunkByLineStream.ts similarity index 100% rename from packages/react/src/runtimes/edge/streams/chunkByLineStream.ts rename to packages/react/src/runtimes/edge/streams/utils/chunkByLineStream.ts diff --git a/packages/react/src/runtimes/edge/streams/utils/index.ts b/packages/react/src/runtimes/edge/streams/utils/index.ts new file mode 100644 index 000000000..b8e7c0c22 --- /dev/null +++ b/packages/react/src/runtimes/edge/streams/utils/index.ts @@ -0,0 +1,12 @@ +import { streamPartDecoderStream } from "./streamPartDecoderStream"; +import { streamPartEncoderStream } from "./streamPartEncoderStream"; +import { StreamPart } from "./StreamPart"; + +export declare namespace StreamUtils { + export { StreamPart }; +} + +export const streamUtils = { + streamPartEncoderStream, + streamPartDecoderStream, +}; diff --git a/packages/react/src/runtimes/edge/streams/utils/streamPartDecoderStream.ts b/packages/react/src/runtimes/edge/streams/utils/streamPartDecoderStream.ts new file mode 100644 index 000000000..1399538c9 --- /dev/null +++ b/packages/react/src/runtimes/edge/streams/utils/streamPartDecoderStream.ts @@ -0,0 +1,29 @@ +import { chunkByLineStream } from "./chunkByLineStream"; +import { PipeableTransformStream } from "./PipeableTransformStream"; +import { StreamPart } from "./StreamPart"; + +const decodeStreamPart = >( + part: string, +): StreamPart => { + const index = part.indexOf(":"); + if (index === -1) throw new Error("Invalid stream part"); + return { + type: part.slice(0, index), + value: JSON.parse(part.slice(index + 1)), + }; +}; + +export function streamPartDecoderStream>() { + const decodeStream = new TransformStream>({ + transform(chunk, controller) { + controller.enqueue(decodeStreamPart(chunk)); + }, + }); + + return new PipeableTransformStream((readable) => { + return readable + .pipeThrough(new TextDecoderStream()) + .pipeThrough(chunkByLineStream()) + .pipeThrough(decodeStream); + }); +} diff --git a/packages/react/src/runtimes/edge/streams/utils/streamPartEncoderStream.ts b/packages/react/src/runtimes/edge/streams/utils/streamPartEncoderStream.ts new file mode 100644 index 000000000..7720e2b8c --- /dev/null +++ b/packages/react/src/runtimes/edge/streams/utils/streamPartEncoderStream.ts @@ -0,0 +1,23 @@ +import { PipeableTransformStream } from "./PipeableTransformStream"; +import { StreamPart } from "./StreamPart"; + +function encodeStreamPart>({ + type, + value, +}: StreamPart): string { + return `${type as string}:${JSON.stringify(value)}\n`; +} + +export function streamPartEncoderStream>() { + const encodeStream = new TransformStream, string>({ + transform(chunk, controller) { + controller.enqueue(encodeStreamPart(chunk)); + }, + }); + + return new PipeableTransformStream((readable) => { + return readable + .pipeThrough(encodeStream) + .pipeThrough(new TextEncoderStream()); + }); +}