Skip to content

Commit

Permalink
feat(ai): conversation streaming and error propagation (#381)
Browse files Browse the repository at this point in the history
Co-authored-by: Danny Banks <[email protected]>
  • Loading branch information
atierian and dbanksdesign authored Nov 11, 2024
1 parent 281efd7 commit 20f30fe
Show file tree
Hide file tree
Showing 19 changed files with 628 additions and 126 deletions.
5 changes: 5 additions & 0 deletions .changeset/clean-apples-camp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@aws-amplify/data-schema": minor
---

add streaming support for conversation routes
5 changes: 5 additions & 0 deletions .changeset/ten-turkeys-pay.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@aws-amplify/data-schema": minor
---

propagate conversation errors through subscription
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ bench('p50 conversation operations', async () => {

await client.conversations.ChatBot.list();

conversation?.onMessage(() => {});
conversation?.onStreamEvent(() => {});

await conversation?.sendMessage({
content: [{ text: 'foo' }],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ bench('prod p50 conversation operations', async () => {

await client.conversations.ChatBot.list();

conversation?.onMessage(() => {});
conversation?.onStreamEvent(() => {});

await conversation?.sendMessage({
content: [{ text: 'foo' }],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,25 @@ type ToolSpecification {
type ToolInputSchema {
json: AWSJSON
}
type ConversationMessageStreamPart @aws_cognito_user_pools {
id: ID!
owner: String
conversationId: ID!
associatedUserMessageId: ID!
contentBlockIndex: Int!
contentBlockText: String
contentBlockDeltaIndex: Int
contentBlockToolUse: ToolUseBlock
contentBlockDoneAtIndex: Int
stopReason: String
errors: [ConversationTurnError]
}
type ConversationTurnError @aws_cognito_user_pools {
message: String!
errorType: String!
}"
`;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ describe('convertItemToConversation()', () => {
id: mockConversationId,
createdAt: '2023-06-01T12:00:00Z',
updatedAt: '2023-06-02T12:00:00Z',
onMessage: expect.any(Function),
onStreamEvent: expect.any(Function),
sendMessage: expect.any(Function),
listMessages: expect.any(Function),
metadata: undefined,
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import type { Conversation } from '../../../src/ai/ConversationType';
import type { BaseClient } from '../../../src/runtime';
import type { ModelIntrospectionSchema } from '../../../src/runtime/bridge-types';
import { customOpFactory } from '../../../src/runtime/internals/operations/custom';
import { createOnStreamEventFunction } from '../../../src/runtime/internals/ai/createOnStreamEventFunction';
import { convertItemToConversationStreamEvent } from '../../../src/runtime/internals/ai/conversationStreamEventDeserializers';
jest.mock('../../../src/runtime/internals/ai/conversationStreamEventDeserializers');
jest.mock('../../../src/runtime/internals/operations/custom');

describe('createOnStreamEventFunction()', () => {
let onStreamEvent: Conversation['onStreamEvent'];
const mockConversationName = 'conversation-name';
const mockConversationId = 'conversation-id';
const mockRole = 'user';
const mockMessageId = 'message-id';
const mockAssociatedUserMessageId = 'associated-user-message-id';
const mockContentBlockIndex = 0;
const mockContentBlockDeltaIndex = 0;
const mockText = 'hello';
const mockStreamEvent = {
associatedUserMessageId: mockAssociatedUserMessageId,
contentBlockIndex: mockContentBlockIndex,
contentBlockDeltaIndex: mockContentBlockDeltaIndex,
text: mockText,
conversationId: mockConversationId,
id: mockMessageId,
role: mockRole,
};
const mockError = { message: 'error message', errorType: 'errorType' };
const mockStreamEventError = {
errors: [mockError],
id: mockMessageId,
conversationId: mockConversationId,
associatedUserMessageId: mockAssociatedUserMessageId,
};
const mockConversationSchema = { message: { subscribe: {} } };
const mockModelIntrospectionSchema = {
conversations: { [mockConversationName]: mockConversationSchema },
} as unknown as ModelIntrospectionSchema;
// assert mocks
const mockCustomOpFactory = customOpFactory as jest.Mock;
const mockConvertItemToConversationStreamEvent =
convertItemToConversationStreamEvent as jest.Mock;
// create mocks
const mockCustomOp = jest.fn();
const mockSubscribe = jest.fn();
const mockHandler = {
next: jest.fn(),
error: jest.fn(),
};

beforeAll(async () => {
mockCustomOp.mockReturnValue({ subscribe: mockSubscribe });
mockCustomOpFactory.mockReturnValue(mockCustomOp);
onStreamEvent = await createOnStreamEventFunction(
{} as BaseClient,
mockModelIntrospectionSchema,
mockConversationId,
mockConversationName,
jest.fn(),
);
});

afterEach(() => {
jest.clearAllMocks();
});

it('returns a onStreamEvent function', async () => {
expect(onStreamEvent).toBeDefined();
});

describe('onStreamEvent()', () => {
it('triggers next handler', async () => {
mockConvertItemToConversationStreamEvent.mockImplementation((next) => ({ next }));
mockSubscribe.mockImplementation((subscription) => {
subscription(mockStreamEvent);
});
const expectedData = {
associatedUserMessageId: mockAssociatedUserMessageId,
contentBlockIndex: mockContentBlockIndex,
contentBlockDeltaIndex: mockContentBlockDeltaIndex,
text: mockText,
conversationId: mockConversationId,
id: mockMessageId,
role: mockRole,
};
const expectedUndefinedFields = {
contentBlockDoneAtIndex: undefined,
toolUse: undefined,
stopReason: undefined,
};
onStreamEvent(mockHandler);

expect(mockCustomOpFactory).toHaveBeenCalledWith(
{},
mockModelIntrospectionSchema,
'subscription',
mockConversationSchema.message.subscribe,
false,
expect.any(Function),
{ action: '7', category: 'ai' },
);
expect(mockCustomOp).toHaveBeenCalledWith({
conversationId: mockConversationId,
});
expect(mockConvertItemToConversationStreamEvent).toHaveBeenCalledWith(
{ ...expectedData, ...expectedUndefinedFields },
);
expect(mockHandler.next).toHaveBeenCalledWith(expectedData);
});

it('triggers errors handler', async () => {
mockConvertItemToConversationStreamEvent.mockImplementation((error) => ({ error }));
mockSubscribe.mockImplementation((subscription) => {
subscription(mockStreamEventError);
});
const expectedError = {
id: mockMessageId,
associatedUserMessageId: mockAssociatedUserMessageId,
conversationId: mockConversationId,
errors: [mockError],
};
const expectedUndefinedFields = {
contentBlockDoneAtIndex: undefined,
toolUse: undefined,
stopReason: undefined,
contentBlockIndex: undefined,
contentBlockDeltaIndex: undefined,
text: undefined,
role: undefined,
};
onStreamEvent(mockHandler);
expect(mockCustomOpFactory).toHaveBeenCalledWith(
{},
mockModelIntrospectionSchema,
'subscription',
mockConversationSchema.message.subscribe,
false,
expect.any(Function),
{ action: '7', category: 'ai' },
);
expect(mockCustomOp).toHaveBeenCalledWith({
conversationId: mockConversationId,
});
expect(mockConvertItemToConversationStreamEvent).toHaveBeenCalledWith(
{ ...expectedError, ...expectedUndefinedFields },
);
expect(mockHandler.error).toHaveBeenCalledWith(expectedError);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import type {
Conversation,
ConversationMessage,
ConversationStreamEvent,
} from '../../ai/ConversationType';
import type { ClientSchemaProperty } from '../Core';

Expand All @@ -12,4 +13,5 @@ export interface ClientConversation
__entityType: 'customConversation';
type: Conversation;
messageType: ConversationMessage;
streamEventType: ConversationStreamEvent;
}
21 changes: 21 additions & 0 deletions packages/data-schema/src/ai/ConversationSchemaTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,25 @@ const ToolInputSchema = `type ToolInputSchema {
json: AWSJSON
}`;

const ConversationMessageStreamEvent = `type ConversationMessageStreamPart @aws_cognito_user_pools {
id: ID!
owner: String
conversationId: ID!
associatedUserMessageId: ID!
contentBlockIndex: Int!
contentBlockText: String
contentBlockDeltaIndex: Int
contentBlockToolUse: ToolUseBlock
contentBlockDoneAtIndex: Int
stopReason: String
errors: [ConversationTurnError]
}`;

const ConversationTurnError = `type ConversationTurnError @aws_cognito_user_pools {
message: String!
errorType: String!
}`;

export const conversationTypes: string[] = [
ConversationParticipantRole,
ConversationMessage,
Expand Down Expand Up @@ -262,4 +281,6 @@ export const conversationTypes: string[] = [
Tool,
ToolSpecification,
ToolInputSchema,
ConversationMessageStreamEvent,
ConversationTurnError,
];
Loading

0 comments on commit 20f30fe

Please sign in to comment.