diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts index 8f6b741d10..ab90e0ff02 100644 --- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts +++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts @@ -21,7 +21,7 @@ import { ConversationTurnEventToolsProvider } from './event-tools-provider'; */ export class BedrockConverseAdapter { private readonly allTools: Array; - private readonly allExecutableTools: Array; + private readonly executableTools: Array; private readonly clientTools: Array; private readonly executableToolByName: Map = new Map(); @@ -38,14 +38,14 @@ export class BedrockConverseAdapter { ), eventToolsProvider = new ConversationTurnEventToolsProvider(event) ) { - this.allExecutableTools = [ + this.executableTools = [ ...eventToolsProvider.getEventTools(), ...additionalTools, ]; this.clientTools = this.event.toolsConfiguration?.clientTools ?? []; - this.allTools = [...this.allExecutableTools, ...this.clientTools]; + this.allTools = [...this.executableTools, ...this.clientTools]; const duplicateTools = new Set(); - this.allExecutableTools.forEach((t) => { + this.executableTools.forEach((t) => { if (this.executableToolByName.has(t.name)) { duplicateTools.add(t.name); } @@ -122,7 +122,7 @@ export class BedrockConverseAdapter { }; private createToolConfiguration = (): ToolConfiguration | undefined => { - if (this.allExecutableTools.length === 0) { + if (this.allTools.length === 0) { return undefined; } diff --git a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts index 7c8c5f5691..00c5e9cadb 100644 --- a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts +++ b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts @@ -203,6 +203,13 @@ class ConversationHandlerTestProject extends TestProjectBase { clientConfig.data.url, apolloClient ); + + await this.assertDefaultConversationHandlerCanExecuteTurnWithClientTool( + backendId, + authenticatedUserCredentials.accessToken, + clientConfig.data.url, + apolloClient + ); } private assertDefaultConversationHandlerCanExecuteTurn = async ( @@ -320,6 +327,71 @@ class ConversationHandlerTestProject extends TestProjectBase { ); }; + private assertDefaultConversationHandlerCanExecuteTurnWithClientTool = async ( + backendId: BackendIdentifier, + accessToken: string, + graphqlApiEndpoint: string, + apolloClient: ApolloClient + ): Promise => { + const defaultConversationHandlerFunction = ( + await this.resourceFinder.findByBackendIdentifier( + backendId, + 'AWS::Lambda::Function', + (name) => name.includes('default') + ) + )[0]; + + // send event + const event: ConversationTurnEvent = { + conversationId: randomUUID().toString(), + currentMessageId: randomUUID().toString(), + graphqlApiEndpoint: graphqlApiEndpoint, + messages: [ + { + role: 'user', + content: [ + { + text: 'What is the temperature in Seattle?', + }, + ], + }, + ], + request: { + headers: { authorization: accessToken }, + }, + toolsConfiguration: { + clientTools: [ + { + name: 'thermometer', + description: 'Provides the current temperature for a given city.', + inputSchema: { + json: { + type: 'object', + properties: { + city: { + type: 'string', + description: 'string', + }, + }, + required: [], + }, + }, + }, + ], + }, + ...commonEventProperties, + }; + const response = await this.executeConversationTurn( + event, + defaultConversationHandlerFunction, + apolloClient + ); + // Assert that tool was used. I.e. that LLM used value returned by the tool. + assert.match(response.content, /Seattle/); + assert.match(response.content, /toolUse/); + assert.match(response.content, /toolUseId/); + }; + private assertCustomConversationHandlerCanExecuteTurn = async ( backendId: BackendIdentifier, accessToken: string,