diff --git a/src/anthropic/lib/streaming/_beta_messages.py b/src/anthropic/lib/streaming/_beta_messages.py index b6241d2e..e9325f55 100644 --- a/src/anthropic/lib/streaming/_beta_messages.py +++ b/src/anthropic/lib/streaming/_beta_messages.py @@ -1,14 +1,14 @@ from __future__ import annotations from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, Type, Callable, cast from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never import httpx from pydantic import BaseModel from ..._utils import consume_sync_iterator, consume_async_iterator -from ..._models import build, construct_type +from ..._models import build, construct_type, construct_type_unchecked from ._beta_types import ( BetaTextEvent, BetaCitationEvent, @@ -372,8 +372,16 @@ def accumulate_event( event: BetaRawMessageStreamEvent, current_snapshot: BetaMessage | None, ) -> BetaMessage: - if not isinstance(event, BaseModel): # pyright: ignore[reportUnnecessaryIsInstance] - raise TypeError(f"Unexpected event runtime type - {event}") + if not isinstance(cast(Any, event), BaseModel): + event = cast( # pyright: ignore[reportUnnecessaryCast] + BetaRawMessageStreamEvent, + construct_type_unchecked( + type_=cast(Type[BetaRawMessageStreamEvent], BetaRawMessageStreamEvent), + value=event, + ), + ) + if not isinstance(cast(Any, event), BaseModel): + raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}") if current_snapshot is None: if event.type == "message_start": diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py index 146a1bab..082ed383 100644 --- a/src/anthropic/lib/streaming/_messages.py +++ b/src/anthropic/lib/streaming/_messages.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, Type, Callable, cast from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never import httpx @@ -17,7 +17,7 @@ ) from ...types import Message, ContentBlock, RawMessageStreamEvent from ..._utils import consume_sync_iterator, consume_async_iterator -from ..._models import build, construct_type +from ..._models import build, construct_type, construct_type_unchecked from ..._streaming import Stream, AsyncStream @@ -372,8 +372,16 @@ def accumulate_event( event: RawMessageStreamEvent, current_snapshot: Message | None, ) -> Message: - if not isinstance(event, BaseModel): # pyright: ignore[reportUnnecessaryIsInstance] - raise TypeError(f"Unexpected event runtime type - {event}") + if not isinstance(cast(Any, event), BaseModel): + event = cast( # pyright: ignore[reportUnnecessaryCast] + RawMessageStreamEvent, + construct_type_unchecked( + type_=cast(Type[RawMessageStreamEvent], RawMessageStreamEvent), + value=event, + ), + ) + if not isinstance(cast(Any, event), BaseModel): + raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}") if current_snapshot is None: if event.type == "message_start":