diff --git a/packages/event-handler/src/bedrockAgentFunction/BedrockAgentFunctionResolver.ts b/packages/event-handler/src/bedrockAgentFunction/BedrockAgentFunctionResolver.ts new file mode 100644 index 0000000000..4732006631 --- /dev/null +++ b/packages/event-handler/src/bedrockAgentFunction/BedrockAgentFunctionResolver.ts @@ -0,0 +1,78 @@ +import type { Context } from 'aws-lambda'; +import type { + ToolConfig, + ToolRegistry, + ToolDefinition, + BedrockAgentFunctionRequest, + ToolFunction, + BedrockAgentFunctionResponse, + ResponseOpts, +} from '../types/Tools'; + +export class BedrockAgentFunctionResolver { + protected registry: ToolRegistry = new Map(); + + public tool(fn: ToolFunction, config: ToolConfig) { + this.registry.set(config.name, { function: fn, config }); + } + + public resolve( + event: BedrockAgentFunctionRequest, + context: Context + ): BedrockAgentFunctionResponse { + const { function: toolName, parameters, actionGroup } = event; + + const tool = this.registry.get(toolName); + + if (tool === undefined) { + console.error(`Cant find tool ${tool}`); + return this.response({ + actionGroup, + function: toolName, + responseBody: 'error', + }); + } + + const parameterObject = parameters.reduce((acc, curr) => { + acc[curr.name] = curr.value; + return acc; + }, {}); + + console.debug(`Calling tool ${tool.config.name}`); + const response = tool.function(parameterObject); + + return this.response({ + actionGroup, + function: toolName, + responseBody: response, + }); + } + + private response(opts: ResponseOpts): BedrockAgentFunctionResponse { + const { + actionGroup, + function: fn, + responseBody, + errorType, + sessionAttributes, + promptSessionAttributes, + } = opts; + return { + messageVersion: '1.0', + response: { + actionGroup, + function: fn, + functionResponse: { + responseState: errorType, + responseBody: { + TEXT: { + body: responseBody, + }, + }, + }, + }, + sessionAttributes, + promptSessionAttributes, + }; + } +} diff --git a/packages/event-handler/src/types/Tools.ts b/packages/event-handler/src/types/Tools.ts new file mode 100644 index 0000000000..06bac68dbd --- /dev/null +++ b/packages/event-handler/src/types/Tools.ts @@ -0,0 +1,77 @@ +type ToolConfig = { + name: string; + definition: string; + validation: { input: object; output: object }; + requireConfirmation: boolean | undefined; +}; + +type ToolDefinition = { + function: ToolFunction; + config: ToolConfig; +}; + +type ToolFunction = Function; + +type ToolRegistry = Map; + +type Parameter = { + name: string; + type: string; + value: string; +}; + +type BedrockAgentFunctionRequest = { + messageVersion: string; + agent: { + name: string; + id: string; + alias: string; + version: string; + }; + inputText: string; + sessionId: string; + actionGroup: string; + function: string; + parameters: Array; + sessionAttributes: Attributes; + promptSessionAttributes: Attributes; +}; + +type Attributes = Record; + +type BedrockAgentFunctionResponse = { + messageVersion: string; + response: { + actionGroup: string; + function: string; + functionResponse: { + responseState?: 'ERROR' | 'REPROMPT'; + responseBody: { + TEXT: { + body: string; + }; + }; + }; + }; + sessionAttributes?: Attributes; + promptSessionAttributes?: Attributes; +}; + +type ResponseOpts = { + actionGroup: string; + function: string; + responseBody: string; + errorType?: 'ERROR' | 'REPROMPT'; + sessionAttributes?: Attributes; + promptSessionAttributes?: Attributes; +}; + +export type { + BedrockAgentFunctionRequest, + ToolConfig, + ToolRegistry, + ToolDefinition, + ToolFunction, + BedrockAgentFunctionResponse, + ResponseOpts, +}; diff --git a/packages/event-handler/tests/unit/BedrockAgentFunctionResolver.test.ts b/packages/event-handler/tests/unit/BedrockAgentFunctionResolver.test.ts new file mode 100644 index 0000000000..f7ba9188d7 --- /dev/null +++ b/packages/event-handler/tests/unit/BedrockAgentFunctionResolver.test.ts @@ -0,0 +1,125 @@ +import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest'; +import { BedrockAgentFunctionResolver } from '../../src/bedrockAgentFunction/BedrockAgentFunctionResolver.js'; +import context from '@aws-lambda-powertools/testing-utils/context'; + +const baseBedrockAgentFunctionRequest = { + messageVersion: '1.0', + agent: { + name: '', + id: 'string', + alias: 'string', + version: 'string', + }, + inputText: 'string', + sessionId: 'string', + actionGroup: 'string', + function: 'string', + parameters: [], + sessionAttributes: {}, + promptSessionAttributes: {}, +}; + +describe('BedrockAgentFunctionResolver', () => { + const ENVIRONMENT_VARIABLES = process.env; + + beforeEach(() => { + vi.clearAllMocks(); + vi.resetModules(); + process.env = { ...ENVIRONMENT_VARIABLES }; + }); + + afterAll(() => { + process.env = ENVIRONMENT_VARIABLES; + }); + + class WrappedResolver extends BedrockAgentFunctionResolver { + public getRegistry() { + return this.registry; + } + } + + it('registers tools', () => { + // Arrange + const resolver = new WrappedResolver(); + // Act + resolver.tool(() => {}, { + name: 'noop', + definition: 'Does nothing', + validation: { input: {}, output: {} }, + requireConfirmation: false, + }); + + // Assess + expect(resolver.getRegistry().get('noop')).toEqual( + expect.objectContaining({ + config: { + name: 'noop', + requireConfirmation: false, + definition: 'Does nothing', + validation: { input: {}, output: {} }, + }, + function: expect.any(Function), + }) + ); + }); + + it('resolves events to the correct tool', () => { + // Arrange + const resolver = new WrappedResolver(); + const noop = vi.fn(); + + resolver.tool(noop, { + name: 'noop', + definition: 'Does nothing', + validation: { input: {}, output: {} }, + requireConfirmation: false, + }); + + // Act + resolver.resolve( + { ...baseBedrockAgentFunctionRequest, function: 'noop' }, + context + ); + + expect(noop).toBeCalled(); + }); + + it('responds with the correct response structure when a tool is successfully invoked', () => { + const resolver = new WrappedResolver(); + const uppercaser = ({ str }): string => str.toUpperCase(); + + resolver.tool(uppercaser, { + name: 'uppercaser', + definition: 'Converts a string to uppercase', + validation: { input: {}, output: {} }, + requireConfirmation: false, + }); + + // Act + const response = resolver.resolve( + { + ...baseBedrockAgentFunctionRequest, + function: 'uppercaser', + parameters: [{ name: 'str', value: 'hello world', type: 'string' }], + }, + context + ); + + expect(response).toEqual( + expect.objectContaining({ + messageVersion: '1.0', + response: { + actionGroup: baseBedrockAgentFunctionRequest.actionGroup, + function: 'uppercaser', + functionResponse: { + responseBody: { + TEXT: { + body: 'HELLO WORLD', + }, + }, + }, + }, + }) + ); + }); +});