From 95232eda6bcdf420fbf85625c7ff6baabab60f08 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Fri, 1 Mar 2024 01:32:26 +0100 Subject: [PATCH] fix(MessageStream): handle errors more gracefully in async iterator (#301) Errors that were thrown in the async iterator of the MessageStream would previously throw as an unhandled promise rejection. See: https://github.com/anthropics/anthropic-sdk-typescript/issues/298 --- src/lib/MessageStream.ts | 31 +++++++++--- tests/api-resources/MessageStream.test.ts | 60 +++++++++++++++++++++-- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/src/lib/MessageStream.ts b/src/lib/MessageStream.ts index 0ffdd5be..628b8255 100644 --- a/src/lib/MessageStream.ts +++ b/src/lib/MessageStream.ts @@ -468,13 +468,16 @@ export class MessageStream implements AsyncIterable { [Symbol.asyncIterator](): AsyncIterator { const pushQueue: MessageStreamEvent[] = []; - const readQueue: ((chunk: MessageStreamEvent | undefined) => void)[] = []; + const readQueue: { + resolve: (chunk: MessageStreamEvent | undefined) => void; + reject: (error: unknown) => void; + }[] = []; let done = false; this.on('streamEvent', (event) => { const reader = readQueue.shift(); if (reader) { - reader(event); + reader.resolve(event); } else { pushQueue.push(event); } @@ -483,7 +486,23 @@ export class MessageStream implements AsyncIterable { this.on('end', () => { done = true; for (const reader of readQueue) { - reader(undefined); + reader.resolve(undefined); + } + readQueue.length = 0; + }); + + this.on('abort', (err) => { + done = true; + for (const reader of readQueue) { + reader.reject(err); + } + readQueue.length = 0; + }); + + this.on('error', (err) => { + done = true; + for (const reader of readQueue) { + reader.reject(err); } readQueue.length = 0; }); @@ -494,9 +513,9 @@ export class MessageStream implements AsyncIterable { if (done) { return { value: undefined, done: true }; } - return new Promise((resolve) => readQueue.push(resolve)).then( - (chunk) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true }), - ); + return new Promise((resolve, reject) => + readQueue.push({ resolve, reject }), + ).then((chunk) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true })); } const chunk = pushQueue.shift()!; return { value: chunk, done: false }; diff --git a/tests/api-resources/MessageStream.test.ts b/tests/api-resources/MessageStream.test.ts index 552771af..bb8861aa 100644 --- a/tests/api-resources/MessageStream.test.ts +++ b/tests/api-resources/MessageStream.test.ts @@ -1,6 +1,6 @@ import { PassThrough } from 'stream'; import { Response } from 'node-fetch'; -import Anthropic, { APIUserAbortError } from '@anthropic-ai/sdk'; +import Anthropic, { APIConnectionError, APIUserAbortError } from '@anthropic-ai/sdk'; import { Message, MessageStreamEvent } from '@anthropic-ai/sdk/resources/messages'; import { type RequestInfo, type RequestInit } from '@anthropic-ai/sdk/_shims/index'; @@ -336,10 +336,62 @@ describe('MessageStream class', () => { } } - await stream.done().catch((e) => { - expect(e).toBeInstanceOf(APIUserAbortError); - }); + await expect(async () => stream.done()).rejects.toThrow(APIUserAbortError); expect(stream.aborted).toBe(true); }); + + it('handles network errors', async () => { + const { fetch, handleRequest } = mockFetch(); + + const anthropic = new Anthropic({ apiKey: '...', fetch }); + + const stream = anthropic.messages.stream( + { + max_tokens: 1024, + model: 'claude-2.1', + messages: [{ role: 'user', content: 'Say hello there!' }], + }, + { maxRetries: 0 }, + ); + + handleRequest(async () => { + throw new Error('mock request error'); + }); + + async function runStream() { + await stream.done(); + } + + await expect(runStream).rejects.toThrow(APIConnectionError); + }); + + it('handles network errors on async iterator', async () => { + const { fetch, handleRequest } = mockFetch(); + + const anthropic = new Anthropic({ apiKey: '...', fetch }); + + const stream = anthropic.messages.stream( + { + max_tokens: 1024, + model: 'claude-2.1', + messages: [{ role: 'user', content: 'Say hello there!' }], + }, + { maxRetries: 0 }, + ); + + handleRequest(async () => { + throw new Error('mock request error'); + }); + + async function runStream() { + for await (const event of stream) { + if (event.type === 'content_block_delta' && event.delta.text.includes('He')) { + break; + } + } + } + + await expect(runStream).rejects.toThrow(APIConnectionError); + }); });