Skip to content

Commit

Permalink
fix(streaming): accumulate citations (#675)
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie authored Jan 27, 2025
1 parent 751ecd0 commit 522118f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 44 deletions.
83 changes: 61 additions & 22 deletions src/lib/BetaMessageStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
type MessageCreateParams as BetaMessageCreateParams,
type MessageCreateParamsBase as BetaMessageCreateParamsBase,
type BetaTextBlock,
type BetaTextCitation,
} from '@anthropic-ai/sdk/resources/beta/messages/messages';
import { type ReadableStream, type Response } from '@anthropic-ai/sdk/_shims/index';
import { Stream } from '@anthropic-ai/sdk/streaming';
Expand All @@ -18,6 +19,7 @@ export interface MessageStreamEvents {
connect: () => void;
streamEvent: (event: BetaMessageStreamEvent, snapshot: BetaMessage) => void;
text: (textDelta: string, textSnapshot: string) => void;
citation: (citation: BetaTextCitation, citationsSnapshot: BetaTextCitation[]) => void;
inputJson: (partialJson: string, jsonSnapshot: unknown) => void;
message: (message: BetaMessage) => void;
contentBlock: (content: BetaContentBlock) => void;
Expand Down Expand Up @@ -413,12 +415,27 @@ export class BetaMessageStream implements AsyncIterable<BetaMessageStreamEvent>
switch (event.type) {
case 'content_block_delta': {
const content = messageSnapshot.content.at(-1)!;
if (event.delta.type === 'text_delta' && content.type === 'text') {
this._emit('text', event.delta.text, content.text || '');
} else if (event.delta.type === 'input_json_delta' && content.type === 'tool_use') {
if (content.input) {
this._emit('inputJson', event.delta.partial_json, content.input);
switch (event.delta.type) {
case 'text_delta': {
if (content.type === 'text') {
this._emit('text', event.delta.text, content.text || '');
}
break;
}
case 'citations_delta': {
if (content.type === 'text') {
this._emit('citation', event.delta.citation, content.citations ?? []);
}
break;
}
case 'input_json_delta': {
if (content.type === 'tool_use' && content.input) {
this._emit('inputJson', event.delta.partial_json, content.input);
}
break;
}
default:
checkNever(event.delta);
}
break;
}
Expand Down Expand Up @@ -505,24 +522,43 @@ export class BetaMessageStream implements AsyncIterable<BetaMessageStreamEvent>
return snapshot;
case 'content_block_delta': {
const snapshotContent = snapshot.content.at(event.index);
if (snapshotContent?.type === 'text' && event.delta.type === 'text_delta') {
snapshotContent.text += event.delta.text;
} else if (snapshotContent?.type === 'tool_use' && event.delta.type === 'input_json_delta') {
// we need to keep track of the raw JSON string as well so that we can
// re-parse it for each delta, for now we just store it as an untyped
// non-enumerable property on the snapshot
let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
jsonBuf += event.delta.partial_json;

Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
value: jsonBuf,
enumerable: false,
writable: true,
});

if (jsonBuf) {
snapshotContent.input = partialParse(jsonBuf);

switch (event.delta.type) {
case 'text_delta': {
if (snapshotContent?.type === 'text') {
snapshotContent.text += event.delta.text;
}
break;
}
case 'citations_delta': {
if (snapshotContent?.type === 'text') {
snapshotContent.citations ??= [];
snapshotContent.citations.push(event.delta.citation);
}
break;
}
case 'input_json_delta': {
if (snapshotContent?.type === 'tool_use') {
// we need to keep track of the raw JSON string as well so that we can
// re-parse it for each delta, for now we just store it as an untyped
// non-enumerable property on the snapshot
let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
jsonBuf += event.delta.partial_json;

Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
value: jsonBuf,
enumerable: false,
writable: true,
});

if (jsonBuf) {
snapshotContent.input = partialParse(jsonBuf);
}
}
break;
}
default:
checkNever(event.delta);
}
return snapshot;
}
Expand Down Expand Up @@ -597,3 +633,6 @@ export class BetaMessageStream implements AsyncIterable<BetaMessageStreamEvent>
return stream.toReadableStream();
}
}

