Skip to content

Commit

Permalink
feat: useRemoteThreadListRuntime (#1286)
Browse files Browse the repository at this point in the history
* feat: useRemoteThreadListRuntime

* Update packages/react/src/runtimes/remote-thread-list/EMPTY_THREAD_CORE.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* review fixes

* Update packages/react/src/runtimes/remote-thread-list/RemoteThreadListThreadListRuntimeCore.tsx

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update packages/react/src/runtimes/remote-thread-list/RemoteThreadListThreadListRuntimeCore.tsx

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 2, 2025
1 parent 5c6b3f7 commit d6b3b79
Show file tree
Hide file tree
Showing 22 changed files with 956 additions and 44 deletions.
5 changes: 5 additions & 0 deletions .changeset/strong-kids-brush.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@assistant-ui/react": patch
---

feat: useRemoteThreadListRuntime
7 changes: 6 additions & 1 deletion packages/react-playground/src/lib/playground-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ const makeModelConfigStore = () =>
type PlaygroundThreadFactory = () => PlaygroundThreadRuntimeCore;

const EMPTY_ARRAY = [] as never[];
const RESOLVED_PROMISE = Promise.resolve();

class PlaygroundThreadListRuntimeCore
implements INTERNAL.ThreadListRuntimeCore
Expand All @@ -79,6 +80,10 @@ class PlaygroundThreadListRuntimeCore
this._mainThread = this.threadFactory();
}

public getLoadThreadsPromise(): Promise<void> {
return RESOLVED_PROMISE;
}

public getMainThreadRuntimeCore() {
return this._mainThread;
}
Expand All @@ -88,7 +93,7 @@ class PlaygroundThreadListRuntimeCore

return {
threadId: "default",
state: "regular",
status: "regular",
runtime: this._mainThread,
} as const;
}
Expand Down
8 changes: 2 additions & 6 deletions packages/react/src/api/AssistantRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ export type AssistantRuntime = {
registerModelConfigProvider(provider: ModelConfigProvider): Unsubscribe;
};

export class AssistantRuntimeImpl
implements
Omit<AssistantRuntimeCore, "thread" | "threadList">,
AssistantRuntime
{
export class AssistantRuntimeImpl implements AssistantRuntime {
public readonly threadList;
public readonly _thread: ThreadRuntime;

Expand All @@ -66,7 +62,7 @@ export class AssistantRuntimeImpl
getState: () => _core.threadList.getMainThreadRuntimeCore(),
subscribe: (callback) => _core.threadList.subscribe(callback),
}),
this.threadList.mainThreadListItem, // TODO capture "main" threadListItem from context around useLocalRuntime / useExternalStoreRuntime
this.threadList.mainItem, // TODO capture "main" threadListItem from context around useLocalRuntime / useExternalStoreRuntime
);
}

Expand Down
3 changes: 2 additions & 1 deletion packages/react/src/api/ThreadListItemRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ export type ThreadListItemState = {
readonly isMain: boolean;

readonly id: string;
readonly remoteId: string | undefined;

/**
* @deprecated This field was renamed to `id`. This field will be removed in 0.8.0.
*/
readonly threadId: string;

readonly state: "archived" | "regular" | "new" | "deleted";
readonly status: "archived" | "regular" | "new" | "deleted";
readonly title?: string | undefined;
};

Expand Down
19 changes: 10 additions & 9 deletions packages/react/src/api/ThreadListRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ export type ThreadListRuntime = {

subscribe(callback: () => void): Unsubscribe;

readonly mainThreadListItem: ThreadListItemRuntime;
getThreadListItemById(threadId: string): ThreadListItemRuntime;
getThreadListItemByIndex(idx: number): ThreadListItemRuntime;
getThreadListArchivedItemByIndex(idx: number): ThreadListItemRuntime;
readonly mainItem: ThreadListItemRuntime;
getItemById(threadId: string): ThreadListItemRuntime;
getItemByIndex(idx: number): ThreadListItemRuntime;
getArchivedItemByIndex(idx: number): ThreadListItemRuntime;
};

