Skip to content

Commit

Permalink
refactor(runtime/external-store): rewrite sync mechanism (#630)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Aug 4, 2024
1 parent d57b049 commit 07f76c8
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 154 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ export class ExternalStoreRuntime extends BaseAssistantRuntime<ExternalStoreThre
this.thread.store = store;
}

public onStoreUpdated() {
return this.thread.onStoreUpdated();
}

public getModelConfig() {
return this._proxyConfigProvider.getModelConfig();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import { create } from "zustand";
import { ReactThreadRuntime } from "../core";
import { MessageRepository } from "../utils/MessageRepository";
import { AppendMessage, ThreadMessage, Unsubscribe } from "../../types";
import { ExternalStoreAdapter } from "./ExternalStoreAdapter";
import { AddToolResultOptions } from "../../context";
import { getExternalStoreMessage } from "./getExternalStoreMessage";
import { useExternalStoreSync } from "./useExternalStoreSync";
import {
getExternalStoreMessage,
symbolInnerMessage,
} from "./getExternalStoreMessage";
import {
ConverterCallback,
ThreadMessageConverter,
} from "./ThreadMessageConverter";
import { getAutoStatus, isAutoStatus } from "./auto-status";
import { fromThreadMessageLike } from "./ThreadMessageLike";

export const hasUpcomingMessage = (
isRunning: boolean,
Expand All @@ -19,40 +26,112 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {
private repository = new MessageRepository();
private assistantOptimisticId: string | null = null;

private useStore;

public get capabilities() {
return {
switchToBranch: this.store.setMessages !== undefined,
edit: this.store.onEdit !== undefined,
reload: this.store.onReload !== undefined,
cancel: this.store.onCancel !== undefined,
copy: this.store.onCopy !== null,
switchToBranch: this._store.setMessages !== undefined,
edit: this._store.onEdit !== undefined,
reload: this._store.onReload !== undefined,
cancel: this._store.onCancel !== undefined,
copy: this._store.onCopy !== null,
};
}

public messages: ThreadMessage[] = [];
public isDisabled = false;
public isRunning = false;
public converter = new ThreadMessageConverter();

private _store;

constructor(store: ExternalStoreAdapter<any>) {
this._store = store;
}

public set store(store: ExternalStoreAdapter<any>) {
const oldStore = this._store;
this._store = store;

// flush the converter cache when the convertMessage prop changes
if (oldStore.convertMessage !== store.convertMessage) {
this.converter = new ThreadMessageConverter();
} else if (
oldStore.isDisabled === store.isDisabled &&
oldStore.isRunning === store.isRunning &&
oldStore.messages === store.messages
) {
// no update needed
return;
}

const isRunning = store.isRunning ?? false;
const isDisabled = store.isDisabled ?? false;

const convertCallback: ConverterCallback<any> = (cache, m, idx) => {
if (!store.convertMessage) return m;

const isLast = idx === store.messages.length - 1;
const autoStatus = getAutoStatus(isLast, isRunning);

constructor(public store: ExternalStoreAdapter<any>) {
this.updateData(
store.isDisabled ?? false,
store.isRunning ?? false,
if (
cache &&
(cache.role !== "assistant" ||
!isAutoStatus(cache.status) ||
cache.status === autoStatus)
)
return cache;

const newMessage = fromThreadMessageLike(
store.convertMessage(m, idx),
idx.toString(),
autoStatus,
);
(newMessage as any)[symbolInnerMessage] = m;
return newMessage;
};

const messages = this.converter.convertMessages(
store.messages,
convertCallback,
);

this.useStore = create(() => ({
store,
}));
for (let i = 0; i < messages.length; i++) {
const message = messages[i]!;
const parent = messages[i - 1];
this.repository.addOrUpdateMessage(parent?.id ?? null, message);
}

if (this.assistantOptimisticId) {
this.repository.deleteMessage(this.assistantOptimisticId);
this.assistantOptimisticId = null;
}

if (hasUpcomingMessage(isRunning, messages)) {
this.assistantOptimisticId = this.repository.appendOptimisticMessage(
messages.at(-1)?.id ?? null,
{
role: "assistant",
content: [],
},
);
}

this.repository.resetHead(
this.assistantOptimisticId ?? messages.at(-1)?.id ?? null,
);

this.messages = this.repository.getMessages();
this.isDisabled = isDisabled;
this.isRunning = isRunning;

for (const callback of this._subscriptions) callback();
}

public getBranches(messageId: string): string[] {
return this.repository.getBranches(messageId);
}

public switchToBranch(branchId: string): void {
if (!this.store.setMessages)
if (!this._store.setMessages)
throw new Error("Runtime does not support switching branches.");

this.repository.switchToBranch(branchId);
Expand All @@ -61,26 +140,26 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {

public async append(message: AppendMessage): Promise<void> {
if (message.parentId !== (this.messages.at(-1)?.id ?? null)) {
if (!this.store.onEdit)
if (!this._store.onEdit)
throw new Error("Runtime does not support editing messages.");
await this.store.onEdit(message);
await this._store.onEdit(message);
} else {
await this.store.onNew(message);
await this._store.onNew(message);
}
}

public async startRun(parentId: string | null): Promise<void> {
if (!this.store.onReload)
if (!this._store.onReload)
throw new Error("Runtime does not support reloading messages.");

await this.store.onReload(parentId);
await this._store.onReload(parentId);
}

public cancelRun(): void {
if (!this.store.onCancel)
if (!this._store.onCancel)
throw new Error("Runtime does not support cancelling runs.");

this.store.onCancel();
this._store.onCancel();

if (this.assistantOptimisticId) {
this.repository.deleteMessage(this.assistantOptimisticId);
Expand All @@ -101,65 +180,14 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {
}

private updateMessages = (messages: ThreadMessage[]) => {
this.store.setMessages?.(
this._store.setMessages?.(
messages.flatMap(getExternalStoreMessage).filter((m) => m != null),
);
};

public onStoreUpdated() {
if (this.useStore.getState().store !== this.store) {
this.useStore.setState({ store: this.store });
}
}

private updateData = (
isDisabled: boolean,
isRunning: boolean,
vm: ThreadMessage[],
) => {
for (let i = 0; i < vm.length; i++) {
const message = vm[i]!;
const parent = vm[i - 1];
this.repository.addOrUpdateMessage(parent?.id ?? null, message);
}

if (this.assistantOptimisticId) {
this.repository.deleteMessage(this.assistantOptimisticId);
this.assistantOptimisticId = null;
}

if (hasUpcomingMessage(isRunning, vm)) {
this.assistantOptimisticId = this.repository.appendOptimisticMessage(
vm.at(-1)?.id ?? null,
{
role: "assistant",
content: [],
},
);
}

this.repository.resetHead(
this.assistantOptimisticId ?? vm.at(-1)?.id ?? null,
);

this.messages = this.repository.getMessages();
this.isDisabled = isDisabled;
this.isRunning = isRunning;

for (const callback of this._subscriptions) callback();
};

unstable_synchronizer = () => {
const { store } = this.useStore();

useExternalStoreSync(store, this.updateData);

return null;
};

addToolResult(options: AddToolResultOptions) {
if (!this.store.onAddToolResult)
if (!this._store.onAddToolResult)
throw new Error("Runtime does not support tool results.");
this.store.onAddToolResult(options);
this._store.onAddToolResult(options);
}
}
5 changes: 5 additions & 0 deletions packages/react/src/runtimes/external-store/auto-status.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import { MessageStatus } from "../../types";

const AUTO_STATUS_RUNNING = Object.freeze({ type: "running" });
const AUTO_STATUS_COMPLETE = Object.freeze({
type: "complete",
reason: "unknown",
});

export const isAutoStatus = (status: MessageStatus) =>
status === AUTO_STATUS_RUNNING || status === AUTO_STATUS_COMPLETE;

export const getAutoStatus = (isLast: boolean, isRunning: boolean) =>
isLast && isRunning ? AUTO_STATUS_RUNNING : AUTO_STATUS_COMPLETE;
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import { useEffect, useInsertionEffect, useState } from "react";
import { useInsertionEffect, useState } from "react";
import { ExternalStoreRuntime } from "./ExternalStoreRuntime";
import { ExternalStoreAdapter } from "./ExternalStoreAdapter";

export const useExternalStoreRuntime = (store: ExternalStoreAdapter<any>) => {
export const useExternalStoreRuntime = <T,>(store: ExternalStoreAdapter<T>) => {
const [runtime] = useState(() => new ExternalStoreRuntime(store));

useInsertionEffect(() => {
runtime.store = store;
});
useEffect(() => {
runtime.onStoreUpdated();
});

return runtime;
};

This file was deleted.

0 comments on commit 07f76c8

Please sign in to comment.