From 184d836a873d9fa3b69f38262ead125fecb2affb Mon Sep 17 00:00:00 2001 From: Simon Farshid Date: Tue, 10 Sep 2024 19:47:48 -0700 Subject: [PATCH] feat: change langgraph callback to accept an array for tool call cancellation support (#818) --- .changeset/thick-kings-buy.md | 5 ++++ packages/react-langgraph/src/index.ts | 5 ---- .../src/useLangGraphMessages.ts | 10 +++---- .../src/useLangGraphRuntime.ts | 30 +++++++++++-------- 4 files changed, 27 insertions(+), 23 deletions(-) create mode 100644 .changeset/thick-kings-buy.md diff --git a/.changeset/thick-kings-buy.md b/.changeset/thick-kings-buy.md new file mode 100644 index 000000000..d061a9c41 --- /dev/null +++ b/.changeset/thick-kings-buy.md @@ -0,0 +1,5 @@ +--- +"@assistant-ui/react-langgraph": patch +--- + +feat: allow multiple message sends to support pending tool call cancellations diff --git a/packages/react-langgraph/src/index.ts b/packages/react-langgraph/src/index.ts index 393f3d210..15e96e19e 100644 --- a/packages/react-langgraph/src/index.ts +++ b/packages/react-langgraph/src/index.ts @@ -8,8 +8,3 @@ export type { LangChainToolCall, LangChainToolCallChunk, } from "./types"; - -/** - * @deprecated Use `useLangGraphRuntime` instead. This will be removed in 0.1.0. - */ -export { useLangGraphRuntime as useLangChainLangGraphRuntime } from "./useLangGraphRuntime"; diff --git a/packages/react-langgraph/src/useLangGraphMessages.ts b/packages/react-langgraph/src/useLangGraphMessages.ts index 86fce4079..12073aa31 100644 --- a/packages/react-langgraph/src/useLangGraphMessages.ts +++ b/packages/react-langgraph/src/useLangGraphMessages.ts @@ -3,7 +3,7 @@ import { useState, useCallback } from "react"; export const useLangGraphMessages = ({ stream, }: { - stream: (message: TMessage) => Promise< + stream: (messages: TMessage[]) => Promise< AsyncGenerator<{ event: string; data: any; @@ -13,12 +13,12 @@ export const useLangGraphMessages = ({ const [messages, setMessages] = useState([]); const sendMessage = useCallback( - async (message: TMessage) => { - if (message !== null) { - setMessages((currentMessages) => [...currentMessages, message]); + async (messages: TMessage[]) => { + if (messages.length > 0) { + setMessages((currentMessages) => [...currentMessages, ...messages]); } - const response = await stream(message); + const response = await stream(messages); const completeMessages: TMessage[] = []; let partialMessages: Map = new Map(); diff --git a/packages/react-langgraph/src/useLangGraphRuntime.ts b/packages/react-langgraph/src/useLangGraphRuntime.ts index 5dc51a058..4fe1ad2ca 100644 --- a/packages/react-langgraph/src/useLangGraphRuntime.ts +++ b/packages/react-langgraph/src/useLangGraphRuntime.ts @@ -13,7 +13,7 @@ export const useLangGraphRuntime = ({ stream, }: { threadId?: string | undefined; - stream: (message: LangChainMessage) => Promise< + stream: (messages: LangChainMessage[]) => Promise< AsyncGenerator<{ event: string; data: any; @@ -25,10 +25,10 @@ export const useLangGraphRuntime = ({ }); const [isRunning, setIsRunning] = useState(false); - const handleSendMessage = async (message: LangChainMessage) => { + const handleSendMessage = async (messages: LangChainMessage[]) => { try { setIsRunning(true); - await sendMessage(message); + await sendMessage(messages); } catch (error) { console.error("Error streaming messages:", error); } finally { @@ -49,18 +49,22 @@ export const useLangGraphRuntime = ({ onNew: (msg) => { if (msg.content.length !== 1 || msg.content[0]?.type !== "text") throw new Error("Only text messages are supported"); - return handleSendMessage({ - type: "human", - content: msg.content[0].text, - }); + return handleSendMessage([ + { + type: "human", + content: msg.content[0].text, + }, + ]); }, onAddToolResult: async ({ toolCallId, toolName, result }) => { - await handleSendMessage({ - type: "tool", - name: toolName, - tool_call_id: toolCallId, - content: JSON.stringify(result), - }); + await handleSendMessage([ + { + type: "tool", + name: toolName, + tool_call_id: toolCallId, + content: JSON.stringify(result), + }, + ]); }, }); };