diff --git a/src/lib/BetaMessageStream.ts b/src/lib/BetaMessageStream.ts index 2baa482a..5de54f06 100644 --- a/src/lib/BetaMessageStream.ts +++ b/src/lib/BetaMessageStream.ts @@ -9,6 +9,7 @@ import { type MessageCreateParams as BetaMessageCreateParams, type MessageCreateParamsBase as BetaMessageCreateParamsBase, type BetaTextBlock, + type BetaTextCitation, } from '@anthropic-ai/sdk/resources/beta/messages/messages'; import { type ReadableStream, type Response } from '@anthropic-ai/sdk/_shims/index'; import { Stream } from '@anthropic-ai/sdk/streaming'; @@ -18,6 +19,7 @@ export interface MessageStreamEvents { connect: () => void; streamEvent: (event: BetaMessageStreamEvent, snapshot: BetaMessage) => void; text: (textDelta: string, textSnapshot: string) => void; + citation: (citation: BetaTextCitation, citationsSnapshot: BetaTextCitation[]) => void; inputJson: (partialJson: string, jsonSnapshot: unknown) => void; message: (message: BetaMessage) => void; contentBlock: (content: BetaContentBlock) => void; @@ -413,12 +415,27 @@ export class BetaMessageStream implements AsyncIterable switch (event.type) { case 'content_block_delta': { const content = messageSnapshot.content.at(-1)!; - if (event.delta.type === 'text_delta' && content.type === 'text') { - this._emit('text', event.delta.text, content.text || ''); - } else if (event.delta.type === 'input_json_delta' && content.type === 'tool_use') { - if (content.input) { - this._emit('inputJson', event.delta.partial_json, content.input); + switch (event.delta.type) { + case 'text_delta': { + if (content.type === 'text') { + this._emit('text', event.delta.text, content.text || ''); + } + break; } + case 'citations_delta': { + if (content.type === 'text') { + this._emit('citation', event.delta.citation, content.citations ?? []); + } + break; + } + case 'input_json_delta': { + if (content.type === 'tool_use' && content.input) { + this._emit('inputJson', event.delta.partial_json, content.input); + } + break; + } + default: + checkNever(event.delta); } break; } @@ -505,24 +522,43 @@ export class BetaMessageStream implements AsyncIterable return snapshot; case 'content_block_delta': { const snapshotContent = snapshot.content.at(event.index); - if (snapshotContent?.type === 'text' && event.delta.type === 'text_delta') { - snapshotContent.text += event.delta.text; - } else if (snapshotContent?.type === 'tool_use' && event.delta.type === 'input_json_delta') { - // we need to keep track of the raw JSON string as well so that we can - // re-parse it for each delta, for now we just store it as an untyped - // non-enumerable property on the snapshot - let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || ''; - jsonBuf += event.delta.partial_json; - - Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, { - value: jsonBuf, - enumerable: false, - writable: true, - }); - - if (jsonBuf) { - snapshotContent.input = partialParse(jsonBuf); + + switch (event.delta.type) { + case 'text_delta': { + if (snapshotContent?.type === 'text') { + snapshotContent.text += event.delta.text; + } + break; } + case 'citations_delta': { + if (snapshotContent?.type === 'text') { + snapshotContent.citations ??= []; + snapshotContent.citations.push(event.delta.citation); + } + break; + } + case 'input_json_delta': { + if (snapshotContent?.type === 'tool_use') { + // we need to keep track of the raw JSON string as well so that we can + // re-parse it for each delta, for now we just store it as an untyped + // non-enumerable property on the snapshot + let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || ''; + jsonBuf += event.delta.partial_json; + + Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, { + value: jsonBuf, + enumerable: false, + writable: true, + }); + + if (jsonBuf) { + snapshotContent.input = partialParse(jsonBuf); + } + } + break; + } + default: + checkNever(event.delta); } return snapshot; } @@ -597,3 +633,6 @@ export class BetaMessageStream implements AsyncIterable return stream.toReadableStream(); } } + +// used to ensure exhaustive case matching without throwing a runtime error +function checkNever(x: never) {} diff --git a/src/lib/MessageStream.ts b/src/lib/MessageStream.ts index b47cded0..4ce3a382 100644 --- a/src/lib/MessageStream.ts +++ b/src/lib/MessageStream.ts @@ -9,6 +9,7 @@ import { type MessageCreateParams, type MessageCreateParamsBase, type TextBlock, + type TextCitation, } from '@anthropic-ai/sdk/resources/messages'; import { type ReadableStream, type Response } from '@anthropic-ai/sdk/_shims/index'; import { Stream } from '@anthropic-ai/sdk/streaming'; @@ -18,6 +19,7 @@ export interface MessageStreamEvents { connect: () => void; streamEvent: (event: MessageStreamEvent, snapshot: Message) => void; text: (textDelta: string, textSnapshot: string) => void; + citation: (citation: TextCitation, citationsSnapshot: TextCitation[]) => void; inputJson: (partialJson: string, jsonSnapshot: unknown) => void; message: (message: Message) => void; contentBlock: (content: ContentBlock) => void; @@ -413,12 +415,27 @@ export class MessageStream implements AsyncIterable { switch (event.type) { case 'content_block_delta': { const content = messageSnapshot.content.at(-1)!; - if (event.delta.type === 'text_delta' && content.type === 'text') { - this._emit('text', event.delta.text, content.text || ''); - } else if (event.delta.type === 'input_json_delta' && content.type === 'tool_use') { - if (content.input) { - this._emit('inputJson', event.delta.partial_json, content.input); + switch (event.delta.type) { + case 'text_delta': { + if (content.type === 'text') { + this._emit('text', event.delta.text, content.text || ''); + } + break; } + case 'citations_delta': { + if (content.type === 'text') { + this._emit('citation', event.delta.citation, content.citations ?? []); + } + break; + } + case 'input_json_delta': { + if (content.type === 'tool_use' && content.input) { + this._emit('inputJson', event.delta.partial_json, content.input); + } + break; + } + default: + checkNever(event.delta); } break; } @@ -505,25 +522,45 @@ export class MessageStream implements AsyncIterable { return snapshot; case 'content_block_delta': { const snapshotContent = snapshot.content.at(event.index); - if (snapshotContent?.type === 'text' && event.delta.type === 'text_delta') { - snapshotContent.text += event.delta.text; - } else if (snapshotContent?.type === 'tool_use' && event.delta.type === 'input_json_delta') { - // we need to keep track of the raw JSON string as well so that we can - // re-parse it for each delta, for now we just store it as an untyped - // non-enumerable property on the snapshot - let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || ''; - jsonBuf += event.delta.partial_json; - - Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, { - value: jsonBuf, - enumerable: false, - writable: true, - }); - - if (jsonBuf) { - snapshotContent.input = partialParse(jsonBuf); + + switch (event.delta.type) { + case 'text_delta': { + if (snapshotContent?.type === 'text') { + snapshotContent.text += event.delta.text; + } + break; + } + case 'citations_delta': { + if (snapshotContent?.type === 'text') { + snapshotContent.citations ??= []; + snapshotContent.citations.push(event.delta.citation); + } + break; } + case 'input_json_delta': { + if (snapshotContent?.type === 'tool_use') { + // we need to keep track of the raw JSON string as well so that we can + // re-parse it for each delta, for now we just store it as an untyped + // non-enumerable property on the snapshot + let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || ''; + jsonBuf += event.delta.partial_json; + + Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, { + value: jsonBuf, + enumerable: false, + writable: true, + }); + + if (jsonBuf) { + snapshotContent.input = partialParse(jsonBuf); + } + } + break; + } + default: + checkNever(event.delta); } + return snapshot; } case 'content_block_stop': @@ -597,3 +634,6 @@ export class MessageStream implements AsyncIterable { return stream.toReadableStream(); } } + +// used to ensure exhaustive case matching without throwing a runtime error +function checkNever(x: never) {}