const getThreadListState = (
Expand All @@ -48,9 +48,10 @@ const getThreadListItemState = (
if (!threadData) return SKIP_UPDATE;
return {
id: threadData.threadId,
remoteId: threadData.remoteId,
threadId: threadData.threadId, // TODO remove in 0.8.0
title: threadData.title,
state: threadData.state,
status: threadData.status,
isMain: threadData.threadId === threadList.mainThreadId,
};
};
Expand Down Expand Up @@ -93,11 +94,11 @@ export class ThreadListRuntimeImpl implements ThreadListRuntime {

private _mainThreadListItemRuntime;

public get mainThreadListItem() {
public get mainItem() {
return this._mainThreadListItemRuntime;
}

public getThreadListItemByIndex(idx: number) {
public getItemByIndex(idx: number) {
return new ThreadListItemRuntimeImpl(
new ShallowMemoizeSubject({
path: {
Expand All @@ -113,7 +114,7 @@ export class ThreadListRuntimeImpl implements ThreadListRuntime {
);
}

public getThreadListArchivedItemByIndex(idx: number) {
public getArchivedItemByIndex(idx: number) {
return new ThreadListItemRuntimeImpl(
new ShallowMemoizeSubject({
path: {
Expand All @@ -132,7 +133,7 @@ export class ThreadListRuntimeImpl implements ThreadListRuntime {
);
}

public getThreadListItemById(threadId: string) {
public getItemById(threadId: string) {
return new ThreadListItemRuntimeImpl(
new ShallowMemoizeSubject({
path: {
Expand Down
6 changes: 3 additions & 3 deletions packages/react/src/api/ThreadRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ export class ThreadRuntimeImpl implements ThreadRuntime {
return this._threadBinding.path;
}

public unstable_getCore() {
return this._threadBinding.getState();
public get __internal_threadBinding() {
return this._threadBinding;
}

private _threadBinding: ThreadRuntimeCoreBinding & {
private readonly _threadBinding: ThreadRuntimeCoreBinding & {
getStateState(): ThreadState;
};

Expand Down
27 changes: 21 additions & 6 deletions packages/react/src/context/providers/AssistantRuntimeProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { ThreadRuntimeProvider } from "./ThreadRuntimeProvider";
import { AssistantRuntime } from "../../api/AssistantRuntime";
import { create } from "zustand";
import { writableStore } from "../ReadonlyStore";
import { AssistantRuntimeCore } from "../../runtimes/core/AssistantRuntimeCore";

export namespace AssistantRuntimeProvider {
export type Props = PropsWithChildren<{
Expand Down Expand Up @@ -45,6 +46,16 @@ const useThreadListStore = (runtime: AssistantRuntime) => {
return store;
};

const DEFAULT_RENDER_COMPONENT: FC<PropsWithChildren> = ({ children }) =>
children;

const getRenderComponent = (runtime: AssistantRuntime) => {
return (
(runtime as { _core?: AssistantRuntimeCore })._core
?.__internal_RenderComponent ?? DEFAULT_RENDER_COMPONENT
);
};

export const AssistantRuntimeProviderImpl: FC<
AssistantRuntimeProvider.Props
> = ({ children, runtime }) => {
Expand All @@ -59,14 +70,18 @@ export const AssistantRuntimeProviderImpl: FC<
};
}, [useAssistantRuntime, useToolUIs, useThreadList]);

const RenderComponent = getRenderComponent(runtime);

return (
<AssistantContext.Provider value={context}>
<ThreadRuntimeProvider
runtime={runtime.thread}
listItemRuntime={runtime.threadList.mainThreadListItem}
>
{children}
</ThreadRuntimeProvider>
<RenderComponent>
<ThreadRuntimeProvider
runtime={runtime.thread}
listItemRuntime={runtime.threadList.mainItem}
>
{children}
</ThreadRuntimeProvider>
</RenderComponent>
</AssistantContext.Provider>
);
};
Expand Down
8 changes: 3 additions & 5 deletions packages/react/src/primitives/threadList/ThreadListItems.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@ const ThreadListItemImpl: FC<ThreadListItemProps> = ({
const runtime = useMemo(
() =>
archived
? assistantRuntime.threadList.getThreadListArchivedItemByIndex(
partIndex,
)
: assistantRuntime.threadList.getThreadListItemByIndex(partIndex),
[assistantRuntime, partIndex],
? assistantRuntime.threadList.getArchivedItemByIndex(partIndex)
: assistantRuntime.threadList.getItemByIndex(partIndex),
[assistantRuntime, partIndex, archived],
);

const ThreadListItemComponent = components.ThreadListItem;
Expand Down
3 changes: 3 additions & 0 deletions packages/react/src/runtimes/core/AssistantRuntimeCore.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { PropsWithChildren } from "react";
import type { ModelConfigProvider } from "../../types/ModelConfigTypes";
import type { Unsubscribe } from "../../types/Unsubscribe";
import { ThreadListRuntimeCore } from "./ThreadListRuntimeCore";
Expand All @@ -6,4 +7,6 @@ export type AssistantRuntimeCore = {
readonly threadList: ThreadListRuntimeCore;

registerModelConfigProvider: (provider: ModelConfigProvider) => Unsubscribe;

__internal_RenderComponent?: React.FC<PropsWithChildren>;
};
6 changes: 4 additions & 2 deletions packages/react/src/runtimes/core/ThreadListRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ import { ThreadRuntimeCore } from "./ThreadRuntimeCore";

type ThreadListItemCoreState = {
readonly threadId: string;
readonly state: "archived" | "regular" | "new" | "deleted";
readonly remoteId?: string;

readonly status: "archived" | "regular" | "new" | "deleted";
readonly title?: string | undefined;

readonly runtime?: ThreadRuntimeCore | undefined;
Expand All @@ -22,7 +24,7 @@ export type ThreadListRuntimeCore = {
switchToThread(threadId: string): Promise<void>;
switchToNewThread(): Promise<void>;

// getLoadThreadsPromise(): Promise<void>;
getLoadThreadsPromise(): Promise<void>;
// getLoadArchivedThreadsPromise(): Promise<void>;

rename(threadId: string, newTitle: string): Promise<void>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { SpeechSynthesisAdapter } from "../speech/SpeechAdapterTypes";
import { ThreadMessageLike } from "./ThreadMessageLike";

export type ExternalStoreThreadData<TState extends "regular" | "archived"> = {
state: TState;
status: TState;
threadId: string;
title?: string | undefined;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ const DEFAULT_THREAD_ID = "DEFAULT_THREAD_ID";
const DEFAULT_THREADS = Object.freeze([DEFAULT_THREAD_ID]);
const DEFAULT_THREAD: ExternalStoreThreadData<"regular"> = Object.freeze({
threadId: DEFAULT_THREAD_ID,
state: "regular",
status: "regular",
});
const RESOLVED_PROMISE = Promise.resolve();

export class ExternalStoreThreadListRuntimeCore
implements ThreadListRuntimeCore
Expand All @@ -35,6 +36,10 @@ export class ExternalStoreThreadListRuntimeCore
return this._archivedThreads;
}

public getLoadThreadsPromise() {
return RESOLVED_PROMISE;
}

private _mainThread: ExternalStoreThreadRuntimeCore;

public get mainThreadId() {
Expand Down
1 change: 1 addition & 0 deletions packages/react/src/runtimes/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export * from "./core";
export * from "./local";
// export * from "./remote-thread-list";
export * from "./edge";
export * from "./external-store";
export * from "./dangerous-in-browser";
Expand Down
23 changes: 14 additions & 9 deletions packages/react/src/runtimes/local/LocalThreadListRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ export type ThreadListAdapter = {

export type LocalThreadData = {
readonly runtime: LocalThreadRuntimeCore;
readonly state: "new" | "regular" | "archived";
readonly status: "new" | "regular" | "archived";
readonly threadId: string;
readonly title?: string | undefined;
};

export type LocalThreadFactory = () => LocalThreadRuntimeCore;

const RESOLVED_PROMISE = Promise.resolve();
export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
private _threadData = new Map<string, LocalThreadData>();
private _threadIds: readonly string[] = [];
Expand Down Expand Up @@ -51,6 +52,10 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
return result;
}

public getLoadThreadsPromise(): Promise<void> {
return RESOLVED_PROMISE;
}

public getItemById(threadId: string) {
return this._threadData.get(threadId);
}
Expand All @@ -61,7 +66,7 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
const data = this._threadData.get(threadId);
if (!data) throw new Error("Thread not found");

if (data.state === "archived") await this.unarchive(threadId);
if (data.status === "archived") await this.unarchive(threadId);

this._mainThreadId = data.threadId;
this._notifySubscribers();
Expand All @@ -84,7 +89,7 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
});
this._threadData.set(threadId, {
runtime,
state: "new",
status: "new",
threadId,
});
this._newThreadId = threadId;
Expand All @@ -101,7 +106,7 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
const data = this._threadData.get(threadId);
if (!data) throw new Error("Thread not found");

const { state: lastState } = data;
const { status: lastState } = data;
if (lastState === newState) return;

// lastState
Expand Down Expand Up @@ -147,7 +152,7 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
if (newState !== "deleted") {
this._threadData.set(threadId, {
...data,
state: newState,
status: newState,
});
}

Expand Down Expand Up @@ -181,7 +186,7 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
public archive(threadId: string): Promise<void> {
const data = this._threadData.get(threadId);
if (!data) throw new Error("Thread not found");
if (data.state !== "regular")
if (data.status !== "regular")
throw new Error("Thread is not yet initialized or already archived");

this._stateOp(threadId, "archived");
Expand All @@ -191,7 +196,7 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
public unarchive(threadId: string): Promise<void> {
const data = this._threadData.get(threadId);
if (!data) throw new Error("Thread not found");
if (data.state !== "archived") throw new Error("Thread is not archived");
if (data.status !== "archived") throw new Error("Thread is not archived");

this._stateOp(threadId, "regular");
return Promise.resolve();
Expand All @@ -200,8 +205,8 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
public delete(threadId: string): Promise<void> {
const data = this._threadData.get(threadId);
if (!data) throw new Error("Thread not found");
if (data.state !== "regular" && data.state !== "archived")
throw new Error("Thread is not yet initalized");
if (data.status !== "regular" && data.status !== "archived")
throw new Error("Thread is not yet initialized");

this._stateOp(threadId, "deleted");
return Promise.resolve();
Expand Down
Loading

0 comments on commit d6b3b79

Please sign in to comment.