Skip to content

Commit

Permalink
feat: useCloudThreadListRuntime (#1290)
Browse files Browse the repository at this point in the history
* feat: useCloudThreadListRuntime

* fix

* fixes

* fixes
  • Loading branch information
Yonom authored Jan 3, 2025
1 parent 37e1abc commit c0b4d31
Show file tree
Hide file tree
Showing 17 changed files with 484 additions and 32 deletions.
41 changes: 32 additions & 9 deletions packages/react-langgraph/src/useLangGraphRuntime.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { useState } from "react";
import { useEffect, useRef, useState } from "react";
import { LangChainMessage, LangChainToolCall } from "./types";
import {
useExternalMessageConverter,
Expand All @@ -10,6 +10,7 @@ import { SimpleImageAttachmentAdapter } from "@assistant-ui/react";
import { AttachmentAdapter } from "@assistant-ui/react";
import { AppendMessage } from "@assistant-ui/react";
import { ExternalStoreAdapter } from "@assistant-ui/react";
import { useThreadListItemRuntime } from "@assistant-ui/react/context/react/ThreadListItemContext";

const getPendingToolCalls = (messages: LangChainMessage[]) => {
const pendingToolCalls = new Map<string, LangChainToolCall>();
Expand Down Expand Up @@ -59,14 +60,17 @@ const getMessageContent = (msg: AppendMessage) => {
};

export const useLangGraphRuntime = ({
threadId,
autoCancelPendingToolCalls,
adapters: { attachments } = {},
unstable_allowImageAttachments,
stream,
threadId,
onSwitchToNewThread,
onSwitchToThread,
adapters: { attachments } = {},
}: {
/**
* @deprecated For thread management use `useCloudThreadListRuntime` instead. This option will be removed in a future version.
*/
threadId?: string | undefined;
autoCancelPendingToolCalls?: boolean | undefined;
/**
Expand All @@ -79,6 +83,9 @@ export const useLangGraphRuntime = ({
data: any;
}>
>;
/**
* @deprecated For thread management use `useCloudThreadListRuntime` instead. This option will be removed in a future version.
*/
onSwitchToNewThread?: () => Promise<void> | void;
onSwitchToThread?: (
threadId: string,
Expand Down Expand Up @@ -118,6 +125,13 @@ export const useLangGraphRuntime = ({
if (unstable_allowImageAttachments)
attachments = new SimpleImageAttachmentAdapter();

const switchToThread = !onSwitchToThread
? undefined
: async (threadId: string) => {
const { messages } = await onSwitchToThread(threadId);
setMessages(messages);
};

const threadList: NonNullable<
ExternalStoreAdapter["adapters"]
>["threadList"] = {
Expand All @@ -128,14 +142,23 @@ export const useLangGraphRuntime = ({
await onSwitchToNewThread();
setMessages([]);
},
onSwitchToThread: !onSwitchToThread
? undefined
: async (threadId) => {
const { messages } = await onSwitchToThread(threadId);
setMessages(messages);
},
onSwitchToThread: switchToThread,
};

const loadingRef = useRef(false);
const threadListItemRuntime = useThreadListItemRuntime({ optional: true });
useEffect(() => {
if (!threadListItemRuntime || !switchToThread || loadingRef.current) return;
console.log("switching to thread");
const externalId = threadListItemRuntime.getState().externalId;
if (externalId) {
loadingRef.current = true;
switchToThread(externalId).finally(() => {
loadingRef.current = false;
});
}
}, []);

return useExternalStoreRuntime({
isRunning,
messages: threadMessages,
Expand Down
1 change: 1 addition & 0 deletions packages/react/src/api/ThreadListItemRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export type ThreadListItemState = {

readonly id: string;
readonly remoteId: string | undefined;
readonly externalId: string | undefined;

/**
* @deprecated This field was renamed to `id`. This field will be removed in 0.8.0.
Expand Down
3 changes: 2 additions & 1 deletion packages/react/src/api/ThreadListRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ const getThreadListItemState = (
if (!threadData) return SKIP_UPDATE;
return {
id: threadData.threadId,
remoteId: threadData.remoteId,
threadId: threadData.threadId, // TODO remove in 0.8.0
remoteId: threadData.remoteId,
externalId: threadData.externalId,
title: threadData.title,
status: threadData.status,
isMain: threadData.threadId === threadList.mainThreadId,
Expand Down
1 change: 1 addition & 0 deletions packages/react/src/runtimes/core/ThreadListRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { ThreadRuntimeCore } from "./ThreadRuntimeCore";
type ThreadListItemCoreState = {
readonly threadId: string;
readonly remoteId?: string | undefined;
readonly externalId?: string | undefined;

readonly status: "archived" | "regular" | "new" | "deleted";
readonly title?: string | undefined;
Expand Down
10 changes: 5 additions & 5 deletions packages/react/src/runtimes/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
export * from "./attachment";
export * from "./core";
export * from "./local";
// export * from "./remote-thread-list";
export * from "./dangerous-in-browser";
export * from "./edge";
export * from "./external-store";
export * from "./dangerous-in-browser";
export * from "./speech";
export * from "./attachment";
export * from "./feedback";
export * from "./local";
export * from "./remote-thread-list";
export * from "./speech";
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,28 @@ import { EMPTY_THREAD_CORE } from "./EMPTY_THREAD_CORE";
import { OptimisticState } from "./OptimisticState";
import { FC, PropsWithChildren, useEffect, useId } from "react";
import { create } from "zustand";
import { CloudInitializeResponse } from "./cloud/CloudContext";

type RemoteThreadData =
| {
readonly threadId: string;
readonly remoteId?: undefined;
readonly externalId?: undefined;
readonly status: "new";
readonly title: undefined;
}
| {
readonly threadId: string;
readonly remoteId: string;
readonly externalId: string | undefined;
readonly status: "regular" | "archived";
readonly title?: string | undefined;
};

const DEFAULT_RENDER_COMPONENT: FC<PropsWithChildren> = ({ children }) => {
return children;
};

type THREAD_MAPPING_ID = string & { __brand: "THREAD_MAPPING_ID" };
function createThreadMappingId(id: string): THREAD_MAPPING_ID {
return id as THREAD_MAPPING_ID;
Expand Down Expand Up @@ -149,6 +156,10 @@ export class RemoteThreadListThreadListRuntimeCore
this._hookManager = new RemoteThreadListHookInstanceManager(
adapter.runtimeHook,
);
this.useRenderComponent = create(() => ({
RenderComponent:
adapter.__internal_RenderComponent ?? DEFAULT_RENDER_COMPONENT,
}));
this.__internal_setAdapter(adapter);

this._loadThreadsPromise = this._state
Expand Down Expand Up @@ -188,6 +199,7 @@ export class RemoteThreadListThreadListRuntimeCore
newThreadData[mappingId] = {
threadId: thread.remoteId,
remoteId: thread.remoteId,
externalId: thread.externalId,
status: thread.status,
title: thread.title,
};
Expand All @@ -213,11 +225,23 @@ export class RemoteThreadListThreadListRuntimeCore
this.switchToNewThread();
}

private useRenderComponent;

public __internal_setAdapter(adapter: RemoteThreadListAdapter) {
if (this._adapter === adapter) return;

this._adapter = adapter;
this._disposeOldAdapter?.();
this._disposeOldAdapter = this._adapter.onInitialize(this._onInitialize);

const RenderComponent =
adapter.__internal_RenderComponent ?? DEFAULT_RENDER_COMPONENT;
if (
RenderComponent !== this.useRenderComponent.getState().RenderComponent
) {
this.useRenderComponent.setState({ RenderComponent }, true);
}

this._hookManager.setRuntimeHook(adapter.runtimeHook);
}

Expand Down Expand Up @@ -303,7 +327,7 @@ export class RemoteThreadListThreadListRuntimeCore
return this.switchToThread(threadId);
}

private _onInitialize = async (task: Promise<{ remoteId: string }>) => {
private _onInitialize = async (task: Promise<CloudInitializeResponse>) => {
const threadId = this._state.value.newThreadId;
if (!threadId)
throw new Error(
Expand All @@ -317,7 +341,7 @@ export class RemoteThreadListThreadListRuntimeCore
optimistic: (state) => {
return updateStatusReducer(state, threadId, "regular");
},
then: (state, { remoteId }) => {
then: (state, { remoteId, externalId }) => {
const data = getThreadData(state, threadId);
if (!data) return state;

Expand All @@ -333,6 +357,7 @@ export class RemoteThreadListThreadListRuntimeCore
[threadId]: {
...data,
remoteId,
externalId,
},
},
};
Expand Down Expand Up @@ -435,7 +460,7 @@ export class RemoteThreadListThreadListRuntimeCore

private useBoundIds = create<string[]>(() => []);

public __internal_RenderThreadRuntimes: FC<PropsWithChildren> = () => {
public __internal_RenderComponent: FC<PropsWithChildren> = ({ children }) => {
const id = useId();
useEffect(() => {
this.useBoundIds.setState((s) => [...s, id], true);
Expand All @@ -445,10 +470,17 @@ export class RemoteThreadListThreadListRuntimeCore
}, []);

const boundIds = this.useBoundIds();
const { RenderComponent } = this.useRenderComponent();

// only render if the component is the first one mounted
if (boundIds.length > 0 && boundIds[0] !== id) return;
return (
<RenderComponent>
{(boundIds.length === 0 || boundIds[0] === id) && (
// only render if the component is the first one mounted
<this._hookManager.__internal_RenderThreadRuntimes />
)}

return <this._hookManager.__internal_RenderThreadRuntimes />;
{children}
</RenderComponent>
);
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { AssistantCloudAPI, AssistantCloudConfig } from "./AssistantCloudAPI";
import { AssistantCloudThreads } from "./AssistantCloudThreads";

export class AssistantCloud {
public readonly threads;

constructor(config: AssistantCloudConfig) {
const api = new AssistantCloudAPI(config);
this.threads = new AssistantCloudThreads(api);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import {
AssistantCloudAuthStrategy,
AssistantCloudJWTAuthStrategy,
AssistantCloudAPIKeyAuthStrategy,
} from "./AssistantCloudAuthStrategy";

export type AssistantCloudConfig =
| {
baseUrl: string;
// TODO use baseUrl to construct the projectId
unstable_projectId: string;
authToken(): Promise<string>;
}
| {
apiKey: string;
workspaceId: string;
};

export class AssistantCloudAPI {
private _tokenManager: AssistantCloudAuthStrategy;
private _baseUrl;

constructor(config: AssistantCloudConfig) {
if ("authToken" in config) {
this._baseUrl = config.baseUrl;
this._tokenManager = new AssistantCloudJWTAuthStrategy(
config.unstable_projectId,
config.authToken,
);
} else {
this._baseUrl = "https://api.assistant-ui.com";
this._tokenManager = new AssistantCloudAPIKeyAuthStrategy(
config.apiKey,
config.workspaceId,
);
}
}

public async makeRequest(
endpoint: string,
options: {
method?: "POST" | "PUT" | "DELETE" | undefined;
query?: Record<string, string | number | boolean> | undefined;
body?: object | undefined;
} = {},
) {
const authHeaders = await this._tokenManager.getAuthHeaders();
const headers = {
...authHeaders,
"Content-Type": "application/json",
};

const queryParams = new URLSearchParams();
if (options.query) {
for (const [key, value] of Object.entries(options.query)) {
if (value === false) continue;
if (value === true) {
queryParams.set(key, "true");
} else {
queryParams.set(key, value.toString());
}
}
}

const url = new URL(`${this._baseUrl}${endpoint}`);
url.search = queryParams.toString();

const response = await fetch(url, {
method: options.method ?? "GET",
headers,
body: options.body ? JSON.stringify(options.body) : null,
});

if (!response.ok) {
// TODO better error handling
throw new Error(`Request failed with status ${response.status}`);
}

return response.json();
}
}
Loading

0 comments on commit c0b4d31

Please sign in to comment.