Skip to content

Commit

Permalink
feat: allow calling ThreadList methods via threadId or remoteId and u…
Browse files Browse the repository at this point in the history
…se remoteId in adapter (#1289)

* feat: allow calling ThreadList methods via threadId or remoteId and use remoteId in adapter

* code review feedback
  • Loading branch information
Yonom authored Jan 2, 2025
1 parent d6b3b79 commit 6fc733c
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 46 deletions.
2 changes: 1 addition & 1 deletion packages/react/src/runtimes/core/ThreadListRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { ThreadRuntimeCore } from "./ThreadRuntimeCore";

type ThreadListItemCoreState = {
readonly threadId: string;
readonly remoteId?: string;
readonly remoteId?: string | undefined;

readonly status: "archived" | "regular" | "new" | "deleted";
readonly title?: string | undefined;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,52 @@ import { OptimisticState } from "./OptimisticState";
import { FC, PropsWithChildren, useEffect, useId } from "react";
import { create } from "zustand";

type RemoteThreadData = {
readonly threadId: string;
readonly remoteId?: string;
readonly status: "new" | "regular" | "archived";
readonly title?: string | undefined;
};
type RemoteThreadData =
| {
readonly threadId: string;
readonly remoteId?: undefined;
readonly status: "new";
readonly title: undefined;
}
| {
readonly threadId: string;
readonly remoteId: string;
readonly status: "regular" | "archived";
readonly title?: string | undefined;
};

type THREAD_MAPPING_ID = string & { __brand: "THREAD_MAPPING_ID" };
function createThreadMappingId(id: string): THREAD_MAPPING_ID {
return id as THREAD_MAPPING_ID;
}

type RemoteThreadState = {
readonly isLoading: boolean;
readonly newThreadId: string | undefined;
readonly threadIds: readonly string[];
readonly archivedThreadIds: readonly string[];
readonly threadData: Readonly<Record<string, RemoteThreadData>>;
readonly threadIdMap: Readonly<Record<string, THREAD_MAPPING_ID>>;
readonly threadData: Readonly<Record<THREAD_MAPPING_ID, RemoteThreadData>>;
};

const getThreadData = (
state: RemoteThreadState,
threadIdOrRemoteId: string,
) => {
const idx = state.threadIdMap[threadIdOrRemoteId];
if (idx === undefined) return undefined;
return state.threadData[idx];
};

const updateStatusReducer = (
state: RemoteThreadState,
threadId: string,
threadIdOrRemoteId: string,
newStatus: "regular" | "archived" | "deleted",
) => {
const data = state.threadData[threadId];
const data = getThreadData(state, threadIdOrRemoteId);
if (!data) return state;

const { status: lastStatus } = data;
const { threadId, remoteId, status: lastStatus } = data;
if (lastStatus === newStatus) return state;

const newState = { ...state };
Expand Down Expand Up @@ -70,6 +92,11 @@ const updateStatusReducer = (
newState.threadData = Object.fromEntries(
Object.entries(newState.threadData).filter(([key]) => key !== threadId),
);
newState.threadIdMap = Object.fromEntries(
Object.entries(newState.threadIdMap).filter(
([key]) => key !== threadId && key !== remoteId,
),
);
break;

default: {
Expand Down Expand Up @@ -107,6 +134,7 @@ export class RemoteThreadListThreadListRuntimeCore
newThreadId: undefined,
threadIds: [],
archivedThreadIds: [],
threadIdMap: {},
threadData: {},
});

Expand Down Expand Up @@ -135,7 +163,11 @@ export class RemoteThreadListThreadListRuntimeCore
then: (state, l) => {
const newThreadIds = [];
const newArchivedThreadIds = [];
const newThreadData = {} as Record<string, RemoteThreadData>;
const newThreadIdMap = {} as Record<string, THREAD_MAPPING_ID>;
const newThreadData = {} as Record<
THREAD_MAPPING_ID,
RemoteThreadData
>;

for (const thread of l.threads) {
switch (thread.status) {
Expand All @@ -151,7 +183,9 @@ export class RemoteThreadListThreadListRuntimeCore
}
}

newThreadData[thread.remoteId] = {
const mappingId = createThreadMappingId(thread.remoteId);
newThreadIdMap[thread.remoteId] = mappingId;
newThreadData[mappingId] = {
threadId: thread.remoteId,
remoteId: thread.remoteId,
status: thread.status,
Expand All @@ -163,6 +197,10 @@ export class RemoteThreadListThreadListRuntimeCore
...state,
threadIds: newThreadIds,
archivedThreadIds: newArchivedThreadIds,
threadIdMap: {
...state.threadIdMap,
...newThreadIdMap,
},
threadData: {
...state.threadData,
...newThreadData,
Expand Down Expand Up @@ -205,24 +243,24 @@ export class RemoteThreadListThreadListRuntimeCore
return result;
}

public getItemById(threadId: string) {
return this._state.value.threadData[threadId];
public getItemById(threadIdOrRemoteId: string) {
return getThreadData(this._state.value, threadIdOrRemoteId);
}

public async switchToThread(threadId: string): Promise<void> {
if (this._mainThreadId === threadId) return;

const data = this.getItemById(threadId);
public async switchToThread(threadIdOrRemoteId: string): Promise<void> {
const data = this.getItemById(threadIdOrRemoteId);
if (!data) throw new Error("Thread not found");

const task = this._hookManager.startThreadRuntime(threadId);
if (this._mainThreadId === data.threadId) return;

const task = this._hookManager.startThreadRuntime(data.threadId);
if (this.mainThreadId !== undefined) {
await task;
} else {
task.then(() => this._notifySubscribers());
}

if (data.status === "archived") void this.unarchive(threadId);
if (data.status === "archived") await this.unarchive(data.threadId);
this._mainThreadId = data.threadId;

this._notifySubscribers();
Expand All @@ -242,11 +280,16 @@ export class RemoteThreadListThreadListRuntimeCore
if (threadId === undefined) {
do {
threadId = `__LOCALID_${generateId()}`;
} while (state.threadData[threadId]);
} while (state.threadIdMap[threadId]);

const mappingId = createThreadMappingId(threadId);
this._state.update({
...state,
newThreadId: threadId,
threadIdMap: {
...state.threadIdMap,
[threadId]: mappingId,
},
threadData: {
...state.threadData,
[threadId]: {
Expand Down Expand Up @@ -275,11 +318,16 @@ export class RemoteThreadListThreadListRuntimeCore
return updateStatusReducer(state, threadId, "regular");
},
then: (state, { remoteId }) => {
const data = state.threadData[threadId];
const data = getThreadData(state, threadId);
if (!data) return state;

const mappingId = createThreadMappingId(remoteId);
return {
...state,
threadIdMap: {
...state.threadIdMap,
[remoteId]: mappingId,
},
threadData: {
...state.threadData,
[threadId]: {
Expand All @@ -292,20 +340,24 @@ export class RemoteThreadListThreadListRuntimeCore
});
};

public rename(threadId: string, newTitle: string): Promise<void> {
public rename(threadIdOrRemoteId: string, newTitle: string): Promise<void> {
const data = this.getItemById(threadIdOrRemoteId);
if (!data) throw new Error("Thread not found");
if (data.status === "new") throw new Error("Thread is not yet initialized");

return this._state.optimisticUpdate({
execute: () => {
return this._adapter.rename(threadId, newTitle);
return this._adapter.rename(data.remoteId, newTitle);
},
optimistic: (state) => {
const data = state.threadData[threadId];
const data = getThreadData(state, threadIdOrRemoteId);
if (!data) return state;

return {
...state,
threadData: {
...state.threadData,
[threadId]: {
[data.threadId]: {
...data,
title: newTitle,
},
Expand All @@ -327,56 +379,56 @@ export class RemoteThreadListThreadListRuntimeCore
}
}

public async archive(threadId: string) {
const data = this.getItemById(threadId);
public async archive(threadIdOrRemoteId: string) {
const data = this.getItemById(threadIdOrRemoteId);
if (!data) throw new Error("Thread not found");
if (data.status !== "regular")
throw new Error("Thread is not yet initialized or already archived");

return this._state.optimisticUpdate({
execute: async () => {
await this._ensureThreadIsNotMain(threadId);
return this._adapter.archive(threadId);
await this._ensureThreadIsNotMain(data.threadId);
return this._adapter.archive(data.remoteId);
},
optimistic: (state) => {
return updateStatusReducer(state, threadId, "archived");
return updateStatusReducer(state, data.threadId, "archived");
},
});
}

public unarchive(threadId: string): Promise<void> {
const data = this.getItemById(threadId);
public unarchive(threadIdOrRemoteId: string): Promise<void> {
const data = this.getItemById(threadIdOrRemoteId);
if (!data) throw new Error("Thread not found");
if (data.status !== "archived") throw new Error("Thread is not archived");

return this._state.optimisticUpdate({
execute: async () => {
try {
return await this._adapter.unarchive(threadId);
return await this._adapter.unarchive(data.remoteId);
} catch (error) {
await this._ensureThreadIsNotMain(threadId);
await this._ensureThreadIsNotMain(data.threadId);
throw error;
}
},
optimistic: (state) => {
return updateStatusReducer(state, threadId, "regular");
return updateStatusReducer(state, data.threadId, "regular");
},
});
}

public async delete(threadId: string) {
const data = this.getItemById(threadId);
public async delete(threadIdOrRemoteId: string) {
const data = this.getItemById(threadIdOrRemoteId);
if (!data) throw new Error("Thread not found");
if (data.status !== "regular" && data.status !== "archived")
throw new Error("Thread is not yet initialized");

return this._state.optimisticUpdate({
execute: async () => {
await this._ensureThreadIsNotMain(threadId);
return await this._adapter.delete(threadId);
await this._ensureThreadIsNotMain(data.threadId);
return await this._adapter.delete(data.remoteId);
},
optimistic: (state) => {
return updateStatusReducer(state, threadId, "deleted");
return updateStatusReducer(state, data.threadId, "deleted");
},
});
}
Expand Down
8 changes: 4 additions & 4 deletions packages/react/src/runtimes/remote-thread-list/types.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ export type RemoteThreadListAdapter = {

list(): Promise<RemoteThreadListResponse>;

rename(threadId: string, newName: string): Promise<void>;
archive(threadId: string): Promise<void>;
unarchive(threadId: string): Promise<void>;
delete(threadId: string): Promise<void>;
rename(remoteId: string, newTitle: string): Promise<void>;
archive(remoteId: string): Promise<void>;
unarchive(remoteId: string): Promise<void>;
delete(remoteId: string): Promise<void>;

onInitialize(
callback: (task: Promise<{ remoteId: string }>) => Promise<void>,
Expand Down

0 comments on commit 6fc733c

Please sign in to comment.