Remove shared memory from MPFuture, fix minor bugs #317

Merged
merged 26 commits into from
Jul 13, 2021

Changes from 10 commits

Commits (26)
4248a91 fix "coroutine was never awaited" (justheuristic, Jul 11, 2021)
c2ce35c client-server-style MPFuture (justheuristic, Jul 13, 2021)
a310f26 Merge branch 'master' into make-mpfuture-great-again (justheuristic, Jul 13, 2021)
5321f7f re-black (justheuristic, Jul 13, 2021)
904c808 re-black (justheuristic, Jul 13, 2021)
34e2d1d py39 compat (justheuristic, Jul 13, 2021)
0e38f5d py39 compat (justheuristic, Jul 13, 2021)
1849fe2 Merge branch 'master' into make-mpfuture-great-again (justheuristic, Jul 13, 2021)
668698d review (borzunov, Jul 13, 2021)
ba2d0cd Merge branch 'master' into make-mpfuture-great-again (justheuristic, Jul 13, 2021)
5d87238 make moe test easier (justheuristic, Jul 13, 2021)
7652a99 Update hivemind/utils/mpfuture.py (justheuristic, Jul 13, 2021)
e5fc179 Update hivemind/utils/mpfuture.py (justheuristic, Jul 13, 2021)
f5600c6 Update hivemind/utils/mpfuture.py (justheuristic, Jul 13, 2021)
9b563bb Update hivemind/utils/mpfuture.py (justheuristic, Jul 13, 2021)
2dcd26d Merge remote-tracking branch 'origin/make-mpfuture-great-again' into … (justheuristic, Jul 13, 2021)
bd852ef make moe test easier (justheuristic, Jul 13, 2021)
d6a4048 review (borzunov, Jul 13, 2021)
f95a00a review (borzunov, Jul 13, 2021)
d564efc review (borzunov, Jul 13, 2021)
1667264 review (borzunov, Jul 13, 2021)
8a01bf1 review (justheuristic, Jul 13, 2021)
1a4d7a8 review (borzunov, Jul 13, 2021)
ad34fdf Update hivemind/utils/mpfuture.py (justheuristic, Jul 13, 2021)
3ee5732 review (justheuristic, Jul 13, 2021)
f147d18 Merge remote-tracking branch 'origin/make-mpfuture-great-again' into … (justheuristic, Jul 13, 2021)
2 changes: 1 addition & 1 deletion hivemind/moe/server/task_pool.py
@@ -100,7 +100,7 @@ def __init__(

def submit_task(self, *args: torch.Tensor) -> Future:
"""Add task to this pool's queue, return Future for its output"""
task = Task(MPFuture(), args)
task = Task(MPFuture(synchronize=False), args)
if self.get_task_size(task) > self.max_batch_size:
exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
task.future.set_exception(exc)
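Context for this one-line change: with synchronize=False, state checks on copies of the future rely on local state only instead of requesting it from the origin process (see the MPFuture changes below); the task pool only sets results and exceptions, so it can skip that round trip. A minimal usage sketch, standalone and not part of the diff:

from hivemind.utils.mpfuture import MPFuture

# origin process: a future whose copies never request state from the origin
future = MPFuture(synchronize=False)
# ... the future is pickled and sent to the pool process ...
# pool process: deliver the result back to the origin through the pipe
future.set_result(42)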
253 changes: 162 additions & 91 deletions hivemind/utils/mpfuture.py
@@ -8,10 +8,9 @@
import os
import threading
import uuid
from weakref import ref
from enum import Enum, auto
from typing import Generic, TypeVar, Dict, Optional, Any, Callable

import torch # used for py3.7-compatible shared memory
from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple

from hivemind.utils.logging import get_logger

@@ -34,10 +33,13 @@ class InvalidStateError(Exception):
"""Raised when attempting to change state of a future in a terminal state (e.g. finished)"""


class UpdateType(Enum):
class MessageType(Enum):
RESULT = auto()
EXCEPTION = auto()
RUNNING = auto()
CANCEL = auto()
STATE_REQUEST = auto()
STATE_RESPONSE = auto()
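
# [Illustrative aside, not part of the diff] Every message on the MPFuture pipe
# is a plain (uid, MessageType, payload) tuple, e.g.:
#   (uid, MessageType.RESULT, value)             deliver set_result to the origin
#   (uid, MessageType.STATE_REQUEST, reply_pipe) ask the origin for current state
#   (uid, MessageType.STATE_RESPONSE, state)     the origin answers with __getstate__()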


class MPFuture(base.Future, Generic[ResultType]):
@@ -46,6 +48,7 @@ class MPFuture(base.Future, Generic[ResultType]):
Any process can access future status and set the result / exception and check for state.
However, only the original process (i.e. the process that created the future) can await the result or exception.

:param synchronize: if True (default), future will request state from origin, otherwise it will only use local state
:param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
If set to False, writing to this future ignores global lock, slightly improving performance, but making user
responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
@@ -60,49 +63,36 @@ class MPFuture(base.Future, Generic[ResultType]):

_initialization_lock = mp.Lock() # global lock that prevents simultaneous initialization of two processes
_update_lock = mp.Lock() # global lock that prevents simultaneous writing to the same pipe
_global_sender_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process
_process_wide_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process
_pipe_waiter_thread: Optional[threading.Thread] = None # process-specific thread that receives results/exceptions
_active_futures: Optional[Dict[UID, MPFuture]] = None # pending or running futures originated from current process
_active_futures: Optional[Dict[UID, "ref[MPFuture]"]] = None # non-done futures originated from this process
_status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None # futures to be updated by origin
_active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively

def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
self._state_cache: Dict[State, State] = {}
# mapping from global to cached local future used that makes updates immediately
# available on setter side; dictionary-based cache works because future can visit any state at most once
SOFT_UPDATE_TIMEOUT = 0.1 # seconds spent awaiting status update before warning is printed
HARD_UPDATE_TIMEOUT = 10.0 # seconds spent awaiting status update before future is automatically cancelled

base.Future.__init__(self) # parent init is deferred because it uses self._shared_state_code
def __init__(self, synchronize: bool = True, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
base.Future.__init__(self)
self.synchronize = synchronize
self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
self._state, self._result, self._exception = base.PENDING, None, None
self._use_lock = use_lock

if self._origin_pid != MPFuture._active_pid:
with MPFuture._initialization_lock:
if self._origin_pid != MPFuture._active_pid:
# note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
self._initialize_mpfuture_backend()
self._initialize_backend_if_necessary()
assert self._uid not in MPFuture._active_futures
MPFuture._active_futures[self._uid] = self
self._sender_pipe = MPFuture._global_sender_pipe
MPFuture._active_futures[self._uid] = ref(self)
self._sender_pipe = MPFuture._process_wide_pipe

try:
self._loop = loop or asyncio.get_event_loop()
self._aio_event = asyncio.Event()
except RuntimeError:
self._loop, self._aio_event = None, None

@property
def _state(self) -> State:
shared_state = ALL_STATES[self._shared_state_code.item()]
return self._state_cache.get(shared_state, shared_state)

@_state.setter
def _state(self, new_state: State):
self._shared_state_code[...] = ALL_STATES.index(new_state)
if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
self._set_event_threadsafe()

def _set_event_threadsafe(self):
def _set_event_if_necessary(self):
if self._aio_event is None or self._aio_event.is_set():
Review thread on the check above:
borzunov (Jul 13, 2021): nit: self._aio_event.is_set() is not guaranteed to be thread-safe. This check can probably be removed.
justheuristic (author): decided to keep is_set for performance reasons
borzunov: [It's thread-safe in the current implementation, and we've agreed that it's hard to imagine that it will stop being thread-safe.]
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
@@ -111,91 +101,169 @@ def _set_event_threadsafe(self):
async def _event_setter():
self._aio_event.set()

if loop == self.get_loop():
if self._loop.is_running() and loop == self.get_loop():
asyncio.create_task(_event_setter())
else:
elif self._loop.is_running() and loop != self.get_loop():
asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
else:
self._loop.run_until_complete(_event_setter())

@classmethod
def _initialize_mpfuture_backend(cls):
def _initialize_backend_if_necessary(cls):
pid = os.getpid()
logger.debug(f"Initializing MPFuture backend for pid {pid}")
assert pid != cls._active_pid, "already initialized"

receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
cls._active_pid, cls._active_futures = pid, {}
cls._pipe_waiter_thread = threading.Thread(
target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
)
cls._pipe_waiter_thread.start()
if MPFuture._active_pid != pid:
with MPFuture._initialization_lock:
if MPFuture._active_pid != pid:
# note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
logger.debug(f"Initializing MPFuture backend for pid {pid}")
receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
cls._pipe_waiter_thread = threading.Thread(
target=cls._process_updates_in_background,
args=[receiver_pipe],
name=f"{__name__}.BACKEND",
daemon=True,
)
cls._pipe_waiter_thread.start()

@classmethod
def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
pid = os.getpid()
while True:
Review thread on the while True: above:
borzunov (Jul 13, 2021): nit: This loop never stops gracefully. Ideally, it would be cool to stop it via a threading.Event in the __del__ finalizer of the last active future in the current process. However, graceful shutdown may be out of scope of this PR.
try:
uid, update_type, payload = receiver_pipe.recv()
if uid not in cls._active_futures:
logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
elif update_type == UpdateType.RESULT:
cls._active_futures.pop(uid).set_result(payload)
elif update_type == UpdateType.EXCEPTION:
cls._active_futures.pop(uid).set_exception(payload)
elif update_type == UpdateType.CANCEL:
cls._active_futures.pop(uid).cancel()
uid, msg_type, payload = receiver_pipe.recv()
future = None
future_ref = cls._active_futures.get(uid)
if future_ref is not None:
future = future_ref()

if msg_type == MessageType.STATE_REQUEST:
future_state = None if future is None else future.__getstate__()
payload.send((uid, MessageType.STATE_RESPONSE, future_state))

elif msg_type == MessageType.STATE_RESPONSE:
future, state_updated_event = cls._status_requests.get(uid) or (None, None)
if future is None:
logger.debug("Received a state update for a future that does not await status update.")
else:
if payload is not None:
future.__setstate__(payload)
else:
base.Future.cancel(future)
state_updated_event.set()

elif future is None:
logger.debug(
f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
)
elif msg_type == MessageType.RESULT:
future.set_result(payload)
elif msg_type == MessageType.EXCEPTION:
future.set_exception(payload)
elif msg_type == MessageType.RUNNING:
try:
future.set_running_or_notify_cancel()
except (InvalidStateError, RuntimeError) as e:
logger.debug(f"could set MPFuture (uid={uid}) to running due to {e}")
elif msg_type == MessageType.CANCEL:
future.cancel()
else:
raise RuntimeError(f"Received unexpected update type {update_type}")
except (BrokenPipeError, EOFError):
raise RuntimeError(f"Received unexpected update type {msg_type}")

if future is None or future.done():
cls._active_futures.pop(uid, None)

except (BrokenPipeError, EOFError, ConnectionError):
logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
except Exception as e:
logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")

def _send_update(self, update_type: UpdateType, payload: Any = None):
def _send_update(self, update_type: MessageType, payload: Any = None):
"""This method sends result, exception or cancel to the MPFuture origin."""
with MPFuture._update_lock if self._use_lock else nullcontext():
self._sender_pipe.send((self._uid, update_type, payload))
try:
with MPFuture._update_lock if self._use_lock else nullcontext():
self._sender_pipe.send((self._uid, update_type, payload))
except (ConnectionError, BrokenPipeError, EOFError) as e:
logger.debug(f"No updates were sent: pipe to origin process is no longer operational ({e}).")

def _synchronize_if_necessary(self):
if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
return

self._initialize_backend_if_necessary()

maybe_existing_request = self._status_requests.get(self._uid)
if maybe_existing_request is not None:
_, status_updated = maybe_existing_request
status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT)
return

# otherwise create a new request for synchronization

try:
status_updated = threading.Event()
self._status_requests[self._uid] = (self, status_updated)
with MPFuture._update_lock if self._use_lock else nullcontext():
self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, self._process_wide_pipe))
status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
if not status_updated.is_set():
logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
if not status_updated.is_set():
self.set_exception(
TimeoutError(
f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
f"mpfuture is cancelled"
)
)
status_updated.set() # this triggers any concurrent _synchronize_if_necessary calls to finish
except (ConnectionError, BrokenPipeError, EOFError) as e:
logger.error(f"MPFuture was cancelled because sender pipe is broken. Origin process is probably down.")
if not self.cancel():
self.set_exception(e)
finally:
self._status_requests.pop(self._uid, None)

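# [Illustrative aside, not part of the diff] The handshake above, end to end:
#   non-origin process                        origin's background thread
#   -------------------------------------     ---------------------------------
#   send (uid, STATE_REQUEST, reply_pipe) ->  look up the future by its uid
#   wait SOFT_UPDATE_TIMEOUT (0.1 s)      <-  send (uid, STATE_RESPONSE, state)
#   apply __setstate__(), or cancel if the
#   origin reports the future gone; after
#   HARD_UPDATE_TIMEOUT (10 s): TimeoutError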
def set_result(self, result: ResultType):
if os.getpid() == self._origin_pid:
super().set_result(result)
MPFuture._active_futures.pop(self._uid, None)
elif self._state in TERMINAL_STATES:
if self._state in TERMINAL_STATES:
raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
elif os.getpid() == self._origin_pid:
MPFuture._active_futures.pop(self._uid, None)
self._set_event_if_necessary()
else:
self._state_cache[self._state], self._result = base.FINISHED, result
self._send_update(UpdateType.RESULT, result)
self._send_update(MessageType.RESULT, result)
super().set_result(result)
Review thread on set_result above:
borzunov (Jul 13, 2021): nit: In Python 3.7, this may raise RuntimeError (instead of InvalidStateError) if set_result is called concurrently inside one process. This is a minor issue and we've agreed it won't be fixed.

def set_exception(self, exception: Optional[BaseException]):
if os.getpid() == self._origin_pid:
super().set_exception(exception)
MPFuture._active_futures.pop(self._uid, None)
elif self._state in TERMINAL_STATES:
if self._state in TERMINAL_STATES:
raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
elif os.getpid() == self._origin_pid:
MPFuture._active_futures.pop(self._uid, None)
self._set_event_if_necessary()
else:
self._state_cache[self._state], self._exception = base.FINISHED, exception
self._send_update(UpdateType.EXCEPTION, exception)
self._send_update(MessageType.EXCEPTION, exception)
super().set_exception(exception)

def cancel(self) -> bool:
if os.getpid() == self._origin_pid:
MPFuture._active_futures.pop(self._uid, None)
return super().cancel()
elif self._state in [base.RUNNING, base.FINISHED]:
if self._state in [base.RUNNING, base.FINISHED]:
return False
elif os.getpid() == self._origin_pid:
MPFuture._active_futures.pop(self._uid, None)
self._set_event_if_necessary()
else:
self._state_cache[self._state] = base.CANCELLED
self._send_update(UpdateType.CANCEL)
return True
self._send_update(MessageType.CANCEL)
return super().cancel()

def set_running_or_notify_cancel(self):
if self._state == base.PENDING:
self._state = base.RUNNING
return True
elif self._state == base.CANCELLED:
return False
else:
raise InvalidStateError(
f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
)
"""if synchronize is set to False, this future will ignore any state changes from origin"""
self._synchronize_if_necessary()
try:
is_running = super().set_running_or_notify_cancel()
if is_running and os.getpid() != self._origin_pid:
self._send_update(MessageType.RUNNING)
return is_running
except RuntimeError as e:
raise InvalidStateError(str(e))

def result(self, timeout: Optional[float] = None) -> ResultType:
if self._state not in TERMINAL_STATES:
@@ -219,12 +287,15 @@ def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
return self._exception

def done(self) -> bool:
self._synchronize_if_necessary()
return self._state in TERMINAL_STATES

def running(self):
self._synchronize_if_necessary()
return self._state == base.RUNNING

def cancelled(self):
self._synchronize_if_necessary()
return self._state == base.CANCELLED

def add_done_callback(self, callback: Callable[[MPFuture], None]):
@@ -240,7 +311,7 @@ def __await__(self):
raise RuntimeError("Can't await: MPFuture was created with no event loop")
yield from self._aio_event.wait().__await__()
try:
return super().result(timeout=0)
return super().result()
except base.CancelledError:
raise asyncio.CancelledError()

@@ -252,8 +323,9 @@ def __del__(self):

def __getstate__(self):
return dict(
synchronize=self.synchronize,
_sender_pipe=self._sender_pipe,
_shared_state_code=self._shared_state_code,
_state=self._state,
_origin_pid=self._origin_pid,
_uid=self._uid,
_use_lock=self._use_lock,
@@ -262,13 +334,12 @@ def __getstate__(self):
)

def __setstate__(self, state):
self.synchronize = state["synchronize"]
self._sender_pipe = state["_sender_pipe"]
self._shared_state_code = state["_shared_state_code"]
self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
self._result, self._exception = state["_result"], state["_exception"]
self._use_lock = state["_use_lock"]

self._waiters, self._done_callbacks = [], []
self._condition = threading.Condition()
self._aio_event, self._loop = None, None
self._state_cache = {}
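
For reference, a minimal end-to-end usage sketch (assuming a fork-based start method; not part of this diff):

import multiprocessing as mp
from hivemind.utils.mpfuture import MPFuture

def worker(future: MPFuture):
    future.set_result(42)  # resolved in the child, sent to the origin over the pipe

if __name__ == "__main__":
    future = MPFuture()  # only the origin process may await the result
    p = mp.Process(target=worker, args=(future,))
    p.start()
    print(future.result())  # blocks until the child's RESULT message arrives -> 42
    p.join()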