Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UI tools #1853

Merged
merged 6 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fifty-bags-type.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@aws-amplify/ai-constructs': patch
---

Add client tools
21 changes: 12 additions & 9 deletions packages/ai-constructs/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,19 @@ type ConversationTurnEvent = {
};
messages: Array<ConversationMessage>;
toolsConfiguration?: {
tools: Array<{
name: string;
description: string;
inputSchema: ToolInputSchema;
dataTools?: Array<ToolDefinition & {
graphqlRequestInputDescriptor: {
queryName: string;
selectionSet: string[];
propertyTypes: Record<string, string>;
};
}>;
clientTools?: Array<ToolDefinition>;
};
};

// @public (undocumented)
type ExecutableTool = {
name: string;
description: string;
inputSchema: ToolInputSchema;
type ExecutableTool = ToolDefinition & {
execute: (input: DocumentType | undefined) => Promise<ToolResultContentBlock>;
};

Expand All @@ -103,10 +98,18 @@ declare namespace runtime {
ConversationMessageContentBlock,
ConversationTurnEvent,
ExecutableTool,
handleConversationTurnEvent
handleConversationTurnEvent,
ToolDefinition
}
}

// @public (undocumented)
type ToolDefinition = {
name: string;
description: string;
inputSchema: ToolInputSchema;
};
Comment on lines +106 to +111
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This abstraction appeared in third place already in public API. Therefore, extracting a type.


// (No @packageDocumentation comment for this package)

