Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(streaming): move to an iterator pattern #528

Merged
merged 3 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions examples/messages_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ async def main() -> None:
],
model="claude-3-opus-20240229",
) as stream:
async for text in stream.text_stream:
print(text, end="", flush=True)
async for event in stream:
if event.type == "text":
print(event.text, end="", flush=True)
elif event.type == "content_block_stop":
print()
print("\ncontent block finished accumulating:", event.content_block)
print()

# you can still get the accumulated final message outside of
Expand Down
32 changes: 0 additions & 32 deletions examples/messages_stream_handler.py

This file was deleted.

16 changes: 4 additions & 12 deletions examples/tools_stream.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
import asyncio
from typing_extensions import override

from anthropic import AsyncAnthropic
from anthropic.lib.streaming.beta import AsyncToolsBetaMessageStream

client = AsyncAnthropic()


class MyHandler(AsyncToolsBetaMessageStream):
@override
async def on_input_json(self, delta: str, snapshot: object) -> None:
print(f"delta: {repr(delta)}")
print(f"snapshot: {snapshot}")
print()


async def main() -> None:
async with client.beta.tools.messages.stream(
max_tokens=1024,
Expand All @@ -38,9 +28,11 @@ async def main() -> None:
}
],
messages=[{"role": "user", "content": "What is the weather in SF?"}],
event_handler=MyHandler,
) as stream:
await stream.until_done()
async for event in stream:
if event.type == "input_json":
print(f"delta: {repr(event.partial_json)}")
print(f"snapshot: {event.snapshot}")

print()

Expand Down
98 changes: 42 additions & 56 deletions helpers.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ object for you).

The stream will be cancelled when the context manager exits but you can also close it prematurely by calling `stream.close()`.

See an example of streaming helpers in action in [`examples/messages_stream.py`](examples/messages_stream.py) and defining custom event handlers in [`examples/messages_stream_handler.py`](examples/messages_stream_handler.py)
See an example of streaming helpers in action in [`examples/messages_stream.py`](examples/messages_stream.py).

> [!NOTE]
> The synchronous client has the same interface just without `async/await`.
Expand All @@ -45,79 +45,65 @@ print()

### Events

You can pass an `event_handler` argument to `client.messages.stream` to register callback methods that are fired when certain events happen:
The events listed here are just the event types that the SDK extends, for a full list of the events returned by the API, see [these docs](https://docs.anthropic.com/en/api/messages-streaming#event-types).

```py
import asyncio
from typing_extensions import override

from anthropic import AsyncAnthropic, AsyncMessageStream
from anthropic.types import MessageStreamEvent
from anthropic import AsyncAnthropic

client = AsyncAnthropic()

class MyStream(AsyncMessageStream):
@override
async def on_text(self, text: str, snapshot: str) -> None:
print(text, end="", flush=True)

@override
async def on_stream_event(self, event: MessageStreamEvent) -> None:
print("on_event fired with:", event)

async def main() -> None:
async with client.messages.stream(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Say hello there!",
}
],
model="claude-3-opus-20240229",
event_handler=MyStream,
) as stream:
message = await stream.get_final_message()
print("accumulated message: ", message.to_json())

asyncio.run(main())
```

#### `await on_stream_event(event: MessageStreamEvent)`

The event is fired when an event is received from the API.

#### `await on_message(message: Message)`

The event is fired when a full Message object has been accumulated. This corresponds to the `message_stop` SSE.
async with client.messages.stream(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Say hello there!",
}
],
model="claude-3-opus-20240229",
) as stream:
async for event in stream:
if event.type == "text":
print(event.text, end="", flush=True)
elif event.type == 'content_block_stop':
print('\n\ncontent block finished accumulating:', event.content_block)

#### `await on_content_block(content_block: ContentBlock)`
print()

The event is fired when a full ContentBlock object has been accumulated. This corresponds to the `content_block_stop` SSE.
# you can still get the accumulated final message outside of
# the context manager, as long as the entire stream was consumed
# inside of the context manager
accumulated = await stream.get_final_message()
print("accumulated message: ", accumulated.to_json())
```

#### `await on_text(text: str, snapshot: str)`
#### `text`

The event is fired when a `text` ContentBlock object is being accumulated. The first argument is the text delta and the second is the current accumulated text, for example:
This event is yielded whenever a `content_block_delta` event is returned by the API & includes the delta and the accumulated snapshot, e.g.

```py
on_text('Hello', 'Hello')
on_text(' there', 'Hello there')
on_text('!', 'Hello there!')
if event.type == "text":
event.text # " there"
event.snapshot # "Hello, there"
```

This corresponds to the `content_block_delta` SSE.

#### `await on_exception(exception: Exception)`
#### `message_stop`

The event is fired when an exception is encountered while streaming the response.
The event is fired when a full Message object has been accumulated.

#### `await on_timeout()`
```py
if event.type == "message_stop":
event.message # Message
```

The event is fired when the request times out.
#### `content_block_stop`

#### `await on_end()`
The event is fired when a full ContentBlock object has been accumulated.

The last event fired in the stream.
```py
if event.type == "content_block_stop":
event.content_block # ContentBlock
```

### Methods

Expand Down
6 changes: 6 additions & 0 deletions src/anthropic/lib/streaming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from ._types import (
TextEvent as TextEvent,
MessageStopEvent as MessageStopEvent,
MessageStreamEvent as MessageStreamEvent,
ContentBlockStopEvent as ContentBlockStopEvent,
)
from ._messages import (
MessageStream as MessageStream,
MessageStreamT as MessageStreamT,
Expand Down
Loading