diff --git a/hivemind/averaging/allreduce.py b/hivemind/averaging/allreduce.py index a87f9ee52..494f09a6e 100644 --- a/hivemind/averaging/allreduce.py +++ b/hivemind/averaging/allreduce.py @@ -231,8 +231,8 @@ async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2. async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode): error = averaging_pb2.AveragingData(group_id=self.group_id, code=code) - # In case of reporting the error, we expect the response stream to contain exactly one item - await asingle(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error))) + # Coroutines are lazy, so we take the first item to start the couroutine's execution + await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error))) def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None): """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers.""" diff --git a/hivemind/averaging/averager.py b/hivemind/averaging/averager.py index 5c1bd7bbf..c24802832 100644 --- a/hivemind/averaging/averager.py +++ b/hivemind/averaging/averager.py @@ -593,7 +593,7 @@ async def _load_state_from_peers(self, future: MPFuture): future.set_result((metadata, tensors)) self.last_updated = get_dht_time() return - except BaseException as e: + except Exception as e: logger.exception(f"Failed to download state from {peer} - {repr(e)}") finally: diff --git a/hivemind/averaging/matchmaking.py b/hivemind/averaging/matchmaking.py index 4dd13a39a..c4e0b2f79 100644 --- a/hivemind/averaging/matchmaking.py +++ b/hivemind/averaging/matchmaking.py @@ -189,7 +189,7 @@ async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiratio gather=self.data_for_gather, group_key=self.group_key_manager.current_key, ) - ).__aiter__() + ) message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout) if message.code == averaging_pb2.ACCEPTED: diff --git a/hivemind/p2p/p2p_daemon.py b/hivemind/p2p/p2p_daemon.py index 922aa4f40..547c9a45d 100644 --- a/hivemind/p2p/p2p_daemon.py +++ b/hivemind/p2p/p2p_daemon.py @@ -311,11 +311,17 @@ async def _read_stream() -> P2P.TInputStream: async def _process_stream() -> None: try: async for response in handler(_read_stream(), context): - await P2P.send_protobuf(response, writer) + try: + await P2P.send_protobuf(response, writer) + except Exception: + # The connection is unexpectedly closed by the caller or broken. + # The loglevel is DEBUG since the actual error will be reported on the caller + logger.debug("Exception while sending response:", exc_info=True) + break except Exception as e: - logger.warning("Exception while processing stream and sending responses:", exc_info=True) - # Sometimes `e` is a connection error, so we won't be able to report the error to the caller + logger.warning("Handler failed with the exception:", exc_info=True) with suppress(Exception): + # Sometimes `e` is a connection error, so it is okay if we fail to report `e` to the caller await P2P.send_protobuf(RPCError(message=str(e)), writer) with closing(writer): diff --git a/hivemind/utils/asyncio.py b/hivemind/utils/asyncio.py index 300a1eba4..f83abbd5a 100644 --- a/hivemind/utils/asyncio.py +++ b/hivemind/utils/asyncio.py @@ -59,7 +59,7 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T] async def asingle(aiter: AsyncIterable[T]) -> T: - """If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`.""" + """If ``aiter`` has exactly one item, returns this item. Otherwise, raises ``ValueError``.""" count = 0 async for item in aiter: count += 1 @@ -70,6 +70,13 @@ async def asingle(aiter: AsyncIterable[T]) -> T: return item +async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]: + """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty.""" + async for item in aiter: + return item + return default + + async def await_cancelled(awaitable: Awaitable) -> bool: try: await awaitable diff --git a/tests/test_p2p_servicer.py b/tests/test_p2p_servicer.py index ebed6f532..17f4aac4f 100644 --- a/tests/test_p2p_servicer.py +++ b/tests/test_p2p_servicer.py @@ -5,6 +5,7 @@ from hivemind.p2p import P2P, P2PContext, ServicerBase from hivemind.proto import test_pb2 +from hivemind.utils.asyncio import anext @pytest.fixture @@ -139,9 +140,9 @@ async def rpc_wait( writer.close() elif cancel_reason == "close_generator": stub = ExampleServicer.get_stub(client, server.peer_id) - iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__() + iter = stub.rpc_wait(test_pb2.TestRequest(number=10)) - assert await iter.__anext__() == test_pb2.TestResponse(number=11) + assert await anext(iter) == test_pb2.TestResponse(number=11) await asyncio.sleep(0.25) await iter.aclose() diff --git a/tests/test_util_modules.py b/tests/test_util_modules.py index 7531c6841..3fec2194a 100644 --- a/tests/test_util_modules.py +++ b/tests/test_util_modules.py @@ -13,7 +13,7 @@ from hivemind.proto.runtime_pb2 import CompressionType from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration -from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip +from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext, asingle, azip from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor from hivemind.utils.mpfuture import InvalidStateError @@ -498,3 +498,14 @@ async def _aiterate(): await anext(iterator) assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5)) + + assert await asingle(aiter(1)) == 1 + with pytest.raises(ValueError): + await asingle(aiter()) + with pytest.raises(ValueError): + await asingle(aiter(1, 2, 3)) + + assert await afirst(aiter(1)) == 1 + assert await afirst(aiter()) is None + assert await afirst(aiter(), -1) == -1 + assert await afirst(aiter(1, 2, 3)) == 1