From a3ff1ff797bc929bec89a46563674d83cbd83284 Mon Sep 17 00:00:00 2001
From: Robert Craigie
Date: Tue, 13 Aug 2024 18:51:12 -0400
Subject: [PATCH] feat(client): add streaming helpers

---
 src/lib/PromptCachingBetaMessageStream.ts     | 579 ++++++++++++++++++
 src/resources/beta/prompt-caching/messages.ts |  10 +
 2 files changed, 589 insertions(+)
 create mode 100644 src/lib/PromptCachingBetaMessageStream.ts

diff --git a/src/lib/PromptCachingBetaMessageStream.ts b/src/lib/PromptCachingBetaMessageStream.ts
new file mode 100644
index 00000000..0e742cba
--- /dev/null
+++ b/src/lib/PromptCachingBetaMessageStream.ts
@@ -0,0 +1,579 @@
+import * as Core from '@anthropic-ai/sdk/core';
+import { AnthropicError, APIUserAbortError } from '@anthropic-ai/sdk/error';
+import { type ContentBlock, type TextBlock } from '@anthropic-ai/sdk/resources/messages';
+import {
+  Messages,
+  type PromptCachingBetaMessage,
+  type RawPromptCachingBetaMessageStreamEvent,
+  type PromptCachingBetaMessageParam,
+  type MessageCreateParams,
+  type MessageCreateParamsBase,
+} from '@anthropic-ai/sdk/resources/beta/prompt-caching/messages';
+import { type ReadableStream } from '@anthropic-ai/sdk/_shims/index';
+import { Stream } from '@anthropic-ai/sdk/streaming';
+import { partialParse } from '../_vendor/partial-json-parser/parser';
+
+export interface PromptCachingBetaMessageStreamEvents {
+  connect: () => void;
+  streamEvent: (event: RawPromptCachingBetaMessageStreamEvent, snapshot: PromptCachingBetaMessage) => void;
+  text: (textDelta: string, textSnapshot: string) => void;
+  inputJson: (partialJson: string, jsonSnapshot: unknown) => void;
+  message: (message: PromptCachingBetaMessage) => void;
+  contentBlock: (content: ContentBlock) => void;
+  finalPromptCachingBetaMessage: (message: PromptCachingBetaMessage) => void;
+  error: (error: AnthropicError) => void;
+  abort: (error: APIUserAbortError) => void;
+  end: () => void;
+}
+
+type PromptCachingBetaMessageStreamEventListeners<Event extends keyof PromptCachingBetaMessageStreamEvents> =
+  {
+    listener: PromptCachingBetaMessageStreamEvents[Event];
+    once?: boolean;
+  }[];
+
+const JSON_BUF_PROPERTY = '__json_buf';
+
+export class PromptCachingBetaMessageStream implements AsyncIterable<RawPromptCachingBetaMessageStreamEvent> {
+  messages: PromptCachingBetaMessageParam[] = [];
+  receivedMessages: PromptCachingBetaMessage[] = [];
+  #currentMessageSnapshot: PromptCachingBetaMessage | undefined;
+
+  controller: AbortController = new AbortController();
+
+  #connectedPromise: Promise<void>;
+  #resolveConnectedPromise: () => void = () => {};
+  #rejectConnectedPromise: (error: AnthropicError) => void = () => {};
+
+  #endPromise: Promise<void>;
+  #resolveEndPromise: () => void = () => {};
+  #rejectEndPromise: (error: AnthropicError) => void = () => {};
+
+  #listeners: {
+    [Event in keyof PromptCachingBetaMessageStreamEvents]?: PromptCachingBetaMessageStreamEventListeners<Event>;
+  } = {};
+
+  #ended = false;
+  #errored = false;
+  #aborted = false;
+  #catchingPromiseCreated = false;
+
+  constructor() {
+    this.#connectedPromise = new Promise<void>((resolve, reject) => {
+      this.#resolveConnectedPromise = resolve;
+      this.#rejectConnectedPromise = reject;
+    });
+
+    this.#endPromise = new Promise<void>((resolve, reject) => {
+      this.#resolveEndPromise = resolve;
+      this.#rejectEndPromise = reject;
+    });
+
+    // Don't let these promises cause unhandled rejection errors.
+    // we will manually cause an unhandled rejection error later
+    // if the user hasn't registered any error listener or called
+    // any promise-returning method.
+    this.#connectedPromise.catch(() => {});
+    this.#endPromise.catch(() => {});
+  }
+
+  /**
+   * Intended for use on the frontend, consuming a stream produced with
+   * `.toReadableStream()` on the backend.
+   *
+   * Note that messages sent to the model do not appear in `.on('message')`
+   * in this context.
+   */
+  static fromReadableStream(stream: ReadableStream): PromptCachingBetaMessageStream {
+    const runner = new PromptCachingBetaMessageStream();
+    runner._run(() => runner._fromReadableStream(stream));
+    return runner;
+  }
+
+  static createMessage(
+    messages: Messages,
+    params: MessageCreateParamsBase,
+    options?: Core.RequestOptions,
+  ): PromptCachingBetaMessageStream {
+    const runner = new PromptCachingBetaMessageStream();
+    for (const message of params.messages) {
+      runner._addPromptCachingBetaMessageParam(message);
+    }
+    runner._run(() =>
+      runner._createPromptCachingBetaMessage(
+        messages,
+        { ...params, stream: true },
+        { ...options, headers: { ...options?.headers, 'X-Stainless-Helper-Method': 'stream' } },
+      ),
+    );
+    return runner;
+  }
+
+  protected _run(executor: () => Promise<any>) {
+    executor().then(() => {
+      this._emitFinal();
+      this._emit('end');
+    }, this.#handleError);
+  }
+
+  protected _addPromptCachingBetaMessageParam(message: PromptCachingBetaMessageParam) {
+    this.messages.push(message);
+  }
+
+  protected _addPromptCachingBetaMessage(message: PromptCachingBetaMessage, emit = true) {
+    this.receivedMessages.push(message);
+    if (emit) {
+      this._emit('message', message);
+    }
+  }
+
+  protected async _createPromptCachingBetaMessage(
+    messages: Messages,
+    params: MessageCreateParams,
+    options?: Core.RequestOptions,
+  ): Promise<void> {
+    const signal = options?.signal;
+    if (signal) {
+      if (signal.aborted) this.controller.abort();
+      signal.addEventListener('abort', () => this.controller.abort());
+    }
+    this.#beginRequest();
+    const stream = await messages.create(
+      { ...params, stream: true },
+      { ...options, signal: this.controller.signal },
+    );
+    this._connected();
+    for await (const event of stream) {
+      this.#addStreamEvent(event);
+    }
+    if (stream.controller.signal?.aborted) {
+      throw new APIUserAbortError();
+    }
+    this.#endRequest();
+  }
+
+  protected _connected() {
+    if (this.ended) return;
+    this.#resolveConnectedPromise();
+    this._emit('connect');
+  }
+
+  get ended(): boolean {
+    return this.#ended;
+  }
+
+  get errored(): boolean {
+    return this.#errored;
+  }
+
+  get aborted(): boolean {
+    return this.#aborted;
+  }
+
+  abort() {
+    this.controller.abort();
+  }
+
+  /**
+   * Adds the listener function to the end of the listeners array for the event.
+   * No checks are made to see if the listener has already been added. Multiple calls passing
+   * the same combination of event and listener will result in the listener being added, and
+   * called, multiple times.
+   * @returns this PromptCachingBetaMessageStream, so that calls can be chained
+   */
+  on<Event extends keyof PromptCachingBetaMessageStreamEvents>(
+    event: Event,
+    listener: PromptCachingBetaMessageStreamEvents[Event],
+  ): this {
+    const listeners: PromptCachingBetaMessageStreamEventListeners<Event> =
+      this.#listeners[event] || (this.#listeners[event] = []);
+    listeners.push({ listener });
+    return this;
+  }
+
+  /**
+   * Removes the specified listener from the listener array for the event.
+   * off() will remove, at most, one instance of a listener from the listener array. If any single
+   * listener has been added multiple times to the listener array for the specified event, then
+   * off() must be called multiple times to remove each instance.
+   * @returns this PromptCachingBetaMessageStream, so that calls can be chained
+   */
+  off<Event extends keyof PromptCachingBetaMessageStreamEvents>(
+    event: Event,
+    listener: PromptCachingBetaMessageStreamEvents[Event],
+  ): this {
+    const listeners = this.#listeners[event];
+    if (!listeners) return this;
+    const index = listeners.findIndex((l) => l.listener === listener);
+    if (index >= 0) listeners.splice(index, 1);
+    return this;
+  }
+
+  /**
+   * Adds a one-time listener function for the event. The next time the event is triggered,
+   * this listener is removed and then invoked.
+   * @returns this PromptCachingBetaMessageStream, so that calls can be chained
+   */
+  once<Event extends keyof PromptCachingBetaMessageStreamEvents>(
+    event: Event,
+    listener: PromptCachingBetaMessageStreamEvents[Event],
+  ): this {
+    const listeners: PromptCachingBetaMessageStreamEventListeners<Event> =
+      this.#listeners[event] || (this.#listeners[event] = []);
+    listeners.push({ listener, once: true });
+    return this;
+  }
+
+  /**
+   * This is similar to `.once()`, but returns a Promise that resolves the next time
+   * the event is triggered, instead of calling a listener callback.
+   * @returns a Promise that resolves the next time the given event is triggered,
+   * or rejects if an error is emitted. (If you request the 'error' event,
+   * returns a promise that resolves with the error).
+   *
+   * Example:
+   *
+   *   const message = await stream.emitted('message') // rejects if the stream errors
+   */
+  emitted<Event extends keyof PromptCachingBetaMessageStreamEvents>(
+    event: Event,
+  ): Promise<
+    Parameters<PromptCachingBetaMessageStreamEvents[Event]> extends [infer Param] ? Param
+    : Parameters<PromptCachingBetaMessageStreamEvents[Event]> extends [] ? void
+    : Parameters<PromptCachingBetaMessageStreamEvents[Event]>
+  > {
+    return new Promise((resolve, reject) => {
+      this.#catchingPromiseCreated = true;
+      if (event !== 'error') this.once('error', reject);
+      this.once(event, resolve as any);
+    });
+  }
+
+  async done(): Promise<void> {
+    this.#catchingPromiseCreated = true;
+    await this.#endPromise;
+  }
+
+  get currentMessage(): PromptCachingBetaMessage | undefined {
+    return this.#currentMessageSnapshot;
+  }
+
+  #getFinalMessage(): PromptCachingBetaMessage {
+    if (this.receivedMessages.length === 0) {
+      throw new AnthropicError(
+        'stream ended without producing a PromptCachingBetaMessage with role=assistant',
+      );
+    }
+    return this.receivedMessages.at(-1)!;
+  }
+
+  /**
+   * @returns a promise that resolves with the final assistant PromptCachingBetaMessage response,
+   * or rejects if an error occurred or the stream ended prematurely without producing a PromptCachingBetaMessage.
+   */
+  async finalMessage(): Promise<PromptCachingBetaMessage> {
+    await this.done();
+    return this.#getFinalMessage();
+  }
+
+  #getFinalText(): string {
+    if (this.receivedMessages.length === 0) {
+      throw new AnthropicError(
+        'stream ended without producing a PromptCachingBetaMessage with role=assistant',
+      );
+    }
+    const textBlocks = this.receivedMessages
+      .at(-1)!
+      .content.filter((block): block is TextBlock => block.type === 'text')
+      .map((block) => block.text);
+    if (textBlocks.length === 0) {
+      throw new AnthropicError('stream ended without producing a content block with type=text');
+    }
+    return textBlocks.join(' ');
+  }
+
+  /**
+   * @returns a promise that resolves with the final assistant PromptCachingBetaMessage's text response,
+   * concatenated together if there is more than one text block.
+   * Rejects if an error occurred or the stream ended prematurely without producing a PromptCachingBetaMessage.
+   */
+  async finalText(): Promise<string> {
+    await this.done();
+    return this.#getFinalText();
+  }
+
+  #handleError = (error: unknown) => {
+    this.#errored = true;
+    if (error instanceof Error && error.name === 'AbortError') {
+      error = new APIUserAbortError();
+    }
+    if (error instanceof APIUserAbortError) {
+      this.#aborted = true;
+      return this._emit('abort', error);
+    }
+    if (error instanceof AnthropicError) {
+      return this._emit('error', error);
+    }
+    if (error instanceof Error) {
+      const anthropicError: AnthropicError = new AnthropicError(error.message);
+      // @ts-ignore
+      anthropicError.cause = error;
+      return this._emit('error', anthropicError);
+    }
+    return this._emit('error', new AnthropicError(String(error)));
+  };
+
+  protected _emit<Event extends keyof PromptCachingBetaMessageStreamEvents>(
+    event: Event,
+    ...args: Parameters<PromptCachingBetaMessageStreamEvents[Event]>
+  ) {
+    // make sure we don't emit any PromptCachingBetaMessageStreamEvents after end
+    if (this.#ended) return;
+
+    if (event === 'end') {
+      this.#ended = true;
+      this.#resolveEndPromise();
+    }
+
+    const listeners: PromptCachingBetaMessageStreamEventListeners<Event> | undefined = this.#listeners[event];
+    if (listeners) {
+      this.#listeners[event] = listeners.filter((l) => !l.once) as any;
+      listeners.forEach(({ listener }: any) => listener(...args));
+    }
+
+    if (event === 'abort') {
+      const error = args[0] as APIUserAbortError;
+      if (!this.#catchingPromiseCreated && !listeners?.length) {
+        Promise.reject(error);
+      }
+      this.#rejectConnectedPromise(error);
+      this.#rejectEndPromise(error);
+      this._emit('end');
+      return;
+    }
+
+    if (event === 'error') {
+      // NOTE: _emit('error', error) should only be called from #handleError().
+
+      const error = args[0] as AnthropicError;
+      if (!this.#catchingPromiseCreated && !listeners?.length) {
+        // Trigger an unhandled rejection if the user hasn't registered any error handlers.
+        // If you are seeing stack traces here, make sure to handle errors via either:
+        // - runner.on('error', () => ...)
+        // - await runner.done()
+        // - await runner.final...()
+        // - etc.
+        Promise.reject(error);
+      }
+      this.#rejectConnectedPromise(error);
+      this.#rejectEndPromise(error);
+      this._emit('end');
+    }
+  }
+
+  protected _emitFinal() {
+    const finalPromptCachingBetaMessage = this.receivedMessages.at(-1);
+    if (finalPromptCachingBetaMessage) {
+      this._emit('finalPromptCachingBetaMessage', this.#getFinalMessage());
+    }
+  }
+
+  #beginRequest() {
+    if (this.ended) return;
+    this.#currentMessageSnapshot = undefined;
+  }
+  #addStreamEvent(event: RawPromptCachingBetaMessageStreamEvent) {
+    if (this.ended) return;
+    const messageSnapshot = this.#accumulateMessage(event);
+    this._emit('streamEvent', event, messageSnapshot);
+
+    switch (event.type) {
+      case 'content_block_delta': {
+        const content = messageSnapshot.content.at(-1)!;
+        if (event.delta.type === 'text_delta' && content.type === 'text') {
+          this._emit('text', event.delta.text, content.text || '');
+        } else if (event.delta.type === 'input_json_delta' && content.type === 'tool_use') {
+          if (content.input) {
+            this._emit('inputJson', event.delta.partial_json, content.input);
+          }
+        }
+        break;
+      }
+      case 'message_stop': {
+        this._addPromptCachingBetaMessageParam(messageSnapshot);
+        this._addPromptCachingBetaMessage(messageSnapshot, true);
+        break;
+      }
+      case 'content_block_stop': {
+        this._emit('contentBlock', messageSnapshot.content.at(-1)!);
+        break;
+      }
+      case 'message_start': {
+        this.#currentMessageSnapshot = messageSnapshot;
+        break;
+      }
+      case 'content_block_start':
+      case 'message_delta':
+        break;
+    }
+  }
+  #endRequest(): PromptCachingBetaMessage {
+    if (this.ended) {
+      throw new AnthropicError(`stream has ended, this shouldn't happen`);
+    }
+    const snapshot = this.#currentMessageSnapshot;
+    if (!snapshot) {
+      throw new AnthropicError(`request ended without sending any chunks`);
+    }
+    this.#currentMessageSnapshot = undefined;
+    return snapshot;
+  }
+
+  protected async _fromReadableStream(
+    readableStream: ReadableStream,
+    options?: Core.RequestOptions,
+  ): Promise<void> {
+    const signal = options?.signal;
+    if (signal) {
+      if (signal.aborted) this.controller.abort();
+      signal.addEventListener('abort', () => this.controller.abort());
+    }
+    this.#beginRequest();
+    this._connected();
+    const stream = Stream.fromReadableStream<RawPromptCachingBetaMessageStreamEvent>(
+      readableStream,
+      this.controller,
+    );
+    for await (const event of stream) {
+      this.#addStreamEvent(event);
+    }
+    if (stream.controller.signal?.aborted) {
+      throw new APIUserAbortError();
+    }
+    this.#endRequest();
+  }
+
+  /**
+   * Mutates this.#currentMessageSnapshot with the current event. Accumulation across multiple
+   * messages must be handled by the caller; this method will throw if you try to accumulate
+   * events for more than one message.
+   */
+  #accumulateMessage(event: RawPromptCachingBetaMessageStreamEvent): PromptCachingBetaMessage {
+    let snapshot = this.#currentMessageSnapshot;
+
+    if (event.type === 'message_start') {
+      if (snapshot) {
+        throw new AnthropicError(`Unexpected event order, got ${event.type} before receiving "message_stop"`);
+      }
+      return event.message;
+    }
+
+    if (!snapshot) {
+      throw new AnthropicError(`Unexpected event order, got ${event.type} before "message_start"`);
+    }
+
+    switch (event.type) {
+      case 'message_stop':
+        return snapshot;
+      case 'message_delta':
+        snapshot.stop_reason = event.delta.stop_reason;
+        snapshot.stop_sequence = event.delta.stop_sequence;
+        snapshot.usage.output_tokens = event.usage.output_tokens;
+        return snapshot;
+      case 'content_block_start':
+        snapshot.content.push(event.content_block);
+        return snapshot;
+      case 'content_block_delta': {
+        const snapshotContent = snapshot.content.at(event.index);
+        if (snapshotContent?.type === 'text' && event.delta.type === 'text_delta') {
+          snapshotContent.text += event.delta.text;
+        } else if (snapshotContent?.type === 'tool_use' && event.delta.type === 'input_json_delta') {
+          // we need to keep track of the raw JSON string as well so that we can
+          // re-parse it for each delta, for now we just store it as an untyped
+          // non-enumerable property on the snapshot
+          let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
+          jsonBuf += event.delta.partial_json;
+
+          Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
+            value: jsonBuf,
+            enumerable: false,
+            writable: true,
+          });
+
+          if (jsonBuf) {
+            snapshotContent.input = partialParse(jsonBuf);
+          }
+        }
+        return snapshot;
+      }
+      case 'content_block_stop':
+        return snapshot;
+    }
+  }
+
+  [Symbol.asyncIterator](): AsyncIterator<RawPromptCachingBetaMessageStreamEvent> {
+    const pushQueue: RawPromptCachingBetaMessageStreamEvent[] = [];
+    const readQueue: {
+      resolve: (chunk: RawPromptCachingBetaMessageStreamEvent | undefined) => void;
+      reject: (error: unknown) => void;
+    }[] = [];
+    let done = false;
+
+    this.on('streamEvent', (event) => {
+      const reader = readQueue.shift();
+      if (reader) {
+        reader.resolve(event);
+      } else {
+        pushQueue.push(event);
+      }
+    });
+
+    this.on('end', () => {
+      done = true;
+      for (const reader of readQueue) {
+        reader.resolve(undefined);
+      }
+      readQueue.length = 0;
+    });
+
+    this.on('abort', (err) => {
+      done = true;
+      for (const reader of readQueue) {
+        reader.reject(err);
+      }
+      readQueue.length = 0;
+    });
+
+    this.on('error', (err) => {
+      done = true;
+      for (const reader of readQueue) {
+        reader.reject(err);
+      }
+      readQueue.length = 0;
+    });
+
+    return {
+      next: async (): Promise<IteratorResult<RawPromptCachingBetaMessageStreamEvent>> => {
+        if (!pushQueue.length) {
+          if (done) {
+            return { value: undefined, done: true };
+          }
+          return new Promise<RawPromptCachingBetaMessageStreamEvent | undefined>((resolve, reject) =>
+            readQueue.push({ resolve, reject }),
+          ).then((chunk) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true }));
+        }
+        const chunk = pushQueue.shift()!;
+        return { value: chunk, done: false };
+      },
+      return: async () => {
+        this.abort();
+        return { value: undefined, done: true };
+      },
+    };
+  }
+
+  toReadableStream(): ReadableStream {
+    const stream = new Stream(this[Symbol.asyncIterator].bind(this), this.controller);
+    return stream.toReadableStream();
+  }
+}
diff --git a/src/resources/beta/prompt-caching/messages.ts b/src/resources/beta/prompt-caching/messages.ts
index 2c27dd36..77f4d475 100644
--- a/src/resources/beta/prompt-caching/messages.ts
+++ b/src/resources/beta/prompt-caching/messages.ts
@@ -6,6 +6,7 @@ import * as Core from '../../../core';
 import * as PromptCachingMessagesAPI from './messages';
 import * as MessagesAPI from '../../messages';
 import { Stream } from '../../../streaming';
+import { PromptCachingBetaMessageStream } from '../../../lib/PromptCachingBetaMessageStream';
 
 export class Messages extends APIResource {
   /**
@@ -41,8 +42,17 @@ export class Messages extends APIResource {
       stream: body.stream ?? false,
     }) as APIPromise<PromptCachingBetaMessage> | APIPromise<Stream<RawPromptCachingBetaMessageStreamEvent>>;
   }
+
+  /**
+   * Create a Message stream
+   */
+  stream(body: MessageStreamParams, options?: Core.RequestOptions): PromptCachingBetaMessageStream {
+    return PromptCachingBetaMessageStream.createMessage(this, body, options);
+  }
 }
 
+export type MessageStreamParams = MessageCreateParamsBase;
+
 export interface PromptCachingBetaCacheControlEphemeral {
   type: 'ephemeral';
 }
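
Usage note: the sketch below shows how the new `stream()` helper is intended to be consumed. It is a minimal example, assuming the default `Anthropic` client export and that `client.beta.promptCaching.messages` is wired to the `Messages` resource patched above; the model name and prompt are placeholders, not part of the patch.

```ts
import Anthropic from '@anthropic-ai/sdk';

const client = new Anthropic();

async function main() {
  // stream() returns a PromptCachingBetaMessageStream instead of a plain promise,
  // so events can be observed while the response is still being generated.
  const stream = client.beta.promptCaching.messages.stream({
    model: 'claude-3-5-sonnet-20240620', // placeholder model name
    max_tokens: 1024,
    messages: [{ role: 'user', content: 'Say hello!' }],
  });

  // 'text' fires for every text delta; the second argument is the accumulated text snapshot.
  stream.on('text', (textDelta) => {
    process.stdout.write(textDelta);
  });

  // finalMessage() resolves with the final assistant message once the stream ends,
  // or rejects if the stream errored or was aborted.
  const message = await stream.finalMessage();
  console.log('\n', message.usage);
}

main();
```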
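
Because the stream implements `AsyncIterable<RawPromptCachingBetaMessageStreamEvent>`, the raw events can also be consumed with `for await` instead of listeners. A sketch, under the same assumptions as above:

```ts
import Anthropic from '@anthropic-ai/sdk';

const client = new Anthropic();

const stream = client.beta.promptCaching.messages.stream({
  model: 'claude-3-5-sonnet-20240620', // placeholder model name
  max_tokens: 1024,
  messages: [{ role: 'user', content: 'Write me a haiku.' }],
});

// Each iteration yields one RawPromptCachingBetaMessageStreamEvent, in arrival order.
for await (const event of stream) {
  if (event.type === 'content_block_delta' && event.delta.type === 'text_delta') {
    process.stdout.write(event.delta.text);
  }
}

// The accumulated assistant message is still available after iteration completes.
const message = await stream.finalMessage();
console.log('\n', message.stop_reason);
```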
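
For the frontend path documented on `fromReadableStream()`, one possible shape is to proxy `toReadableStream()` through an HTTP response on the server and rebuild the helper in the browser. This is a sketch only: the route shape, the `fetch('/stream')` URL, and the deep import path for `PromptCachingBetaMessageStream` (assumed from the new file's location) are illustrative, not part of the patch.

```ts
import Anthropic from '@anthropic-ai/sdk';
import { PromptCachingBetaMessageStream } from '@anthropic-ai/sdk/lib/PromptCachingBetaMessageStream'; // assumed import path

const client = new Anthropic();

// Server side: start the stream and forward the raw events as the response body.
async function handleStreamRequest(request: Request): Promise<Response> {
  const { prompt } = await request.json();
  const stream = client.beta.promptCaching.messages.stream({
    model: 'claude-3-5-sonnet-20240620', // placeholder model name
    max_tokens: 1024,
    messages: [{ role: 'user', content: prompt }],
  });
  return new Response(stream.toReadableStream());
}

// Browser side: rebuild the helper from the forwarded body. As the doc comment on
// fromReadableStream() notes, messages sent *to* the model do not appear in .on('message') here.
async function consumeStream(): Promise<string> {
  const response = await fetch('/stream', { method: 'POST', body: JSON.stringify({ prompt: 'Hi' }) });
  const stream = PromptCachingBetaMessageStream.fromReadableStream(response.body!);
  stream.on('text', (textDelta) => {
    console.log(textDelta);
  });
  return await stream.finalText();
}
```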