Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix minor asyncio issues in averager #356

Merged
merged 5 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hivemind/averaging/allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.

async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
# In case of reporting the error, we expect the response stream to contain exactly one item
await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
# Coroutines are lazy, so we take the first item to start the couroutine's execution
await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))

def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
"""finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
Expand Down
2 changes: 1 addition & 1 deletion hivemind/averaging/averager.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ async def _load_state_from_peers(self, future: MPFuture):
future.set_result((metadata, tensors))
self.last_updated = get_dht_time()
return
except BaseException as e:
except Exception as e:
logger.exception(f"Failed to download state from {peer} - {repr(e)}")

finally:
Expand Down
2 changes: 1 addition & 1 deletion hivemind/averaging/matchmaking.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiratio
gather=self.data_for_gather,
group_key=self.group_key_manager.current_key,
)
).__aiter__()
)
message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)

if message.code == averaging_pb2.ACCEPTED:
Expand Down
12 changes: 9 additions & 3 deletions hivemind/p2p/p2p_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,17 @@ async def _read_stream() -> P2P.TInputStream:
async def _process_stream() -> None:
try:
async for response in handler(_read_stream(), context):
await P2P.send_protobuf(response, writer)
try:
await P2P.send_protobuf(response, writer)
except Exception:
# The connection is unexpectedly closed by the caller or broken.
# The loglevel is DEBUG since the actual error will be reported on the caller
logger.debug("Exception while sending response:", exc_info=True)
break
except Exception as e:
logger.warning("Exception while processing stream and sending responses:", exc_info=True)
# Sometimes `e` is a connection error, so we won't be able to report the error to the caller
logger.warning("Handler failed with the exception:", exc_info=True)
with suppress(Exception):
# Sometimes `e` is a connection error, so it is okay if we fail to report `e` to the caller
await P2P.send_protobuf(RPCError(message=str(e)), writer)

with closing(writer):
Expand Down
9 changes: 8 additions & 1 deletion hivemind/utils/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]


async def asingle(aiter: AsyncIterable[T]) -> T:
"""If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`."""
"""If ``aiter`` has exactly one item, returns this item. Otherwise, raises ``ValueError``."""
count = 0
async for item in aiter:
count += 1
Expand All @@ -70,6 +70,13 @@ async def asingle(aiter: AsyncIterable[T]) -> T:
return item


async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
"""Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
async for item in aiter:
return item
return default


async def await_cancelled(awaitable: Awaitable) -> bool:
try:
await awaitable
Expand Down
5 changes: 3 additions & 2 deletions tests/test_p2p_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from hivemind.p2p import P2P, P2PContext, ServicerBase
from hivemind.proto import test_pb2
from hivemind.utils.asyncio import anext


@pytest.fixture
Expand Down Expand Up @@ -139,9 +140,9 @@ async def rpc_wait(
writer.close()
elif cancel_reason == "close_generator":
stub = ExampleServicer.get_stub(client, server.peer_id)
iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
iter = stub.rpc_wait(test_pb2.TestRequest(number=10))

assert await iter.__anext__() == test_pb2.TestResponse(number=11)
assert await anext(iter) == test_pb2.TestResponse(number=11)
await asyncio.sleep(0.25)

await iter.aclose()
Expand Down
13 changes: 12 additions & 1 deletion tests/test_util_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext, asingle, azip
from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.mpfuture import InvalidStateError

Expand Down Expand Up @@ -498,3 +498,14 @@ async def _aiterate():
await anext(iterator)

assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))

assert await asingle(aiter(1)) == 1
with pytest.raises(ValueError):
await asingle(aiter())
with pytest.raises(ValueError):
await asingle(aiter(1, 2, 3))

assert await afirst(aiter(1)) == 1
assert await afirst(aiter()) is None
assert await afirst(aiter(), -1) == -1
assert await afirst(aiter(1, 2, 3)) == 1