Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MPFuture to use a single pipe/thread per process #298

Merged
merged 80 commits into from
Jul 3, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
2c57843
wip
justheuristic Jun 28, 2021
6918249
Merge branch 'master' into a-better-future
justheuristic Jun 30, 2021
1c04d87
working MPFuture prototype
justheuristic Jun 30, 2021
6e73f92
pep8
justheuristic Jun 30, 2021
3a01b30
review
borzunov Jun 30, 2021
63c9652
review
justheuristic Jun 30, 2021
38e0815
go-deeper test, pytorch version for now
justheuristic Jun 30, 2021
dc83f61
partially transfer to new MPFuture
justheuristic Jun 30, 2021
70193f0
partially transfer to new MPFuture
justheuristic Jun 30, 2021
b4d901f
partially transfer to new MPFuture
justheuristic Jun 30, 2021
e64fbb6
py37 compatibility
justheuristic Jun 30, 2021
f408a7a
edge cases
justheuristic Jun 30, 2021
1998174
edge cases
justheuristic Jul 1, 2021
f2fc224
WIP
justheuristic Jul 1, 2021
31e83bc
refactor global variables as class variables, make results immediatel…
justheuristic Jul 1, 2021
0d825bd
enum-based message type
justheuristic Jul 1, 2021
40cef93
sync set event
justheuristic Jul 1, 2021
d5c2005
set event threadsafe
justheuristic Jul 1, 2021
4751c2b
clarify docstr, rm global lock
justheuristic Jul 1, 2021
061e1b8
Merge branch 'master' into a-better-future
justheuristic Jul 1, 2021
8de6f8f
review
justheuristic Jul 1, 2021
5823cf3
WIP
justheuristic Jul 2, 2021
ee77e0b
test done callback
justheuristic Jul 2, 2021
c781b5b
test callback
justheuristic Jul 2, 2021
b2f6eb5
relock
justheuristic Jul 2, 2021
66a55cc
move file limit to test_utils
justheuristic Jul 2, 2021
28176ea
fix tests
justheuristic Jul 2, 2021
945b5c2
better callback
justheuristic Jul 2, 2021
8057b9d
Merge branch 'master' into a-better-future
justheuristic Jul 2, 2021
c176db3
Update hivemind/utils/mpfuture.py
justheuristic Jul 2, 2021
7f02f94
better callback
justheuristic Jul 2, 2021
58626fe
Merge remote-tracking branch 'origin/a-better-future' into a-better-f…
justheuristic Jul 2, 2021
e3dce48
TEST WIP
justheuristic Jul 2, 2021
bb73cbe
rollback
justheuristic Jul 2, 2021
802feb2
rollback
borzunov Jul 2, 2021
8a2ad22
rollback
borzunov Jul 2, 2021
d24586b
\n
justheuristic Jul 2, 2021
2749bdc
pattern
borzunov Jul 2, 2021
4628199
review
borzunov Jul 2, 2021
f4683c0
erase all mentions of HIVEMIND_THREADS from history
justheuristic Jul 2, 2021
5965225
Update tests/test_util_modules.py
justheuristic Jul 2, 2021
5abf3ff
rollback
borzunov Jul 2, 2021
14f551e
Merge remote-tracking branch 'origin/a-better-future' into a-better-f…
justheuristic Jul 2, 2021
3205df9
isort
justheuristic Jul 2, 2021
e159717
test_many_futures
justheuristic Jul 2, 2021
9748159
review
borzunov Jul 2, 2021
bdec22c
review
borzunov Jul 2, 2021
59d7ce7
Update hivemind/utils/mpfuture.py
justheuristic Jul 2, 2021
f26e0a8
Apply suggestions from code review
justheuristic Jul 2, 2021
20a480a
[internal changes to MPFuture] (#305)
justheuristic Jul 2, 2021
690bbb1
review
borzunov Jul 2, 2021
5f4f828
fallback
justheuristic Jul 2, 2021
6190b1b
fallback
justheuristic Jul 2, 2021
26d02a8
hard-kill damaged managers
justheuristic Jul 2, 2021
f487354
hard-kill damaged managers
justheuristic Jul 2, 2021
f4b331a
down with the SyncManager
justheuristic Jul 2, 2021
461a7bf
down with the SyncManager
justheuristic Jul 3, 2021
ff7b20d
Update hivemind/utils/mpfuture.py
justheuristic Jul 3, 2021
c48d7c1
Update tests/conftest.py
justheuristic Jul 3, 2021
4ed5a5b
shutdown gracefully
justheuristic Jul 3, 2021
8e7fc0f
Merge remote-tracking branch 'origin/a-better-future' into a-better-f…
justheuristic Jul 3, 2021
72b7444
shutdown gracefully
justheuristic Jul 3, 2021
123bd06
review
mryab Jul 3, 2021
831b3f7
Update hivemind/server/task_pool.py
justheuristic Jul 3, 2021
09464a2
Update hivemind/utils/__init__.py
justheuristic Jul 3, 2021
f492bed
Update hivemind/client/averaging/training.py
justheuristic Jul 3, 2021
0808601
review
justheuristic Jul 3, 2021
3f37eea
Merge remote-tracking branch 'origin/a-better-future' into a-better-f…
justheuristic Jul 3, 2021
c8add97
review
mryab Jul 3, 2021
c9dee07
review
mryab Jul 3, 2021
8de6c70
review
justheuristic Jul 3, 2021
6206233
review
mryab Jul 3, 2021
c8b518b
review
mryab Jul 3, 2021
18d8abf
Update hivemind/utils/mpfuture.py
justheuristic Jul 3, 2021
6e0db1d
review
yhn112 Jul 3, 2021
8401f98
Merge remote-tracking branch 'origin/a-better-future' into a-better-f…
justheuristic Jul 3, 2021
976d689
Update hivemind/utils/mpfuture.py
justheuristic Jul 3, 2021
1cd7fe4
Update hivemind/utils/mpfuture.py
justheuristic Jul 3, 2021
9568979
switch to RuntimeError
justheuristic Jul 3, 2021
cc113c9
Merge remote-tracking branch 'origin/a-better-future' into a-better-f…
justheuristic Jul 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions hivemind/server/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import torch

from hivemind.utils import MPFuture, get_logger, FutureStateError
from hivemind.utils import MPFuture, get_logger
from concurrent.futures import InvalidStateError

logger = get_logger(__name__)
Task = namedtuple("Task", ("future", "args"))
Expand Down Expand Up @@ -127,7 +128,7 @@ def iterate_minibatches(self, *args, **kwargs):
if task.future.set_running_or_notify_cancel():
batch.append(task)
total_size += task_size
except FutureStateError as e:
except InvalidStateError as e:
logger.debug(f"Failed to add task to batch: {task.future} raised {e}")

def run(self, *args, **kwargs):
Expand Down Expand Up @@ -196,7 +197,7 @@ def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]):
for task, task_outputs in zip(batch_tasks, outputs_per_task):
try:
task.future.set_result(tuple(task_outputs))
except FutureStateError as e:
except InvalidStateError as e:
logger.debug(f"Failed to send task result due to an exception: {e}")

