Skip to content

Commit

Permalink
fix(MessageStream): handle errors more gracefully in async iterator (#…
Browse files Browse the repository at this point in the history
…301)

Errors that were thrown in the async iterator of the MessageStream would previously
throw as an unhandled promise rejection.
See: #298
  • Loading branch information
stainless-bot authored and RobertCraigie committed Mar 4, 2024
1 parent d55c320 commit 95232ed
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 10 deletions.
31 changes: 25 additions & 6 deletions src/lib/MessageStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,16 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {

[Symbol.asyncIterator](): AsyncIterator<MessageStreamEvent> {
const pushQueue: MessageStreamEvent[] = [];
const readQueue: ((chunk: MessageStreamEvent | undefined) => void)[] = [];
const readQueue: {
resolve: (chunk: MessageStreamEvent | undefined) => void;
reject: (error: unknown) => void;
}[] = [];
let done = false;

this.on('streamEvent', (event) => {
const reader = readQueue.shift();
if (reader) {
reader(event);
reader.resolve(event);
} else {
pushQueue.push(event);
}
Expand All @@ -483,7 +486,23 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
this.on('end', () => {
done = true;
for (const reader of readQueue) {
reader(undefined);
reader.resolve(undefined);
}
readQueue.length = 0;
});

this.on('abort', (err) => {
done = true;
for (const reader of readQueue) {
reader.reject(err);
}
readQueue.length = 0;
});

this.on('error', (err) => {
done = true;
for (const reader of readQueue) {
reader.reject(err);
}
readQueue.length = 0;
});
Expand All @@ -494,9 +513,9 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
if (done) {
return { value: undefined, done: true };
}
return new Promise<MessageStreamEvent | undefined>((resolve) => readQueue.push(resolve)).then(
(chunk) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true }),
);
return new Promise<MessageStreamEvent | undefined>((resolve, reject) =>
readQueue.push({ resolve, reject }),
).then((chunk) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true }));
}
const chunk = pushQueue.shift()!;
return { value: chunk, done: false };
Expand Down
60 changes: 56 additions & 4 deletions tests/api-resources/MessageStream.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { PassThrough } from 'stream';
import { Response } from 'node-fetch';
import Anthropic, { APIUserAbortError } from '@anthropic-ai/sdk';
import Anthropic, { APIConnectionError, APIUserAbortError } from '@anthropic-ai/sdk';
import { Message, MessageStreamEvent } from '@anthropic-ai/sdk/resources/messages';
import { type RequestInfo, type RequestInit } from '@anthropic-ai/sdk/_shims/index';

Expand Down Expand Up @@ -336,10 +336,62 @@ describe('MessageStream class', () => {
}
}

await stream.done().catch((e) => {
expect(e).toBeInstanceOf(APIUserAbortError);
});
await expect(async () => stream.done()).rejects.toThrow(APIUserAbortError);

expect(stream.aborted).toBe(true);
});

it('handles network errors', async () => {
const { fetch, handleRequest } = mockFetch();

const anthropic = new Anthropic({ apiKey: '...', fetch });

const stream = anthropic.messages.stream(
{
max_tokens: 1024,
model: 'claude-2.1',
messages: [{ role: 'user', content: 'Say hello there!' }],
},
{ maxRetries: 0 },
);

handleRequest(async () => {
throw new Error('mock request error');
});

async function runStream() {
await stream.done();
}

await expect(runStream).rejects.toThrow(APIConnectionError);
});

it('handles network errors on async iterator', async () => {
const { fetch, handleRequest } = mockFetch();

const anthropic = new Anthropic({ apiKey: '...', fetch });

const stream = anthropic.messages.stream(
{
max_tokens: 1024,
model: 'claude-2.1',
messages: [{ role: 'user', content: 'Say hello there!' }],
},
{ maxRetries: 0 },
);

handleRequest(async () => {
throw new Error('mock request error');
});

async function runStream() {
for await (const event of stream) {
if (event.type === 'content_block_delta' && event.delta.text.includes('He')) {
break;
}
}
}

await expect(runStream).rejects.toThrow(APIConnectionError);
});
});

0 comments on commit 95232ed

Please sign in to comment.