```
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { describe, it, mock } from 'node:test';
import assert from 'node:assert';
import { ConversationTurnEvent, ExecutableTool } from './types';
import { ConversationTurnEvent, ExecutableTool, ToolDefinition } from './types';
import { BedrockConverseAdapter } from './bedrock_converse_adapter';
import {
BedrockRuntimeClient,
Expand Down Expand Up @@ -101,7 +101,7 @@ void describe('Bedrock converse adapter', () => {
assert.deepStrictEqual(bedrockRequest.input, expectedBedrockInput);
});

void it('uses tools while calling bedrock', async () => {
void it('uses executable tools while calling bedrock', async () => {
const additionalToolOutput: ToolResultContentBlock = {
text: 'additionalToolOutput',
};
Expand Down Expand Up @@ -322,6 +322,26 @@ void describe('Bedrock converse adapter', () => {
new BedrockConverseAdapter(
{
...commonEvent,
toolsConfiguration: {
clientTools: [
{
// this one overlaps with executable tools below
name: 'duplicateName3',
description: '',
inputSchema: { json: {} },
},
{
name: 'duplicateName4',
description: '',
inputSchema: { json: {} },
},
{
name: 'duplicateName4',
description: '',
inputSchema: { json: {} },
},
],
},
},
[
{
Expand All @@ -348,19 +368,25 @@ void describe('Bedrock converse adapter', () => {
inputSchema: { json: {} },
execute: () => Promise.reject(new Error()),
},
{
name: 'duplicateName3',
description: '',
inputSchema: { json: {} },
execute: () => Promise.reject(new Error()),
},
]
),
(error: Error) => {
assert.strictEqual(
error.message,
'Tools must have unique names. Duplicate tools: duplicateName1, duplicateName2.'
'Tools must have unique names. Duplicate tools: duplicateName1, duplicateName2, duplicateName3, duplicateName4.'
);
return true;
}
);
});

void it('tool error is reported to bedrock', async () => {
void it('executable tool error is reported to bedrock', async () => {
const tool: ExecutableTool = {
name: 'testTool',
description: 'tool description',
Expand Down Expand Up @@ -454,7 +480,7 @@ void describe('Bedrock converse adapter', () => {
} as Message);
});

void it('tool error of unknown type is reported to bedrock', async () => {
void it('executable tool error of unknown type is reported to bedrock', async () => {
const tool: ExecutableTool = {
name: 'testTool',
description: 'tool description',
Expand Down Expand Up @@ -549,4 +575,120 @@ void describe('Bedrock converse adapter', () => {
],
} as Message);
});

void it('returns client tool input block when client tool is requested and ignores executable tools', async () => {
const additionalToolOutput: ToolResultContentBlock = {
text: 'additionalToolOutput',
};
const additionalTool: ExecutableTool = {
name: 'additionalTool',
description: 'additional tool description',
inputSchema: {
json: {
required: ['additionalToolRequiredProperty'],
},
},
execute: () => Promise.resolve(additionalToolOutput),
};
const clientTool: ToolDefinition = {
name: 'clientTool',
description: 'client tool description',
inputSchema: {
json: {
required: ['clientToolRequiredProperty'],
},
},
};

const event: ConversationTurnEvent = {
...commonEvent,
toolsConfiguration: {
clientTools: [clientTool],
},
};

const bedrockClient = new BedrockRuntimeClient();
const bedrockResponseQueue: Array<ConverseCommandOutput> = [];
const clientToolUseBlock = {
toolUse: {
toolUseId: randomUUID().toString(),
name: clientTool.name,
input: 'clientToolInput',
},
};
const toolUseBedrockResponse: ConverseCommandOutput = {
$metadata: {},
metrics: undefined,
output: {
message: {
role: 'assistant',
content: [
{
toolUse: {
toolUseId: randomUUID().toString(),
name: additionalTool.name,
input: 'additionalToolInput',
},
},
clientToolUseBlock,
],
},
},
stopReason: 'tool_use',
usage: undefined,
};
bedrockResponseQueue.push(toolUseBedrockResponse);

const bedrockClientSendMock = mock.method(bedrockClient, 'send', () =>
Promise.resolve(bedrockResponseQueue.shift())
);

const responseContent = await new BedrockConverseAdapter(
event,
[additionalTool],
bedrockClient
).askBedrock();

assert.deepStrictEqual(responseContent, [clientToolUseBlock]);

assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1);
const expectedToolConfig: ToolConfiguration = {
tools: [
{
toolSpec: {
name: additionalTool.name,
description: additionalTool.description,
inputSchema: additionalTool.inputSchema,
},
},
{
toolSpec: {
name: clientTool.name,
description: clientTool.description,
inputSchema: clientTool.inputSchema,
},
},
],
};
const expectedBedrockInputCommonProperties = {
modelId: event.modelConfiguration.modelId,
inferenceConfig: {
maxTokens: 2000,
temperature: 0,
},
system: [
{
text: event.modelConfiguration.systemPrompt,
},
],
toolConfig: expectedToolConfig,
};
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
.arguments[0] as unknown as ConverseCommand;
const expectedBedrockInput: ConverseCommandInput = {
messages: event.messages,
...expectedBedrockInputCommonProperties,
};
assert.deepStrictEqual(bedrockRequest.input, expectedBedrockInput);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,24 @@ import {
Tool,
ToolConfiguration,
} from '@aws-sdk/client-bedrock-runtime';
import { ConversationTurnEvent, ExecutableTool } from './types.js';
import {
ConversationTurnEvent,
ExecutableTool,
ToolDefinition,
} from './types.js';
import { ConversationTurnEventToolsProvider } from './event-tools-provider';

/**
* This class is responsible for interacting with Bedrock Converse API
* in order to produce final response that can be sent back to caller.
*/
export class BedrockConverseAdapter {
private readonly allTools: Array<ExecutableTool>;
private readonly toolByName: Map<string, ExecutableTool> = new Map();
private readonly allTools: Array<ToolDefinition>;
private readonly executableTools: Array<ExecutableTool>;
private readonly clientTools: Array<ToolDefinition>;
private readonly executableToolByName: Map<string, ExecutableTool> =
new Map();
private readonly clientToolByName: Map<string, ToolDefinition> = new Map();

/**
* Creates Bedrock Converse Adapter.
Expand All @@ -30,13 +38,27 @@ export class BedrockConverseAdapter {
),
eventToolsProvider = new ConversationTurnEventToolsProvider(event)
) {
this.allTools = [...eventToolsProvider.getEventTools(), ...additionalTools];
this.executableTools = [
...eventToolsProvider.getEventTools(),
...additionalTools,
];
this.clientTools = this.event.toolsConfiguration?.clientTools ?? [];
this.allTools = [...this.executableTools, ...this.clientTools];
const duplicateTools = new Set<string>();
this.allTools.forEach((t) => {
if (this.toolByName.has(t.name)) {
this.executableTools.forEach((t) => {
if (this.executableToolByName.has(t.name)) {
duplicateTools.add(t.name);
}
this.toolByName.set(t.name, t);
this.executableToolByName.set(t.name, t);
});
this.clientTools.forEach((t) => {
if (this.executableToolByName.has(t.name)) {
duplicateTools.add(t.name);
}
if (this.clientToolByName.has(t.name)) {
duplicateTools.add(t.name);
}
this.clientToolByName.set(t.name, t);
});
if (duplicateTools.size > 0) {
throw new Error(
Expand Down Expand Up @@ -74,13 +96,24 @@ export class BedrockConverseAdapter {
if (bedrockResponse.stopReason === 'tool_use') {
const responseContentBlocks =
bedrockResponse.output?.message?.content ?? [];
for (const responseContentBlock of responseContentBlocks) {
if ('toolUse' in responseContentBlock) {
const toolUseBlock =
responseContentBlock as ContentBlock.ToolUseMember;
const toolMessage = await this.executeTool(toolUseBlock);
messages.push(toolMessage);
}
const toolUseBlocks = responseContentBlocks.filter(
(block) => 'toolUse' in block
) as Array<ContentBlock.ToolUseMember>;
const clientToolUseBlocks = responseContentBlocks.filter(
(block) =>
block.toolUse?.name &&
this.clientToolByName.has(block.toolUse?.name)
);
if (clientToolUseBlocks.length > 0) {
// For now if any of client tools is used we ignore executable tools
// and propagate result back to client.
return clientToolUseBlocks;
}
for (const responseContentBlock of toolUseBlocks) {
const toolUseBlock =
responseContentBlock as ContentBlock.ToolUseMember;
const toolMessage = await this.executeTool(toolUseBlock);
messages.push(toolMessage);
}
}
} while (bedrockResponse.stopReason === 'tool_use');
Expand Down Expand Up @@ -112,7 +145,7 @@ export class BedrockConverseAdapter {
if (!toolUseBlock.toolUse.name) {
throw Error('Bedrock tool use response is missing a tool name');
}
const tool = this.toolByName.get(toolUseBlock.toolUse.name);
const tool = this.executableToolByName.get(toolUseBlock.toolUse.name);
if (!tool) {
throw Error(
`Bedrock tool use response contains unknown tool '${toolUseBlock.toolUse.name}'`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void describe('events tool provider', () => {
selectionSet: '',
},
toolsConfiguration: {
tools: [toolDefinition1, toolDefinition2],
dataTools: [toolDefinition1, toolDefinition2],
},
};
const queryFactory = new GraphQlQueryFactory();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ export class ConversationTurnEventToolsProvider {

getEventTools = (): Array<ExecutableTool> => {
const { toolsConfiguration, graphqlApiEndpoint } = this.event;
if (!toolsConfiguration || !toolsConfiguration.tools) {
if (!toolsConfiguration || !toolsConfiguration.dataTools) {
return [];
}
const tools = toolsConfiguration.tools?.map((tool) => {
const tools = toolsConfiguration.dataTools?.map((tool) => {
const { name, description, inputSchema } = tool;
const query = this.graphQlQueryFactory.createQuery(tool);
return new GraphQlTool(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { ConversationTurnEvent } from '../types';

export type ConversationTurnEventToolConfiguration = NonNullable<
ConversationTurnEvent['toolsConfiguration']
>['tools'][number];
NonNullable<ConversationTurnEvent['toolsConfiguration']>['dataTools']
>[number];
Loading
Loading