Skip to content

Commit

Permalink
fix: rewrite staggered_race to be race safe (#101)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
bdraco and pre-commit-ci[bot] authored Sep 30, 2024
1 parent c8f1fa9 commit 9db617a
Show file tree
Hide file tree
Showing 7 changed files with 492 additions and 40 deletions.
159 changes: 130 additions & 29 deletions src/aiohappyeyeballs/_staggered.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,54 @@
import asyncio
import contextlib
from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)

_T = TypeVar("_T")

class _Done(Exception):
pass

def _set_result(wait_next: "asyncio.Future[None]") -> None:
"""Set the result of a future if it is not already done."""
if not wait_next.done():
wait_next.set_result(None)

_T = TypeVar("_T")

async def _wait_one(
futures: "Iterable[asyncio.Future[Any]]",
loop: asyncio.AbstractEventLoop,
) -> _T:
"""Wait for the first future to complete."""
wait_next = loop.create_future()

def _on_completion(fut: "asyncio.Future[Any]") -> None:
if not wait_next.done():
wait_next.set_result(fut)

for f in futures:
f.add_done_callback(_on_completion)

try:
return await wait_next
finally:
for f in futures:
f.remove_done_callback(_on_completion)


async def staggered_race(
coro_fns: Iterable[Callable[[], Awaitable[_T]]], delay: Optional[float]
coro_fns: Iterable[Callable[[], Awaitable[_T]]],
delay: Optional[float],
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
"""
Run coroutines with staggered start times and take the first to finish.
Expand All @@ -38,14 +75,18 @@ async def staggered_race(
raise
Args:
----
coro_fns: an iterable of coroutine functions, i.e. callables that
return a coroutine object when called. Use ``functools.partial`` or
lambdas to pass arguments.
delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.
loop: the event loop to use. If ``None``, the running loop is used.
Returns:
-------
tuple *(winner_result, winner_index, exceptions)* where
- *winner_result*: the result of the winning coroutine, or ``None``
Expand All @@ -62,40 +103,100 @@ async def staggered_race(
coroutine's entry is ``None``.
"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
winner_result = None
winner_index = None
loop = loop or asyncio.get_running_loop()
exceptions: List[Optional[BaseException]] = []
tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()

async def run_one_coro(
this_index: int,
coro_fn: Callable[[], Awaitable[_T]],
this_failed: asyncio.Event,
) -> None:
this_index: int,
start_next: "asyncio.Future[None]",
) -> Optional[Tuple[_T, int]]:
"""
Run a single coroutine.
If the coroutine fails, set the exception in the exceptions list and
start the next coroutine by setting the result of the start_next.
If the coroutine succeeds, return the result and the index of the
coroutine in the coro_fns list.
If SystemExit or KeyboardInterrupt is raised, re-raise it.
"""
try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as e:
exceptions[this_index] = e
this_failed.set() # Kickstart the next coroutine
else:
# Store winner's results
nonlocal winner_index, winner_result
assert winner_index is None # noqa: S101
winner_index = this_index
winner_result = result
raise _Done
_set_result(start_next) # Kickstart the next coroutine
return None

return result, this_index

start_next_timer: Optional[asyncio.TimerHandle] = None
start_next: Optional[asyncio.Future[None]]
task: asyncio.Task[Optional[Tuple[_T, int]]]
done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
coro_iter = iter(coro_fns)
this_index = -1
try:
async with asyncio.TaskGroup() as tg:
for this_index, coro_fn in enumerate(coro_fns):
this_failed = asyncio.Event()
while True:
if coro_fn := next(coro_iter, None):
this_index += 1
exceptions.append(None)
tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
with contextlib.suppress(TimeoutError):
await asyncio.wait_for(this_failed.wait(), delay)
except* _Done:
pass

return winner_result, winner_index, exceptions
start_next = loop.create_future()
task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
tasks.add(task)
start_next_timer = (
loop.call_later(delay, _set_result, start_next) if delay else None
)
elif not tasks:
# We exhausted the coro_fns list and no tasks are running
# so we have no winner and all coroutines failed.
break

while tasks:
done = await _wait_one(
[*tasks, start_next] if start_next else tasks, loop
)
if done is start_next:
# The current task has failed or the timer has expired
# so we need to start the next task.
start_next = None
if start_next_timer:
start_next_timer.cancel()
start_next_timer = None

# Break out of the task waiting loop to start the next
# task.
break

if TYPE_CHECKING:
assert isinstance(done, asyncio.Task)

tasks.remove(done)
if winner := done.result():
return *winner, exceptions
finally:
# We either have:
# - a winner
# - all tasks failed
# - a KeyboardInterrupt or SystemExit.

#
# If the timer is still running, cancel it.
#
if start_next_timer:
start_next_timer.cancel()

#
# If there are any tasks left, cancel them and than
# wait them so they fill the exceptions list.
#
for task in tasks:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task

return None, None, exceptions
4 changes: 2 additions & 2 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import List, Optional, Sequence, Union

from . import staggered
from . import _staggered
from .types import AddrInfoType

if sys.version_info < (3, 8, 2): # noqa: UP036
Expand Down Expand Up @@ -86,7 +86,7 @@ async def start_connection(
except (RuntimeError, OSError):
continue
else: # using happy eyeballs
sock, _, _ = await staggered.staggered_race(
sock, _, _ = await _staggered.staggered_race(
(
functools.partial(
_connect_sock, current_loop, exceptions, addrinfo, local_addr_infos
Expand Down
9 changes: 0 additions & 9 deletions src/aiohappyeyeballs/staggered.py

This file was deleted.

32 changes: 32 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Configuration for the tests."""

import asyncio
import threading
from typing import Generator

import pytest


@pytest.fixture(autouse=True)
def verify_threads_ended():
"""Verify that the threads are not running after the test."""
threads_before = frozenset(threading.enumerate())
yield
threads = frozenset(threading.enumerate()) - threads_before
assert not threads


@pytest.fixture(autouse=True)
def verify_no_lingering_tasks(
event_loop: asyncio.AbstractEventLoop,
) -> Generator[None, None, None]:
"""Verify that all tasks are cleaned up."""
tasks_before = asyncio.all_tasks(event_loop)
yield

tasks = asyncio.all_tasks(event_loop) - tasks_before
for task in tasks:
pytest.fail(f"Task still running: {task!r}")
task.cancel()
if tasks:
event_loop.run_until_complete(asyncio.wait(tasks))
86 changes: 86 additions & 0 deletions tests/test_staggered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import asyncio
import sys
from functools import partial

import pytest

from aiohappyeyeballs._staggered import staggered_race


@pytest.mark.asyncio
async def test_one_winners():
"""Test that there is only one winner when there is no await in the coro."""
winners = []

async def coro(idx):
winners.append(idx)
return idx

coros = [partial(coro, idx) for idx in range(4)]

winner, index, excs = await staggered_race(
coros,
delay=None,
)
assert len(winners) == 1
assert winners == [0]
assert winner == 0
assert index == 0
assert excs == [None]


@pytest.mark.asyncio
async def test_multiple_winners():
"""Test multiple winners are handled correctly."""
loop = asyncio.get_running_loop()
winners = []
finish = loop.create_future()

async def coro(idx):
await finish
winners.append(idx)
return idx

coros = [partial(coro, idx) for idx in range(4)]

task = loop.create_task(staggered_race(coros, delay=0.00001))
await asyncio.sleep(0.1)
loop.call_soon(finish.set_result, None)
winner, index, excs = await task
assert len(winners) == 4
assert winners == [0, 1, 2, 3]
assert winner == 0
assert index == 0
assert excs == [None, None, None, None]


@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher")
def test_multiple_winners_eager_task_factory():
"""Test multiple winners are handled correctly."""
loop = asyncio.new_event_loop()
eager_task_factory = asyncio.create_eager_task_factory(asyncio.Task)
loop.set_task_factory(eager_task_factory)
asyncio.set_event_loop(None)

async def run():
winners = []
finish = loop.create_future()

async def coro(idx):
await finish
winners.append(idx)
return idx

coros = [partial(coro, idx) for idx in range(4)]

task = loop.create_task(staggered_race(coros, delay=0.00001))
await asyncio.sleep(0.1)
loop.call_soon(finish.set_result, None)
winner, index, excs = await task
assert len(winners) == 4
assert winners == [0, 1, 2, 3]
assert winner == 0
assert index == 0
assert excs == [None, None, None, None]

loop.run_until_complete(run())
Loading

0 comments on commit 9db617a

Please sign in to comment.