Skip to content

Commit

Permalink
fix: copy staggered from standard lib for python 3.12+ (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Sep 27, 2024
1 parent 04c42b4 commit c5a4023
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 1 deletion.
101 changes: 101 additions & 0 deletions src/aiohappyeyeballs/_staggered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import asyncio
import contextlib
from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, TypeVar


class _Done(Exception):
pass


_T = TypeVar("_T")


async def staggered_race(
coro_fns: Iterable[Callable[[], Awaitable[_T]]], delay: Optional[float]
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
"""
Run coroutines with staggered start times and take the first to finish.
This method takes an iterable of coroutine functions. The first one is
started immediately. From then on, whenever the immediately preceding one
fails (raises an exception), or when *delay* seconds has passed, the next
coroutine is started. This continues until one of the coroutines complete
successfully, in which case all others are cancelled, or until all
coroutines fail.
The coroutines provided should be well-behaved in the following way:
* They should only ``return`` if completed successfully.
* They should always raise an exception if they did not complete
successfully. In particular, if they handle cancellation, they should
probably reraise, like this::
try:
# do work
except asyncio.CancelledError:
# undo partially completed work
raise
Args:
coro_fns: an iterable of coroutine functions, i.e. callables that
return a coroutine object when called. Use ``functools.partial`` or
lambdas to pass arguments.
delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.
Returns:
tuple *(winner_result, winner_index, exceptions)* where
- *winner_result*: the result of the winning coroutine, or ``None``
if no coroutines won.
- *winner_index*: the index of the winning coroutine in
``coro_fns``, or ``None`` if no coroutines won. If the winning
coroutine may return None on success, *winner_index* can be used
to definitively determine whether any coroutine won.
- *exceptions*: list of exceptions returned by the coroutines.
``len(exceptions)`` is equal to the number of coroutines actually
started, and the order is the same as in ``coro_fns``. The winning
coroutine's entry is ``None``.
"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
winner_result = None
winner_index = None
exceptions: List[Optional[BaseException]] = []

async def run_one_coro(
this_index: int,
coro_fn: Callable[[], Awaitable[_T]],
this_failed: asyncio.Event,
) -> None:
try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as e:
exceptions[this_index] = e
this_failed.set() # Kickstart the next coroutine
else:
# Store winner's results
nonlocal winner_index, winner_result
assert winner_index is None # noqa: S101
winner_index = this_index
winner_result = result
raise _Done

try:
async with asyncio.TaskGroup() as tg:
for this_index, coro_fn in enumerate(coro_fns):
this_failed = asyncio.Event()
exceptions.append(None)
tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
with contextlib.suppress(TimeoutError):
await asyncio.wait_for(this_failed.wait(), delay)
except* _Done:
pass

return winner_result, winner_index, exceptions
2 changes: 1 addition & 1 deletion src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import itertools
import socket
import sys
from asyncio import staggered
from typing import List, Optional, Sequence

from . import staggered
from .types import AddrInfoType

if sys.version_info < (3, 8, 2): # noqa: UP036
Expand Down
9 changes: 9 additions & 0 deletions src/aiohappyeyeballs/staggered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import sys

if sys.version_info > (3, 11):
# https://github.com/python/cpython/issues/124639#issuecomment-2378129834
from ._staggered import staggered_race
else:
from asyncio.staggered import staggered_race

__all__ = ["staggered_race"]
82 changes: 82 additions & 0 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,88 @@ async def _sock_connect(
]


@patch_socket
@pytest.mark.asyncio
@pytest.mark.xfail(reason="raises RuntimeError: coroutine ignored GeneratorExit")
async def test_handling_system_exit(
m_socket: ModuleType,
) -> None:
"""Test handling SystemExit."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
raise SystemExit

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
with pytest.raises(SystemExit), mock.patch.object(
loop, "sock_connect", _sock_connect
):
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)

# Stopped after the first call
assert create_calls == [
("dead:beef::", 80, 0, 0),
]


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info >= (3, 8, 2), reason="requires < python 3.8.2")
def test_python_38_compat() -> None:
Expand Down

0 comments on commit c5a4023

Please sign in to comment.