Skip to content

Commit

Permalink
feat(streaming): add tools support
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed May 30, 2024
1 parent 2decf85 commit 4c83bb1
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 776 deletions.
2 changes: 1 addition & 1 deletion examples/tools-streaming.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { inspect } from 'util';
const client = new Anthropic();

async function main() {
const stream = client.beta.tools.messages
const stream = client.messages
.stream({
messages: [
{
Expand Down
10 changes: 5 additions & 5 deletions examples/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import assert from 'node:assert';
const client = new Anthropic(); // gets API Key from environment variable ANTHROPIC_API_KEY

async function main() {
const userMessage: Anthropic.Beta.Tools.ToolsBetaMessageParam = {
const userMessage: Anthropic.MessageParam = {
role: 'user',
content: 'What is the weather in SF?',
};
const tools: Anthropic.Beta.Tools.Tool[] = [
const tools: Anthropic.Tool[] = [
{
name: 'get_weather',
description: 'Get the weather for a specific location',
Expand All @@ -21,7 +21,7 @@ async function main() {
},
];

const message = await client.beta.tools.messages.create({
const message = await client.messages.create({
model: 'claude-3-opus-20240229',
max_tokens: 1024,
messages: [userMessage],
Expand All @@ -33,11 +33,11 @@ async function main() {
assert(message.stop_reason === 'tool_use');

const tool = message.content.find(
(content): content is Anthropic.Beta.Tools.ToolUseBlock => content.type === 'tool_use',
(content): content is Anthropic.ToolUseBlock => content.type === 'tool_use',
);
assert(tool);

const result = await client.beta.tools.messages.create({
const result = await client.messages.create({
model: 'claude-3-opus-20240229',
max_tokens: 1024,
messages: [
Expand Down
36 changes: 31 additions & 5 deletions src/lib/MessageStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@ import {
MessageStreamEvent,
MessageParam,
MessageCreateParams,
MessageStreamParams,
MessageCreateParamsBase,
} from '@anthropic-ai/sdk/resources/messages';
import { type ReadableStream } from '@anthropic-ai/sdk/_shims/index';
import { Stream } from '@anthropic-ai/sdk/streaming';
import { TextBlock } from '@anthropic-ai/sdk/resources';
import { partialParse } from '../_vendor/partial-json-parser/parser';

export interface MessageStreamEvents {
connect: () => void;
streamEvent: (event: MessageStreamEvent, snapshot: Message) => void;
text: (textDelta: string, textSnapshot: string) => void;
inputJson: (jsonDelta: string, jsonSnapshot: unknown) => void;
message: (message: Message) => void;
contentBlock: (content: ContentBlock) => void;
finalMessage: (message: Message) => void;
Expand All @@ -29,6 +32,8 @@ type MessageStreamEventListeners<Event extends keyof MessageStreamEvents> = {
once?: boolean;
}[];

const JSON_BUF_PROPERTY = '__json_buf';

export class MessageStream implements AsyncIterable<MessageStreamEvent> {
messages: MessageParam[] = [];
receivedMessages: Message[] = [];
Expand Down Expand Up @@ -85,7 +90,7 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {

static createMessage(
messages: Messages,
params: MessageStreamParams,
params: MessageCreateParamsBase,
options?: Core.RequestOptions,
): MessageStream {
const runner = new MessageStream();
Expand Down Expand Up @@ -264,7 +269,7 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
}
const textBlocks = this.receivedMessages
.at(-1)!
.content.filter((block) => block.type === 'text')
.content.filter((block): block is TextBlock => block.type === 'text')
.map((block) => block.text);
if (textBlocks.length === 0) {
throw new AnthropicError('stream ended without producing a content block with type=text');
Expand Down Expand Up @@ -369,8 +374,13 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {

switch (event.type) {
case 'content_block_delta': {
if (event.delta.type === 'text_delta') {
this._emit('text', event.delta.text, messageSnapshot.content.at(-1)!.text || '');
const content = messageSnapshot.content.at(-1)!;
if (event.delta.type === 'text_delta' && content.type === 'text') {
this._emit('text', event.delta.text, content.text || '');
} else if (event.delta.type === 'input_json_delta' && content.type === 'tool_use') {
if (content.input) {
this._emit('inputJson', event.delta.partial_json, content.input);
}
}
break;
}
Expand Down Expand Up @@ -459,6 +469,22 @@ export class MessageStream implements AsyncIterable<MessageStreamEvent> {
const snapshotContent = snapshot.content.at(event.index);
if (snapshotContent?.type === 'text' && event.delta.type === 'text_delta') {
snapshotContent.text += event.delta.text;
} else if (snapshotContent?.type === 'tool_use' && event.delta.type === 'input_json_delta') {
// we need to keep track of the raw JSON string as well so that we can
// re-parse it for each delta, for now we just store it as an untyped
// non-enumerable property on the snapshot
let jsonBuf = (snapshotContent as any)[JSON_BUF_PROPERTY] || '';
jsonBuf += event.delta.partial_json;

Object.defineProperty(snapshotContent, JSON_BUF_PROPERTY, {
value: jsonBuf,
enumerable: false,
writable: true,
});

if (jsonBuf) {
snapshotContent.input = partialParse(jsonBuf);
}
}
return snapshot;
}
Expand Down
Loading

0 comments on commit 4c83bb1

Please sign in to comment.