Skip to content

Commit

Permalink
Implement state sharing priority (#415)
Browse files Browse the repository at this point in the history
Previously, peers would download state from an (effectively) random existing peer.
As a result, sometimes if many peers got out of sync, they would re-download state from another out-of-sync peer for many times before they got it right.

This PR prioritizes downloading state from "latest" peers, depending on the optimizer.

Note: an intermediate version of this PR used a multiprocessing.Event to trigger re-declaring priority to the DHT. However, this turned out to **harm test performance for py39**. The slowdown was caused specifically by this one line: `await loop.run_in_executor(None, mp_event.wait, timeout_here)` in _declare_for_download_periodically. Other python versions are not affected. I have no idea why py39 reacts like this, but i ended up switching away from MP events just in case.

Co-authored-by: Aleksandr Borzunov <[email protected]>
  • Loading branch information
justheuristic and borzunov authored Nov 25, 2021
1 parent 50ab48e commit 00538db
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 11 deletions.
46 changes: 36 additions & 10 deletions hivemind/averaging/averager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ctypes
import multiprocessing as mp
import os
import random
import threading
import weakref
from dataclasses import asdict
Expand Down Expand Up @@ -164,7 +165,6 @@ def __init__(

self._averaged_tensors = tuple(averaged_tensors)
self.lock_averaged_tensors = mp.Lock()
self.last_updated: DHTExpiration = -float("inf")
for tensor in self._averaged_tensors:
assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
tensor.share_memory_()
Expand Down Expand Up @@ -193,6 +193,8 @@ def __init__(
self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True) # a control pipe used to communicate with daemon

self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
self._state_sharing_priority = mp.Value(ctypes.c_double, 0)

if allow_state_sharing is None:
allow_state_sharing = not client_mode and not auxiliary
self.allow_state_sharing = allow_state_sharing
Expand Down Expand Up @@ -221,7 +223,27 @@ def allow_state_sharing(self, value: bool):
if value and self.client_mode:
raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
else:
self._allow_state_sharing.value = value
old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
if value != old_value:
self._outer_pipe.send(("_trigger_declare_load_state", [], {}))

@property
def state_sharing_priority(self) -> float:
"""Others will preferentially downloading state from peers with highest priority."""
return float(self._state_sharing_priority.value)

@state_sharing_priority.setter
def state_sharing_priority(self, value: float):
if value and self.client_mode:
raise ValueError("State sharing priority is unused: averager in client mode cannot share its state.")
else:
old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
if self.allow_state_sharing and value != old_value:
self._outer_pipe.send(("_trigger_declare_load_state", [], {}))

async def _trigger_declare_load_state(self):
# note: previously tried to set mp.Event instead of this. Awaiting it in executor caused degradation in py39
self._state_updated.set()

@property
def peer_id(self) -> PeerID:
Expand Down Expand Up @@ -490,7 +512,6 @@ async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kw
async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
# all-reduce is performed asynchronously while iterating
tensor.add_(update, alpha=self._averaging_alpha)
self.last_updated = get_dht_time()
self._state_updated.set()

else:
Expand Down Expand Up @@ -550,24 +571,29 @@ async def rpc_aggregate_part(

async def _declare_for_download_periodically(self):
download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
sharing_was_allowed = self.allow_state_sharing
while True:
if self.allow_state_sharing:
self._state_updated.clear()
expiration_time = get_dht_time() + self.declare_state_period
expiration_time = get_dht_time() + self.declare_state_period
if self.allow_state_sharing or sharing_was_allowed:
# notify either if sharing is allowed or if it was just switched off (to overwrite previous message)
asyncio.create_task(
asyncio.wait_for(
self.dht.store(
download_key,
subkey=self.peer_id.to_bytes(),
value=self.last_updated,
value=self.state_sharing_priority if self.allow_state_sharing else None,
expiration_time=expiration_time,
return_future=True,
),
timeout=expiration_time - self.request_timeout,
timeout=expiration_time - get_dht_time(),
)
)
sharing_was_allowed = self.allow_state_sharing

# report again either in state_declare_period or after the field was changed by the user
self._state_updated.clear()
try:
await asyncio.wait_for(self._state_updated.wait(), self.declare_state_period - self.request_timeout)
await asyncio.wait_for(self._state_updated.wait(), timeout=max(0.0, expiration_time - get_dht_time()))
except asyncio.TimeoutError:
pass

Expand Down Expand Up @@ -632,7 +658,7 @@ async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float
key_manager = self._matchmaking.group_key_manager
peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
peer_priority = {
PeerID(peer_id): float(info.value)
PeerID(peer_id): (float(info.value), random.random()) # using randomness as a tie breaker
for peer_id, info in peer_priority.items()
if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
}
Expand Down
4 changes: 4 additions & 0 deletions hivemind/optim/collaborative.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,14 @@ def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindG
self.averager.local_step = current_step + 1
self.collaboration_state_updated.set()
self.update_scheduler()

if grad_scaler is not None:
with grad_scaler.running_global_step():
assert grad_scaler.update()

if not self.averager.client_mode:
self.averager.state_sharing_priority = self.local_step

logger.log(self.status_loglevel, f"Optimizer step: done!")

return group_info
Expand Down
4 changes: 4 additions & 0 deletions hivemind/optim/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def step(self, *args, **kwargs):
if self.local_step % self.averaging_step_period == 0:
self.update_event.set()
self.averager.pending_updates_done.wait()

if not self.averager.client_mode:
self.averager.state_sharing_priority = get_dht_time()

return loss
finally:
self.lock_parameters.acquire()
Expand Down
46 changes: 45 additions & 1 deletion tests/test_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,6 @@ def get_current_state(self):
target_group_size=2,
)

dht_instances[1].get("demo-run.all_averagers")
averager2 = TestAverager(
[torch.randn(3), torch.rand(5)],
dht=dht_instances[1],
Expand All @@ -381,6 +380,8 @@ def get_current_state(self):
target_group_size=2,
)

time.sleep(0.5)

assert num_calls == 0
got_metadata, got_tensors = averager2.load_state_from_peers()
assert num_calls == 1
Expand All @@ -399,7 +400,9 @@ def get_current_state(self):

averager1.allow_state_sharing = False
assert averager2.load_state_from_peers() is None

averager1.allow_state_sharing = True
time.sleep(0.5)
got_metadata, got_tensors = averager2.load_state_from_peers()
assert num_calls == 3
assert got_metadata == super_metadata
Expand All @@ -408,6 +411,47 @@ def get_current_state(self):
instance.shutdown()


@pytest.mark.forked
def test_load_state_priority():
dht_instances = launch_dht_instances(4)

averagers = []
for i in range(4):
averager = hivemind.DecentralizedAverager(
[torch.randn(3), torch.rand(5), torch.tensor([i], dtype=torch.float32)],
dht=dht_instances[i],
start=True,
prefix="demo-run",
target_group_size=2,
allow_state_sharing=i != 1,
)
averager.state_sharing_priority = 5 - abs(2 - i)
averagers.append(averager)

time.sleep(0.5)
metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
assert tensors[-1].item() == 2

metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
assert tensors[-1].item() == 3

averagers[0].state_sharing_priority = 10
time.sleep(0.2)

metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
assert tensors[-1].item() == 0

averagers[1].allow_state_sharing = False
averagers[2].allow_state_sharing = False
metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
assert tensors[-1].item() == 3

for averager in averagers:
averager.shutdown()
for dht in dht_instances:
dht.shutdown()


@pytest.mark.forked
def test_getset_bits():
dht = hivemind.DHT(start=True)
Expand Down

0 comments on commit 00538db

Please sign in to comment.