Fix deadlocks in DecentralizedAverager and MPFuture (#331)
This PR does the following:

1. Fix a possible deadlock in DecentralizedAverager.rpc_join_group().
2. Fix a possible deadlock related to corrupted MPFuture state after killing child processes.
3. Add -v flag to pytest in CI.

Co-authored-by: justheuristic <[email protected]>
borzunov and justheuristic authored Jul 22, 2021
1 parent 407c201 commit 0d67284
Showing 5 changed files with 30 additions and 5 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/run-tests.yml
@@ -33,7 +33,7 @@ jobs:
       - name: Test
         run: |
           cd tests
-          pytest --durations=0 --durations-min=1.0
+          pytest --durations=0 --durations-min=1.0 -v
 
   build_and_test_p2pd:
     runs-on: ubuntu-latest
@@ -60,7 +60,7 @@ jobs:
       - name: Test
         run: |
           cd tests
-          pytest -k "p2p"
+          pytest -k "p2p" -v
 
   codecov_in_develop_mode:
@@ -87,6 +87,6 @@ jobs:
           pip install -e .
       - name: Test
         run: |
-          pytest --cov=hivemind tests
+          pytest --cov=hivemind -v tests
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1
4 changes: 4 additions & 0 deletions hivemind/averaging/matchmaking.py
@@ -180,6 +180,7 @@ async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpirat
                         expiration=expiration_time,
                         client_mode=self.client_mode,
                         gather=self.data_for_gather,
+                        group_key=self.group_key_manager.current_key,
                     )
                 )
             message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)
@@ -315,11 +316,14 @@ def _check_reasons_to_reject(
             or not isinstance(request.endpoint, Endpoint)
             or len(request.endpoint) == 0
             or self.client_mode
+            or not isinstance(request.group_key, GroupKey)
         ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
 
         elif request.schema_hash != self.schema_hash:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
+        elif request.group_key != self.group_key_manager.current_key:
+            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_GROUP_KEY)
         elif self.potential_leaders.declared_group_key is None:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
         elif self.potential_leaders.declared_expiration_time > (request.expiration or float("inf")):
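To illustrate the leader-side behavior added above: a minimal, standalone sketch in plain Python (JoinRequest, check_reasons_to_reject, and the string codes below are simplified stand-ins, not hivemind's real classes). A malformed request is still a protocol violation, a schema mismatch is rejected as before, and a follower whose group key differs from the leader's current key now receives an explicit BAD_GROUP_KEY reply rather than slipping through unchecked.

```python
# Standalone illustration only: simplified stand-ins for the real hivemind types and protobuf codes.
from dataclasses import dataclass
from typing import Optional

GroupKey = str  # hivemind group keys are strings, e.g. "my_averager.0b011011101"


@dataclass
class JoinRequest:
    schema_hash: bytes
    group_key: GroupKey


def check_reasons_to_reject(request: JoinRequest, my_schema_hash: bytes, my_group_key: GroupKey) -> Optional[str]:
    """Return a rejection code, or None if this (simplified) leader may accept the follower."""
    if not isinstance(request.group_key, GroupKey):
        return "PROTOCOL_VIOLATION"  # malformed request
    if request.schema_hash != my_schema_hash:
        return "BAD_SCHEMA_HASH"  # incompatible tensor schema
    if request.group_key != my_group_key:
        return "BAD_GROUP_KEY"  # follower used a different (possibly stale) key
    return None  # further checks (declared key, expiration, etc.) omitted here


assert check_reasons_to_reject(JoinRequest(b"h", "avg.0b01"), b"h", "avg.0b01") is None
assert check_reasons_to_reject(JoinRequest(b"h", "avg.0b01"), b"h", "avg.0b10") == "BAD_GROUP_KEY"
```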
2 changes: 2 additions & 0 deletions hivemind/proto/averaging.proto
@@ -28,6 +28,7 @@ enum MessageCode {
   INTERNAL_ERROR = 15;    // "I messed up, we will have to stop allreduce because of that."
   CANCELLED = 16;         // "[from peer during allreduce] I no longer want to participate in AllReduce."
   GROUP_DISBANDED = 17;   // "[from leader] The group is closed. Go find another group."
+  BAD_GROUP_KEY = 18;     // "I will not accept you. My current group key differs (maybe you used my older key)."
 }
 
 message JoinRequest {
@@ -36,6 +37,7 @@ message JoinRequest {
   double expiration = 3;  // Follower would like to **begin** all_reduce by this point in time
   bytes gather = 4;       // optional metadata that is gathered from all peers (e.g. batch size or current loss)
   bool client_mode = 5;   // if True, the incoming averager is a client with no capacity for averaging
+  string group_key = 6;   // group key identifying an All-Reduce bucket, e.g my_averager.0b011011101
 }
 
 message MessageFromLeader {
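Because group_key is a plain proto3 string, setting it through the generated bindings is a one-liner. A small usage sketch, assuming the protos are compiled to hivemind.proto.averaging_pb2 as in the matchmaking code above (the field values are made up):

```python
from hivemind.proto import averaging_pb2

request = averaging_pb2.JoinRequest(
    expiration=1626950000.0,  # made-up expiration timestamp (field 3)
    client_mode=False,  # field 5
    group_key="my_averager.0b011011101",  # new field 6, same format as the comment above
)

# proto3 scalars need no "required" flag: a peer built before this change simply leaves
# group_key at its default empty string, which is unlikely to match any leader's current key.
print(request.group_key)

reply = averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_GROUP_KEY)
```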
17 changes: 16 additions & 1 deletion hivemind/utils/mpfuture.py
@@ -4,7 +4,6 @@
 import concurrent.futures._base as base
 from contextlib import nullcontext, suppress
 import multiprocessing as mp
-import multiprocessing.connection
 import os
 import threading
 import uuid
@@ -127,11 +126,27 @@ def _initialize_backend_if_necessary(cls):
             )
             cls._pipe_waiter_thread.start()
 
+    @classmethod
+    def reset_backend(cls):
+        """
+        Reset the MPFuture backend. This is useful when the state may have been corrupted
+        (e.g. killing child processes may leave the locks acquired and the background thread blocked).
+        This method is neither thread-safe nor process-safe.
+        """
+
+        cls._initialization_lock = mp.Lock()
+        cls._update_lock = mp.Lock()
+        cls._active_pid = None
+
     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
         pid = os.getpid()
         while True:
             try:
+                if cls._pipe_waiter_thread is not threading.current_thread():
+                    break  # Backend was reset, a new background thread has started
+
                 uid, msg_type, payload = receiver_pipe.recv()
                 future = None
                 future_ref = cls._active_futures.get(uid)
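The one-line guard added to _process_updates_in_background is what makes reset_backend() safe to call while an older listener thread may still exist: the superseded thread notices it is no longer the registered cls._pipe_waiter_thread and stops consuming updates. A self-contained toy version of that pattern (ToyBackend and its pipe handling are illustrative, not hivemind's actual backend):

```python
import multiprocessing as mp
import threading
import time


class ToyBackend:
    """Illustrative only: a class-level pipe + reader thread that can be reset safely."""

    _pipe_waiter_thread = None
    _receiver, _sender = None, None

    @classmethod
    def start(cls):
        cls._receiver, cls._sender = mp.Pipe(duplex=False)
        cls._pipe_waiter_thread = threading.Thread(target=cls._loop, args=(cls._receiver,), daemon=True)
        cls._pipe_waiter_thread.start()

    @classmethod
    def _loop(cls, receiver):
        while True:
            # Same guard as above: only the currently registered thread keeps serving updates
            if cls._pipe_waiter_thread is not threading.current_thread():
                break  # superseded by a reset; stop touching the (old) pipe
            print("handled:", receiver.recv())

    @classmethod
    def reset(cls):
        # Install a fresh pipe and thread; the superseded thread exits the next time it wakes up
        cls.start()


ToyBackend.start()
ToyBackend._sender.send("before reset")
time.sleep(0.1)
ToyBackend.reset()
ToyBackend._sender.send("after reset")  # served by the new thread on the new pipe
time.sleep(0.1)
```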
6 changes: 5 additions & 1 deletion tests/conftest.py
@@ -4,7 +4,8 @@
 import psutil
 import pytest
 
-from hivemind.utils import get_logger
+from hivemind.utils.logging import get_logger
+from hivemind.utils.mpfuture import MPFuture
 
 
 logger = get_logger(__name__)
@@ -26,3 +27,6 @@ def cleanup_children():
         for child in children:
             with suppress(psutil.NoSuchProcess):
                 child.kill()
+
+    # Broken code or killing of child processes may leave the MPFuture backend corrupted
+    MPFuture.reset_backend()
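For reference, a complete fixture along these lines might look roughly as follows. Only its tail appears in the hunk above, so the decorator arguments, the yield placement, and the child-process enumeration are assumptions rather than the exact contents of tests/conftest.py:

```python
from contextlib import suppress

import psutil
import pytest

from hivemind.utils.mpfuture import MPFuture


@pytest.fixture(autouse=True)  # assumed arguments; the real fixture may differ
def cleanup_children():
    yield  # run the test first, then clean up

    # kill any child processes the test left behind (assumed to match the elided part above)
    for child in psutil.Process().children(recursive=True):
        with suppress(psutil.NoSuchProcess):
            child.kill()

    # Broken code or killing of child processes may leave the MPFuture backend corrupted
    MPFuture.reset_backend()
```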
