Skip to content

Commit

Permalink
feat: append assistant / system message support (#472)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jul 12, 2024
1 parent 3928cb7 commit e832e1c
Show file tree
Hide file tree
Showing 21 changed files with 125 additions and 108 deletions.
2 changes: 1 addition & 1 deletion apps/docs/components/docs/parameters/context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export const AssistantModelConfigState: ParametersTableProps = {
},
{
name: "registerModelConfigProvider",
type: "(provider: () => ModelConfig) => Unsubscribe",
type: "(provider: ModelConfigProvider) => Unsubscribe",
description:
"Registers a model config provider to update the model config.",
required: true,
Expand Down
4 changes: 4 additions & 0 deletions packages/react-ai-sdk/src/rsc/VercelRSCThreadRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ export class VercelRSCThreadRuntime<T extends WeakKey = VercelRSCMessage>
}

public async append(message: AppendMessage): Promise<void> {
if (message.role !== "user")
throw new Error(
"Only appending user messages are supported in VercelRSCRuntime. This is likely an internal bug in assistant-ui.",
);
if (message.parentId !== (this.messages.at(-1)?.id ?? null)) {
if (!this.adapter.edit)
throw new Error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ export class VercelUseAssistantThreadRuntime implements ReactThreadRuntime {

public async append(message: AppendMessage): Promise<void> {
// add user message
if (message.role !== "user")
throw new Error(
"Only appending user messages are supported in VercelUseAssistantRuntime. This is likely an internal bug in assistant-ui.",
);
if (message.content.length !== 1 || message.content[0]?.type !== "text")
throw new Error("VercelUseAssistantRuntime only supports text content.");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ export class VercelUseChatThreadRuntime implements ReactThreadRuntime {

public async append(message: AppendMessage): Promise<void> {
// add user message
if (message.role !== "user")
throw new Error(
"Only appending user messages are supported in VercelUseChatRuntime. This is likely an internal bug in assistant-ui.",
);

if (message.content.length !== 1 || message.content[0]?.type !== "text")
throw new Error(
"Only text content is supported by VercelUseChatRuntime.",
Expand Down
2 changes: 1 addition & 1 deletion packages/react-hook-form/src/useAssistantForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ export const useAssistantForm = <
},
},
};
return registerModelConfigProvider(() => value);
return registerModelConfigProvider({ getModelConfig: () => value });
}, [control, setValue, getValues, registerModelConfigProvider]);

const renderFormFieldTool = props?.assistant?.tools?.set_form_field?.render;
Expand Down
2 changes: 1 addition & 1 deletion packages/react/src/context/providers/AssistantProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export const AssistantProvider: FC<
return { useModelConfig, useToolUIs, useAssistantActions };
});

const getModelConfig = context.useModelConfig((c) => c.getModelConfig);
const getModelConfig = context.useModelConfig();
useEffect(() => {
return runtime.registerModelConfigProvider(getModelConfig);
}, [runtime, getModelConfig]);
Expand Down
4 changes: 2 additions & 2 deletions packages/react/src/context/providers/MessageProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import { type FC, type PropsWithChildren, useEffect, useState } from "react";
import { StoreApi, create } from "zustand";
import type {
AppendContentPart,
AppendUserContentPart,
ThreadMessage,
} from "../../types/AssistantTypes";
import { getMessageText } from "../../utils/getMessageText";
Expand Down Expand Up @@ -81,7 +81,7 @@ const useMessageContext = (messageIndex: number) => {
);

const nonTextParts = message.content.filter(
(part): part is AppendContentPart =>
(part): part is AppendUserContentPart =>
part.type !== "text" && part.type !== "ui",
);
useThreadActions.getState().append({
Expand Down
9 changes: 5 additions & 4 deletions packages/react/src/context/stores/AssistantModelConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import { create } from "zustand";
import type { ModelConfigProvider } from "../../types/ModelConfigTypes";
import { ProxyConfigProvider } from "../../utils/ProxyConfigProvider";

export type AssistantModelConfigState = Readonly<{
getModelConfig: ModelConfigProvider;
registerModelConfigProvider: (provider: ModelConfigProvider) => () => void;
}>;
export type AssistantModelConfigState = Readonly<
ModelConfigProvider & {
registerModelConfigProvider: (provider: ModelConfigProvider) => () => void;
}
>;

export const makeAssistantModelConfigStore = () =>
create<AssistantModelConfigState>(() => {
Expand Down
2 changes: 1 addition & 1 deletion packages/react/src/hooks/useAppendMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const toAppendMessage = (
message.parentId ?? useThreadMessages.getState().at(-1)?.id ?? null,
role: message.role ?? "user",
content: message.content,
};
} as AppendMessage;
};

export const useAppendMessage = () => {
Expand Down
1 change: 1 addition & 0 deletions packages/react/src/internal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export { MessageRepository } from "./runtimes/utils/MessageRepository";
export { BaseAssistantRuntime } from "./runtimes/core/BaseAssistantRuntime";
export { useSmooth } from "./utils/hooks/useSmooth";
export { TooltipIconButton } from "./ui/base/tooltip-icon-button";
export { generateId } from "./utils/idUtils";
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ export const useAssistantInstructions = (instruction: string) => {
const config = {
system: instruction,
};
return registerModelConfigProvider(() => config);
return registerModelConfigProvider({ getModelConfig: () => config });
}, [registerModelConfigProvider, instruction]);
};
4 changes: 3 additions & 1 deletion packages/react/src/model-config/useAssistantTool.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ export const useAssistantTool = <TArgs, TResult>(
[tool.toolName]: rest,
},
};
const unsub1 = registerModelConfigProvider(() => config);
const unsub1 = registerModelConfigProvider({
getModelConfig: () => config,
});
const unsub2 = render ? setToolUI(toolName, render) : undefined;
return () => {
unsub1();
Expand Down
64 changes: 64 additions & 0 deletions packages/react/src/runtimes/edge/EdgeChatAdapter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import { ChatModelAdapter, ChatModelRunOptions } from "../local";
import { ChatModelRunResult } from "../local/ChatModelAdapter";
import { toCoreMessage } from "./converters/toCoreMessage";
import { toLanguageModelTools } from "./converters/toLanguageModelTools";
import { EdgeRuntimeRequestOptions } from "./EdgeRuntimeRequestOptions";
import { assistantDecoderStream } from "./streams/assistantDecoderStream";
import { chunkByLineStream } from "./streams/chunkByLineStream";
import { runResultStream } from "./streams/runResultStream";
import { toolResultStream } from "./streams/toolResultStream";

export function asAsyncIterable<T>(
source: ReadableStream<T>,
): AsyncIterable<T> {
return {
[Symbol.asyncIterator]: () => {
const reader = source.getReader();
return {
async next(): Promise<IteratorResult<T, undefined>> {
const { done, value } = await reader.read();
return done
? { done: true, value: undefined }
: { done: false, value };
},
};
},
};
}
export type EdgeRuntimeOptions = { api: string };

export class EdgeChatAdapter implements ChatModelAdapter {
constructor(private options: EdgeRuntimeOptions) {}

async run({ messages, abortSignal, config, onUpdate }: ChatModelRunOptions) {
const result = await fetch(this.options.api, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
system: config.system,
messages: messages.map(toCoreMessage),
tools: toLanguageModelTools(
config.tools,
) as EdgeRuntimeRequestOptions["tools"],
} satisfies EdgeRuntimeRequestOptions),
signal: abortSignal,
});

const stream = result
.body!.pipeThrough(new TextDecoderStream())
.pipeThrough(chunkByLineStream())
.pipeThrough(assistantDecoderStream())
.pipeThrough(toolResultStream(config.tools))
.pipeThrough(runResultStream());

let update: ChatModelRunResult | undefined;
for await (update of asAsyncIterable(stream)) {
onUpdate(update);
}
if (update === undefined)
throw new Error("No data received from Edge Runtime");
return update;
}
}
2 changes: 2 additions & 0 deletions packages/react/src/runtimes/edge/index.ts
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
export { useEdgeRuntime } from "./useEdgeRuntime";
export { EdgeChatAdapter } from "./EdgeChatAdapter";
export type { EdgeRuntimeOptions } from "./EdgeChatAdapter";
72 changes: 3 additions & 69 deletions packages/react/src/runtimes/edge/useEdgeRuntime.ts
Original file line number Diff line number Diff line change
@@ -1,74 +1,8 @@
import { assistantDecoderStream } from "./streams/assistantDecoderStream";
import { chunkByLineStream } from "./streams/chunkByLineStream";
import {
ChatModelAdapter,
ChatModelRunResult,
} from "../local/ChatModelAdapter";
import { runResultStream } from "./streams/runResultStream";
import { useLocalRuntime } from "..";
import { useMemo } from "react";
import { toolResultStream } from "./streams/toolResultStream";
import { EdgeRuntimeRequestOptions } from "./EdgeRuntimeRequestOptions";
import { toLanguageModelTools } from "./converters/toLanguageModelTools";
import { toCoreMessage } from "./converters/toCoreMessage";

