Skip to content

Commit

Permalink
feat: add streamUtils (#664)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Aug 13, 2024
1 parent f03d4da commit 915b5b7
Show file tree
Hide file tree
Showing 13 changed files with 150 additions and 83 deletions.
5 changes: 5 additions & 0 deletions .changeset/empty-snakes-try.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@assistant-ui/react": patch
---

feat: expose streamUtils
5 changes: 2 additions & 3 deletions packages/react/src/runtimes/edge/EdgeChatAdapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { toCoreMessages } from "./converters/toCoreMessages";
import { toLanguageModelTools } from "./converters/toLanguageModelTools";
import { EdgeRuntimeRequestOptions } from "./EdgeRuntimeRequestOptions";
import { assistantDecoderStream } from "./streams/assistantDecoderStream";
import { chunkByLineStream } from "./streams/chunkByLineStream";
import { streamPartDecoderStream } from "./streams/utils/streamPartDecoderStream";
import { runResultStream } from "./streams/runResultStream";
import { toolResultStream } from "./streams/toolResultStream";

Expand Down Expand Up @@ -53,8 +53,7 @@ export class EdgeChatAdapter implements ChatModelAdapter {
}

const stream = result
.body!.pipeThrough(new TextDecoderStream())
.pipeThrough(chunkByLineStream())
.body!.pipeThrough(streamPartDecoderStream())
.pipeThrough(assistantDecoderStream())
.pipeThrough(toolResultStream(config.tools))
.pipeThrough(runResultStream());
Expand Down
3 changes: 2 additions & 1 deletion packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
} from "../../types/ModelConfigTypes";
import { ChatModelRunResult } from "../local";
import { toCoreMessage } from "./converters/toCoreMessages";
import { streamPartEncoderStream } from "./streams/utils/streamPartEncoderStream";

