Skip to content

Commit

Permalink
nicer work-around
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Dec 28, 2024
1 parent 6e60cda commit 3c8071f
Showing 1 changed file with 26 additions and 36 deletions.
62 changes: 26 additions & 36 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,20 @@ class _Eof(enum.Enum):
EOF: typing.Final = _Eof.EOF
Eof = typing.Literal[_Eof.EOF]

_T_co = typing.TypeVar("_T_co", covariant=True)

class _HadException(Exception):
def __init__(self, wrapped: BaseException, /, *args: object):
super().__init__(wrapped, *args)
self.wrapped: typing.Final = wrapped

class StartableAsyncFn(typing.Generic[_T_co], typing.Protocol):
async def __call__(self, /, *, task_status: anyio.abc.TaskStatus[_T_co]) -> None: ...


@contextlib.contextmanager
def _handle_task(portal: anyio.abc.BlockingPortal, async_fn: StartableAsyncFn[_T_co]) -> typing.Generator[_T_co]:
fut, result = portal.start_task(async_fn)
try:
yield result
finally:
fut.result()


class WebSocketTestSession:
Expand All @@ -116,40 +125,21 @@ def __init__(
self.should_close: anyio.Event

def __enter__(self) -> WebSocketTestSession:
try:
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())

fut, cs = self.portal.start_task(self._run)

@stack.callback
def handle_task() -> None:
portal.call(cs.cancel)
e = fut.exception()
if e is None:
return
# work-around for https://github.com/python/cpython/issues/69968
try:
raise _HadException(e)
finally:
del e

self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self
except _HadException as e:
raise e.wrapped
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())
cs = stack.enter_context(_handle_task(portal, self._run))
stack.callback(portal.call, cs.cancel)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self

def __exit__(self, *args: typing.Any) -> None:
try:
self.exit_stack.close()
except _HadException as e:
raise e.wrapped
self.exit_stack.close()

while True:
message = self._send_queue.get()
Expand Down

0 comments on commit 3c8071f

Please sign in to comment.