@property
Expand Down
276 changes: 158 additions & 118 deletions hivemind/utils/mpfuture.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,169 +4,209 @@
import concurrent.futures._base as base
import multiprocessing as mp
import multiprocessing.connection
import time
from functools import lru_cache
from typing import Optional, Tuple, Generic, TypeVar
import os
import threading
from typing import Tuple, Generic, TypeVar, Dict, Optional, Any, Callable

from hivemind.utils.threading import run_in_background
import torch
Copy link
Member Author

@justheuristic justheuristic Jun 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note on torch: it is indeed weird, but so far we're still not sure how else to implement a shared value for py3.7

Options considered:

  • current version (with torch.empty)
  • mp.Value or mp.Event - cannot send to other processes (cannot serialize)
  • using multiprocessing.shared_memory - incompatible with py3.7 (and thus colab & kaggle kernels)
  • using _posixshmem (extra dependency to requirements.txt)
  • using mp.Pipe - back to where we started, will need an extra pipe per each future; too many open files

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that's alright for now, but please import the two methods/attributes explicitly (probably with a short comment that explains why they are needed, like "needed for python 3.7-compatible shared memory")


from hivemind.utils.logging import get_logger

ResultType = TypeVar('ResultType')

logger = get_logger(__name__)

# flavour types
ResultType = TypeVar('ResultType')
# PID / UID are plain ints; the aliases only document which kind of id is expected
PID, UID, State, PipeEnd = int, int, Any, mp.connection.Connection
# position in this tuple is the integer code stored in MPFuture._shared_state_code
ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED
# states after which a future can no longer change
TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}

# held while a process registers its pipe and starts its background thread
INITIALIZER_LOCK = mp.Lock()
# background thread of the current process that applies incoming updates (None until initialized)
PIPE_WAITER: Optional[threading.Thread] = None
# pid -> (receiver_pipe, sender_pipe), one one-directional pipe per initialized process
# NOTE: mp.Manager() starts a manager server process, so this line has a side effect at import time
MPFUTURE_PIPES: Dict[PID, Tuple[PipeEnd, PipeEnd]] = mp.Manager().dict()
# futures created in the current process (None until initialized)
# NOTE(review): entries are keyed by the future's uid (MPFuture._uid), not by pid, despite the annotation
ACTIVE_FUTURES: Optional[Dict[PID, MPFuture]] = None
# pid for which the backend above was initialized (None until initialized)
ACTIVE_PID: Optional[PID] = None


def _initialize_mpfuture_backend():
    """
    Prepare the current process for owning MPFuture instances.

    Registers a one-directional pipe for this pid in MPFUTURE_PIPES, resets the
    per-process registry of active futures and starts the daemon thread that
    applies updates arriving through the pipe.
    """
    global ACTIVE_PID, ACTIVE_FUTURES, PIPE_WAITER
    current_pid = os.getpid()
    logger.debug(f"Initializing MPFuture backend for pid {current_pid}")
    assert current_pid != ACTIVE_PID and current_pid not in MPFUTURE_PIPES, "already initialized"

    with INITIALIZER_LOCK:
        # create the pipe first so that nothing is modified if pipe creation fails
        pipe_pair = mp.Pipe(duplex=False)
        ACTIVE_PID = current_pid
        ACTIVE_FUTURES = {}
        MPFUTURE_PIPES[current_pid] = pipe_pair

        waiter = threading.Thread(target=_process_updates_in_background, name=f'{__name__}.BACKEND', daemon=True)
        PIPE_WAITER = waiter
        waiter.start()


def _send_update(pid: PID, uid: UID, message_type: State, payload: Any = None):
    """
    Deliver an update for the future identified by (pid, uid) to the process that owns it.

    :param pid: id of the process that created the future
    :param uid: identifier of the future inside that process
    :param message_type: kind of update to apply
    :param payload: optional value that accompanies the update
    """
    pipe_pair = MPFUTURE_PIPES.get(pid)
    if not pipe_pair:
        logger.warning(f"Could not update MPFuture(pid={pid}, uid={uid}): unknown pid.")
        return

    # only the sending end is needed here; the receiving end is read by the owner's background thread
    _, outgoing = pipe_pair
    outgoing.send((uid, message_type, payload))


def _process_updates_in_background():
pid = os.getpid()
receiver_pipe, sender_pipe = MPFUTURE_PIPES[pid]
while True:
try:
uid, message_type, payload = receiver_pipe.recv()
if uid not in ACTIVE_FUTURES:
logger.debug(f"Ignoring update to future with uid={uid}: the future is no longer active.")
elif message_type == Exception:
base.Future.set_exception(ACTIVE_FUTURES[uid], payload)
elif message_type == base.FINISHED:
base.Future.set_result(ACTIVE_FUTURES[uid], payload)
elif message_type == base.CANCELLED:
base.Future.cancel(ACTIVE_FUTURES[uid])
else:
raise ValueError(f"Unexpected message type {message_type}")

class FutureStateError(RuntimeError):
"""Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
pass
except BrokenPipeError:
logger.debug(f"MPFuture backend was shut down (pid={pid}).")
except Exception as e:
logger.warning(f"Internal error (type={e}, pid={pid}): could not retrieve update for MPFuture.")
logger.exception(e)


class MPFuture(base.Future, Generic[ResultType]):
""" Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
"""
Multiprocessing-aware version of concurrent.futures.Future / asyncio.Future.
Any process can access future status and set the result / exception. However, only the
original process (i.e. the process that created the future) can retrieve the result or exception.

TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
This primitive works between processes created through inheritance (e.g. fork), *not* for arbitrary processes.
For independently spawned processes, please instead use mp.Pipe / mp.connection.Connection.
"""

def __init__(self, connection: mp.connection.Connection):
""" manually create MPFuture. Please use MPFuture.make_pair instead """
def __init__(self, loop: Optional[asyncio.BaseEventLoop] = None):
self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
self._state, self._result, self._exception = base.PENDING, None, None
self.connection = connection

@classmethod
def make_pair(cls) -> Tuple[MPFuture, MPFuture]:
""" Create a pair of linked futures to be used in two processes """
connection1, connection2 = mp.Pipe()
return cls(connection1), cls(connection2)
self._origin_pid, self._uid = os.getpid(), id(self)
# note: self._uid is only unique inside process that spawned it
super().__init__()
if ACTIVE_PID != self._origin_pid:
_initialize_mpfuture_backend()
ACTIVE_FUTURES[self._uid] = self

def _send_updates(self):
""" Send updates to a paired MPFuture """
try:
self.connection.send((self._state, self._result, self._exception))
if self._state in self.TERMINAL_STATES:
self._shutdown_trigger.set_result(True)
self.connection.close()
return True
except BrokenPipeError:
return False
self._loop = loop or asyncio.get_event_loop()
self._aio_event = asyncio.Event()
except RuntimeError:
self._loop = self._aio_event = None

def _recv_updates(self, timeout: Optional[float]):
""" Await updates from a paired MPFuture """
try:
future = base.wait([run_in_background(self.connection.poll, timeout), self._shutdown_trigger],
return_when=base.FIRST_COMPLETED)[0].pop()
if future is self._shutdown_trigger:
raise BrokenPipeError()
if not future.result():
raise TimeoutError()
self._state, result, exception = self.connection.recv()
self._result = result if result is not None else self._result
self._exception = exception if exception is not None else self._exception
if self._state in self.TERMINAL_STATES:
self.connection.close()
except TimeoutError as e:
raise e
except (BrokenPipeError, OSError, EOFError) as e:
if self._state in (base.PENDING, base.RUNNING):
self._state, self._exception = base.FINISHED, e

def _await_terminal_state(self, timeout: Optional[float]):
""" Await updates until future is either finished, cancelled or got an exception """
time_left = float('inf') if timeout is None else timeout
time_before = time.monotonic()
while self._state not in self.TERMINAL_STATES and time_left > 0:
self._recv_updates(time_left if timeout else None)
time_spent = time.monotonic() - time_before
time_left, time_before = time_left - time_spent, time_before + time_spent

def _sync_updates(self):
""" Apply queued updates from a paired MPFuture without waiting for new ones """
try:
self._recv_updates(timeout=0)
except TimeoutError:
pass
@property
def _state(self) -> State:
    """ Current state of the future, decoded from the shared integer state code """
    state_index = self._shared_state_code.item()
    return ALL_STATES[state_index]

@_state.setter
def _state(self, new_state):
    """ Store the new state as its index in ALL_STATES; on a terminal state, schedule setting the asyncio event """
    self._shared_state_code[...] = ALL_STATES.index(new_state)
    if self._state not in TERMINAL_STATES:
        return
    if self._aio_event is None:
        return
    # the event is set from inside the loop via a coroutine submitted with run_coroutine_threadsafe
    asyncio.run_coroutine_threadsafe(self._set_event(), self.get_loop())

async def _set_event(self):
self._aio_event.set()

def set_result(self, result: ResultType):
self._sync_updates()
if self._state in self.TERMINAL_STATES:
raise FutureStateError(f"Can't set_result to a future that is {self._state} ({self})")
self._state, self._result = base.FINISHED, result
return self._send_updates()
if os.getpid() == self._origin_pid:
super().set_result(result)
elif self._state in TERMINAL_STATES:
raise base.InvalidStateError(f"Can't set_result to a future that is {self._state} ({self})")
else:
_send_update(self._origin_pid, self._uid, base.FINISHED, result)

def set_exception(self, exception: BaseException):
self._sync_updates()
if self._state in self.TERMINAL_STATES:
raise FutureStateError(f"Can't set_exception to a future that is {self._state} ({self})")
self._state, self._exception = base.FINISHED, exception
self._send_updates()
if os.getpid() == self._origin_pid:
super().set_exception(exception)
elif self._state in TERMINAL_STATES:
raise base.InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self})")
else:
_send_update(self._origin_pid, self._uid, Exception, exception)