type FinishResult = {
messages: CoreMessage[];
Expand Down Expand Up @@ -174,7 +175,7 @@ export const createEdgeRuntimeAPI = ({
return new Response(
stream
.pipeThrough(assistantEncoderStream())
.pipeThrough(new TextEncoderStream()),
.pipeThrough(streamPartEncoderStream()),
{
headers: {
contentType: "text/plain; charset=utf-8",
Expand Down
2 changes: 2 additions & 0 deletions packages/react/src/runtimes/edge/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
export * from "./converters";

export * from "./streams/utils";

export { useEdgeRuntime, type EdgeRuntimeOptions } from "./useEdgeRuntime";
export { EdgeChatAdapter } from "./EdgeChatAdapter";
export type { EdgeRuntimeRequestOptions } from "./EdgeRuntimeRequestOptions";
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,22 @@ export enum AssistantStreamChunkType {
Finish = "F",
}

export type AssistantStreamChunkTuple =
| [AssistantStreamChunkType.TextDelta, string]
| [
AssistantStreamChunkType.ToolCallBegin,
{
id: string;
name: string;
},
]
| [AssistantStreamChunkType.ToolCallArgsTextDelta, string]
| [
AssistantStreamChunkType.ToolCallResult,
{
id: string;
result: any;
},
]
| [AssistantStreamChunkType.Error, unknown]
| [
AssistantStreamChunkType.Finish,
Omit<
LanguageModelV1StreamPart & {
type: "finish";
},
"type"
>,
];
export type AssistantStreamChunk = {
[AssistantStreamChunkType.TextDelta]: string;
[AssistantStreamChunkType.ToolCallBegin]: {
id: string;
name: string;
};
[AssistantStreamChunkType.ToolCallArgsTextDelta]: string;
[AssistantStreamChunkType.ToolCallResult]: {
id: string;
result: any;
};
[AssistantStreamChunkType.Error]: unknown;
[AssistantStreamChunkType.Finish]: Omit<
LanguageModelV1StreamPart & {
type: "finish";
},
"type"
>;
};
29 changes: 11 additions & 18 deletions packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import {
AssistantStreamChunkTuple,
AssistantStreamChunk,
AssistantStreamChunkType,
} from "./AssistantStreamChunkType";
import { StreamPart } from "./utils/StreamPart";
import { ToolResultStreamPart } from "./toolResultStream";

export function assistantDecoderStream() {
Expand All @@ -10,14 +11,15 @@ export function assistantDecoderStream() {
| { id: string; name: string; argsText: string }
| undefined;

return new TransformStream<string, ToolResultStreamPart>({
transform(chunk, controller) {
const [code, value] = parseStreamPart(chunk);

return new TransformStream<
StreamPart<AssistantStreamChunk>,
ToolResultStreamPart
>({
transform({ type, value }, controller) {
if (
currentToolCall &&
code !== AssistantStreamChunkType.ToolCallArgsTextDelta &&
code !== AssistantStreamChunkType.Error
type !== AssistantStreamChunkType.ToolCallArgsTextDelta &&
type !== AssistantStreamChunkType.Error
) {
controller.enqueue({
type: "tool-call",
Expand All @@ -29,7 +31,7 @@ export function assistantDecoderStream() {
currentToolCall = undefined;
}

switch (code) {
switch (type) {
case AssistantStreamChunkType.TextDelta: {
controller.enqueue({
type: "text-delta",
Expand Down Expand Up @@ -80,19 +82,10 @@ export function assistantDecoderStream() {
break;
}
default: {
const unhandledType: never = code;
const unhandledType: never = type;
throw new Error(`Unhandled chunk type: ${unhandledType}`);
}
}
},
});
}

const parseStreamPart = (part: string): AssistantStreamChunkTuple => {
const index = part.indexOf(":");
if (index === -1) throw new Error("Invalid stream part");
return [
part.slice(0, index) as AssistantStreamChunkType,
JSON.parse(part.slice(index + 1)),
] as const;
};
66 changes: 32 additions & 34 deletions packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts
Original file line number Diff line number Diff line change
@@ -1,41 +1,42 @@
import {
AssistantStreamChunkTuple,
AssistantStreamChunk,
AssistantStreamChunkType,
} from "./AssistantStreamChunkType";
import { StreamPart } from "./utils/StreamPart";
import { ToolResultStreamPart } from "./toolResultStream";

export function assistantEncoderStream() {
const toolCalls = new Set<string>();
return new TransformStream<ToolResultStreamPart, string>({
return new TransformStream<
ToolResultStreamPart,
StreamPart<AssistantStreamChunk>
>({
transform(chunk, controller) {
const chunkType = chunk.type;
switch (chunkType) {
case "text-delta": {
controller.enqueue(
formatStreamPart(
AssistantStreamChunkType.TextDelta,
chunk.textDelta,
),
);
controller.enqueue({
type: AssistantStreamChunkType.TextDelta,
value: chunk.textDelta,
});
break;
}
case "tool-call-delta": {
if (!toolCalls.has(chunk.toolCallId)) {
toolCalls.add(chunk.toolCallId);
controller.enqueue(
formatStreamPart(AssistantStreamChunkType.ToolCallBegin, {
controller.enqueue({
type: AssistantStreamChunkType.ToolCallBegin,
value: {
id: chunk.toolCallId,
name: chunk.toolName,
}),
);
},
});
}

controller.enqueue(
formatStreamPart(
AssistantStreamChunkType.ToolCallArgsTextDelta,
chunk.argsTextDelta,
),
);
controller.enqueue({
type: AssistantStreamChunkType.ToolCallArgsTextDelta,
value: chunk.argsTextDelta,
});
break;
}

Expand All @@ -44,27 +45,30 @@ export function assistantEncoderStream() {
break;

case "tool-result": {
controller.enqueue(
formatStreamPart(AssistantStreamChunkType.ToolCallResult, {
controller.enqueue({
type: AssistantStreamChunkType.ToolCallResult,
value: {
id: chunk.toolCallId,
result: chunk.result,
}),
);
},
});
break;
}

case "finish": {
const { type, ...rest } = chunk;
controller.enqueue(
formatStreamPart(AssistantStreamChunkType.Finish, rest),
);
controller.enqueue({
type: AssistantStreamChunkType.Finish,
value: rest,
});
break;
}

case "error": {
controller.enqueue(
formatStreamPart(AssistantStreamChunkType.Error, chunk.error),
);
controller.enqueue({
type: AssistantStreamChunkType.Error,
value: chunk.error,
});
break;
}
default: {
Expand All @@ -75,9 +79,3 @@ export function assistantEncoderStream() {
},
});
}

export function formatStreamPart(
...[code, value]: AssistantStreamChunkTuple
): string {
return `${code}:${JSON.stringify(value)}\n`;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
export class PipeableTransformStream<I, O> extends TransformStream<I, O> {
constructor(transform: (readable: ReadableStream<I>) => ReadableStream<O>) {
super();
const readable = transform(super.readable as any);
Object.defineProperty(this, "readable", {
value: readable,
writable: false,
});
}
}
3 changes: 3 additions & 0 deletions packages/react/src/runtimes/edge/streams/utils/StreamPart.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
export type StreamPart<T extends Record<string, unknown>> = {
[K in keyof T]: { type: K; value: T[K] };
}[keyof T];
12 changes: 12 additions & 0 deletions packages/react/src/runtimes/edge/streams/utils/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { streamPartDecoderStream } from "./streamPartDecoderStream";
import { streamPartEncoderStream } from "./streamPartEncoderStream";
import { StreamPart } from "./StreamPart";

export declare namespace StreamUtils {
export { StreamPart };
}

export const streamUtils = {
streamPartEncoderStream,
streamPartDecoderStream,
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { chunkByLineStream } from "./chunkByLineStream";
import { PipeableTransformStream } from "./PipeableTransformStream";
import { StreamPart } from "./StreamPart";

const decodeStreamPart = <T extends Record<string, unknown>>(
part: string,
): StreamPart<T> => {
const index = part.indexOf(":");
if (index === -1) throw new Error("Invalid stream part");
return {
type: part.slice(0, index),
value: JSON.parse(part.slice(index + 1)),
};
};

export function streamPartDecoderStream<T extends Record<string, unknown>>() {
const decodeStream = new TransformStream<string, StreamPart<T>>({
transform(chunk, controller) {
controller.enqueue(decodeStreamPart<T>(chunk));
},
});

return new PipeableTransformStream((readable) => {
return readable
.pipeThrough(new TextDecoderStream())
.pipeThrough(chunkByLineStream())
.pipeThrough(decodeStream);
});
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { PipeableTransformStream } from "./PipeableTransformStream";
import { StreamPart } from "./StreamPart";

function encodeStreamPart<T extends Record<string, unknown>>({
type,
value,
}: StreamPart<T>): string {
return `${type as string}:${JSON.stringify(value)}\n`;
}

export function streamPartEncoderStream<T extends Record<string, unknown>>() {
const encodeStream = new TransformStream<StreamPart<T>, string>({
transform(chunk, controller) {
controller.enqueue(encodeStreamPart<T>(chunk));
},
});

return new PipeableTransformStream((readable) => {
return readable
.pipeThrough(encodeStream)
.pipeThrough(new TextEncoderStream());
});
}

0 comments on commit 915b5b7

Please sign in to comment.