// used to ensure exhaustive case matching without throwing a runtime error
function checkNever(x: never) {}
84 changes: 62 additions & 22 deletions src/lib/MessageStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
type MessageCreateParams,
type MessageCreateParamsBase,
type TextBlock,
type TextCitation,
} from '@anthropic-ai/sdk/resources/messages';
import { type ReadableStream, type Response } from '@anthropic-ai/sdk/_shims/index';
import { Stream } from '@anthropic-ai/sdk/streaming';
Expand All @@ -18,6 +19,7 @@ export interface MessageStreamEvents {
connect: () => void;
streamEvent: (event: MessageStreamEvent, snapshot: Message) => void;
text: (textDelta: string, textSnapshot: string) => void;
citation: (citation: TextCitation, citationsSnapshot: TextCitation[]) => void;
inputJson: (partialJson: string, jsonSnapshot: unknown) => void;
message: (message: Message) => void;
contentBlock: (content: ContentBlock) => void;
Expand Down Expand Up @@ -413,12 +415,27 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
switch (event.type) {
case 'content_block_delta': {
const content = messageSnapshot.content.at(-1)!;
if (event.delta.type === 'text_delta' && content.type === 'text') {
this._emit('text', event.delta.text, content.text || '');
} else if (event.delta.type === 'input_json_delta' && content.type === 'tool_use') {
if (content.input) {
this._emit('inputJson', event.delta.partial_json, content.input);
switch (event.delta.type) {
case 'text_delta': {
if (content.type === 'text') {
this._emit('text', event.delta.text, content.text || '');
}
break;
}
case 'citations_delta': {
if (content.type === 'text') {
this._emit('citation', event.delta.citation, content.citations ?? []);
}
break;
}
case 'input_json_delta': {
if (content.type === 'tool_use' && content.input) {
this._emit('inputJson', event.delta.partial_json, content.input);
}
break;
}
default:
checkNever(event.delta);
}
break;
}
Expand Down Expand Up @@ -505,25 +522,45 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
return snapshot;
case 'content_block_delta': {
const snapshotContent = snapshot.content.at(event.index);
if (snapshotContent?.type === 'text' && event.delta.type === 'text_delta') {
snapshotContent.text += event.delta.text;
} else if (snapshotContent?.type === 'tool_use' && event.delta.type === 'input_json_delta') {
// we need to keep track of the raw JSON string as well so that we can
// re-parse it for each delta, for now we just store it as an untyped
// non-enumerable property on the snapshot
let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
jsonBuf += event.delta.partial_json;

Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
value: jsonBuf,
enumerable: false,
writable: true,
});

if (jsonBuf) {
snapshotContent.input = partialParse(jsonBuf);

switch (event.delta.type) {
case 'text_delta': {
if (snapshotContent?.type === 'text') {
snapshotContent.text += event.delta.text;
}
break;
}
case 'citations_delta': {
if (snapshotContent?.type === 'text') {
snapshotContent.citations ??= [];
snapshotContent.citations.push(event.delta.citation);
}
break;
}
case 'input_json_delta': {
if (snapshotContent?.type === 'tool_use') {
// we need to keep track of the raw JSON string as well so that we can
// re-parse it for each delta, for now we just store it as an untyped
// non-enumerable property on the snapshot
let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
jsonBuf += event.delta.partial_json;

Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
value: jsonBuf,
enumerable: false,
writable: true,
});

if (jsonBuf) {
snapshotContent.input = partialParse(jsonBuf);
}
}
break;
}
default:
checkNever(event.delta);
}

return snapshot;
}
case 'content_block_stop':
Expand Down Expand Up @@ -597,3 +634,6 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
return stream.toReadableStream();
}
}

// used to ensure exhaustive case matching without throwing a runtime error
function checkNever(x: never) {}

0 comments on commit 522118f

Please sign in to comment.