diff --git a/examples/messages_stream.py b/examples/messages_stream.py index 523c485e..be69a2c1 100644 --- a/examples/messages_stream.py +++ b/examples/messages_stream.py @@ -16,8 +16,12 @@ async def main() -> None: ], model="claude-3-opus-20240229", ) as stream: - async for text in stream.text_stream: - print(text, end="", flush=True) + async for event in stream: + if event.type == "text": + print(event.text, end="", flush=True) + elif event.type == "content_block_stop": + print() + print("\ncontent block finished accumulating:", event.content_block) print() # you can still get the accumulated final message outside of diff --git a/examples/messages_stream_handler.py b/examples/messages_stream_handler.py deleted file mode 100644 index 6bf98dab..00000000 --- a/examples/messages_stream_handler.py +++ /dev/null @@ -1,32 +0,0 @@ -import asyncio -from typing_extensions import override - -from anthropic import AsyncAnthropic, AsyncMessageStream -from anthropic.types import MessageStreamEvent - -client = AsyncAnthropic() - - -class MyStream(AsyncMessageStream): - @override - async def on_stream_event(self, event: MessageStreamEvent) -> None: - print("on_event fired with:", event) - - -async def main() -> None: - async with client.messages.stream( - max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Say hello there!", - } - ], - model="claude-3-opus-20240229", - event_handler=MyStream, - ) as stream: - accumulated = await stream.get_final_message() - print("accumulated message: ", accumulated.to_json()) - - -asyncio.run(main()) diff --git a/examples/tools_stream.py b/examples/tools_stream.py index 5b0c1927..28aa6d5a 100644 --- a/examples/tools_stream.py +++ b/examples/tools_stream.py @@ -1,20 +1,10 @@ import asyncio -from typing_extensions import override from anthropic import AsyncAnthropic -from anthropic.lib.streaming.beta import AsyncToolsBetaMessageStream client = AsyncAnthropic() -class MyHandler(AsyncToolsBetaMessageStream): - @override - async def 
on_input_json(self, delta: str, snapshot: object) -> None: - print(f"delta: {repr(delta)}") - print(f"snapshot: {snapshot}") - print() - - async def main() -> None: async with client.beta.tools.messages.stream( max_tokens=1024, @@ -38,9 +28,11 @@ async def main() -> None: } ], messages=[{"role": "user", "content": "What is the weather in SF?"}], - event_handler=MyHandler, ) as stream: - await stream.until_done() + async for event in stream: + if event.type == "input_json": + print(f"delta: {repr(event.partial_json)}") + print(f"snapshot: {event.snapshot}") print() diff --git a/helpers.md b/helpers.md index ef3d4817..399a9bba 100644 --- a/helpers.md +++ b/helpers.md @@ -26,7 +26,7 @@ object for you). The stream will be cancelled when the context manager exits but you can also close it prematurely by calling `stream.close()`. -See an example of streaming helpers in action in [`examples/messages_stream.py`](examples/messages_stream.py) and defining custom event handlers in [`examples/messages_stream_handler.py`](examples/messages_stream_handler.py) +See an example of streaming helpers in action in [`examples/messages_stream.py`](examples/messages_stream.py). > [!NOTE] > The synchronous client has the same interface just without `async/await`. @@ -45,79 +45,65 @@ print() ### Events -You can pass an `event_handler` argument to `client.messages.stream` to register callback methods that are fired when certain events happen: +The events listed here are just the event types that the SDK extends; for a full list of the events returned by the API, see [these docs](https://docs.anthropic.com/en/api/messages-streaming#event-types).
```py -import asyncio -from typing_extensions import override - -from anthropic import AsyncAnthropic, AsyncMessageStream -from anthropic.types import MessageStreamEvent +from anthropic import AsyncAnthropic client = AsyncAnthropic() -class MyStream(AsyncMessageStream): - @override - async def on_text(self, text: str, snapshot: str) -> None: - print(text, end="", flush=True) - - @override - async def on_stream_event(self, event: MessageStreamEvent) -> None: - print("on_event fired with:", event) - -async def main() -> None: - async with client.messages.stream( - max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Say hello there!", - } - ], - model="claude-3-opus-20240229", - event_handler=MyStream, - ) as stream: - message = await stream.get_final_message() - print("accumulated message: ", message.to_json()) - -asyncio.run(main()) -``` - -#### `await on_stream_event(event: MessageStreamEvent)` - -The event is fired when an event is received from the API. - -#### `await on_message(message: Message)` - -The event is fired when a full Message object has been accumulated. This corresponds to the `message_stop` SSE. +async with client.messages.stream( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + model="claude-3-opus-20240229", +) as stream: + async for event in stream: + if event.type == "text": + print(event.text, end="", flush=True) + elif event.type == 'content_block_stop': + print('\n\ncontent block finished accumulating:', event.content_block) -#### `await on_content_block(content_block: ContentBlock)` + print() -The event is fired when a full ContentBlock object has been accumulated. This corresponds to the `content_block_stop` SSE. 
+# you can still get the accumulated final message outside of +# the context manager, as long as the entire stream was consumed +# inside of the context manager +accumulated = await stream.get_final_message() +print("accumulated message: ", accumulated.to_json()) +``` -#### `await on_text(text: str, snapshot: str)` +#### `text` -The event is fired when a `text` ContentBlock object is being accumulated. The first argument is the text delta and the second is the current accumulated text, for example: +This event is yielded whenever the API returns a `content_block_delta` event carrying a text delta, and includes that delta and the accumulated text snapshot, e.g. ```py -on_text('Hello', 'Hello') -on_text(' there', 'Hello there') -on_text('!', 'Hello there!') +if event.type == "text": + event.text # " there" + event.snapshot # "Hello, there" ``` -This corresponds to the `content_block_delta` SSE. - -#### `await on_exception(exception: Exception)` +#### `message_stop` -The event is fired when an exception is encountered while streaming the response. +The event is fired when a full Message object has been accumulated. -#### `await on_timeout()` +```py +if event.type == "message_stop": + event.message # Message +``` -The event is fired when the request times out. +#### `content_block_stop` -#### `await on_end()` +The event is fired when a full ContentBlock object has been accumulated. -The last event fired in the stream.
+```py +if event.type == "content_block_stop": + event.content_block # ContentBlock +``` ### Methods diff --git a/src/anthropic/lib/streaming/__init__.py b/src/anthropic/lib/streaming/__init__.py index 71c5efd6..a9329d5d 100644 --- a/src/anthropic/lib/streaming/__init__.py +++ b/src/anthropic/lib/streaming/__init__.py @@ -1,3 +1,9 @@ +from ._types import ( + TextEvent as TextEvent, + MessageStopEvent as MessageStopEvent, + MessageStreamEvent as MessageStreamEvent, + ContentBlockStopEvent as ContentBlockStopEvent, +) from ._messages import ( MessageStream as MessageStream, MessageStreamT as MessageStreamT, diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py index b41e6df7..0a49c819 100644 --- a/src/anthropic/lib/streaming/_messages.py +++ b/src/anthropic/lib/streaming/_messages.py @@ -7,7 +7,8 @@ import httpx -from ...types import Message, ContentBlock, MessageStreamEvent +from ._types import TextEvent, MessageStopEvent, MessageStreamEvent, ContentBlockStopEvent +from ...types import Message, ContentBlock, RawMessageStreamEvent from ..._utils import consume_sync_iterator, consume_async_iterator from ..._streaming import Stream, AsyncStream @@ -31,7 +32,7 @@ class MessageStream: def __init__( self, *, - cast_to: type[MessageStreamEvent], + cast_to: type[RawMessageStreamEvent], response: httpx.Response, client: Anthropic, ) -> None: @@ -43,7 +44,7 @@ def __init__( self.__final_message_snapshot: Message | None = None self._iterator = self.__stream__() - self._raw_stream: Stream[MessageStreamEvent] = Stream(cast_to=cast_to, response=response, client=client) + self._raw_stream: Stream[RawMessageStreamEvent] = Stream(cast_to=cast_to, response=response, client=client) def __next__(self) -> MessageStreamEvent: return self._iterator.__next__() @@ -110,7 +111,7 @@ def current_message_snapshot(self) -> Message: return self.__final_message_snapshot # event handlers - def on_stream_event(self, event: MessageStreamEvent) -> None: + def 
on_stream_event(self, event: RawMessageStreamEvent) -> None: """Callback that is fired for every Server-Sent-Event""" def on_message(self, message: Message) -> None: @@ -154,9 +155,10 @@ def __stream__(self) -> Iterator[MessageStreamEvent]: event=sse_event, current_snapshot=self.__final_message_snapshot, ) - self._emit_sse_event(sse_event) - yield sse_event + events_to_fire = self._emit_sse_event(sse_event) + for event in events_to_fire: + yield event except (httpx.TimeoutException, asyncio.TimeoutError) as exc: self.on_timeout() self.on_exception(exc) @@ -172,32 +174,47 @@ def __stream_text__(self) -> Iterator[str]: if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": yield chunk.delta.text - def _emit_sse_event(self, event: MessageStreamEvent) -> None: + def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStreamEvent]: self.on_stream_event(event) + events_to_fire: list[MessageStreamEvent] = [] + if event.type == "message_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_delta": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_stop": self.on_message(self.current_message_snapshot) + events_to_fire.append(MessageStopEvent(type="message_stop", message=self.current_message_snapshot)) elif event.type == "content_block_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "content_block_delta": - content = self.current_message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content.type == "text": - self.on_text(event.delta.text, content.text) + events_to_fire.append(event) + + content_block = self.current_message_snapshot.content[event.index] + if event.delta.type == "text_delta" and content_block.type == "text": + self.on_text(event.delta.text, content_block.text) + events_to_fire.append( + TextEvent( + type="text", + 
text=event.delta.text, + snapshot=content_block.text, + ) + ) elif event.type == "content_block_stop": - content = self.current_message_snapshot.content[event.index] - self.on_content_block(content) + content_block = self.current_message_snapshot.content[event.index] + self.on_content_block(content_block) + + events_to_fire.append( + ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block), + ) else: # we only want exhaustive checking for linters, not at runtime if TYPE_CHECKING: # type: ignore[unreachable] assert_never(event) + return events_to_fire + MessageStreamT = TypeVar("MessageStreamT", bound=MessageStream) @@ -213,7 +230,7 @@ class MessageStreamManager(Generic[MessageStreamT]): """ def __init__( - self, api_request: Callable[[], Stream[MessageStreamEvent]], event_handler_cls: type[MessageStreamT] + self, api_request: Callable[[], Stream[RawMessageStreamEvent]], event_handler_cls: type[MessageStreamT] ) -> None: self.__event_handler: MessageStreamT | None = None self.__event_handler_cls: type[MessageStreamT] = event_handler_cls @@ -256,7 +273,7 @@ class AsyncMessageStream: def __init__( self, *, - cast_to: type[MessageStreamEvent], + cast_to: type[RawMessageStreamEvent], response: httpx.Response, client: AsyncAnthropic, ) -> None: @@ -268,7 +285,7 @@ def __init__( self.__final_message_snapshot: Message | None = None self._iterator = self.__stream__() - self._raw_stream: AsyncStream[MessageStreamEvent] = AsyncStream( + self._raw_stream: AsyncStream[RawMessageStreamEvent] = AsyncStream( cast_to=cast_to, response=response, client=client ) @@ -337,7 +354,7 @@ def current_message_snapshot(self) -> Message: return self.__final_message_snapshot # event handlers - async def on_stream_event(self, event: MessageStreamEvent) -> None: + async def on_stream_event(self, event: RawMessageStreamEvent) -> None: """Callback that is fired for every Server-Sent-Event""" async def on_message(self, message: Message) -> None: @@ -387,9 
+404,10 @@ async def __stream__(self) -> AsyncIterator[MessageStreamEvent]: event=sse_event, current_snapshot=self.__final_message_snapshot, ) - await self._emit_sse_event(sse_event) - yield sse_event + events_to_fire = await self._emit_sse_event(sse_event) + for event in events_to_fire: + yield event except (httpx.TimeoutException, asyncio.TimeoutError) as exc: await self.on_timeout() await self.on_exception(exc) @@ -405,35 +423,47 @@ async def __stream_text__(self) -> AsyncIterator[str]: if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": yield chunk.delta.text - async def _emit_sse_event(self, event: MessageStreamEvent) -> None: + async def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStreamEvent]: await self.on_stream_event(event) + events_to_fire: list[MessageStreamEvent] = [] + if event.type == "message_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_delta": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_stop": await self.on_message(self.current_message_snapshot) + events_to_fire.append(MessageStopEvent(type="message_stop", message=self.current_message_snapshot)) elif event.type == "content_block_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "content_block_delta": - content = self.current_message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content.type == "text": - await self.on_text(event.delta.text, content.text) + events_to_fire.append(event) + + content_block = self.current_message_snapshot.content[event.index] + if event.delta.type == "text_delta" and content_block.type == "text": + await self.on_text(event.delta.text, content_block.text) + events_to_fire.append( + TextEvent( + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) + ) elif event.type == "content_block_stop": - 
content = self.current_message_snapshot.content[event.index] - await self.on_content_block(content) + content_block = self.current_message_snapshot.content[event.index] + await self.on_content_block(content_block) - if content.type == "text": - await self.on_final_text(content.text) + events_to_fire.append( + ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block), + ) else: # we only want exhaustive checking for linters, not at runtime if TYPE_CHECKING: # type: ignore[unreachable] assert_never(event) + return events_to_fire + AsyncMessageStreamT = TypeVar("AsyncMessageStreamT", bound=AsyncMessageStream) @@ -451,7 +481,7 @@ class AsyncMessageStreamManager(Generic[AsyncMessageStreamT]): """ def __init__( - self, api_request: Awaitable[AsyncStream[MessageStreamEvent]], event_handler_cls: type[AsyncMessageStreamT] + self, api_request: Awaitable[AsyncStream[RawMessageStreamEvent]], event_handler_cls: type[AsyncMessageStreamT] ) -> None: self.__event_handler: AsyncMessageStreamT | None = None self.__event_handler_cls: type[AsyncMessageStreamT] = event_handler_cls @@ -478,7 +508,7 @@ async def __aexit__( await self.__event_handler.close() -def accumulate_event(*, event: MessageStreamEvent, current_snapshot: Message | None) -> Message: +def accumulate_event(*, event: RawMessageStreamEvent, current_snapshot: Message | None) -> Message: if current_snapshot is None: if event.type == "message_start": return event.message diff --git a/src/anthropic/lib/streaming/_types.py b/src/anthropic/lib/streaming/_types.py new file mode 100644 index 00000000..ad19ae93 --- /dev/null +++ b/src/anthropic/lib/streaming/_types.py @@ -0,0 +1,47 @@ +from typing import Union +from typing_extensions import Literal + +from ...types import ( + Message, + ContentBlock, + MessageDeltaEvent as RawMessageDeltaEvent, + MessageStartEvent as RawMessageStartEvent, + RawMessageStopEvent, + ContentBlockDeltaEvent as RawContentBlockDeltaEvent, + ContentBlockStartEvent 
as RawContentBlockStartEvent, + RawContentBlockStopEvent, +) +from ..._models import BaseModel + + +class TextEvent(BaseModel): + type: Literal["text"] + + text: str + """The text delta""" + + snapshot: str + """The entire accumulated text""" + + +class MessageStopEvent(RawMessageStopEvent): + type: Literal["message_stop"] + + message: Message + + +class ContentBlockStopEvent(RawContentBlockStopEvent): + type: Literal["content_block_stop"] + + content_block: ContentBlock + + +MessageStreamEvent = Union[ + TextEvent, + RawMessageStartEvent, + RawMessageDeltaEvent, + MessageStopEvent, + RawContentBlockStartEvent, + RawContentBlockDeltaEvent, + ContentBlockStopEvent, +] diff --git a/src/anthropic/lib/streaming/beta/__init__.py b/src/anthropic/lib/streaming/beta/__init__.py index 6fd08cdb..83cefcd5 100644 --- a/src/anthropic/lib/streaming/beta/__init__.py +++ b/src/anthropic/lib/streaming/beta/__init__.py @@ -6,3 +6,9 @@ ToolsBetaMessageStreamManager as ToolsBetaMessageStreamManager, AsyncToolsBetaMessageStreamManager as AsyncToolsBetaMessageStreamManager, ) +from ._types import ( + ToolsBetaInputJsonEvent as ToolsBetaInputJsonEvent, + ToolsBetaMessageStopEvent as ToolsBetaMessageStopEvent, + ToolsBetaMessageStreamEvent as ToolsBetaMessageStreamEvent, + ToolsBetaContentBlockStopEvent as ToolsBetaContentBlockStopEvent, +) diff --git a/src/anthropic/lib/streaming/beta/_tools.py b/src/anthropic/lib/streaming/beta/_tools.py index fab77a63..0a43618d 100644 --- a/src/anthropic/lib/streaming/beta/_tools.py +++ b/src/anthropic/lib/streaming/beta/_tools.py @@ -3,20 +3,27 @@ import asyncio from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, cast -from typing_extensions import Iterator, Awaitable, AsyncIterator, override, assert_never +from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never import httpx +from ._types import ( + TextEvent, + ToolsBetaInputJsonEvent, + ToolsBetaMessageStopEvent, + 
ToolsBetaMessageStreamEvent, + ToolsBetaContentBlockStopEvent, +) from ...._utils import consume_sync_iterator, consume_async_iterator from ...._models import construct_type from ...._streaming import Stream, AsyncStream -from ....types.beta.tools import ToolsBetaMessage, ToolsBetaContentBlock, ToolsBetaMessageStreamEvent +from ....types.beta.tools import ToolsBetaMessage, ToolsBetaContentBlock, RawToolsBetaMessageStreamEvent if TYPE_CHECKING: from ...._client import Anthropic, AsyncAnthropic -class ToolsBetaMessageStream(Stream[ToolsBetaMessageStreamEvent]): +class ToolsBetaMessageStream: text_stream: Iterator[str] """Iterator over just the text deltas in the stream. @@ -27,18 +34,53 @@ class ToolsBetaMessageStream(Stream[ToolsBetaMessageStreamEvent]): ``` """ + response: httpx.Response + def __init__( self, *, - cast_to: type[ToolsBetaMessageStreamEvent], + cast_to: type[RawToolsBetaMessageStreamEvent], response: httpx.Response, client: Anthropic, ) -> None: - super().__init__(cast_to=cast_to, response=response, client=client) + self.response = response + self._cast_to = cast_to + self._client = client self.text_stream = self.__stream_text__() self.__final_message_snapshot: ToolsBetaMessage | None = None - self.__events: list[ToolsBetaMessageStreamEvent] = [] + + self._iterator = self.__stream__() + self._raw_stream: Stream[RawToolsBetaMessageStreamEvent] = Stream( + cast_to=cast_to, response=response, client=client + ) + + def __next__(self) -> ToolsBetaMessageStreamEvent: + return self._iterator.__next__() + + def __iter__(self) -> Iterator[ToolsBetaMessageStreamEvent]: + for item in self._iterator: + yield item + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + """ + Close the response and release the connection. + + Automatically called if the response body is read to completion. 
+ """ + self.response.close() + self.on_end() def get_final_message(self) -> ToolsBetaMessage: """Waits until the stream has been read to completion and returns @@ -71,11 +113,6 @@ def until_done(self) -> None: """Blocks until the stream has been consumed""" consume_sync_iterator(self) - @override - def close(self) -> None: - super().close() - self.on_end() - # properties @property def current_message_snapshot(self) -> ToolsBetaMessage: @@ -83,7 +120,7 @@ def current_message_snapshot(self) -> ToolsBetaMessage: return self.__final_message_snapshot # event handlers - def on_stream_event(self, event: ToolsBetaMessageStreamEvent) -> None: + def on_stream_event(self, event: RawToolsBetaMessageStreamEvent) -> None: """Callback that is fired for every Server-Sent-Event""" def on_message(self, message: ToolsBetaMessage) -> None: @@ -132,19 +169,17 @@ def on_end(self) -> None: def on_timeout(self) -> None: """Fires if the request times out""" - @override def __stream__(self) -> Iterator[ToolsBetaMessageStreamEvent]: try: - for event in super().__stream__(): - self.__events.append(event) - + for sse_event in self._raw_stream: self.__final_message_snapshot = accumulate_event( - event=event, + event=sse_event, current_snapshot=self.__final_message_snapshot, ) - self._emit_sse_event(event) - yield event + events_to_fire = self._emit_sse_event(sse_event) + for event in events_to_fire: + yield event except (httpx.TimeoutException, asyncio.TimeoutError) as exc: self.on_timeout() self.on_exception(exc) @@ -160,34 +195,58 @@ def __stream_text__(self) -> Iterator[str]: if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": yield chunk.delta.text - def _emit_sse_event(self, event: ToolsBetaMessageStreamEvent) -> None: + def _emit_sse_event(self, event: RawToolsBetaMessageStreamEvent) -> list[ToolsBetaMessageStreamEvent]: self.on_stream_event(event) + events_to_fire: list[ToolsBetaMessageStreamEvent] = [] + if event.type == "message_start": - # nothing special we 
want to fire here - pass + events_to_fire.append(event) elif event.type == "message_delta": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_stop": self.on_message(self.current_message_snapshot) + events_to_fire.append(ToolsBetaMessageStopEvent(type="message_stop", message=self.current_message_snapshot)) elif event.type == "content_block_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "content_block_delta": - content = self.current_message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content.type == "text": - self.on_text(event.delta.text, content.text) - elif event.delta.type == "input_json_delta" and content.type == "tool_use": - self.on_input_json(event.delta.partial_json, content.input) + events_to_fire.append(event) + + content_block = self.current_message_snapshot.content[event.index] + if event.delta.type == "text_delta" and content_block.type == "text": + self.on_text(event.delta.text, content_block.text) + events_to_fire.append( + TextEvent( + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) + ) + elif event.delta.type == "input_json_delta" and content_block.type == "tool_use": + self.on_input_json(event.delta.partial_json, content_block.input) + events_to_fire.append( + ToolsBetaInputJsonEvent( + type="input_json", + partial_json=event.delta.partial_json, + snapshot=content_block.input, + ) + ) elif event.type == "content_block_stop": - content = self.current_message_snapshot.content[event.index] - self.on_content_block(content) + content_block = self.current_message_snapshot.content[event.index] + self.on_content_block(content_block) + + events_to_fire.append( + ToolsBetaContentBlockStopEvent( + type="content_block_stop", index=event.index, content_block=content_block + ), + ) else: # we only want exhaustive checking for linters, not at runtime if TYPE_CHECKING: # type: ignore[unreachable] 
assert_never(event) + return events_to_fire + ToolsBetaMessageStreamT = TypeVar("ToolsBetaMessageStreamT", bound=ToolsBetaMessageStream) @@ -202,13 +261,25 @@ class ToolsBetaMessageStreamManager(Generic[ToolsBetaMessageStreamT]): ``` """ - def __init__(self, api_request: Callable[[], ToolsBetaMessageStreamT]) -> None: - self.__stream: ToolsBetaMessageStreamT | None = None + def __init__( + self, + api_request: Callable[[], Stream[RawToolsBetaMessageStreamEvent]], + event_handler_cls: type[ToolsBetaMessageStreamT], + ) -> None: + self.__event_handler: ToolsBetaMessageStreamT | None = None + self.__event_handler_cls: type[ToolsBetaMessageStreamT] = event_handler_cls self.__api_request = api_request def __enter__(self) -> ToolsBetaMessageStreamT: - self.__stream = self.__api_request() - return self.__stream + raw_stream = self.__api_request() + + self.__event_handler = self.__event_handler_cls( + cast_to=raw_stream._cast_to, + response=raw_stream.response, + client=raw_stream._client, + ) + + return self.__event_handler def __exit__( self, @@ -216,11 +287,11 @@ def __exit__( exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: - if self.__stream is not None: - self.__stream.close() + if self.__event_handler is not None: + self.__event_handler.close() -class AsyncToolsBetaMessageStream(AsyncStream[ToolsBetaMessageStreamEvent]): +class AsyncToolsBetaMessageStream: text_stream: AsyncIterator[str] """Async iterator over just the text deltas in the stream. 
@@ -231,18 +302,53 @@ class AsyncToolsBetaMessageStream(AsyncStream[ToolsBetaMessageStreamEvent]): ``` """ + response: httpx.Response + def __init__( self, *, - cast_to: type[ToolsBetaMessageStreamEvent], + cast_to: type[RawToolsBetaMessageStreamEvent], response: httpx.Response, client: AsyncAnthropic, ) -> None: - super().__init__(cast_to=cast_to, response=response, client=client) + self.response = response + self._cast_to = cast_to + self._client = client self.text_stream = self.__stream_text__() self.__final_message_snapshot: ToolsBetaMessage | None = None - self.__events: list[ToolsBetaMessageStreamEvent] = [] + + self._iterator = self.__stream__() + self._raw_stream: AsyncStream[RawToolsBetaMessageStreamEvent] = AsyncStream( + cast_to=cast_to, response=response, client=client + ) + + async def __anext__(self) -> ToolsBetaMessageStreamEvent: + return await self._iterator.__anext__() + + async def __aiter__(self) -> AsyncIterator[ToolsBetaMessageStreamEvent]: + async for item in self._iterator: + yield item + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + async def close(self) -> None: + """ + Close the response and release the connection. + + Automatically called if the response body is read to completion. 
+ """ + await self.response.aclose() + await self.on_end() async def get_final_message(self) -> ToolsBetaMessage: """Waits until the stream has been read to completion and returns @@ -275,11 +381,6 @@ async def until_done(self) -> None: """Waits until the stream has been consumed""" await consume_async_iterator(self) - @override - async def close(self) -> None: - await super().close() - await self.on_end() - # properties @property def current_message_snapshot(self) -> ToolsBetaMessage: @@ -287,7 +388,7 @@ def current_message_snapshot(self) -> ToolsBetaMessage: return self.__final_message_snapshot # event handlers - async def on_stream_event(self, event: ToolsBetaMessageStreamEvent) -> None: + async def on_stream_event(self, event: RawToolsBetaMessageStreamEvent) -> None: """Callback that is fired for every Server-Sent-Event""" async def on_message(self, message: ToolsBetaMessage) -> None: @@ -342,19 +443,17 @@ async def on_end(self) -> None: async def on_timeout(self) -> None: """Fires if the request times out""" - @override async def __stream__(self) -> AsyncIterator[ToolsBetaMessageStreamEvent]: try: - async for event in super().__stream__(): - self.__events.append(event) - + async for sse_event in self._raw_stream: self.__final_message_snapshot = accumulate_event( - event=event, + event=sse_event, current_snapshot=self.__final_message_snapshot, ) - await self._emit_sse_event(event) - yield event + events_to_fire = await self._emit_sse_event(sse_event) + for event in events_to_fire: + yield event except (httpx.TimeoutException, asyncio.TimeoutError) as exc: await self.on_timeout() await self.on_exception(exc) @@ -370,40 +469,61 @@ async def __stream_text__(self) -> AsyncIterator[str]: if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": yield chunk.delta.text - async def _emit_sse_event(self, event: ToolsBetaMessageStreamEvent) -> None: + async def _emit_sse_event(self, event: RawToolsBetaMessageStreamEvent) -> 
list[ToolsBetaMessageStreamEvent]: await self.on_stream_event(event) + events_to_fire: list[ToolsBetaMessageStreamEvent] = [] + if event.type == "message_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_delta": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_stop": await self.on_message(self.current_message_snapshot) + events_to_fire.append(ToolsBetaMessageStopEvent(type="message_stop", message=self.current_message_snapshot)) elif event.type == "content_block_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "content_block_delta": - content = self.current_message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content.type == "text": - await self.on_text(event.delta.text, content.text) - elif event.delta.type == "input_json_delta" and content.type == "tool_use": - await self.on_input_json(event.delta.partial_json, content.input) - else: - # TODO: warn? 
- pass + events_to_fire.append(event) + + content_block = self.current_message_snapshot.content[event.index] + if event.delta.type == "text_delta" and content_block.type == "text": + await self.on_text(event.delta.text, content_block.text) + events_to_fire.append( + TextEvent( + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) + ) + elif event.delta.type == "input_json_delta" and content_block.type == "tool_use": + await self.on_input_json(event.delta.partial_json, content_block.input) + events_to_fire.append( + ToolsBetaInputJsonEvent( + type="input_json", + partial_json=event.delta.partial_json, + snapshot=content_block.input, + ) + ) elif event.type == "content_block_stop": - content = self.current_message_snapshot.content[event.index] - await self.on_content_block(content) + content_block = self.current_message_snapshot.content[event.index] + await self.on_content_block(content_block) - if content.type == "text": - await self.on_final_text(content.text) + if content_block.type == "text": + await self.on_final_text(content_block.text) + + events_to_fire.append( + ToolsBetaContentBlockStopEvent( + type="content_block_stop", index=event.index, content_block=content_block + ), + ) else: # we only want exhaustive checking for linters, not at runtime if TYPE_CHECKING: # type: ignore[unreachable] assert_never(event) + return events_to_fire + AsyncToolsBetaMessageStreamT = TypeVar("AsyncToolsBetaMessageStreamT", bound=AsyncToolsBetaMessageStream) @@ -420,13 +540,25 @@ class AsyncToolsBetaMessageStreamManager(Generic[AsyncToolsBetaMessageStreamT]): ``` """ - def __init__(self, api_request: Awaitable[AsyncToolsBetaMessageStreamT]) -> None: - self.__stream: AsyncToolsBetaMessageStreamT | None = None + def __init__( + self, + api_request: Awaitable[AsyncStream[RawToolsBetaMessageStreamEvent]], + event_handler_cls: type[AsyncToolsBetaMessageStreamT], + ) -> None: + self.__event_handler: AsyncToolsBetaMessageStreamT | None = None + 
self.__event_handler_cls: type[AsyncToolsBetaMessageStreamT] = event_handler_cls self.__api_request = api_request async def __aenter__(self) -> AsyncToolsBetaMessageStreamT: - self.__stream = await self.__api_request - return self.__stream + raw_stream = await self.__api_request + + self.__event_handler = self.__event_handler_cls( + cast_to=raw_stream._cast_to, + response=raw_stream.response, + client=raw_stream._client, + ) + + return self.__event_handler async def __aexit__( self, @@ -434,8 +566,8 @@ async def __aexit__( exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: - if self.__stream is not None: - await self.__stream.close() + if self.__event_handler is not None: + await self.__event_handler.close() JSON_BUF_PROPERTY = "__json_buf" @@ -443,7 +575,7 @@ async def __aexit__( def accumulate_event( *, - event: ToolsBetaMessageStreamEvent, + event: RawToolsBetaMessageStreamEvent, current_snapshot: ToolsBetaMessage | None, ) -> ToolsBetaMessage: if current_snapshot is None: diff --git a/src/anthropic/lib/streaming/beta/_types.py b/src/anthropic/lib/streaming/beta/_types.py new file mode 100644 index 00000000..3fd3555d --- /dev/null +++ b/src/anthropic/lib/streaming/beta/_types.py @@ -0,0 +1,53 @@ +from typing import Union +from typing_extensions import Literal + +from .._types import TextEvent +from ....types import RawMessageStopEvent, RawMessageDeltaEvent, RawMessageStartEvent, RawContentBlockStopEvent +from ...._models import BaseModel +from ....types.beta.tools import ( + ToolsBetaMessage, + ToolsBetaContentBlock, + RawToolsBetaContentBlockDeltaEvent, + RawToolsBetaContentBlockStartEvent, +) + + +class ToolsBetaMessageStopEvent(RawMessageStopEvent): + type: Literal["message_stop"] + + message: ToolsBetaMessage + + +class ToolsBetaContentBlockStopEvent(RawContentBlockStopEvent): + type: Literal["content_block_stop"] + + content_block: ToolsBetaContentBlock + + +class ToolsBetaInputJsonEvent(BaseModel): + type: Literal["input_json"] + + 
partial_json: str + """A partial JSON string delta + + e.g. `'"San Francisco,'` + """ + + snapshot: object + """The currently accumulated parsed object. + + + e.g. `{'location': 'San Francisco, CA'}` + """ + + +ToolsBetaMessageStreamEvent = Union[ + TextEvent, + RawMessageStartEvent, + RawMessageDeltaEvent, + ToolsBetaMessageStopEvent, + RawToolsBetaContentBlockDeltaEvent, + RawToolsBetaContentBlockStartEvent, + ToolsBetaInputJsonEvent, + ToolsBetaContentBlockStopEvent, +] diff --git a/src/anthropic/resources/beta/tools/messages.py b/src/anthropic/resources/beta/tools/messages.py index aa211c4f..eb6aa197 100644 --- a/src/anthropic/resources/beta/tools/messages.py +++ b/src/anthropic/resources/beta/tools/messages.py @@ -1008,9 +1008,9 @@ def stream( # pyright: ignore[reportInconsistentOverload] ), cast_to=ToolsBetaMessage, stream=True, - stream_cls=event_handler, + stream_cls=Stream[RawToolsBetaMessageStreamEvent], ) - return ToolsBetaMessageStreamManager(make_request) + return ToolsBetaMessageStreamManager(make_request, event_handler) class AsyncMessages(AsyncAPIResource): @@ -1984,9 +1984,9 @@ def stream( # pyright: ignore[reportInconsistentOverload] ), cast_to=ToolsBetaMessage, stream=True, - stream_cls=event_handler, + stream_cls=AsyncStream[RawToolsBetaMessageStreamEvent], ) - return AsyncToolsBetaMessageStreamManager(request) + return AsyncToolsBetaMessageStreamManager(request, event_handler) class MessagesWithRawResponse: