Convert averager to libp2p backend #323

Merged Jul 28, 2021 (70 commits)
Commits
e907eb2
Implement DHT.p2p property
borzunov Jul 15, 2021
ea3b56d
Start converting averager to libp2p backend
borzunov Jul 15, 2021
2795176
Support inheritance and arbitrary parameter names for rpc_* methods i…
borzunov Jul 15, 2021
85785b9
Make test_load_state_from_peers work
borzunov Jul 15, 2021
c89b598
Support calling Servicer.get_stub without having servicer instances
borzunov Jul 15, 2021
a8fcb0a
Convert AllReduceRunner, Matchmaking, and GroupKeyManager to libp2p b…
borzunov Jul 15, 2021
83c5d30
Fix test_allreduce.py
borzunov Jul 15, 2021
0384737
Fix test_allreduce_once
borzunov Jul 16, 2021
4f5acb5
Fix test_averaging.py
borzunov Jul 16, 2021
7eb91a3
Move launch_dht_instances() to test_utils.py
borzunov Jul 16, 2021
36282f8
Continue fix test_averaging.py
borzunov Jul 16, 2021
2ae476f
Speed up DHT swarm creation
borzunov Jul 16, 2021
2e51140
Remove `endpoint` parameter of GroupKeyManager
borzunov Jul 16, 2021
12e8039
Merge remote-tracking branch 'origin/master' into averager-libp2p
borzunov Jul 16, 2021
20f19b1
Fix RPC in ServicerBase derivatives for test_training_averager
borzunov Jul 16, 2021
f615693
Rename _get_stub to _get_peer_stub
borzunov Jul 16, 2021
1256e09
Fix benchmark_averaging.py
borzunov Jul 16, 2021
02f1d47
Fix bugs with misusing `str` and PeerID
borzunov Jul 16, 2021
9557452
Make diff smaller
borzunov Jul 16, 2021
df26bfc
Try removing timeout
borzunov Jul 16, 2021
b90cef4
Remove excess import
borzunov Jul 16, 2021
0042076
Remove `listen_on` argument
borzunov Jul 16, 2021
a0fec34
Fix TrainingState field types
borzunov Jul 16, 2021
6c20858
Return timeout=5
borzunov Jul 16, 2021
7e01f8e
Reuse binary stream methods for protobuf streams
borzunov Jul 16, 2021
356cff0
Fix test_allreduce_grid(), skip test_overcrowded()
borzunov Jul 19, 2021
0cba0a0
Fix typing and sort imports
borzunov Jul 19, 2021
2bb98a9
Blackify
borzunov Jul 19, 2021
3c92944
Fix bug in benchmark_averaging.py from master
borzunov Jul 19, 2021
a01d6a2
Increase sleep time in test_decentralized_optimizer_averaging()
borzunov Jul 19, 2021
ae89f83
Merge remote-tracking branch 'origin/master' into averager-libp2p
borzunov Jul 19, 2021
27fbe2d
Increase test time limit
borzunov Jul 19, 2021
411981e
Improve docstring
borzunov Jul 20, 2021
06fdf1f
Copy dht.client_mode to averager.client_mode unless it is explicitly …
borzunov Jul 20, 2021
2041893
Increase test verbosity
borzunov Jul 20, 2021
93e4ccd
Recreate MPFuture locks after killing leftover children in tests
borzunov Jul 20, 2021
7a52b84
Reset MPFuture state in tests
borzunov Jul 20, 2021
f917b9d
Skip test_allreduce_grid()
borzunov Jul 20, 2021
4305987
Skip test_decentralized_optimizer_averaging()
borzunov Jul 20, 2021
bb10f4e
Update CI settings
borzunov Jul 20, 2021
10f6d08
Reset only MPFuture locks in tests (not the whole state)
borzunov Jul 20, 2021
2f2c7cd
Merge remote-tracking branch 'origin/master' into averager-libp2p
borzunov Jul 20, 2021
868be1b
Ensure group key equality in rpc_join_group()
borzunov Jul 21, 2021
7a9f48f
Increase test verbosity
borzunov Jul 21, 2021
8f5eeeb
Recreate MPFuture locks after killing child processes
borzunov Jul 21, 2021
8b6f13e
Reset the whole MPFuture backend in tests
borzunov Jul 21, 2021
c11f3fd
Add docstring for MPFuture.reset_state()
borzunov Jul 21, 2021
91c88f8
Add comments to protobufs
borzunov Jul 21, 2021
c524f96
Merge branch 'ensure-equal-group-keys' into averager-libp2p
borzunov Jul 21, 2021
b1a43a5
call_binary_stream_handler: Retry on ControlError
borzunov Jul 16, 2021
a21f6d9
Blackify
borzunov Jul 21, 2021
b0173f6
Extract _get_handler_name method in Servicer
borzunov Jul 22, 2021
182529f
Implement servicer namespaces
borzunov Jul 22, 2021
0d683fd
Merge remote-tracking branch 'origin/master' into averager-libp2p
borzunov Jul 22, 2021
9314701
Remove excess import
borzunov Jul 22, 2021
cb22719
Unskip test_allreduce_grid()
borzunov Jul 22, 2021
ca8b563
Merge remote-tracking branch 'origin/master' into averager-libp2p
borzunov Jul 26, 2021
79bf112
Fix some of @mryab's comments
borzunov Jul 26, 2021
ec4cc63
Implement asingle()
borzunov Jul 26, 2021
11cc1f0
Revert "call_binary_stream_handler: Retry on ControlError"
borzunov Jul 26, 2021
2e8a6a3
Make some DHT methods static
borzunov Jul 26, 2021
b374164
Fix some of @mryab's comments
borzunov Jul 26, 2021
c3a1747
Blackify
borzunov Jul 26, 2021
fc8d296
Remove circular references between AllReduceRunner/Matchmaking and De…
borzunov Jul 26, 2021
cdd66ad
Avoid duplicating defaults
borzunov Jul 27, 2021
3f39a62
Set DEFAULT_PART_SIZE_BYTES = 2 ** 19
borzunov Jul 27, 2021
85fa631
Fix @mryab's comments
borzunov Jul 28, 2021
40d60f0
Rename unused function/method args to _arg
borzunov Jul 28, 2021
1b2d13d
Merge branch 'master' into averager-libp2p
borzunov Jul 28, 2021
b9973fb
Merge branch 'master' into averager-libp2p
justheuristic Jul 28, 2021
6 changes: 3 additions & 3 deletions hivemind/averaging/allreduce.py
@@ -7,7 +7,7 @@
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
 from hivemind.p2p import P2P, P2PContext, PeerID as Endpoint, ServicerBase, StubBase
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
+from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor, asingle
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.proto import averaging_pb2

@@ -231,9 +231,9 @@ async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

     async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
+        # In case of reporting the error, we expect the response stream to contain exactly one item
         error = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint.to_base58(), code=code)
-        async for _ in self._get_peer_stub(peer_endpoint).rpc_aggregate_part(aiter(error)):
-            pass
+        await asingle(self._get_peer_stub(peer_endpoint).rpc_aggregate_part(aiter(error)))

     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
20 changes: 3 additions & 17 deletions hivemind/p2p/p2p_daemon.py
@@ -15,7 +15,7 @@
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure
 from hivemind.proto.p2pd_pb2 import RPCError
-from hivemind.utils.asyncio import aiter
+from hivemind.utils.asyncio import aiter, asingle
 from hivemind.utils.logging import get_logger

 logger = get_logger(__name__)
@@ -401,15 +401,7 @@ async def add_protobuf_handler(
         """

         async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
-            if stream_input:
-                input = requests
-            else:
-                count = 0
-                async for input in requests:
-                    count += 1
-                if count != 1:
-                    raise ValueError(f"Got {count} requests for handler {name} instead of one")
-
+            input = requests if stream_input else await asingle(requests)
             output = handler(input, context)

             if isinstance(output, AsyncIterableABC):
@@ -429,13 +421,7 @@ async def call_protobuf_handler(
     ) -> Awaitable[TOutputProtobuf]:
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
-
-        count = 0
-        async for response in responses:
-            count += 1
-        if count != 1:
-            raise ValueError(f"Got {count} responses from handler {name} instead of one")
-        return response
+        return await asingle(responses)

     def iterate_protobuf_handler(
         self,
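The two hunks above replace hand-written counting loops with a single dispatch: a streaming handler receives the request stream as is, while a unary handler receives exactly one unwrapped request. The following is a minimal standalone sketch of that dispatch, not the hivemind implementation. `make_stream_handler`, `from_list`, `echo_upper`, and `count_items` are hypothetical names introduced for illustration, and `asingle` is re-declared locally (with its parameter renamed to `aiterable`) so the snippet runs without the library.

```python
import asyncio
from typing import AsyncIterable, AsyncIterator, Callable, List, TypeVar

T = TypeVar("T")


async def asingle(aiterable: AsyncIterable[T]) -> T:
    """Local copy of the helper from this PR: return the only item or raise ValueError."""
    count = 0
    async for item in aiterable:
        count += 1
    if count != 1:
        raise ValueError(f"Iterable has {count} items instead of one (as expected)")
    return item


async def from_list(items: List[T]) -> AsyncIterator[T]:
    """Hypothetical helper: expose a list as an async stream."""
    for item in items:
        yield item


def make_stream_handler(handler: Callable, *, stream_input: bool) -> Callable:
    """Hypothetical wrapper mirroring the one-line dispatch in _stream_handler."""

    async def _stream_handler(requests: AsyncIterable):
        # Streaming handlers get the whole stream; unary handlers get exactly one request.
        payload = requests if stream_input else await asingle(requests)
        return await handler(payload)

    return _stream_handler


async def echo_upper(request: str) -> str:
    """Hypothetical unary handler."""
    return request.upper()


async def count_items(requests: AsyncIterable[str]) -> int:
    """Hypothetical streaming handler."""
    return len([r async for r in requests])


async def demo():
    unary = make_stream_handler(echo_upper, stream_input=False)
    streaming = make_stream_handler(count_items, stream_input=True)
    one = await unary(from_list(["ping"]))
    many = await streaming(from_list(["a", "b", "c"]))
    try:
        await unary(from_list(["a", "b"]))  # two requests for a unary handler
        rejected = False
    except ValueError:
        rejected = True
    return one, many, rejected


RESULT = asyncio.run(demo())
```

One visible consequence of the refactoring is that the error message no longer names the handler: the old loops raised "Got {count} requests for handler {name} instead of one", while `asingle()` raises a generic message.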
10 changes: 10 additions & 0 deletions hivemind/utils/asyncio.py
@@ -59,6 +59,16 @@ async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]
         index += 1


+async def asingle(aiter: AsyncIterable[T]) -> T:
Review comment (Member, Author): The name is inspired by Single() from LINQ (.NET functional programming functions).

+    """If ``aiter`` has exactly one item, returns this item. Otherwise, raises `ValueError`."""
+    count = 0
+    async for item in aiter:
+        count += 1
+    if count != 1:
+        raise ValueError(f"Iterable has {count} items instead of one (as expected)")
+    return item
+
+
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
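To make the contract of `asingle()` concrete, here is a small self-contained sketch that runs a copy of the helper against streams of zero, one, and two items. The helper body follows the diff above, except that the parameter is renamed to `aiterable` here; `numbers` and `outcome` are hypothetical names used only for this illustration.

```python
import asyncio
from typing import AsyncIterable, AsyncIterator, TypeVar

T = TypeVar("T")


async def asingle(aiterable: AsyncIterable[T]) -> T:
    """If the iterable has exactly one item, return it. Otherwise, raise ValueError."""
    count = 0
    async for item in aiterable:
        count += 1
    if count != 1:
        raise ValueError(f"Iterable has {count} items instead of one (as expected)")
    return item


async def numbers(n: int) -> AsyncIterator[int]:
    """Hypothetical stream yielding 0, 1, ..., n - 1."""
    for i in range(n):
        yield i


async def outcome(n: int) -> str:
    """Report what asingle() does for a stream of n items."""
    try:
        return f"ok: {await asingle(numbers(n))}"
    except ValueError as e:
        return f"error: {e}"


# Streams of length 0, 1, and 2: only the middle one satisfies the contract.
OUTCOMES = [asyncio.run(outcome(n)) for n in (0, 1, 2)]
```

As written, the helper drains the whole iterable before checking the count, so a stream with extra items is fully consumed before the `ValueError` is raised.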