def set_running_or_notify_cancel(self):
self._sync_updates()
if self._state == base.PENDING:
self._state = base.RUNNING
return self._send_updates()
return True
elif self._state == base.CANCELLED:
return False
else:
raise FutureStateError(f"Can't set_running_or_notify_cancel to a future that is in {self._state} ({self})")
raise base.InvalidStateError(f"Can't set_running_or_notify_cancel when future is in {self._state} ({self})")

def cancel(self):
self._sync_updates()
if self._state in self.TERMINAL_STATES:
if os.getpid() == self._origin_pid:
return super().cancel()
elif self._state in TERMINAL_STATES:
return False
self._state, self._exception = base.CANCELLED, base.CancelledError()
return self._send_updates()
else:
_send_update(self._origin_pid, self._uid, base.CANCELLED)
return True

def result(self, timeout: Optional[float] = None) -> ResultType:
self._await_terminal_state(timeout)
if self._exception is not None:
if self._state not in TERMINAL_STATES:
assert os.getpid() == self._origin_pid, "only the process that created MPFuture can await result."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd suggest removing the trailing period from the message and capitalizing the first letter, so the message style is consistent across the library :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move to #275

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with @borzunov: since this is an addition to the codebase, there are no reasons why this shouldn't be consistent with it even now without library-wide changes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed, fixed

return super().result(timeout)
elif self._state == base.CANCELLED:
raise base.CancelledError()
elif self._exception:
raise self._exception
return self._result
else:
return self._result

def exception(self, timeout=None) -> BaseException:
self._await_terminal_state(timeout)
if self._state == base.CANCELLED:
def exception(self, timeout: Optional[float] = None) -> BaseException:
if self._state not in TERMINAL_STATES:
assert os.getpid() == self._origin_pid, "only the process that created MPFuture can await exception."
return super().exception(timeout)
elif self._state == base.CANCELLED:
raise base.CancelledError()
return self._exception

def done(self) -> bool:
self._sync_updates()
return self._state in self.TERMINAL_STATES
return self._state in TERMINAL_STATES

def running(self):
self._sync_updates()
return self._state == base.RUNNING

def cancelled(self):
self._sync_updates()
return self._state == base.CANCELLED

def add_done_callback(self, callback):
raise NotImplementedError(f"MPFuture doesn't support callbacks.")
def add_done_callback(self, callback: Callable):
assert os.getpid() == self._origin_pid, "only the process that created MPFuture can set callbacks."
return super().add_done_callback(callback)

def remove_done_callback(self, callback):
raise NotImplementedError(f"MPFuture doesn't support callbacks.")
def remove_done_callback(self, callback: Callable):
assert os.getpid() == self._origin_pid, "only the process that created MPFuture can set callbacks."
return super().add_done_callback(callback)

def get_loop(self):
raise NotImplementedError(f"MPFuture doesn't support get_loop")

@property
@lru_cache()
def _shutdown_trigger(self):
return base.Future()

def __repr__(self):
self._sync_updates()
if self._state == base.FINISHED:
if self._exception:
return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
else:
return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
else:
return "<MPFuture at 0x{:x} state={}>".format(id(self), self._state)
def get_loop(self) -> Optional[asyncio.BaseEventLoop]:
return self._loop

def __await__(self):
yield from asyncio.get_running_loop().run_in_executor(None, self._await_terminal_state, None).__await__()
if self._exception:
raise self._exception
return self._result
if not self._aio_event:
raise RuntimeError("Can't await: MPFuture was created with no event loop.")
yield from self._aio_event.wait().__await__()
try:
return super().result(timeout=0)
except base.CancelledError:
raise asyncio.CancelledError()

def __del__(self):
self._shutdown_trigger.set_result(True)
if hasattr(self, 'connection'):
self.connection.close()
if self._aio_event:
self._aio_event.set()
del ACTIVE_FUTURES[self._uid]

def __getstate__(self):
return dict(_shared_state_code=self._shared_state_code,
_origin_pid=self._origin_pid, _uid=self._uid,
_result=self._result, _exception=self._exception)

def __setstate__(self, state):
self._shared_state_code = state['_shared_state_code']
self._origin_pid, self._uid = state['_origin_pid'], state['_uid']
self._result, self._exception = state['_result'], state['_exception']
self._waiters, self._done_callbacks = [], []
self._condition = threading.Condition()
Loading