From a7eadbcaa6f28acc9ed0e6665830f73d7897b85c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 13 May 2024 10:23:29 -0700 Subject: [PATCH] Continue streaming open threads in background when switching to another thread --- frontend/src/App.tsx | 13 ++-- frontend/src/components/Chat.tsx | 6 +- frontend/src/hooks/useStreamState.tsx | 98 ++++++++++++++++++--------- 3 files changed, 76 insertions(+), 41 deletions(-) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 0acc62f3..edf72b34 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -23,7 +23,7 @@ function App(props: { edit?: boolean }) { const [sidebarOpen, setSidebarOpen] = useState(false); const { chats, createChat, updateChat, deleteChat } = useChatList(); const { configs, saveConfig, deleteConfig } = useConfigList(); - const { startStream, stopStream, stream } = useStreamState(); + const { startStream, stopStream, streams } = useStreamState(); const { configSchema, configDefaults } = useSchemas(); const { currentChat, assistantConfig, isLoading } = useThreadAndAssistant(); @@ -92,9 +92,6 @@ function App(props: { edit?: boolean }) { const selectChat = useCallback( async (id: string | null) => { - if (currentChat) { - stopStream?.(true); - } if (!id) { const firstAssistant = configs?.[0]?.assistant_id ?? null; navigate(firstAssistant ? `/assistant/${firstAssistant}` : "/"); @@ -106,7 +103,7 @@ function App(props: { edit?: boolean }) { setSidebarOpen(false); } }, - [currentChat, sidebarOpen, stopStream, configs, navigate], + [sidebarOpen, configs, navigate], ); const selectConfig = useCallback( @@ -144,7 +141,11 @@ function App(props: { edit?: boolean }) { } > {currentChat && assistantConfig && ( - + )} {currentChat && !assistantConfig && ( diff --git a/frontend/src/components/Chat.tsx b/frontend/src/components/Chat.tsx index 3b52f47f..caffc280 100644 --- a/frontend/src/components/Chat.tsx +++ b/frontend/src/components/Chat.tsx @@ -1,5 +1,5 @@ import { useEffect, useRef, useState } from "react"; -import { StreamStateProps } from "../hooks/useStreamState"; +import { StreamState } from "../hooks/useStreamState"; import { useChatMessages } from "../hooks/useChatMessages"; import TypingBox from "./TypingBox"; import { MessageViewer } from "./Message"; @@ -14,12 +14,14 @@ import { useMessageEditing } from "../hooks/useMessageEditing.ts"; import { MessageEditor } from "./MessageEditor.tsx"; import { Message } from "../types.ts"; -interface ChatProps extends Pick { +interface ChatProps { startStream: ( message: MessageWithFiles | null, thread_id: string, assistantType: string, ) => Promise; + stopStream?: (clear?: boolean) => void; + stream: StreamState; } function usePrevious(value: T): T | undefined { diff --git a/frontend/src/hooks/useStreamState.tsx b/frontend/src/hooks/useStreamState.tsx index 36596284..59399f68 100644 --- a/frontend/src/hooks/useStreamState.tsx +++ b/frontend/src/hooks/useStreamState.tsx @@ -10,17 +10,19 @@ export interface StreamState { } export interface StreamStateProps { - stream: StreamState | null; + streams: { + [tid: string]: StreamState; + }; startStream: ( input: Message[] | Record | null, thread_id: string, config?: Record, ) => Promise; - stopStream?: (clear?: boolean) => void; + stopStream?: (thread_id: string, clear?: boolean) => void; } export function useStreamState(): StreamStateProps { - const [current, setCurrent] = useState(null); + const [current, setCurrent] = useState<{ [tid: string]: StreamState }>({}); const [controller, setController] = useState(null); const startStream = useCallback( @@ -31,7 +33,10 @@ export function useStreamState(): StreamStateProps { ) => { const controller = new AbortController(); setController(controller); - setCurrent({ status: "inflight", messages: input || [] }); + setCurrent((threads) => ({ + ...threads, + [thread_id]: { status: "inflight", messages: input || [] }, + })); await fetchEventSource("/runs/stream", { signal: controller.signal, @@ -42,39 +47,60 @@ export function useStreamState(): StreamStateProps { onmessage(msg) { if (msg.event === "data") { const messages = JSON.parse(msg.data); - setCurrent((current) => ({ - status: "inflight" as StreamState["status"], - messages: mergeMessagesById(current?.messages, messages), - run_id: current?.run_id, + setCurrent((threads) => ({ + ...threads, + [thread_id]: { + status: "inflight" as StreamState["status"], + messages: mergeMessagesById( + threads[thread_id]?.messages, + messages, + ), + run_id: threads[thread_id]?.run_id, + }, })); } else if (msg.event === "metadata") { const { run_id } = JSON.parse(msg.data); - setCurrent((current) => ({ - status: "inflight", - messages: current?.messages, - run_id: run_id, + setCurrent((threads) => ({ + ...threads, + [thread_id]: { + status: "inflight" as StreamState["status"], + messages: threads[thread_id]?.messages, + run_id, + }, })); } else if (msg.event === "error") { - setCurrent((current) => ({ - status: "error", - messages: current?.messages, - run_id: current?.run_id, + setCurrent((threads) => ({ + ...threads, + [thread_id]: { + status: "error", + messages: threads[thread_id]?.messages, + run_id: threads[thread_id]?.run_id, + }, })); } }, onclose() { - setCurrent((current) => ({ - status: current?.status === "error" ? current.status : "done", - messages: current?.messages, - run_id: current?.run_id, + setCurrent((threads) => ({ + ...threads, + [thread_id]: { + status: + threads[thread_id]?.status === "error" + ? threads[thread_id].status + : "done", + messages: threads[thread_id]?.messages, + run_id: threads[thread_id]?.run_id, + }, })); setController(null); }, onerror(error) { - setCurrent((current) => ({ - status: "error", - messages: current?.messages, - run_id: current?.run_id, + setCurrent((threads) => ({ + ...threads, + [thread_id]: { + status: "error", + messages: threads[thread_id]?.messages, + run_id: threads[thread_id]?.run_id, + }, })); setController(null); throw error; @@ -85,19 +111,25 @@ export function useStreamState(): StreamStateProps { ); const stopStream = useCallback( - (clear: boolean = false) => { + (thread_id: string, clear: boolean = false) => { controller?.abort(); setController(null); if (clear) { - setCurrent((current) => ({ - status: "done", - run_id: current?.run_id, + setCurrent((threads) => ({ + ...threads, + [thread_id]: { + status: "done", + run_id: threads[thread_id]?.run_id, + }, })); } else { - setCurrent((current) => ({ - status: "done", - messages: current?.messages, - run_id: current?.run_id, + setCurrent((threads) => ({ + ...threads, + [thread_id]: { + status: "done", + messages: threads[thread_id]?.messages, + run_id: threads[thread_id]?.run_id, + }, })); } }, @@ -107,7 +139,7 @@ export function useStreamState(): StreamStateProps { return { startStream, stopStream, - stream: current, + streams: current, }; }