Skip to content

Commit

Permalink
UI tools (#1853)
Browse files Browse the repository at this point in the history
* add client tools to schema

* api

* logic and unit tests

* e2e

* pr feedback
  • Loading branch information
sobolk authored Aug 13, 2024
1 parent 2e48ae2 commit 4929bb7
Show file tree
Hide file tree
Showing 10 changed files with 312 additions and 49 deletions.
5 changes: 5 additions & 0 deletions .changeset/fifty-bags-type.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@aws-amplify/ai-constructs': patch
---

Add client tools
21 changes: 12 additions & 9 deletions packages/ai-constructs/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,19 @@ type ConversationTurnEvent = {
};
messages: Array<ConversationMessage>;
toolsConfiguration?: {
tools: Array<{
name: string;
description: string;
inputSchema: ToolInputSchema;
dataTools?: Array<ToolDefinition & {
graphqlRequestInputDescriptor: {
queryName: string;
selectionSet: string[];
propertyTypes: Record<string, string>;
};
}>;
clientTools?: Array<ToolDefinition>;
};
};

// @public (undocumented)
type ExecutableTool = {
name: string;
description: string;
inputSchema: ToolInputSchema;
type ExecutableTool = ToolDefinition & {
execute: (input: DocumentType | undefined) => Promise<ToolResultContentBlock>;
};

Expand All @@ -103,10 +98,18 @@ declare namespace runtime {
ConversationMessageContentBlock,
ConversationTurnEvent,
ExecutableTool,
handleConversationTurnEvent
handleConversationTurnEvent,
ToolDefinition
}
}

// @public (undocumented)
type ToolDefinition = {
name: string;
description: string;
inputSchema: ToolInputSchema;
};

// (No @packageDocumentation comment for this package)

```
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { describe, it, mock } from 'node:test';
import assert from 'node:assert';
import { ConversationTurnEvent, ExecutableTool } from './types';
import { ConversationTurnEvent, ExecutableTool, ToolDefinition } from './types';
import { BedrockConverseAdapter } from './bedrock_converse_adapter';
import {
BedrockRuntimeClient,
Expand Down Expand Up @@ -101,7 +101,7 @@ void describe('Bedrock converse adapter', () => {
assert.deepStrictEqual(bedrockRequest.input, expectedBedrockInput);
});

void it('uses tools while calling bedrock', async () => {
void it('uses executable tools while calling bedrock', async () => {
const additionalToolOutput: ToolResultContentBlock = {
text: 'additionalToolOutput',
};
Expand Down Expand Up @@ -322,6 +322,26 @@ void describe('Bedrock converse adapter', () => {
new BedrockConverseAdapter(
{
...commonEvent,
toolsConfiguration: {
clientTools: [
{
// this one overlaps with executable tools below
name: 'duplicateName3',
description: '',
inputSchema: { json: {} },
},
{
name: 'duplicateName4',
description: '',
inputSchema: { json: {} },
},
{
name: 'duplicateName4',
description: '',
inputSchema: { json: {} },
},
],
},
},
[
{
Expand All @@ -348,19 +368,25 @@ void describe('Bedrock converse adapter', () => {
inputSchema: { json: {} },
execute: () => Promise.reject(new Error()),
},
{
name: 'duplicateName3',
description: '',
inputSchema: { json: {} },
execute: () => Promise.reject(new Error()),
},
]
),
(error: Error) => {
assert.strictEqual(
error.message,
'Tools must have unique names. Duplicate tools: duplicateName1, duplicateName2.'
'Tools must have unique names. Duplicate tools: duplicateName1, duplicateName2, duplicateName3, duplicateName4.'
);
return true;
}
);
});

void it('tool error is reported to bedrock', async () => {
void it('executable tool error is reported to bedrock', async () => {
const tool: ExecutableTool = {
name: 'testTool',
description: 'tool description',
Expand Down Expand Up @@ -454,7 +480,7 @@ void describe('Bedrock converse adapter', () => {
} as Message);
});

void it('tool error of unknown type is reported to bedrock', async () => {
void it('executable tool error of unknown type is reported to bedrock', async () => {
const tool: ExecutableTool = {
name: 'testTool',
description: 'tool description',
Expand Down Expand Up @@ -549,4 +575,120 @@ void describe('Bedrock converse adapter', () => {
],
} as Message);
});

void it('returns client tool input block when client tool is requested and ignores executable tools', async () => {
const additionalToolOutput: ToolResultContentBlock = {
text: 'additionalToolOutput',
};
const additionalTool: ExecutableTool = {
name: 'additionalTool',
description: 'additional tool description',
inputSchema: {
json: {
required: ['additionalToolRequiredProperty'],
},
},
execute: () => Promise.resolve(additionalToolOutput),
};
const clientTool: ToolDefinition = {
name: 'clientTool',
description: 'client tool description',
inputSchema: {
json: {
required: ['clientToolRequiredProperty'],
},
},
};

const event: ConversationTurnEvent = {
...commonEvent,
toolsConfiguration: {
clientTools: [clientTool],
},
};

const bedrockClient = new BedrockRuntimeClient();
const bedrockResponseQueue: Array<ConverseCommandOutput> = [];
const clientToolUseBlock = {
toolUse: {
toolUseId: randomUUID().toString(),
name: clientTool.name,
input: 'clientToolInput',
},
};
const toolUseBedrockResponse: ConverseCommandOutput = {
$metadata: {},
metrics: undefined,
output: {
message: {
role: 'assistant',
content: [
{
toolUse: {
toolUseId: randomUUID().toString(),
name: additionalTool.name,
input: 'additionalToolInput',
},
},
clientToolUseBlock,
],
},
},
stopReason: 'tool_use',
usage: undefined,
};
bedrockResponseQueue.push(toolUseBedrockResponse);

const bedrockClientSendMock = mock.method(bedrockClient, 'send', () =>
Promise.resolve(bedrockResponseQueue.shift())
);

const responseContent = await new BedrockConverseAdapter(
event,
[additionalTool],
bedrockClient
).askBedrock();

assert.deepStrictEqual(responseContent, [clientToolUseBlock]);

assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1);
const expectedToolConfig: ToolConfiguration = {
tools: [
{
toolSpec: {
name: additionalTool.name,
description: additionalTool.description,
inputSchema: additionalTool.inputSchema,
},
},
{
toolSpec: {
name: clientTool.name,
description: clientTool.description,
inputSchema: clientTool.inputSchema,
},
},
],
};
const expectedBedrockInputCommonProperties = {
modelId: event.modelConfiguration.modelId,
inferenceConfig: {
maxTokens: 2000,
temperature: 0,
},
system: [
{
text: event.modelConfiguration.systemPrompt,
},
],
toolConfig: expectedToolConfig,
};
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
.arguments[0] as unknown as ConverseCommand;
const expectedBedrockInput: ConverseCommandInput = {
messages: event.messages,
...expectedBedrockInputCommonProperties,
};
assert.deepStrictEqual(bedrockRequest.input, expectedBedrockInput);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,24 @@ import {
Tool,
ToolConfiguration,
} from '@aws-sdk/client-bedrock-runtime';
import { ConversationTurnEvent, ExecutableTool } from './types.js';
import {
ConversationTurnEvent,
ExecutableTool,
ToolDefinition,
} from './types.js';
import { ConversationTurnEventToolsProvider } from './event-tools-provider';

/**
* This class is responsible for interacting with Bedrock Converse API
* in order to produce final response that can be sent back to caller.
*/
export class BedrockConverseAdapter {
private readonly allTools: Array<ExecutableTool>;
private readonly toolByName: Map<string, ExecutableTool> = new Map();
private readonly allTools: Array<ToolDefinition>;
private readonly executableTools: Array<ExecutableTool>;
private readonly clientTools: Array<ToolDefinition>;
private readonly executableToolByName: Map<string, ExecutableTool> =
new Map();
private readonly clientToolByName: Map<string, ToolDefinition> = new Map();

/**
* Creates Bedrock Converse Adapter.
Expand All @@ -30,13 +38,27 @@ export class BedrockConverseAdapter {
),
eventToolsProvider = new ConversationTurnEventToolsProvider(event)
) {
this.allTools = [...eventToolsProvider.getEventTools(), ...additionalTools];
this.executableTools = [
...eventToolsProvider.getEventTools(),
...additionalTools,
];
this.clientTools = this.event.toolsConfiguration?.clientTools ?? [];
this.allTools = [...this.executableTools, ...this.clientTools];
const duplicateTools = new Set<string>();
this.allTools.forEach((t) => {
if (this.toolByName.has(t.name)) {
this.executableTools.forEach((t) => {
if (this.executableToolByName.has(t.name)) {
duplicateTools.add(t.name);
}
this.toolByName.set(t.name, t);
this.executableToolByName.set(t.name, t);
});
this.clientTools.forEach((t) => {
if (this.executableToolByName.has(t.name)) {
duplicateTools.add(t.name);
}
if (this.clientToolByName.has(t.name)) {
duplicateTools.add(t.name);
}
this.clientToolByName.set(t.name, t);
});
if (duplicateTools.size > 0) {
throw new Error(
Expand Down Expand Up @@ -74,13 +96,24 @@ export class BedrockConverseAdapter {
if (bedrockResponse.stopReason === 'tool_use') {
const responseContentBlocks =
bedrockResponse.output?.message?.content ?? [];
for (const responseContentBlock of responseContentBlocks) {
if ('toolUse' in responseContentBlock) {
const toolUseBlock =
responseContentBlock as ContentBlock.ToolUseMember;
const toolMessage = await this.executeTool(toolUseBlock);
messages.push(toolMessage);
}
const toolUseBlocks = responseContentBlocks.filter(
(block) => 'toolUse' in block
) as Array<ContentBlock.ToolUseMember>;
const clientToolUseBlocks = responseContentBlocks.filter(
(block) =>
block.toolUse?.name &&
this.clientToolByName.has(block.toolUse?.name)
);
if (clientToolUseBlocks.length > 0) {
// For now if any of client tools is used we ignore executable tools
// and propagate result back to client.
return clientToolUseBlocks;
}
for (const responseContentBlock of toolUseBlocks) {
const toolUseBlock =
responseContentBlock as ContentBlock.ToolUseMember;
const toolMessage = await this.executeTool(toolUseBlock);
messages.push(toolMessage);
}
}
} while (bedrockResponse.stopReason === 'tool_use');
Expand Down Expand Up @@ -112,7 +145,7 @@ export class BedrockConverseAdapter {
if (!toolUseBlock.toolUse.name) {
throw Error('Bedrock tool use response is missing a tool name');
}
const tool = this.toolByName.get(toolUseBlock.toolUse.name);
const tool = this.executableToolByName.get(toolUseBlock.toolUse.name);
if (!tool) {
throw Error(
`Bedrock tool use response contains unknown tool '${toolUseBlock.toolUse.name}'`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void describe('events tool provider', () => {
selectionSet: '',
},
toolsConfiguration: {
tools: [toolDefinition1, toolDefinition2],
dataTools: [toolDefinition1, toolDefinition2],
},
};
const queryFactory = new GraphQlQueryFactory();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ export class ConversationTurnEventToolsProvider {

getEventTools = (): Array<ExecutableTool> => {
const { toolsConfiguration, graphqlApiEndpoint } = this.event;
if (!toolsConfiguration || !toolsConfiguration.tools) {
if (!toolsConfiguration || !toolsConfiguration.dataTools) {
return [];
}
const tools = toolsConfiguration.tools?.map((tool) => {
const tools = toolsConfiguration.dataTools?.map((tool) => {
const { name, description, inputSchema } = tool;
const query = this.graphQlQueryFactory.createQuery(tool);
return new GraphQlTool(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { ConversationTurnEvent } from '../types';

export type ConversationTurnEventToolConfiguration = NonNullable<
ConversationTurnEvent['toolsConfiguration']
>['tools'][number];
NonNullable<ConversationTurnEvent['toolsConfiguration']>['dataTools']
>[number];
Loading

0 comments on commit 4929bb7

Please sign in to comment.