diff --git a/.changeset/clean-apples-camp.md b/.changeset/clean-apples-camp.md new file mode 100644 index 000000000..2c29bc4a6 --- /dev/null +++ b/.changeset/clean-apples-camp.md @@ -0,0 +1,5 @@ +--- +"@aws-amplify/data-schema": minor +--- + +add streaming support for conversation routes diff --git a/.changeset/ten-turkeys-pay.md b/.changeset/ten-turkeys-pay.md new file mode 100644 index 000000000..9f8948f27 --- /dev/null +++ b/.changeset/ten-turkeys-pay.md @@ -0,0 +1,5 @@ +--- +"@aws-amplify/data-schema": minor +--- + +propagate conversation errors through subscription diff --git a/packages/benches/p50/ai/p50-conversation-operations.bench.ts b/packages/benches/p50/ai/p50-conversation-operations.bench.ts index a64e733f8..db85a9e47 100644 --- a/packages/benches/p50/ai/p50-conversation-operations.bench.ts +++ b/packages/benches/p50/ai/p50-conversation-operations.bench.ts @@ -89,7 +89,7 @@ bench('p50 conversation operations', async () => { await client.conversations.ChatBot.list(); - conversation?.onMessage(() => {}); + conversation?.onStreamEvent(() => {}); await conversation?.sendMessage({ content: [{ text: 'foo' }], diff --git a/packages/benches/p50/ai/p50-conversation-prod-operations.bench.ts b/packages/benches/p50/ai/p50-conversation-prod-operations.bench.ts index 811ebddfe..1c8f22acb 100644 --- a/packages/benches/p50/ai/p50-conversation-prod-operations.bench.ts +++ b/packages/benches/p50/ai/p50-conversation-prod-operations.bench.ts @@ -635,7 +635,7 @@ bench('prod p50 conversation operations', async () => { await client.conversations.ChatBot.list(); - conversation?.onMessage(() => {}); + conversation?.onStreamEvent(() => {}); await conversation?.sendMessage({ content: [{ text: 'foo' }], diff --git a/packages/data-schema/__tests__/__snapshots__/ClientSchema.test.ts.snap b/packages/data-schema/__tests__/__snapshots__/ClientSchema.test.ts.snap index fb4cb3643..0effe070e 100644 --- a/packages/data-schema/__tests__/__snapshots__/ClientSchema.test.ts.snap +++ b/packages/data-schema/__tests__/__snapshots__/ClientSchema.test.ts.snap @@ -207,6 +207,25 @@ type ToolSpecification { type ToolInputSchema { json: AWSJSON +} + +type ConversationMessageStreamPart @aws_cognito_user_pools { + id: ID! + owner: String + conversationId: ID! + associatedUserMessageId: ID! + contentBlockIndex: Int! + contentBlockText: String + contentBlockDeltaIndex: Int + contentBlockToolUse: ToolUseBlock + contentBlockDoneAtIndex: Int + stopReason: String + errors: [ConversationTurnError] +} + +type ConversationTurnError @aws_cognito_user_pools { + message: String! + errorType: String! }" `; diff --git a/packages/data-schema/__tests__/internals/ai/convertItemToConversation.test.ts b/packages/data-schema/__tests__/internals/ai/convertItemToConversation.test.ts index 3c03653ad..dbf88557a 100644 --- a/packages/data-schema/__tests__/internals/ai/convertItemToConversation.test.ts +++ b/packages/data-schema/__tests__/internals/ai/convertItemToConversation.test.ts @@ -27,7 +27,7 @@ describe('convertItemToConversation()', () => { id: mockConversationId, createdAt: '2023-06-01T12:00:00Z', updatedAt: '2023-06-02T12:00:00Z', - onMessage: expect.any(Function), + onStreamEvent: expect.any(Function), sendMessage: expect.any(Function), listMessages: expect.any(Function), metadata: undefined, diff --git a/packages/data-schema/__tests__/internals/ai/createOnMessageFunction.test.ts b/packages/data-schema/__tests__/internals/ai/createOnMessageFunction.test.ts deleted file mode 100644 index fba2e4469..000000000 --- a/packages/data-schema/__tests__/internals/ai/createOnMessageFunction.test.ts +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -import type { Conversation } from '../../../src/ai/ConversationType'; -import type { BaseClient } from '../../../src/runtime'; -import type { ModelIntrospectionSchema } from '../../../src/runtime/bridge-types'; -import { convertItemToConversationMessage } from '../../../src/runtime/internals/ai/convertItemToConversationMessage'; -import { createOnMessageFunction } from '../../../src/runtime/internals/ai/createOnMessageFunction'; -import { customOpFactory } from '../../../src/runtime/internals/operations/custom'; - -jest.mock('../../../src/runtime/internals/ai/convertItemToConversationMessage'); -jest.mock('../../../src/runtime/internals/operations/custom'); - -describe('createOnMessageFunction()', () => { - let onMessage: Conversation['onMessage']; - const mockConversationName = 'conversation-name'; - const mockConversationId = 'conversation-id'; - const mockContent = [{ text: 'foo' }]; - const mockRole = 'user'; - const mockCreatedAt = '2024-06-27T00:00:00Z'; - const mockMessageId = 'message-id'; - const mockMessage = { - content: mockContent, - conversationId: mockConversationId, - createdAt: mockCreatedAt, - id: mockMessageId, - role: mockRole, - }; - const mockConversationSchema = { message: { subscribe: {} } }; - const mockModelIntrospectionSchema = { - conversations: { [mockConversationName]: mockConversationSchema }, - } as unknown as ModelIntrospectionSchema; - // assert mocks - const mockCustomOpFactory = customOpFactory as jest.Mock; - const mockConvertItemToConversationMessage = - convertItemToConversationMessage as jest.Mock; - // create mocks - const mockCustomOp = jest.fn(); - const mockSubscribe = jest.fn(); - const mockHandler = jest.fn(); - - beforeAll(async () => { - mockConvertItemToConversationMessage.mockImplementation((data) => data); - mockCustomOp.mockReturnValue({ subscribe: mockSubscribe }); - mockCustomOpFactory.mockReturnValue(mockCustomOp); - mockSubscribe.mockImplementation((subscription) => { - subscription(mockMessage); - }); - onMessage = await createOnMessageFunction( - {} as BaseClient, - mockModelIntrospectionSchema, - mockConversationId, - mockConversationName, - jest.fn(), - ); - }); - - afterEach(() => { - jest.clearAllMocks(); - }); - - it('returns a onMessage function', async () => { - expect(onMessage).toBeDefined(); - }); - - describe('onMessage()', () => { - it('triggers handler', async () => { - const expectedData = { - content: mockContent, - conversationId: mockConversationId, - createdAt: mockCreatedAt, - id: mockMessageId, - role: mockRole, - }; - onMessage(mockHandler); - - expect(mockCustomOpFactory).toHaveBeenCalledWith( - {}, - mockModelIntrospectionSchema, - 'subscription', - mockConversationSchema.message.subscribe, - false, - expect.any(Function), - { action: '7', category: 'ai' }, - ); - expect(mockCustomOp).toHaveBeenCalledWith({ - conversationId: mockConversationId, - }); - expect(mockConvertItemToConversationMessage).toHaveBeenCalledWith( - expectedData, - ); - expect(mockHandler).toHaveBeenCalledWith(expectedData); - }); - }); -}); diff --git a/packages/data-schema/__tests__/internals/ai/createOnStreamEventFunction.test.ts b/packages/data-schema/__tests__/internals/ai/createOnStreamEventFunction.test.ts new file mode 100644 index 000000000..063c42af6 --- /dev/null +++ b/packages/data-schema/__tests__/internals/ai/createOnStreamEventFunction.test.ts @@ -0,0 +1,154 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +import type { Conversation } from '../../../src/ai/ConversationType'; +import type { BaseClient } from '../../../src/runtime'; +import type { ModelIntrospectionSchema } from '../../../src/runtime/bridge-types'; +import { customOpFactory } from '../../../src/runtime/internals/operations/custom'; +import { createOnStreamEventFunction } from '../../../src/runtime/internals/ai/createOnStreamEventFunction'; +import { convertItemToConversationStreamEvent } from '../../../src/runtime/internals/ai/conversationStreamEventDeserializers'; +jest.mock('../../../src/runtime/internals/ai/conversationStreamEventDeserializers'); +jest.mock('../../../src/runtime/internals/operations/custom'); + +describe('createOnStreamEventFunction()', () => { + let onStreamEvent: Conversation['onStreamEvent']; + const mockConversationName = 'conversation-name'; + const mockConversationId = 'conversation-id'; + const mockRole = 'user'; + const mockMessageId = 'message-id'; + const mockAssociatedUserMessageId = 'associated-user-message-id'; + const mockContentBlockIndex = 0; + const mockContentBlockDeltaIndex = 0; + const mockText = 'hello'; + const mockStreamEvent = { + associatedUserMessageId: mockAssociatedUserMessageId, + contentBlockIndex: mockContentBlockIndex, + contentBlockDeltaIndex: mockContentBlockDeltaIndex, + text: mockText, + conversationId: mockConversationId, + id: mockMessageId, + role: mockRole, + }; + const mockError = { message: 'error message', errorType: 'errorType' }; + const mockStreamEventError = { + errors: [mockError], + id: mockMessageId, + conversationId: mockConversationId, + associatedUserMessageId: mockAssociatedUserMessageId, + }; + const mockConversationSchema = { message: { subscribe: {} } }; + const mockModelIntrospectionSchema = { + conversations: { [mockConversationName]: mockConversationSchema }, + } as unknown as ModelIntrospectionSchema; + // assert mocks + const mockCustomOpFactory = customOpFactory as jest.Mock; + const mockConvertItemToConversationStreamEvent = + convertItemToConversationStreamEvent as jest.Mock; + // create mocks + const mockCustomOp = jest.fn(); + const mockSubscribe = jest.fn(); + const mockHandler = { + next: jest.fn(), + error: jest.fn(), + }; + + beforeAll(async () => { + mockCustomOp.mockReturnValue({ subscribe: mockSubscribe }); + mockCustomOpFactory.mockReturnValue(mockCustomOp); + onStreamEvent = await createOnStreamEventFunction( + {} as BaseClient, + mockModelIntrospectionSchema, + mockConversationId, + mockConversationName, + jest.fn(), + ); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('returns a onStreamEvent function', async () => { + expect(onStreamEvent).toBeDefined(); + }); + + describe('onStreamEvent()', () => { + it('triggers next handler', async () => { + mockConvertItemToConversationStreamEvent.mockImplementation((next) => ({ next })); + mockSubscribe.mockImplementation((subscription) => { + subscription(mockStreamEvent); + }); + const expectedData = { + associatedUserMessageId: mockAssociatedUserMessageId, + contentBlockIndex: mockContentBlockIndex, + contentBlockDeltaIndex: mockContentBlockDeltaIndex, + text: mockText, + conversationId: mockConversationId, + id: mockMessageId, + role: mockRole, + }; + const expectedUndefinedFields = { + contentBlockDoneAtIndex: undefined, + toolUse: undefined, + stopReason: undefined, + }; + onStreamEvent(mockHandler); + + expect(mockCustomOpFactory).toHaveBeenCalledWith( + {}, + mockModelIntrospectionSchema, + 'subscription', + mockConversationSchema.message.subscribe, + false, + expect.any(Function), + { action: '7', category: 'ai' }, + ); + expect(mockCustomOp).toHaveBeenCalledWith({ + conversationId: mockConversationId, + }); + expect(mockConvertItemToConversationStreamEvent).toHaveBeenCalledWith( + { ...expectedData, ...expectedUndefinedFields }, + ); + expect(mockHandler.next).toHaveBeenCalledWith(expectedData); + }); + + it('triggers errors handler', async () => { + mockConvertItemToConversationStreamEvent.mockImplementation((error) => ({ error })); + mockSubscribe.mockImplementation((subscription) => { + subscription(mockStreamEventError); + }); + const expectedError = { + id: mockMessageId, + associatedUserMessageId: mockAssociatedUserMessageId, + conversationId: mockConversationId, + errors: [mockError], + }; + const expectedUndefinedFields = { + contentBlockDoneAtIndex: undefined, + toolUse: undefined, + stopReason: undefined, + contentBlockIndex: undefined, + contentBlockDeltaIndex: undefined, + text: undefined, + role: undefined, + }; + onStreamEvent(mockHandler); + expect(mockCustomOpFactory).toHaveBeenCalledWith( + {}, + mockModelIntrospectionSchema, + 'subscription', + mockConversationSchema.message.subscribe, + false, + expect.any(Function), + { action: '7', category: 'ai' }, + ); + expect(mockCustomOp).toHaveBeenCalledWith({ + conversationId: mockConversationId, + }); + expect(mockConvertItemToConversationStreamEvent).toHaveBeenCalledWith( + { ...expectedError, ...expectedUndefinedFields }, + ); + expect(mockHandler.error).toHaveBeenCalledWith(expectedError); + }); + }); +}); diff --git a/packages/data-schema/src/ClientSchema/ai/ClientConversation.ts b/packages/data-schema/src/ClientSchema/ai/ClientConversation.ts index 4ace0d41d..3cb2d61d6 100644 --- a/packages/data-schema/src/ClientSchema/ai/ClientConversation.ts +++ b/packages/data-schema/src/ClientSchema/ai/ClientConversation.ts @@ -4,6 +4,7 @@ import type { Conversation, ConversationMessage, + ConversationStreamEvent, } from '../../ai/ConversationType'; import type { ClientSchemaProperty } from '../Core'; @@ -12,4 +13,5 @@ export interface ClientConversation __entityType: 'customConversation'; type: Conversation; messageType: ConversationMessage; + streamEventType: ConversationStreamEvent; } diff --git a/packages/data-schema/src/ai/ConversationSchemaTypes.ts b/packages/data-schema/src/ai/ConversationSchemaTypes.ts index 446f64430..6651d8550 100644 --- a/packages/data-schema/src/ai/ConversationSchemaTypes.ts +++ b/packages/data-schema/src/ai/ConversationSchemaTypes.ts @@ -230,6 +230,25 @@ const ToolInputSchema = `type ToolInputSchema { json: AWSJSON }`; +const ConversationMessageStreamEvent = `type ConversationMessageStreamPart @aws_cognito_user_pools { + id: ID! + owner: String + conversationId: ID! + associatedUserMessageId: ID! + contentBlockIndex: Int! + contentBlockText: String + contentBlockDeltaIndex: Int + contentBlockToolUse: ToolUseBlock + contentBlockDoneAtIndex: Int + stopReason: String + errors: [ConversationTurnError] +}`; + +const ConversationTurnError = `type ConversationTurnError @aws_cognito_user_pools { + message: String! + errorType: String! +}`; + export const conversationTypes: string[] = [ ConversationParticipantRole, ConversationMessage, @@ -262,4 +281,6 @@ export const conversationTypes: string[] = [ Tool, ToolSpecification, ToolInputSchema, + ConversationMessageStreamEvent, + ConversationTurnError, ]; diff --git a/packages/data-schema/src/ai/ConversationType.ts b/packages/data-schema/src/ai/ConversationType.ts index 11e2867a2..474a15f3c 100644 --- a/packages/data-schema/src/ai/ConversationType.ts +++ b/packages/data-schema/src/ai/ConversationType.ts @@ -12,6 +12,7 @@ import { } from './types/ConversationMessageContent'; import { ToolConfiguration } from './types/ToolConfiguration'; import { AiModel } from '@aws-amplify/data-schema-types'; +import { ConversationStreamErrorEvent, ConversationStreamEvent } from './types/ConversationStreamEvent'; export const brandName = 'conversationCustomOperation'; @@ -22,6 +23,7 @@ export interface ConversationMessage { createdAt: string; id: string; role: 'user' | 'assistant'; + associatedUserMessageId?: string; } // conversation route types @@ -104,7 +106,10 @@ interface ConversationListMessagesInput { nextToken?: string | null; } -type ConversationOnMessageHandler = (message: ConversationMessage) => void; +type ConversationOnStreamEventHandler = { + next: (event: ConversationStreamEvent) => void; + error: (error: ConversationStreamErrorEvent) => void; +}; export interface Conversation { id: string; @@ -132,9 +137,9 @@ export interface Conversation { /** * @experimental * - * Subscribes to new messages on the current conversation. + * Subscribes to new stream events on the current conversation. */ - onMessage: (handler: ConversationOnMessageHandler) => Subscription; + onStreamEvent: (handler: ConversationOnStreamEventHandler) => Subscription; } // schema definition input @@ -185,3 +190,5 @@ function _conversation(input: ConversationInput): ConversationType { export function conversation(input: ConversationInput): ConversationType { return _conversation(input); } +export { ConversationStreamEvent }; + diff --git a/packages/data-schema/src/ai/types/ConversationStreamEvent.ts b/packages/data-schema/src/ai/types/ConversationStreamEvent.ts new file mode 100644 index 000000000..f2cce47bf --- /dev/null +++ b/packages/data-schema/src/ai/types/ConversationStreamEvent.ts @@ -0,0 +1,70 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +import { ToolUseBlock } from "./contentBlocks"; + +export interface ConversationStreamTextEvent { + id: string; + conversationId: string; + associatedUserMessageId: string; + contentBlockIndex: number; + contentBlockDeltaIndex: number; + contentBlockDoneAtIndex?: never; + text: string; + toolUse?: never; + stopReason?: never; +} + +export interface ConversationStreamToolUseEvent { + id: string; + conversationId: string; + associatedUserMessageId: string; + contentBlockIndex: number; + contentBlockDeltaIndex?: never; + contentBlockDoneAtIndex?: never; + text?: never; + toolUse: ToolUseBlock; + stopReason?: never; +} + +export interface ConversationStreamDoneAtIndexEvent { + id: string; + conversationId: string; + associatedUserMessageId: string; + contentBlockIndex: number; + contentBlockDoneAtIndex: number; + contentBlockDeltaIndex?: never; + text?: never; + toolUse?: never; + stopReason?: never; +} + +export interface ConversationStreamTurnDoneEvent { + id: string; + conversationId: string; + associatedUserMessageId: string; + contentBlockIndex: number; + contentBlockDoneAtIndex?: never; + contentBlockDeltaIndex?: never; + text?: never; + toolUse?: never; + stopReason: string; +} + +export interface ConversationStreamErrorEvent { + id: string; + conversationId: string; + associatedUserMessageId: string; + errors: ConversationTurnError[] +} + +export interface ConversationTurnError { + message: string; + errorType: string; +} + +export type ConversationStreamEvent = + | ConversationStreamTextEvent + | ConversationStreamToolUseEvent + | ConversationStreamDoneAtIndexEvent + | ConversationStreamTurnDoneEvent; diff --git a/packages/data-schema/src/runtime/internals/ai/conversationStreamEventDeserializers.ts b/packages/data-schema/src/runtime/internals/ai/conversationStreamEventDeserializers.ts new file mode 100644 index 000000000..ab6ee2a23 --- /dev/null +++ b/packages/data-schema/src/runtime/internals/ai/conversationStreamEventDeserializers.ts @@ -0,0 +1,62 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +import { ConversationStreamEvent } from "../../../ai/ConversationType"; +import { ToolUseBlock } from "../../../ai/types/contentBlocks"; +import { ConversationStreamErrorEvent } from "../../../ai/types/ConversationStreamEvent"; + +export const convertItemToConversationStreamEvent = ({ + id, + conversationId, + associatedUserMessageId, + contentBlockIndex, + contentBlockDoneAtIndex, + contentBlockDeltaIndex, + contentBlockText, + contentBlockToolUse, + stopReason, + errors, +}: any): { next?: ConversationStreamEvent, error?: ConversationStreamErrorEvent } => { + if (errors) { + const error = { + id, + conversationId, + associatedUserMessageId, + errors, + }; + return { error }; + } + const next = removeNullsFromConversationStreamEvent({ + id, + conversationId, + associatedUserMessageId, + contentBlockIndex, + contentBlockDoneAtIndex, + contentBlockDeltaIndex, + text: contentBlockText, + toolUse: deserializeToolUseBlock(contentBlockToolUse), + stopReason, + }); + + return { next }; +}; + +const deserializeToolUseBlock = ( + contentBlockToolUse: any, +): ToolUseBlock | undefined => { + if (contentBlockToolUse) { + const toolUseBlock = { + ...contentBlockToolUse, + input: JSON.parse(contentBlockToolUse.input), + }; + + return toolUseBlock; + } +}; + +const removeNullsFromConversationStreamEvent = ( + block: ConversationStreamEvent, +): ConversationStreamEvent => + Object.fromEntries( + Object.entries(block).filter(([_, v]) => v !== null), + ) as ConversationStreamEvent; diff --git a/packages/data-schema/src/runtime/internals/ai/convertItemToConversation.ts b/packages/data-schema/src/runtime/internals/ai/convertItemToConversation.ts index a3ffcada2..7a65b25d8 100644 --- a/packages/data-schema/src/runtime/internals/ai/convertItemToConversation.ts +++ b/packages/data-schema/src/runtime/internals/ai/convertItemToConversation.ts @@ -10,7 +10,7 @@ import type { SchemaModel, } from '../../bridge-types'; import { createListMessagesFunction } from './createListMessagesFunction'; -import { createOnMessageFunction } from './createOnMessageFunction'; +import { createOnStreamEventFunction } from './createOnStreamEventFunction'; import { createSendMessageFunction } from './createSendMessageFunction'; export const convertItemToConversation = ( @@ -36,7 +36,7 @@ export const convertItemToConversation = ( updatedAt: conversationUpdatedAt, metadata: conversationMetadata, name: conversationName, - onMessage: createOnMessageFunction( + onStreamEvent: createOnStreamEventFunction( client as BaseBrowserClient, modelIntrospection, conversationId, diff --git a/packages/data-schema/src/runtime/internals/ai/convertItemToConversationMessage.ts b/packages/data-schema/src/runtime/internals/ai/convertItemToConversationMessage.ts index f1e69e8ea..3117bde9a 100644 --- a/packages/data-schema/src/runtime/internals/ai/convertItemToConversationMessage.ts +++ b/packages/data-schema/src/runtime/internals/ai/convertItemToConversationMessage.ts @@ -10,7 +10,7 @@ export const convertItemToConversationMessage = ({ conversationId, role, }: any) => ({ - content: deserializeContent(content), + content: deserializeContent(content ?? []), conversationId, createdAt, id, diff --git a/packages/data-schema/src/runtime/internals/ai/createOnMessageFunction.ts b/packages/data-schema/src/runtime/internals/ai/createOnStreamEventFunction.ts similarity index 77% rename from packages/data-schema/src/runtime/internals/ai/createOnMessageFunction.ts rename to packages/data-schema/src/runtime/internals/ai/createOnStreamEventFunction.ts index 5d847851d..0a02d2b8a 100644 --- a/packages/data-schema/src/runtime/internals/ai/createOnMessageFunction.ts +++ b/packages/data-schema/src/runtime/internals/ai/createOnStreamEventFunction.ts @@ -9,17 +9,17 @@ import { ModelIntrospectionSchema, } from '../../bridge-types'; import { customOpFactory } from '../operations/custom'; -import { convertItemToConversationMessage } from './convertItemToConversationMessage'; import { AiAction, getCustomUserAgentDetails } from './getCustomUserAgentDetails'; +import { convertItemToConversationStreamEvent } from './conversationStreamEventDeserializers'; -export const createOnMessageFunction = +export const createOnStreamEventFunction = ( client: BaseClient, modelIntrospection: ModelIntrospectionSchema, conversationId: string, conversationRouteName: string, getInternals: ClientInternalsGetter, - ): Conversation['onMessage'] => + ): Conversation['onStreamEvent'] => (handler): Subscription => { const { conversations } = modelIntrospection; // Safe guard for standalone function. When called as part of client generation, this should never be falsy. @@ -35,9 +35,11 @@ export const createOnMessageFunction = subscribeSchema, false, getInternals, - getCustomUserAgentDetails(AiAction.OnMessage), + getCustomUserAgentDetails(AiAction.OnStreamEvent), ) as (args?: Record) => Observable; return subscribeOperation({ conversationId }).subscribe((data) => { - handler(convertItemToConversationMessage(data)); + const { next, error } = convertItemToConversationStreamEvent(data); + if (error) handler.error(error); + if (next) handler.next(next); }); }; diff --git a/packages/data-schema/src/runtime/internals/ai/getCustomUserAgentDetails.ts b/packages/data-schema/src/runtime/internals/ai/getCustomUserAgentDetails.ts index 14ad2940c..3248a8a56 100644 --- a/packages/data-schema/src/runtime/internals/ai/getCustomUserAgentDetails.ts +++ b/packages/data-schema/src/runtime/internals/ai/getCustomUserAgentDetails.ts @@ -24,7 +24,7 @@ export enum AiAction { DeleteConversation = '4', SendMessage = '5', ListMessages = '6', - OnMessage = '7', + OnStreamEvent = '7', Generation = '8', UpdateConversation = '9', } diff --git a/packages/integration-tests/__tests__/defined-behavior/2-expected-use/__snapshots__/ai-conversation.ts.snap b/packages/integration-tests/__tests__/defined-behavior/2-expected-use/__snapshots__/ai-conversation.ts.snap index b5fa569ef..66bfa779f 100644 --- a/packages/integration-tests/__tests__/defined-behavior/2-expected-use/__snapshots__/ai-conversation.ts.snap +++ b/packages/integration-tests/__tests__/defined-behavior/2-expected-use/__snapshots__/ai-conversation.ts.snap @@ -106,6 +106,30 @@ exports[`AI Conversation Routes Conversations Update a conversation 1`] = ` ] `; +exports[`AI Conversation Routes Conversations Update a conversation 2`] = ` +[ + [ + { + "authMode": undefined, + "authToken": undefined, + "query": "mutation($input: UpdateConversationChatBotInput!) { updateConversationChatBot(input: $input) { id name metadata createdAt updatedAt owner } }", + "variables": { + "input": { + "id": "conversation-id", + "metadata": "{"arbitrary":"data"}", + "name": "updated conversation name", + }, + }, + Symbol(INTERNAL_USER_AGENT_OVERRIDE): { + "action": "9", + "category": "ai", + }, + }, + {}, + ], +] +`; + exports[`AI Conversation Routes Messages List messages 1`] = ` [ [ @@ -531,5 +555,24 @@ type ToolSpecification { type ToolInputSchema { json: AWSJSON +} + +type ConversationMessageStreamPart @aws_cognito_user_pools { + id: ID! + owner: String + conversationId: ID! + associatedUserMessageId: ID! + contentBlockIndex: Int! + contentBlockText: String + contentBlockDeltaIndex: Int + contentBlockToolUse: ToolUseBlock + contentBlockDoneAtIndex: Int + stopReason: String + errors: [ConversationTurnError] +} + +type ConversationTurnError @aws_cognito_user_pools { + message: String! + errorType: String! }" `; diff --git a/packages/integration-tests/__tests__/defined-behavior/2-expected-use/ai-conversation.ts b/packages/integration-tests/__tests__/defined-behavior/2-expected-use/ai-conversation.ts index 2eed4b9b1..8dc3833b6 100644 --- a/packages/integration-tests/__tests__/defined-behavior/2-expected-use/ai-conversation.ts +++ b/packages/integration-tests/__tests__/defined-behavior/2-expected-use/ai-conversation.ts @@ -7,6 +7,7 @@ import { subOptionsAndHeaders, pause, } from '../../utils'; +import { Subscriber } from 'rxjs'; describe('AI Conversation Routes', () => { // data/resource.ts @@ -64,7 +65,7 @@ describe('AI Conversation Routes', () => { listMessages: expect.any(Function), metadata: {}, name: 'Test Conversation', - onMessage: expect.any(Function), + onStreamEvent: expect.any(Function), sendMessage: expect.any(Function), updatedAt: '2023-08-02T12:00:00Z', }); @@ -113,7 +114,56 @@ describe('AI Conversation Routes', () => { listMessages: expect.any(Function), metadata: {}, name: 'Test Conversation', - onMessage: expect.any(Function), + onStreamEvent: expect.any(Function), + sendMessage: expect.any(Function), + updatedAt: '2023-08-02T12:00:00Z', + }); + // #endregion assertions + }); + + test('Update a conversation', async () => { + const sampleConversation = { + id: 'conversation-id', + createdAt: '2023-06-01T12:00:00Z', + updatedAt: '2023-08-02T12:00:00Z', + metadata: {}, + name: 'Test Conversation', + }; + + const { spy, generateClient } = mockedGenerateClient([ + { + data: { + updateConversation: sampleConversation, + }, + }, + ]); + // simulated amplifyconfiguration.json + const config = await buildAmplifyConfig(schema); + // #endregion mocking + + // #region api call + // App.tsx + Amplify.configure(config); + const client = generateClient(); + // create conversation + const { data: updatedConversation, errors: updateConversationErrors } = + await client.conversations.chatBot.update({ + id: sampleConversation.id, + name: 'updated conversation name', + metadata: { arbitrary: 'data' }, + }); + // #endregion api call + + // #region assertions + expect(optionsAndHeaders(spy)).toMatchSnapshot(); + expect(updateConversationErrors).toBeUndefined(); + expect(updatedConversation).toStrictEqual({ + createdAt: '2023-06-01T12:00:00Z', + id: sampleConversation.id, + listMessages: expect.any(Function), + metadata: {}, + name: 'Test Conversation', + onStreamEvent: expect.any(Function), sendMessage: expect.any(Function), updatedAt: '2023-08-02T12:00:00Z', }); @@ -156,7 +206,7 @@ describe('AI Conversation Routes', () => { listMessages: expect.any(Function), metadata: {}, name: 'Test Conversation', - onMessage: expect.any(Function), + onStreamEvent: expect.any(Function), sendMessage: expect.any(Function), updatedAt: '2023-06-01T12:00:00Z', }); @@ -209,7 +259,7 @@ describe('AI Conversation Routes', () => { listMessages: expect.any(Function), metadata: {}, name: 'Test Conversation', - onMessage: expect.any(Function), + onStreamEvent: expect.any(Function), sendMessage: expect.any(Function), updatedAt: '2023-08-02T12:00:00Z', }, @@ -219,7 +269,7 @@ describe('AI Conversation Routes', () => { listMessages: expect.any(Function), metadata: {}, name: 'Test Conversation2', - onMessage: expect.any(Function), + onStreamEvent: expect.any(Function), sendMessage: expect.any(Function), updatedAt: '2024-09-05T12:00:00Z', }, @@ -280,7 +330,7 @@ describe('AI Conversation Routes', () => { listMessages: expect.any(Function), metadata: {}, name: 'Test Conversation', - onMessage: expect.any(Function), + onStreamEvent: expect.any(Function), sendMessage: expect.any(Function), updatedAt: '2023-08-02T12:00:00Z', }, @@ -290,7 +340,7 @@ describe('AI Conversation Routes', () => { listMessages: expect.any(Function), metadata: {}, name: 'Test Conversation2', - onMessage: expect.any(Function), + onStreamEvent: expect.any(Function), sendMessage: expect.any(Function), updatedAt: '2024-09-05T12:00:00Z', }, @@ -319,6 +369,45 @@ describe('AI Conversation Routes', () => { id: 'message-id', role: 'user', }; + const sampleConversationStreamTextEvent = { + id: 'stream-text-event-id', + conversationId: sampleConversation.id, + associatedUserMessageId: sampleConversationMessage1.id, + contentBlockDeltaIndex: 0, + contentBlockIndex: 0, + contentBlockText: 'foo', + }; + const sampleConversationStreamToolUseEvent = { + id: 'stream-tooluse-event-id', + conversationId: sampleConversation.id, + associatedUserMessageId: sampleConversationMessage1.id, + contentBlockIndex: 0, + contentBlockToolUse: { + toolUseId: 'toolUseId', + name: 'toolUseName', + input: JSON.stringify({ toolUseParam: 'toolUseParam' }), + }, + }; + const sampleConversationStreamDoneAtIndexEvent = { + id: 'stream-doneatindex-event-id', + conversationId: sampleConversation.id, + associatedUserMessageId: sampleConversationMessage1.id, + contentBlockDoneAtIndex: 0, + contentBlockIndex: 0, + }; + const sampleConversationStreamTurnDoneEvent = { + id: 'stream-turndone-event-id', + conversationId: sampleConversation.id, + associatedUserMessageId: sampleConversationMessage1.id, + contentBlockIndex: 0, + stopReason: 'stopReason', + }; + const sampleConversationStreamErrorEvent = { + id: 'stream-error-event-id', + conversationId: sampleConversation.id, + associatedUserMessageId: sampleConversationMessage1.id, + errors: [{ message: 'error message', errorType: 'errorType' }], + }; // #endregion mocking common test('Send a message', async () => { @@ -471,6 +560,117 @@ describe('AI Conversation Routes', () => { // #endregion assertions }); + describe('Stream events', () => { + const mockNextHandler = jest.fn(); + const mockErrorHandler = jest.fn(); + let subs: Record>; + + beforeAll(async () => { + const { subs: mockedSubs, generateClient } = mockedGenerateClient([ + { + data: { + getConversation: sampleConversation, + }, + }, + ]); + subs = mockedSubs; + + const config = await buildAmplifyConfig(schema); + Amplify.configure(config); + + const client = generateClient(); + const { data: conversation } = await client.conversations.chatBot.get({ + id: sampleConversation.id, + }); + // subscribe to messages + conversation?.onStreamEvent({ + next: (streamEvent) => mockNextHandler(streamEvent), + error: (error) => mockErrorHandler(error), + }); + }); + + test('Text event', async () => { + subs.onCreateAssistantResponseChatBot.next({ + data: { + onCreateAssistantResponseChatBot: { + ...sampleConversationStreamTextEvent, + }, + }, + }); + + await pause(1); + const { + contentBlockText, + ...rest + } = sampleConversationStreamTextEvent; + const expectedConversationStreamEvent = { + text: contentBlockText, + ...rest, + }; + expect(mockNextHandler).toHaveBeenCalledWith(expectedConversationStreamEvent); + }); + + test('Tool use event', async () => { + subs.onCreateAssistantResponseChatBot.next({ + data: { + onCreateAssistantResponseChatBot: { + ...sampleConversationStreamToolUseEvent, + }, + }, + }); + + await pause(1); + const { + contentBlockToolUse: { input, toolUseId, name }, + ...rest + } = sampleConversationStreamToolUseEvent; + const expectedConversationStreamEvent = { + toolUse: { input: JSON.parse(input), toolUseId, name }, + ...rest, + }; + expect(mockNextHandler).toHaveBeenCalledWith(expectedConversationStreamEvent); + }); + + test('Done at index event', async () => { + subs.onCreateAssistantResponseChatBot.next({ + data: { + onCreateAssistantResponseChatBot: { + ...sampleConversationStreamDoneAtIndexEvent, + }, + }, + }); + + await pause(1); + expect(mockNextHandler).toHaveBeenCalledWith(sampleConversationStreamDoneAtIndexEvent); + }); + + test('Turn done event', async () => { + subs.onCreateAssistantResponseChatBot.next({ + data: { + onCreateAssistantResponseChatBot: { + ...sampleConversationStreamTurnDoneEvent, + }, + }, + }); + + await pause(1); + expect(mockNextHandler).toHaveBeenCalledWith(sampleConversationStreamTurnDoneEvent); + }); + + test('Error event', async () => { + subs.onCreateAssistantResponseChatBot.next({ + data: { + onCreateAssistantResponseChatBot: { + ...sampleConversationStreamErrorEvent, + }, + }, + }); + + await pause(1); + expect(mockErrorHandler).toHaveBeenCalledWith(sampleConversationStreamErrorEvent); + }); + }); + test('Subscribe to messages', async () => { // #region mocking const { spy, subSpy, subs, generateClient } = mockedGenerateClient([ @@ -485,7 +685,7 @@ describe('AI Conversation Routes', () => { }, }, ]); - const mockHandler = jest.fn(); + const mockNextHandler = jest.fn(); // simulated amplifyconfiguration.json const config = await buildAmplifyConfig(schema); // #endregion mocking @@ -499,15 +699,17 @@ describe('AI Conversation Routes', () => { id: sampleConversation.id, }); // subscribe to messages - conversation?.onMessage((message) => { - mockHandler(message); + conversation?.onStreamEvent({ + next: (streamEvent) => { + mockNextHandler(streamEvent); + }, + error: () => { }, }); subs.onCreateAssistantResponseChatBot.next({ data: { onCreateAssistantResponseChatBot: { - ...sampleConversationMessage1, - role: 'assistant', + ...sampleConversationStreamTextEvent, }, }, }); @@ -517,10 +719,15 @@ describe('AI Conversation Routes', () => { // #region assertions expect(optionsAndHeaders(spy)).toMatchSnapshot(); expect(subOptionsAndHeaders(subSpy)).toMatchSnapshot(); - expect(mockHandler).toHaveBeenCalledWith({ - ...sampleConversationMessage1, - role: 'assistant', - }); + const { + contentBlockText, + ...rest + } = sampleConversationStreamTextEvent; + const expectedConversationStreamEvent = { + text: contentBlockText, + ...rest, + }; + expect(mockNextHandler).toHaveBeenCalledWith(expectedConversationStreamEvent); // #endregion assertions });