export function asAsyncIterable<T>(
source: ReadableStream<T>,
): AsyncIterable<T> {
return {
[Symbol.asyncIterator]: () => {
const reader = source.getReader();
return {
async next(): Promise<IteratorResult<T, undefined>> {
const { done, value } = await reader.read();
return done
? { done: true, value: undefined }
: { done: false, value };
},
};
},
};
}

type EdgeRuntimeOptions = { api: string };

const createEdgeChatAdapter = ({
api,
}: EdgeRuntimeOptions): ChatModelAdapter => ({
run: async ({ messages, abortSignal, config, onUpdate }) => {
const result = await fetch(api, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
system: config.system,
messages: messages.map(toCoreMessage),
tools: toLanguageModelTools(
config.tools,
) as EdgeRuntimeRequestOptions["tools"],
} satisfies EdgeRuntimeRequestOptions),
signal: abortSignal,
});

const stream = result
.body!.pipeThrough(new TextDecoderStream())
.pipeThrough(chunkByLineStream())
.pipeThrough(assistantDecoderStream())
.pipeThrough(toolResultStream(config.tools))
.pipeThrough(runResultStream());

let update: ChatModelRunResult | undefined;
for await (update of asAsyncIterable(stream)) {
onUpdate(update);
}
if (update === undefined)
throw new Error("No data received from Edge Runtime");
return update;
},
});
import { useState } from "react";
import { EdgeRuntimeOptions, EdgeChatAdapter } from "./EdgeChatAdapter";

