Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ai): conversation streaming and error propagation #381

Merged
merged 2 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/clean-apples-camp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@aws-amplify/data-schema": minor
---

add streaming support for conversation routes
5 changes: 5 additions & 0 deletions .changeset/ten-turkeys-pay.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@aws-amplify/data-schema": minor
---

propagate conversation errors through subscription
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ bench('p50 conversation operations', async () => {

await client.conversations.ChatBot.list();

conversation?.onMessage(() => {});
conversation?.onStreamEvent(() => {});

await conversation?.sendMessage({
content: [{ text: 'foo' }],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ bench('prod p50 conversation operations', async () => {

await client.conversations.ChatBot.list();

conversation?.onMessage(() => {});
conversation?.onStreamEvent(() => {});

await conversation?.sendMessage({
content: [{ text: 'foo' }],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,25 @@ type ToolSpecification {

type ToolInputSchema {
json: AWSJSON
}

type ConversationMessageStreamPart @aws_cognito_user_pools {
id: ID!
owner: String
conversationId: ID!
associatedUserMessageId: ID!
contentBlockIndex: Int!
contentBlockText: String
contentBlockDeltaIndex: Int
contentBlockToolUse: ToolUseBlock
contentBlockDoneAtIndex: Int
stopReason: String
errors: [ConversationTurnError]
}

type ConversationTurnError @aws_cognito_user_pools {
message: String!
errorType: String!
}"
`;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ describe('convertItemToConversation()', () => {
id: mockConversationId,
createdAt: '2023-06-01T12:00:00Z',
updatedAt: '2023-06-02T12:00:00Z',
onMessage: expect.any(Function),
onStreamEvent: expect.any(Function),
sendMessage: expect.any(Function),
listMessages: expect.any(Function),
metadata: undefined,
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import type { Conversation } from '../../../src/ai/ConversationType';
import type { BaseClient } from '../../../src/runtime';
import type { ModelIntrospectionSchema } from '../../../src/runtime/bridge-types';
import { customOpFactory } from '../../../src/runtime/internals/operations/custom';
import { createOnStreamEventFunction } from '../../../src/runtime/internals/ai/createOnStreamEventFunction';
import { convertItemToConversationStreamEvent } from '../../../src/runtime/internals/ai/conversationStreamEventDeserializers';
jest.mock('../../../src/runtime/internals/ai/conversationStreamEventDeserializers');
jest.mock('../../../src/runtime/internals/operations/custom');

describe('createOnStreamEventFunction()', () => {
let onStreamEvent: Conversation['onStreamEvent'];
const mockConversationName = 'conversation-name';
const mockConversationId = 'conversation-id';
const mockRole = 'user';
const mockMessageId = 'message-id';
const mockAssociatedUserMessageId = 'associated-user-message-id';
const mockContentBlockIndex = 0;
const mockContentBlockDeltaIndex = 0;
const mockText = 'hello';
const mockStreamEvent = {
associatedUserMessageId: mockAssociatedUserMessageId,
contentBlockIndex: mockContentBlockIndex,
contentBlockDeltaIndex: mockContentBlockDeltaIndex,
text: mockText,
conversationId: mockConversationId,
id: mockMessageId,
role: mockRole,
};
const mockError = { message: 'error message', errorType: 'errorType' };
const mockStreamEventError = {
errors: [mockError],
id: mockMessageId,
conversationId: mockConversationId,
associatedUserMessageId: mockAssociatedUserMessageId,
};
const mockConversationSchema = { message: { subscribe: {} } };
const mockModelIntrospectionSchema = {
conversations: { [mockConversationName]: mockConversationSchema },
} as unknown as ModelIntrospectionSchema;
// assert mocks
const mockCustomOpFactory = customOpFactory as jest.Mock;
const mockConvertItemToConversationStreamEvent =
convertItemToConversationStreamEvent as jest.Mock;
// create mocks
const mockCustomOp = jest.fn();
const mockSubscribe = jest.fn();
const mockHandler = {
next: jest.fn(),
error: jest.fn(),
};

beforeAll(async () => {
mockCustomOp.mockReturnValue({ subscribe: mockSubscribe });
mockCustomOpFactory.mockReturnValue(mockCustomOp);
onStreamEvent = await createOnStreamEventFunction(
{} as BaseClient,
mockModelIntrospectionSchema,
mockConversationId,
mockConversationName,
jest.fn(),
);
});

afterEach(() => {
jest.clearAllMocks();
});

it('returns a onStreamEvent function', async () => {
expect(onStreamEvent).toBeDefined();
});

describe('onStreamEvent()', () => {
it('triggers next handler', async () => {
mockConvertItemToConversationStreamEvent.mockImplementation((next) => ({ next }));
mockSubscribe.mockImplementation((subscription) => {
subscription(mockStreamEvent);
});
const expectedData = {
associatedUserMessageId: mockAssociatedUserMessageId,
contentBlockIndex: mockContentBlockIndex,
contentBlockDeltaIndex: mockContentBlockDeltaIndex,
text: mockText,
conversationId: mockConversationId,
id: mockMessageId,
role: mockRole,
};
const expectedUndefinedFields = {
contentBlockDoneAtIndex: undefined,
toolUse: undefined,
stopReason: undefined,
};
onStreamEvent(mockHandler);

expect(mockCustomOpFactory).toHaveBeenCalledWith(
{},
mockModelIntrospectionSchema,
'subscription',
mockConversationSchema.message.subscribe,
false,
expect.any(Function),
{ action: '7', category: 'ai' },
);
expect(mockCustomOp).toHaveBeenCalledWith({
conversationId: mockConversationId,
});
expect(mockConvertItemToConversationStreamEvent).toHaveBeenCalledWith(
{ ...expectedData, ...expectedUndefinedFields },
);
expect(mockHandler.next).toHaveBeenCalledWith(expectedData);
});

it('triggers errors handler', async () => {
mockConvertItemToConversationStreamEvent.mockImplementation((error) => ({ error }));
mockSubscribe.mockImplementation((subscription) => {
subscription(mockStreamEventError);
});
const expectedError = {
id: mockMessageId,
associatedUserMessageId: mockAssociatedUserMessageId,
conversationId: mockConversationId,
errors: [mockError],
};
const expectedUndefinedFields = {
contentBlockDoneAtIndex: undefined,
toolUse: undefined,
stopReason: undefined,
contentBlockIndex: undefined,
contentBlockDeltaIndex: undefined,
text: undefined,
role: undefined,
};
onStreamEvent(mockHandler);
expect(mockCustomOpFactory).toHaveBeenCalledWith(
{},
mockModelIntrospectionSchema,
'subscription',
mockConversationSchema.message.subscribe,
false,
expect.any(Function),
{ action: '7', category: 'ai' },
);
expect(mockCustomOp).toHaveBeenCalledWith({
conversationId: mockConversationId,
});
expect(mockConvertItemToConversationStreamEvent).toHaveBeenCalledWith(
{ ...expectedError, ...expectedUndefinedFields },
);
expect(mockHandler.error).toHaveBeenCalledWith(expectedError);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import type {
Conversation,
ConversationMessage,
ConversationStreamEvent,
} from '../../ai/ConversationType';
import type { ClientSchemaProperty } from '../Core';

Expand All @@ -12,4 +13,5 @@ export interface ClientConversation
__entityType: 'customConversation';
type: Conversation;
messageType: ConversationMessage;
streamEventType: ConversationStreamEvent;
}
21 changes: 21 additions & 0 deletions packages/data-schema/src/ai/ConversationSchemaTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,25 @@ const ToolInputSchema = `type ToolInputSchema {
json: AWSJSON
}`;

const ConversationMessageStreamEvent = `type ConversationMessageStreamPart @aws_cognito_user_pools {
id: ID!
owner: String
conversationId: ID!
associatedUserMessageId: ID!
contentBlockIndex: Int!
contentBlockText: String
contentBlockDeltaIndex: Int
contentBlockToolUse: ToolUseBlock
contentBlockDoneAtIndex: Int
stopReason: String
errors: [ConversationTurnError]
}`;

const ConversationTurnError = `type ConversationTurnError @aws_cognito_user_pools {
message: String!
errorType: String!
}`;

export const conversationTypes: string[] = [
ConversationParticipantRole,
ConversationMessage,
Expand Down Expand Up @@ -262,4 +281,6 @@ export const conversationTypes: string[] = [
Tool,
ToolSpecification,
ToolInputSchema,
ConversationMessageStreamEvent,
ConversationTurnError,
];
Loading
Loading