From 91d3951365d2a483137f2f869776dc1ed5a41f3d Mon Sep 17 00:00:00 2001 From: Simon Farshid Date: Mon, 14 Oct 2024 17:25:35 -0700 Subject: [PATCH] feat: MessageRuntime.getContentPartByToolCallId (#1013) --- .changeset/shy-eggs-march.md | 5 +++++ packages/react/src/api/MessageRuntime.ts | 23 ++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 .changeset/shy-eggs-march.md diff --git a/.changeset/shy-eggs-march.md b/.changeset/shy-eggs-march.md new file mode 100644 index 0000000000..29c306ebd3 --- /dev/null +++ b/.changeset/shy-eggs-march.md @@ -0,0 +1,5 @@ +--- +"@assistant-ui/react": patch +--- + +feat: MessageRuntime.getContentPartByToolCallId diff --git a/packages/react/src/api/MessageRuntime.ts b/packages/react/src/api/MessageRuntime.ts index 24e72e5754..4a13853866 100644 --- a/packages/react/src/api/MessageRuntime.ts +++ b/packages/react/src/api/MessageRuntime.ts @@ -143,6 +143,8 @@ export type MessageRuntime = { subscribe(callback: () => void): Unsubscribe; getContentPartByIndex(idx: number): ContentPartRuntime; + getContentPartByToolCallId(toolCallId: string): ContentPartRuntime; + getAttachmentByIndex(idx: number): AttachmentRuntime & { source: "message" }; }; @@ -243,7 +245,7 @@ export class MessageRuntimeImpl implements MessageRuntime { } public getContentPartByIndex(idx: number) { - if (idx < 0) throw new Error("Message index must be >= 0"); + if (idx < 0) throw new Error("Content part index must be >= 0"); return new ContentPartRuntimeImpl( new ShallowMemoizeSubject({ getState: () => { @@ -256,6 +258,25 @@ export class MessageRuntimeImpl implements MessageRuntime { ); } + public getContentPartByToolCallId(toolCallId: string) { + return new ContentPartRuntimeImpl( + new ShallowMemoizeSubject({ + getState: () => { + const state = this._core.getState(); + const idx = state.content.findIndex( + (part) => + part.type === "tool-call" && part.toolCallId === toolCallId, + ); + if (idx === -1) return SKIP_UPDATE; + return getContentPartState(state, idx); + }, + subscribe: (callback) => this._core.subscribe(callback), + }), + this._core, + this._threadBinding, + ); + } + public getAttachmentByIndex(idx: number) { return new MessageAttachmentRuntimeImpl( new ShallowMemoizeSubject({