export const useEdgeRuntime = (options: EdgeRuntimeOptions) => {
const adapter = useMemo(() => createEdgeChatAdapter(options), [options]);
const [adapter] = useState(() => new EdgeChatAdapter(options));
return useLocalRuntime(adapter);
};
33 changes: 16 additions & 17 deletions packages/react/src/runtimes/local/LocalRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,31 @@ import type {
ThreadAssistantMessage,
ThreadUserMessage,
} from "../../types/AssistantTypes";
import {
type ModelConfigProvider,
mergeModelConfigs,
} from "../../types/ModelConfigTypes";
import { type ModelConfigProvider } from "../../types/ModelConfigTypes";
import type { Unsubscribe } from "../../types/Unsubscribe";
import { ThreadRuntime } from "../core";
import { MessageRepository } from "../utils/MessageRepository";
import { generateId } from "../../utils/idUtils";
import { BaseAssistantRuntime } from "../core/BaseAssistantRuntime";
import type { ChatModelAdapter, ChatModelRunResult } from "./ChatModelAdapter";
import { AddToolResultOptions } from "../../context";
import { ProxyConfigProvider } from "../../internal";

export class LocalRuntime extends BaseAssistantRuntime<LocalThreadRuntime> {
private readonly _configProviders: Set<ModelConfigProvider>;
private readonly _proxyConfigProvider: ProxyConfigProvider;

constructor(adapter: ChatModelAdapter) {
const configProviders = new Set<ModelConfigProvider>();
super(new LocalThreadRuntime(configProviders, adapter));
this._configProviders = configProviders;
const proxyConfigProvider = new ProxyConfigProvider();
super(new LocalThreadRuntime(proxyConfigProvider, adapter));
this._proxyConfigProvider = proxyConfigProvider;
}

public set adapter(adapter: ChatModelAdapter) {
this.thread.adapter = adapter;
}

registerModelConfigProvider(provider: ModelConfigProvider) {
this._configProviders.add(provider);
return () => this._configProviders.delete(provider);
return this._proxyConfigProvider.registerModelConfigProvider(provider);
}

public switchToThread(threadId: string | null) {
Expand All @@ -39,7 +36,7 @@ export class LocalRuntime extends BaseAssistantRuntime<LocalThreadRuntime> {
}

return (this.thread = new LocalThreadRuntime(
this._configProviders,
this._proxyConfigProvider,
this.thread.adapter,
));
}
Expand Down Expand Up @@ -68,7 +65,7 @@ class LocalThreadRuntime implements ThreadRuntime {
}

constructor(
private _configProviders: Set<ModelConfigProvider>,
private configProvider: ModelConfigProvider,
public adapter: ChatModelAdapter,
) {}

Expand All @@ -82,6 +79,11 @@ class LocalThreadRuntime implements ThreadRuntime {
}

public async append(message: AppendMessage): Promise<void> {
if (message.role !== "user")
throw new Error(
"Only appending user messages are supported in LocalRuntime. This is likely an internal bug in assistant-ui.",
);

// add user message
const userMessageId = generateId();
const userMessage: ThreadUserMessage = {
Expand All @@ -96,14 +98,12 @@ class LocalThreadRuntime implements ThreadRuntime {
}

public async startRun(parentId: string | null): Promise<void> {
const id = generateId();

this.repository.resetHead(parentId);
const messages = this.repository.getMessages();

// add assistant message
const message: ThreadAssistantMessage = {
id,
id: generateId(),
role: "assistant",
status: { type: "in_progress" },
content: [{ type: "text", text: "" }],
Expand All @@ -126,7 +126,7 @@ class LocalThreadRuntime implements ThreadRuntime {
const result = await this.adapter.run({
messages,
abortSignal: this.abortController.signal,
config: mergeModelConfigs(this._configProviders),
config: this.configProvider.getModelConfig(),
onUpdate: updateHandler,
});
if (result !== undefined) {
Expand Down Expand Up @@ -154,7 +154,6 @@ class LocalThreadRuntime implements ThreadRuntime {

this.abortController.abort();
this.abortController = null;
this.notifySubscribers();
}

private notifySubscribers() {
Expand Down
7 changes: 6 additions & 1 deletion packages/react/src/runtimes/local/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
export { useLocalRuntime } from "./useLocalRuntime";
export type { ChatModelAdapter, ChatModelRunOptions } from "./ChatModelAdapter";
export type {
ChatModelAdapter,
ChatModelRunOptions,
ChatModelRunResult,
ChatModelRunUpdate,
} from "./ChatModelAdapter";
Loading

0 comments on commit e832e1c

Please sign in to comment.