Skip to content

Commit

Permalink
Fix StepControl.cancel() in DecentralizedAverager (#411)
Browse files Browse the repository at this point in the history
Previously, calling StepControl.cancel() resulted in infinite recursion, because the method called itself instead of the parent class implementation. This PR fixes that and adds a dedicated cancellation future, so that matchmaking is cancelled at any point once the user calls step_control.cancel().

Co-authored-by: Aleksandr Borzunov <[email protected]>
  • Loading branch information
justheuristic and borzunov authored Nov 18, 2021
1 parent 09e34f8 commit 7a79aea
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 17 deletions.
34 changes: 23 additions & 11 deletions hivemind/averaging/averager.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,25 +377,35 @@ def step(
data_for_gather=data_for_gather,
)

future_for_trigger = MPFuture()
self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
step.attach_trigger(future_for_trigger.result())
future_for_init = MPFuture()
self._outer_pipe.send(("_step", [], dict(step=step, future_for_init=future_for_init)))
step.attach(*future_for_init.result())

if not require_trigger:
step.allow_allreduce()
return step.result() if wait else step

async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
async def _step(self, *, step: StepControl, future_for_init: MPFuture):
try:
trigger = MPFuture()
step.attach_trigger(trigger)
future_for_trigger.set_result(trigger)
trigger, cancel = MPFuture(), MPFuture()
step.attach(trigger, cancel)
future_for_init.set_result((trigger, cancel))

while not step.done():
try:
self._pending_group_assembled.clear()
step.stage = AveragingStage.LOOKING_FOR_GROUP
group_info = await self._matchmaking.look_for_group(step)
matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
check_cancel_task = asyncio.create_task(step.wait_for_cancel())

await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
if step.cancelled():
matchmaking_task.cancel()
raise asyncio.CancelledError()
else:
check_cancel_task.cancel()

group_info = await matchmaking_task
if group_info is None:
raise AllreduceException("Averaging step failed: could not find a group.")

Expand Down Expand Up @@ -424,9 +434,11 @@ async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
asyncio.InvalidStateError,
P2PHandlerError,
) as e:
if not step.allow_retries or get_dht_time() >= step.deadline:
logger.exception(e)
step.set_exception(e)
if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
if not step.cancelled():
logger.exception(e)
if not step.done():
step.set_exception(e)
else:
logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")

Expand Down
20 changes: 14 additions & 6 deletions hivemind/averaging/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
super().__init__()
self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
self._trigger: Optional[MPFuture] = None
self._cancel: Optional[MPFuture] = None

# Buffer contents:
# scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce: (bool, 1 byte)
Expand All @@ -52,12 +53,12 @@ def __init__(
self.weight = weight
self.began_allreduce = False

def attach_trigger(self, trigger: MPFuture):
assert self._trigger is None, "Trigger is already attached"
self._trigger = trigger
def attach(self, trigger: MPFuture, cancel: MPFuture):
"""Store the trigger and cancel futures on this control. Asserts that neither future was attached before."""
assert self._trigger is None and self._cancel is None, "Futures are already attached"
self._trigger, self._cancel = trigger, cancel

def allow_allreduce(self):
"""Allow averager to begin allreduce when it finds a group. Meant to be triggered by user."""
"""Allow averager to begin all-reduce when it finds a group. Meant to be triggered by user."""
assert self._trigger is not None, "StepControl does not have an attached trigger"
if self._trigger.done():
logger.warning("Trigger is already set")
Expand Down Expand Up @@ -133,16 +134,23 @@ def __getstate__(self):
return dict(
super().__getstate__(),
_trigger=self._trigger,
_cancel=self._cancel,
_shared_buffer=self._shared_buffer,
immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
)

def __setstate__(self, state):
super().__setstate__(state)
self._trigger, self._shared_buffer = state["_trigger"], state["_shared_buffer"]
self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]

def cancel(self) -> bool:
if self._trigger is not None:
self._trigger.cancel()
return self.cancel()
if self._cancel is not None:
self._cancel.set_result(None)
return super().cancel()

async def wait_for_cancel(self):
"""Wait until the user cancels this step. Should be called from inside the averager."""
# self._cancel starts out as None and is only set by attach(), so attach() must have been called first.
await self._cancel
31 changes: 31 additions & 0 deletions tests/test_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,37 @@ def test_averaging_trigger():
c0.allow_allreduce()


@pytest.mark.forked
def test_averaging_cancel():
"""Cancel two of four pending averaging steps and check that they report cancelled while the other two return a result."""
# One averager per DHT instance; even-indexed averagers run in client mode.
averagers = tuple(
hivemind.averaging.DecentralizedAverager(
averaged_tensors=[torch.randn(3)],
dht=dht,
min_matchmaking_time=0.5,
request_timeout=0.3,
client_mode=(i % 2 == 0),
prefix="mygroup",
start=True,
)
for i, dht in enumerate(launch_dht_instances(4))
)

# Launch a non-blocking step on every averager, scheduled one second (DHT time) from now.
step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]

# Cancel the first two steps 0.2 s after launch, i.e. before their scheduled time.
time.sleep(0.2)
step_controls[0].cancel()
step_controls[1].cancel()

for i, control in enumerate(step_controls):
if i in (0, 1):
assert control.cancelled()
else:
# NOTE(review): presumably the two remaining averagers form a group of two, hence len == 2; confirm.
assert control.result() is not None and len(control.result()) == 2

for averager in averagers:
averager.shutdown()


@pytest.mark.forked
def test_training_averager(n_steps: int = 10, n_dims: int = 16):
torch.manual_seed(42)
Expand Down

0 comments on commit 7a79aea

Please sign in to comment.