Improve Matchmaking finalizers #357

Merged (8 commits) on Aug 25, 2021

70 changes: 29 additions & 41 deletions hivemind/averaging/matchmaking.py
@@ -15,7 +15,7 @@
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
-from hivemind.utils.asyncio import anext
+from hivemind.utils.asyncio import anext, cancel_and_wait

 logger = get_logger(__name__)

@@ -127,10 +127,9 @@ async def look_for_group(self, *, data_for_gather: bytes, timeout: Optional[floa
                 raise

             finally:
-                if not request_leaders_task.done():
-                    request_leaders_task.cancel()
-                if not self.assembled_group.done():
-                    self.assembled_group.cancel()
+                await cancel_and_wait(request_leaders_task)
+                self.assembled_group.cancel()

                 while len(self.current_followers) > 0:
                     await self.follower_was_discarded.wait()
                     self.follower_was_discarded.clear()

Member Author commented:
.cancel() just returns False when the awaitable is done.

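A minimal standalone sketch (not part of the PR) of the behavior that comment refers to: cancelling a future that is already done is a harmless no-op that returns False, so the removed ".done()" guards were unnecessary.

import asyncio


async def main():
    future = asyncio.get_running_loop().create_future()
    future.set_result(42)
    # The future already holds a result, so cancel() does nothing and returns False;
    # guarding it with "if not future.done()" is therefore redundant.
    assert future.cancel() is False


asyncio.run(main())
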
@@ -229,7 +228,7 @@ async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiratio
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
         except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            logger.exception(f"{self} - failed to request potential leader {leader}:")
             return None

         finally:

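For reference (not part of the PR): logger.exception() records the traceback of the active exception automatically, which is why the manual "{e}" interpolation can be dropped. A minimal sketch:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise ValueError("boom")
except ValueError:
    # logger.exception() logs at ERROR level and appends the current traceback,
    # so the exception no longer needs to be formatted into the message by hand.
    logger.exception("failed to request potential leader:")
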
@@ -413,10 +412,9 @@ async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[flo
         try:
             yield self
         finally:
-            if not update_queue_task.done():
-                update_queue_task.cancel()
-            if declare and not declare_averager_task.done():
-                declare_averager_task.cancel()
+            await cancel_and_wait(update_queue_task)
+            if declare:
+                await cancel_and_wait(declare_averager_task)

             for field in (
                 self.past_attempts,

Member Author commented:
Using cancel_and_wait() here resolves this frequent warning:

Task was destroyed but it is pending!
task: <Task pending name='Task-11' coro=<PotentialLeaders._declare_averager_periodically() done, defined at /home/jheuristic/Documents/exp/hivemind/hivemind/averaging/matchmaking.py:513> wait_for=<Future pending cb=[<TaskWakeupMethWrapper object at 0x7efcec662fd0>()]>>

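For context, a rough standalone sketch (not hivemind code) of the cancel-then-await pattern that cancel_and_wait() implements: the warning above appears when a still-pending task is garbage-collected at loop shutdown, and awaiting the cancelled task lets its finalizers run to completion first.

import asyncio


async def background_loop():
    try:
        await asyncio.Event().wait()  # stand-in for a long-running periodic job
    except asyncio.CancelledError:
        # cleanup code here only gets a chance to finish
        # if somebody awaits the cancelled task
        raise


async def main():
    task = asyncio.create_task(background_loop())
    await asyncio.sleep(0)

    task.cancel()
    try:
        await task  # wait for the cancellation to actually complete
    except asyncio.CancelledError:
        pass


asyncio.run(main())
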
@@ -477,37 +475,31 @@ def request_expiration_time(self) -> float:
         else:
             return min(get_dht_time() + self.averaging_expiration, self.search_end_time)

-    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
-        try:
-            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-            while get_dht_time() < self.search_end_time:
-                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
-                self.max_assured_time = max(
-                    self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
-                )
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
+        DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+        while get_dht_time() < self.search_end_time:
+            new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+            self.max_assured_time = max(
+                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+            )

-                self.leader_queue.clear()
-                for peer, peer_expiration_time in new_peers:
-                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
-                        continue
-                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+            self.leader_queue.clear()
+            for peer, peer_expiration_time in new_peers:
+                if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
+                    continue
+                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)

-                self.update_finished.set()
+            self.update_finished.set()

-                await asyncio.wait(
-                    {self.running.wait(), self.update_triggered.wait()},
-                    return_when=asyncio.ALL_COMPLETED,
-                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
-                )
-                self.update_triggered.clear()
-        except (concurrent.futures.CancelledError, asyncio.CancelledError):
-            return  # note: this is a compatibility layer for python3.7
-        except Exception as e:
-            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
-            raise
+            await asyncio.wait(
+                {self.running.wait(), self.update_triggered.wait()},
+                return_when=asyncio.ALL_COMPLETED,
+                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
+            )
+            self.update_triggered.clear()

-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
+    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:

Member Author commented on _update_queue_periodically():
The code is identical besides removing the try-except block (rationale is explained in this comment).

@@ -521,10 +513,6 @@ async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
-            except (concurrent.futures.CancelledError, asyncio.CancelledError):
-                pass  # note: this is a compatibility layer for python3.7
             except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
                 logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time

Member commented:
[non-blocking]
Q: are you certain that we no longer need to handle concurrent.futures.CancelledError here, or is it an educated guess?

Member Author replied:
I have edited this PR to only remove the except statements in the _update_queue_periodically() and _declare_averager_periodically() methods.
These methods are only awaited in cancel_and_wait(), which handles both cancellation and regular errors by itself.

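A small standalone sketch (not hivemind code) of the behavior described in the reply: the coroutine can simply let CancelledError propagate, because the caller that awaits the cancelled task (the role played by cancel_and_wait()/await_cancelled()) absorbs it.

import asyncio


async def periodic_job():
    # no "except CancelledError" inside: the error propagates out of the task...
    while True:
        await asyncio.sleep(1)


async def main():
    task = asyncio.create_task(periodic_job())
    await asyncio.sleep(0)

    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # ...and is absorbed by whoever awaits the cancelled task.
        pass


asyncio.run(main())
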
15 changes: 4 additions & 11 deletions hivemind/dht/__init__.py
@@ -27,7 +27,7 @@
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_logger, switch_to_uvloop

 logger = get_logger(__name__)

@@ -261,18 +261,11 @@ def run_coroutine(
     async def _run_coroutine(
         self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
     ):
-        main_task = asyncio.create_task(coro(self, self._node))
-        cancel_task = asyncio.create_task(await_cancelled(future))
         try:
-            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
-            if future.cancelled():
-                main_task.cancel()
-            else:
-                future.set_result(await main_task)
+            future.set_result(await coro(self, self._node))
         except BaseException as e:
-            logger.exception(f"Caught an exception when running a coroutine: {e}")
-            if not future.done():
-                future.set_exception(e)
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)

     def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
         if not self._ready.done():

borzunov (Member Author) commented on Aug 25, 2021:
Cancels here did not work since the new MPFuture implementation does not support asyncio awaits in a child process.

They produced the following exception that was accidentally suppressed in await_cancelled():

Traceback (most recent call last):
  File "/home/borzunov/hivemind/hivemind/utils/asyncio.py", line 83, in await_cancelled
    await awaitable
  File "/home/borzunov/hivemind/hivemind/utils/mpfuture.py", line 284, in __await__
    raise RuntimeError("Can't await: MPFuture was created with no event loop")
RuntimeError: Can't await: MPFuture was created with no event loop

This PR removes them since they are not used, as discussed with @justheuristic.

17 changes: 16 additions & 1 deletion hivemind/utils/asyncio.py
@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union

@@ -81,12 +82,26 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
         return False
-    except asyncio.CancelledError:
+    except (asyncio.CancelledError, concurrent.futures.CancelledError):
+        # In Python 3.7, awaiting a cancelled asyncio.Future raises concurrent.futures.CancelledError
+        # instead of asyncio.CancelledError
         return True
     except BaseException:
         logger.exception(f"Exception in {awaitable}:")
         return False


+async def cancel_and_wait(awaitable: Awaitable) -> bool:
+    """
+    Cancels ``awaitable`` and waits for its cancellation.
+    In case of ``asyncio.Task``, helps to avoid ``Task was destroyed but it is pending!`` errors.
+    In case of ``asyncio.Future``, equal to ``future.cancel()``.
+    """
+
+    awaitable.cancel()
+    return await await_cancelled(awaitable)
+
+
 async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,

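A hypothetical usage sketch of the new helper (assuming a hivemind build that includes this PR; the surrounding function and the background task are made up for illustration):

import asyncio

from hivemind.utils.asyncio import cancel_and_wait


async def example():
    background = asyncio.create_task(asyncio.sleep(3600))  # hypothetical long-running task
    try:
        await asyncio.sleep(0.1)  # do some actual work
    finally:
        # replaces the old "if not background.done(): background.cancel()" pattern
        # and also waits until the cancellation has actually completed
        assert await cancel_and_wait(background)


asyncio.run(example())
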
44 changes: 43 additions & 1 deletion tests/test_util_modules.py
@@ -13,7 +13,17 @@
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
-from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext, asingle, azip
+from hivemind.utils.asyncio import (
+    achain,
+    aenumerate,
+    afirst,
+    aiter,
+    amap_in_executor,
+    anext,
+    asingle,
+    azip,
+    cancel_and_wait,
+)
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError

@@ -509,3 +519,35 @@ async def _aiterate():
     assert await afirst(aiter()) is None
     assert await afirst(aiter(), -1) == -1
     assert await afirst(aiter(1, 2, 3)) == 1
+
+
+@pytest.mark.asyncio
+async def test_cancel_and_wait():
+    finished_gracefully = False
+
+    async def coro_with_finalizer():
+        nonlocal finished_gracefully
+
+        try:
+            await asyncio.Event().wait()
+        except asyncio.CancelledError:
+            await asyncio.sleep(0.05)
+            finished_gracefully = True
+            raise
+
+    task = asyncio.create_task(coro_with_finalizer())
+    await asyncio.sleep(0.05)
+    assert await cancel_and_wait(task)
+    assert finished_gracefully
+
+    async def coro_with_result():
+        return 777
+
+    async def coro_with_error():
+        raise ValueError("error")
+
+    task_with_result = asyncio.create_task(coro_with_result())
+    task_with_error = asyncio.create_task(coro_with_error())
+    await asyncio.sleep(0.05)
+    assert not await cancel_and_wait(task_with_result)
+    assert not await cancel_and_wait(task_with_error)