-
Notifications
You must be signed in to change notification settings - Fork 176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor MPFuture to use a single pipe/thread per process #298
Changes from 7 commits
2c57843
6918249
1c04d87
6e73f92
3a01b30
63c9652
38e0815
dc83f61
70193f0
b4d901f
e64fbb6
f408a7a
1998174
f2fc224
31e83bc
0d825bd
40cef93
d5c2005
4751c2b
061e1b8
8de6f8f
5823cf3
ee77e0b
c781b5b
b2f6eb5
66a55cc
28176ea
945b5c2
8057b9d
c176db3
7f02f94
58626fe
e3dce48
bb73cbe
802feb2
8a2ad22
d24586b
2749bdc
4628199
f4683c0
5965225
5abf3ff
14f551e
3205df9
e159717
9748159
bdec22c
59d7ce7
f26e0a8
20a480a
690bbb1
5f4f828
6190b1b
26d02a8
f487354
f4b331a
461a7bf
ff7b20d
c48d7c1
4ed5a5b
8e7fc0f
72b7444
123bd06
831b3f7
09464a2
f492bed
0808601
3f37eea
c8add97
c9dee07
8de6c70
6206233
c8b518b
18d8abf
6e0db1d
8401f98
976d689
1cd7fe4
9568979
cc113c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,169 +4,209 @@ | |
import concurrent.futures._base as base | ||
import multiprocessing as mp | ||
import multiprocessing.connection | ||
import time | ||
from functools import lru_cache | ||
from typing import Optional, Tuple, Generic, TypeVar | ||
import os | ||
import threading | ||
from typing import Tuple, Generic, TypeVar, Dict, Optional, Any, Callable | ||
|
||
from hivemind.utils.threading import run_in_background | ||
import torch | ||
|
||
from hivemind.utils.logging import get_logger | ||
|
||
ResultType = TypeVar('ResultType') | ||
|
||
logger = get_logger(__name__) | ||
|
||
# flavour types | ||
ResultType = TypeVar('ResultType') | ||
PID, UID, State, PipeEnd = int, int, Any, mp.connection.Connection | ||
justheuristic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED | ||
mryab marked this conversation as resolved.
Show resolved
Hide resolved
|
||
TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED} | ||
|
||
INITIALIZER_LOCK = mp.Lock() | ||
PIPE_WAITER: Optional[threading.Thread] = None | ||
MPFUTURE_PIPES: Dict[PID, Tuple[PipeEnd, PipeEnd]] = mp.Manager().dict() | ||
ACTIVE_FUTURES: Optional[Dict[PID, MPFuture]] = None | ||
ACTIVE_PID: Optional[PID] = None | ||
borzunov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def _initialize_mpfuture_backend(): | ||
global ACTIVE_PID, ACTIVE_FUTURES, PIPE_WAITER | ||
pid = os.getpid() | ||
logger.debug(f"Initializing MPFuture backend for pid {pid}") | ||
assert pid != ACTIVE_PID and pid not in MPFUTURE_PIPES, "already initialized" | ||
|
||
with INITIALIZER_LOCK: | ||
ACTIVE_PID, ACTIVE_FUTURES, MPFUTURE_PIPES[pid] = pid, {}, mp.Pipe(duplex=False) | ||
PIPE_WAITER = threading.Thread(target=_process_updates_in_background, name=f'{__name__}.BACKEND', daemon=True) | ||
PIPE_WAITER.start() | ||
|
||
|
||
def _send_update(pid: PID, uid: UID, message_type: State, payload: Any = None): | ||
pipes = MPFUTURE_PIPES.get(pid) | ||
if pipes: | ||
receiver_pipe, sender_pipe = pipes | ||
sender_pipe.send((uid, message_type, payload)) | ||
else: | ||
logger.warning(f"Could not update MPFuture(pid={pid}, uid={uid}): unknown pid.") | ||
|
||
|
||
def _process_updates_in_background(): | ||
pid = os.getpid() | ||
receiver_pipe, sender_pipe = MPFUTURE_PIPES[pid] | ||
while True: | ||
try: | ||
uid, message_type, payload = receiver_pipe.recv() | ||
if uid not in ACTIVE_FUTURES: | ||
logger.debug(f"Ignoring update to future with uid={uid}: the future is no longer active.") | ||
elif message_type == Exception: | ||
base.Future.set_exception(ACTIVE_FUTURES[uid], payload) | ||
elif message_type == base.FINISHED: | ||
base.Future.set_result(ACTIVE_FUTURES[uid], payload) | ||
elif message_type == base.CANCELLED: | ||
base.Future.cancel(ACTIVE_FUTURES[uid]) | ||
else: | ||
raise ValueError(f"Unexpected message type {message_type}") | ||
|
||
class FutureStateError(RuntimeError): | ||
"""Raised when attempting to change state of a future in a terminal state (e.g. finished)""" | ||
pass | ||
except BrokenPipeError: | ||
logger.debug(f"MPFuture backend was shut down (pid={pid}).") | ||
except Exception as e: | ||
logger.warning(f"Internal error (type={e}, pid={pid}): could not retrieve update for MPFuture.") | ||
logger.exception(e) | ||
|
||
|
||
class MPFuture(base.Future, Generic[ResultType]): | ||
""" Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """ | ||
""" | ||
Multiprocessing-aware version of concurrent.futures.Future / asyncio.Future. | ||
Any process can access future status and set the result / exception. However, only the | ||
original process (i.e. the process that created the future) can retrieve the result or exception. | ||
|
||
TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED} | ||
This primitive works between processes created through inheritance (e.g. fork), *not* for arbitrary processes. | ||
For independently spawned processes, please instead use mp.Pipe / mp.connection.Connection. | ||
""" | ||
|
||
def __init__(self, connection: mp.connection.Connection): | ||
""" manually create MPFuture. Please use MPFuture.make_pair instead """ | ||
def __init__(self, loop: Optional[asyncio.BaseEventLoop] = None): | ||
self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_() | ||
self._state, self._result, self._exception = base.PENDING, None, None | ||
self.connection = connection | ||
|
||
@classmethod | ||
def make_pair(cls) -> Tuple[MPFuture, MPFuture]: | ||
""" Create a pair of linked futures to be used in two processes """ | ||
connection1, connection2 = mp.Pipe() | ||
return cls(connection1), cls(connection2) | ||
self._origin_pid, self._uid = os.getpid(), id(self) | ||
# note: self._uid is only unique inside process that spawned it | ||
super().__init__() | ||
if ACTIVE_PID != self._origin_pid: | ||
_initialize_mpfuture_backend() | ||
ACTIVE_FUTURES[self._uid] = self | ||
|
||
def _send_updates(self): | ||
""" Send updates to a paired MPFuture """ | ||
try: | ||
self.connection.send((self._state, self._result, self._exception)) | ||
if self._state in self.TERMINAL_STATES: | ||
self._shutdown_trigger.set_result(True) | ||
self.connection.close() | ||
return True | ||
except BrokenPipeError: | ||
return False | ||
self._loop = loop or asyncio.get_event_loop() | ||
self._aio_event = asyncio.Event() | ||
except RuntimeError: | ||
self._loop = self._aio_event = None | ||
justheuristic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _recv_updates(self, timeout: Optional[float]): | ||
""" Await updates from a paired MPFuture """ | ||
try: | ||
future = base.wait([run_in_background(self.connection.poll, timeout), self._shutdown_trigger], | ||
return_when=base.FIRST_COMPLETED)[0].pop() | ||
if future is self._shutdown_trigger: | ||
raise BrokenPipeError() | ||
if not future.result(): | ||
raise TimeoutError() | ||
self._state, result, exception = self.connection.recv() | ||
self._result = result if result is not None else self._result | ||
self._exception = exception if exception is not None else self._exception | ||
if self._state in self.TERMINAL_STATES: | ||
self.connection.close() | ||
except TimeoutError as e: | ||
raise e | ||
except (BrokenPipeError, OSError, EOFError) as e: | ||
if self._state in (base.PENDING, base.RUNNING): | ||
self._state, self._exception = base.FINISHED, e | ||
|
||
def _await_terminal_state(self, timeout: Optional[float]): | ||
""" Await updates until future is either finished, cancelled or got an exception """ | ||
time_left = float('inf') if timeout is None else timeout | ||
time_before = time.monotonic() | ||
while self._state not in self.TERMINAL_STATES and time_left > 0: | ||
self._recv_updates(time_left if timeout else None) | ||
time_spent = time.monotonic() - time_before | ||
time_left, time_before = time_left - time_spent, time_before + time_spent | ||
|
||
def _sync_updates(self): | ||
""" Apply queued updates from a paired MPFuture without waiting for new ones """ | ||
try: | ||
self._recv_updates(timeout=0) | ||
except TimeoutError: | ||
pass | ||
@property | ||
def _state(self) -> State: | ||
return ALL_STATES[self._shared_state_code.item()] | ||
|
||
@_state.setter | ||
def _state(self, new_state): | ||
self._shared_state_code[...] = ALL_STATES.index(new_state) | ||
if self._state in TERMINAL_STATES and self._aio_event is not None: | ||
asyncio.run_coroutine_threadsafe(self._set_event(), self.get_loop()) | ||
|
||
async def _set_event(self): | ||
self._aio_event.set() | ||
|
||
def set_result(self, result: ResultType): | ||
self._sync_updates() | ||
if self._state in self.TERMINAL_STATES: | ||
raise FutureStateError(f"Can't set_result to a future that is {self._state} ({self})") | ||
self._state, self._result = base.FINISHED, result | ||
return self._send_updates() | ||
if os.getpid() == self._origin_pid: | ||
super().set_result(result) | ||
elif self._state in TERMINAL_STATES: | ||
raise base.InvalidStateError(f"Can't set_result to a future that is {self._state} ({self})") | ||
else: | ||
_send_update(self._origin_pid, self._uid, base.FINISHED, result) | ||
|
||
def set_exception(self, exception: BaseException): | ||
justheuristic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._sync_updates() | ||
if self._state in self.TERMINAL_STATES: | ||
raise FutureStateError(f"Can't set_exception to a future that is {self._state} ({self})") | ||
self._state, self._exception = base.FINISHED, exception | ||
self._send_updates() | ||
if os.getpid() == self._origin_pid: | ||
super().set_exception(exception) | ||
elif self._state in TERMINAL_STATES: | ||
raise base.InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self})") | ||
else: | ||
_send_update(self._origin_pid, self._uid, Exception, exception) | ||
|
||
def set_running_or_notify_cancel(self): | ||
self._sync_updates() | ||
if self._state == base.PENDING: | ||
self._state = base.RUNNING | ||
return self._send_updates() | ||
return True | ||
elif self._state == base.CANCELLED: | ||
return False | ||
else: | ||
raise FutureStateError(f"Can't set_running_or_notify_cancel to a future that is in {self._state} ({self})") | ||
raise base.InvalidStateError(f"Can't set_running_or_notify_cancel when future is in {self._state} ({self})") | ||
|
||
def cancel(self): | ||
self._sync_updates() | ||
if self._state in self.TERMINAL_STATES: | ||
if os.getpid() == self._origin_pid: | ||
return super().cancel() | ||
elif self._state in TERMINAL_STATES: | ||
return False | ||
self._state, self._exception = base.CANCELLED, base.CancelledError() | ||
return self._send_updates() | ||
else: | ||
_send_update(self._origin_pid, self._uid, base.CANCELLED) | ||
return True | ||
|
||
def result(self, timeout: Optional[float] = None) -> ResultType: | ||
self._await_terminal_state(timeout) | ||
if self._exception is not None: | ||
if self._state not in TERMINAL_STATES: | ||
assert os.getpid() == self._origin_pid, "only the process that created MPFuture can await result." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I'd suggest to remove dots in the end of the message and capitalize the first letter, so the message style is consistent across the library :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please move to #275 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with @borzunov: since this is an addition to the codebase, there are no reasons why this shouldn't be consistent with it even now without library-wide changes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agreed, fixed |
||
return super().result(timeout) | ||
elif self._state == base.CANCELLED: | ||
raise base.CancelledError() | ||
elif self._exception: | ||
raise self._exception | ||
return self._result | ||
else: | ||
return self._result | ||
|
||
def exception(self, timeout=None) -> BaseException: | ||
self._await_terminal_state(timeout) | ||
if self._state == base.CANCELLED: | ||
def exception(self, timeout: Optional[float] = None) -> BaseException: | ||
justheuristic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if self._state not in TERMINAL_STATES: | ||
assert os.getpid() == self._origin_pid, "only the process that created MPFuture can await exception." | ||
justheuristic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return super().exception(timeout) | ||
elif self._state == base.CANCELLED: | ||
raise base.CancelledError() | ||
return self._exception | ||
|
||
def done(self) -> bool: | ||
self._sync_updates() | ||
return self._state in self.TERMINAL_STATES | ||
return self._state in TERMINAL_STATES | ||
|
||
def running(self): | ||
self._sync_updates() | ||
return self._state == base.RUNNING | ||
|
||
def cancelled(self): | ||
self._sync_updates() | ||
return self._state == base.CANCELLED | ||
|
||
def add_done_callback(self, callback): | ||
raise NotImplementedError(f"MPFuture doesn't support callbacks.") | ||
def add_done_callback(self, callback: Callable): | ||
assert os.getpid() == self._origin_pid, "only the process that created MPFuture can set callbacks." | ||
return super().add_done_callback(callback) | ||
|
||
def remove_done_callback(self, callback): | ||
raise NotImplementedError(f"MPFuture doesn't support callbacks.") | ||
def remove_done_callback(self, callback: Callable): | ||
assert os.getpid() == self._origin_pid, "only the process that created MPFuture can set callbacks." | ||
return super().add_done_callback(callback) | ||
|
||
def get_loop(self): | ||
raise NotImplementedError(f"MPFuture doesn't support get_loop") | ||
|
||
@property | ||
@lru_cache() | ||
def _shutdown_trigger(self): | ||
return base.Future() | ||
|
||
def __repr__(self): | ||
self._sync_updates() | ||
if self._state == base.FINISHED: | ||
if self._exception: | ||
return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception)) | ||
else: | ||
return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result)) | ||
else: | ||
return "<MPFuture at 0x{:x} state={}>".format(id(self), self._state) | ||
def get_loop(self) -> Optional[asyncio.BaseEventLoop]: | ||
return self._loop | ||
|
||
def __await__(self): | ||
yield from asyncio.get_running_loop().run_in_executor(None, self._await_terminal_state, None).__await__() | ||
if self._exception: | ||
raise self._exception | ||
return self._result | ||
if not self._aio_event: | ||
raise RuntimeError("Can't await: MPFuture was created with no event loop.") | ||
yield from self._aio_event.wait().__await__() | ||
try: | ||
return super().result(timeout=0) | ||
except base.CancelledError: | ||
raise asyncio.CancelledError() | ||
|
||
def __del__(self): | ||
self._shutdown_trigger.set_result(True) | ||
if hasattr(self, 'connection'): | ||
self.connection.close() | ||
if self._aio_event: | ||
self._aio_event.set() | ||
justheuristic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
del ACTIVE_FUTURES[self._uid] | ||
|
||
def __getstate__(self): | ||
return dict(_shared_state_code=self._shared_state_code, | ||
_origin_pid=self._origin_pid, _uid=self._uid, | ||
_result=self._result, _exception=self._exception) | ||
|
||
def __setstate__(self, state): | ||
self._shared_state_code = state['_shared_state_code'] | ||
self._origin_pid, self._uid = state['_origin_pid'], state['_uid'] | ||
self._result, self._exception = state['_result'], state['_exception'] | ||
self._waiters, self._done_callbacks = [], [] | ||
self._condition = threading.Condition() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note on torch: it is indeed weird, but so far we're still not sure how else to implement shared value for py3.7
Options considered:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess that's alright for now, but please import the two methods/attributes explicitly (probably with a short comment that explains its necessity, like "needed for python 3.7-compatible shared memory")