diff --git a/package.json b/package.json index e4ee62738..2e728d452 100644 --- a/package.json +++ b/package.json @@ -93,7 +93,8 @@ "digest-fetch": "^1.3.0", "form-data-encoder": "1.7.2", "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" + "node-fetch": "^2.6.7", + "web-streams-polyfill": "^3.2.1" }, "devDependencies": { "@types/jest": "^29.4.0", diff --git a/src/_shims/auto/types.d.ts b/src/_shims/auto/types.d.ts index 9c1cc2550..d7755070b 100644 --- a/src/_shims/auto/types.d.ts +++ b/src/_shims/auto/types.d.ts @@ -96,4 +96,6 @@ export declare class FsReadStream extends Readable { // @ts-ignore type _ReadableStream = unknown extends ReadableStream ? never : ReadableStream; -export { type _ReadableStream as ReadableStream }; +// @ts-ignore +declare const _ReadableStream: unknown extends typeof ReadableStream ? never : typeof ReadableStream; +export { _ReadableStream as ReadableStream }; diff --git a/src/_shims/index.d.ts b/src/_shims/index.d.ts index 044f2cfcf..4e52b952e 100644 --- a/src/_shims/index.d.ts +++ b/src/_shims/index.d.ts @@ -62,6 +62,8 @@ export type Readable = SelectType; export type FsReadStream = SelectType; // @ts-ignore export type ReadableStream = SelectType; +// @ts-ignore +export const ReadableStream: SelectType; export function getMultipartRequestOptions>( form: FormData, diff --git a/src/_shims/node-runtime.ts b/src/_shims/node-runtime.ts index e2398e2b3..cc16e8542 100644 --- a/src/_shims/node-runtime.ts +++ b/src/_shims/node-runtime.ts @@ -13,6 +13,7 @@ import { Readable } from 'node:stream'; import { type RequestOptions } from '../core'; import { MultipartBody } from './MultipartBody'; import { type Shims } from './registry'; +import { ReadableStream } from 'web-streams-polyfill'; type FileFromPathOptions = Omit; @@ -71,6 +72,7 @@ export function getRuntime(): Shims { FormData: fd.FormData, Blob: fd.Blob, File: fd.File, + ReadableStream, getMultipartRequestOptions, getDefaultAgent: (url: string): Agent => (url.startsWith('https') ? defaultHttpsAgent : defaultHttpAgent), fileFromPath, diff --git a/src/_shims/node-types.d.ts b/src/_shims/node-types.d.ts index 28fe60499..b31698f78 100644 --- a/src/_shims/node-types.d.ts +++ b/src/_shims/node-types.d.ts @@ -7,7 +7,7 @@ import * as fd from 'formdata-node'; export { type Agent } from 'node:http'; export { type Readable } from 'node:stream'; export { type ReadStream as FsReadStream } from 'node:fs'; -export { type ReadableStream } from 'web-streams-polyfill'; +export { ReadableStream } from 'web-streams-polyfill'; export const fetch: typeof nf.default; diff --git a/src/_shims/registry.ts b/src/_shims/registry.ts index 0e0706877..65b570d0c 100644 --- a/src/_shims/registry.ts +++ b/src/_shims/registry.ts @@ -12,6 +12,7 @@ export interface Shims { FormData: any; Blob: any; File: any; + ReadableStream: any; getMultipartRequestOptions: >( form: Shims['FormData'], opts: RequestOptions, @@ -32,6 +33,7 @@ export let Headers: Shims['Headers'] | undefined = undefined; export let FormData: Shims['FormData'] | undefined = undefined; export let Blob: Shims['Blob'] | undefined = undefined; export let File: Shims['File'] | undefined = undefined; +export let ReadableStream: Shims['ReadableStream'] | undefined = undefined; export let getMultipartRequestOptions: Shims['getMultipartRequestOptions'] | undefined = undefined; export let getDefaultAgent: Shims['getDefaultAgent'] | undefined = undefined; export let fileFromPath: Shims['fileFromPath'] | undefined = undefined; @@ -55,6 +57,7 @@ export function setShims(shims: Shims, options: { auto: boolean } = { auto: fals FormData = shims.FormData; Blob = shims.Blob; File = shims.File; + ReadableStream = shims.ReadableStream; getMultipartRequestOptions = shims.getMultipartRequestOptions; getDefaultAgent = shims.getDefaultAgent; fileFromPath = shims.fileFromPath; diff --git a/src/_shims/web-runtime.ts b/src/_shims/web-runtime.ts index 12d73e965..92dadeb89 100644 --- a/src/_shims/web-runtime.ts +++ b/src/_shims/web-runtime.ts @@ -72,6 +72,18 @@ export function getRuntime({ manuallyImported }: { manuallyImported?: boolean } } } ), + ReadableStream: + // @ts-ignore + typeof ReadableStream !== 'undefined' ? ReadableStream : ( + class ReadableStream { + // @ts-ignore + constructor() { + throw new Error( + `streaming isn't supported in this environment yet as 'ReadableStream' is undefined. ${recommendation}`, + ); + } + } + ), getMultipartRequestOptions: async >( // @ts-ignore form: FormData, diff --git a/src/_shims/web-types.d.ts b/src/_shims/web-types.d.ts index ec96cd817..4ff351383 100644 --- a/src/_shims/web-types.d.ts +++ b/src/_shims/web-types.d.ts @@ -79,4 +79,5 @@ export declare class FsReadStream extends Readable { } type _ReadableStream = ReadableStream; -export { type _ReadableStream as ReadableStream }; +declare const _ReadableStream: typeof ReadableStream; +export { _ReadableStream as ReadableStream }; diff --git a/src/core.ts b/src/core.ts index a715f4507..73b2f5574 100644 --- a/src/core.ts +++ b/src/core.ts @@ -44,7 +44,7 @@ async function defaultParseResponse(props: APIResponseProps): Promise { if (props.options.stream) { // Note: there is an invariant here that isn't represented in the type system // that if you set `stream: true` the response type must also be `Stream` - return new Stream(response, props.controller) as any; + return Stream.fromSSEResponse(response, props.controller) as any; } const contentType = response.headers.get('content-type'); diff --git a/src/error.ts b/src/error.ts index 873087c78..28a8b540f 100644 --- a/src/error.ts +++ b/src/error.ts @@ -19,7 +19,7 @@ export class APIError extends OpenAIError { message: string | undefined, headers: Headers | undefined, ) { - super(`${status} ${APIError.makeMessage(error, message)}`); + super(`${APIError.makeMessage(status, error, message)}`); this.status = status; this.headers = headers; @@ -30,13 +30,14 @@ export class APIError extends OpenAIError { this.type = data?.['type']; } - private static makeMessage(error: any, message: string | undefined) { + private static makeMessage(status: number | undefined, error: any, message: string | undefined) { return ( - error?.message ? + (status || '') + + (error?.message ? typeof error.message === 'string' ? error.message : JSON.stringify(error.message) : error ? JSON.stringify(error) - : message || 'status code (no body)' + : message || 'status code (no body)') ); } diff --git a/src/shims/node.ts b/src/shims/node.ts index 9273d4eae..73df5600c 100644 --- a/src/shims/node.ts +++ b/src/shims/node.ts @@ -45,6 +45,6 @@ declare module '../_shims/manual-types' { // @ts-ignore export type FsReadStream = types.FsReadStream; // @ts-ignore - export type ReadableStream = types.ReadableStream; + export import ReadableStream = types.ReadableStream; } } diff --git a/src/shims/web.ts b/src/shims/web.ts index 9970f1db8..f72d78444 100644 --- a/src/shims/web.ts +++ b/src/shims/web.ts @@ -45,6 +45,6 @@ declare module '../_shims/manual-types' { // @ts-ignore export type FsReadStream = types.FsReadStream; // @ts-ignore - export type ReadableStream = types.ReadableStream; + export import ReadableStream = types.ReadableStream; } } diff --git a/src/streaming.ts b/src/streaming.ts index 7f9ebe0a0..f69724d64 100644 --- a/src/streaming.ts +++ b/src/streaming.ts @@ -1,4 +1,4 @@ -import { type Response } from './_shims/index'; +import { ReadableStream, type Response } from './_shims/index'; import { OpenAIError } from './error'; type Bytes = string | ArrayBuffer | Uint8Array | Buffer | null | undefined; @@ -12,67 +12,175 @@ type ServerSentEvent = { export class Stream implements AsyncIterable { controller: AbortController; - private response: Response; - private decoder: SSEDecoder; - - constructor(response: Response, controller: AbortController) { - this.response = response; + constructor(private iterator: () => AsyncIterator, controller: AbortController) { this.controller = controller; - this.decoder = new SSEDecoder(); } - private async *iterMessages(): AsyncGenerator { - if (!this.response.body) { - this.controller.abort(); - throw new OpenAIError(`Attempted to iterate over a response with no body`); - } + static fromSSEResponse(response: Response, controller: AbortController) { + let consumed = false; + const decoder = new SSEDecoder(); + + async function* iterMessages(): AsyncGenerator { + if (!response.body) { + controller.abort(); + throw new OpenAIError(`Attempted to iterate over a response with no body`); + } + + const lineDecoder = new LineDecoder(); - const lineDecoder = new LineDecoder(); + const iter = readableStreamAsyncIterable(response.body); + for await (const chunk of iter) { + for (const line of lineDecoder.decode(chunk)) { + const sse = decoder.decode(line); + if (sse) yield sse; + } + } - const iter = readableStreamAsyncIterable(this.response.body); - for await (const chunk of iter) { - for (const line of lineDecoder.decode(chunk)) { - const sse = this.decoder.decode(line); + for (const line of lineDecoder.flush()) { + const sse = decoder.decode(line); if (sse) yield sse; } } - for (const line of lineDecoder.flush()) { - const sse = this.decoder.decode(line); - if (sse) yield sse; + async function* iterator(): AsyncIterator { + if (consumed) { + throw new Error('Cannot iterate over a consumed stream, use `.tee()` to split the stream.'); + } + consumed = true; + let done = false; + try { + for await (const sse of iterMessages()) { + if (done) continue; + + if (sse.data.startsWith('[DONE]')) { + done = true; + continue; + } + + if (sse.event === null) { + try { + yield JSON.parse(sse.data); + } catch (e) { + console.error(`Could not parse message into JSON:`, sse.data); + console.error(`From chunk:`, sse.raw); + throw e; + } + } + } + done = true; + } catch (e) { + // If the user calls `stream.controller.abort()`, we should exit without throwing. + if (e instanceof Error && e.name === 'AbortError') return; + throw e; + } finally { + // If the user `break`s, abort the ongoing request. + if (!done) controller.abort(); + } } + + return new Stream(iterator, controller); } - async *[Symbol.asyncIterator](): AsyncIterator { - let done = false; - try { - for await (const sse of this.iterMessages()) { - if (done) continue; + // Generates a Stream from a newline-separated ReadableStream where each item + // is a JSON Value. + static fromReadableStream(readableStream: ReadableStream, controller: AbortController) { + let consumed = false; - if (sse.data.startsWith('[DONE]')) { - done = true; - continue; + async function* iterLines(): AsyncGenerator { + const lineDecoder = new LineDecoder(); + + const iter = readableStreamAsyncIterable(readableStream); + for await (const chunk of iter) { + for (const line of lineDecoder.decode(chunk)) { + yield line; } + } - if (sse.event === null) { - try { - yield JSON.parse(sse.data); - } catch (e) { - console.error(`Could not parse message into JSON:`, sse.data); - console.error(`From chunk:`, sse.raw); - throw e; - } + for (const line of lineDecoder.flush()) { + yield line; + } + } + + async function* iterator(): AsyncIterator { + if (consumed) { + throw new Error('Cannot iterate over a consumed stream, use `.tee()` to split the stream.'); + } + consumed = true; + let done = false; + try { + for await (const line of iterLines()) { + if (done) continue; + if (line) yield JSON.parse(line); } + done = true; + } catch (e) { + // If the user calls `stream.controller.abort()`, we should exit without throwing. + if (e instanceof Error && e.name === 'AbortError') return; + throw e; + } finally { + // If the user `break`s, abort the ongoing request. + if (!done) controller.abort(); } - done = true; - } catch (e) { - // If the user calls `stream.controller.abort()`, we should exit without throwing. - if (e instanceof Error && e.name === 'AbortError') return; - throw e; - } finally { - // If the user `break`s, abort the ongoing request. - if (!done) this.controller.abort(); } + + return new Stream(iterator, controller); + } + + [Symbol.asyncIterator](): AsyncIterator { + return this.iterator(); + } + + tee(): [Stream, Stream] { + const left: Array>> = []; + const right: Array>> = []; + const iterator = this.iterator(); + + const teeIterator = (queue: Array>>): AsyncIterator => { + return { + next: () => { + if (queue.length === 0) { + const result = iterator.next(); + left.push(result); + right.push(result); + } + return queue.shift()!; + }, + }; + }; + + return [ + new Stream(() => teeIterator(left), this.controller), + new Stream(() => teeIterator(right), this.controller), + ]; + } + + // Converts this stream to a newline-separated ReadableStream of JSON Stringified values in the stream + // which can be turned back into a Stream with Stream.fromReadableStream. + toReadableStream(): ReadableStream { + const self = this; + let iter: AsyncIterator; + const encoder = new TextEncoder(); + + return new ReadableStream({ + async start() { + iter = self[Symbol.asyncIterator](); + }, + async pull(ctrl) { + try { + const { value, done } = await iter.next(); + if (done) return ctrl.close(); + + const bytes = encoder.encode(JSON.stringify(value) + '\n'); + + ctrl.enqueue(bytes); + } catch (err) { + ctrl.error(err); + } + }, + async cancel() { + await iter.return?.(); + }, + }); } } diff --git a/yarn.lock b/yarn.lock index fb2b9a2c1..d01ab81de 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4067,6 +4067,11 @@ web-streams-polyfill@4.0.0-beta.1: resolved "https://registry.yarnpkg.com/web-streams-polyfill/-/web-streams-polyfill-4.0.0-beta.1.tgz#3b19b9817374b7cee06d374ba7eeb3aeb80e8c95" integrity sha512-3ux37gEX670UUphBF9AMCq8XM6iQ8Ac6A+DSRRjDoRBm1ufCkaCDdNVbaqq60PsEkdNlLKrGtv/YBP4EJXqNtQ== +web-streams-polyfill@^3.2.1: + version "3.2.1" + resolved "https://registry.yarnpkg.com/web-streams-polyfill/-/web-streams-polyfill-3.2.1.tgz#71c2718c52b45fd49dbeee88634b3a60ceab42a6" + integrity sha512-e0MO3wdXWKrLbL0DgGnUV7WHVuw9OUvL4hjgnPkIeEvESk74gAITi5G606JtZPp39cd8HA9VQzCIvA49LpPN5Q== + webidl-conversions@^3.0.0: version "3.0.1" resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"