diff --git a/benchmarks/benchmark_averaging.py b/benchmarks/benchmark_averaging.py index a08dffcc0..ef5c758ac 100644 --- a/benchmarks/benchmark_averaging.py +++ b/benchmarks/benchmark_averaging.py @@ -6,7 +6,7 @@ import torch import hivemind -from hivemind.utils import LOCALHOST, increase_file_limit, get_logger +from hivemind.utils import LOCALHOST, get_logger, increase_file_limit from hivemind.proto import runtime_pb2 diff --git a/benchmarks/benchmark_dht.py b/benchmarks/benchmark_dht.py index 3a987558e..b3288c337 100644 --- a/benchmarks/benchmark_dht.py +++ b/benchmarks/benchmark_dht.py @@ -6,7 +6,7 @@ import hivemind import hivemind.server.expert_uid -from hivemind.utils.threading import increase_file_limit +from hivemind.utils.limits import increase_file_limit logger = hivemind.get_logger(__name__) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index f9b6f29bf..50754dd6f 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -9,7 +9,7 @@ import hivemind from hivemind import find_open_port from hivemind.server import layers -from hivemind.utils.threading import increase_file_limit +from hivemind.utils.limits import increase_file_limit from hivemind.utils.logging import get_logger diff --git a/examples/albert/README.md b/examples/albert/README.md index 21db2b031..37918b6a9 100644 --- a/examples/albert/README.md +++ b/examples/albert/README.md @@ -40,7 +40,7 @@ wandb: Run `wandb offline` to turn off syncing. - if necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config` (see [default paths](https://github.com/learning-at-home/hivemind/blob/collaborative_albert_example/examples/albert/run_trainer.py#L63-L69) for reference) - run: ```shell -HIVEMIND_THREADS=64 python run_trainer.py \ +python run_trainer.py \ --experiment_prefix SAME_AS_IN_RUN_FIRST_PEER --initial_peers ONE_OR_MORE_PEERS --seed 42 \ --logging_first_step --logging_steps 100 --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs ``` @@ -88,7 +88,7 @@ Here's an example of a full trainer script for Google Colab: !pip install transformers datasets sentencepiece torch_optimizer==0.1.0 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e . 
!curl -L YOUR_HOSTED_DATA | tar xzf - # example: https://hivemind-data.s3.us-east-2.amazonaws.com/wikitext103.tar.gz -!ulimit -n 4096 && HIVEMIND_THREADS=256 python ./hivemind/examples/albert/run_trainer.py \ +!ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \ --client_mode --initial_peers ONE_OR_MORE_PEERS --averaging_expiration 10 \ --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \ --logging_first_step --logging_steps 100 --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \ diff --git a/hivemind/client/averaging/__init__.py b/hivemind/client/averaging/__init__.py index 56c5a3951..97b77b115 100644 --- a/hivemind/client/averaging/__init__.py +++ b/hivemind/client/averaging/__init__.py @@ -290,9 +290,9 @@ def step(self, gather: Optional[GatheredData] = None, weight: Optional[float] = weight = float(self.mode != AveragingMode.AUX) assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}" - future, _future = MPFuture.make_pair() + future = MPFuture() gather_binary = self.serializer.dumps(gather) # serialize here to avoid loading modules in the averager process - self._outer_pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight, + self._outer_pipe.send(('_step', [], dict(future=future, gather_binary=gather_binary, weight=weight, allow_retries=allow_retries, timeout=timeout))) return future.result() if wait else future @@ -463,8 +463,8 @@ def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]: async def _get_current_state_from_host_process(self): """ Executed in the averager process inside rpc_download_state """ - future, _future = MPFuture.make_pair() - self._inner_pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future)) + future = MPFuture() + self._inner_pipe.send(('_TRIGGER_GET_CURRENT_STATE', future)) return await future def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]: @@ -477,8 +477,8 @@ def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch The exact contents of both metadata and tensors are determined by get_current_state method """ - future, _future = MPFuture.make_pair() - self._outer_pipe.send(('_load_state_from_peers', [], dict(future=_future))) + future = MPFuture() + self._outer_pipe.send(('_load_state_from_peers', [], dict(future=future))) return future.result() if wait else future async def _load_state_from_peers(self, future: MPFuture): @@ -537,8 +537,8 @@ def get_group_bits(self, wait: bool = True): :param wait: if True, return bits immediately. Otherwise return awaitable MPFuture :returns: averager's current group key bits (without prefix) """ - future, _future = MPFuture.make_pair() - self._outer_pipe.send(('_get_group_bits', [], dict(future=_future))) + future = MPFuture() + self._outer_pipe.send(('_get_group_bits', [], dict(future=future))) return future.result() if wait else future async def _get_group_bits(self, future: MPFuture): @@ -549,9 +549,9 @@ def set_group_bits(self, group_bits: str, wait: bool = True): :param group_bits: group bits (string of '0' or '1') to be used in averager's group key :param wait: if True, wait until the update is confirmed by the averager. 
Otherwise return immediately """ - future, _future = MPFuture.make_pair() + future = MPFuture() assert all(bit in '01' for bit in group_bits) - self._outer_pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future))) + self._outer_pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=future))) return future.result() if wait else future async def _set_group_bits(self, group_bits: str, future: MPFuture): diff --git a/hivemind/client/averaging/training.py b/hivemind/client/averaging/training.py index d3916ff76..ba6e66d3d 100644 --- a/hivemind/client/averaging/training.py +++ b/hivemind/client/averaging/training.py @@ -1,4 +1,5 @@ """ An extension of averager that supports common optimization use cases. """ +from concurrent.futures import ThreadPoolExecutor from itertools import chain from threading import Lock from typing import Sequence, Dict, Iterator, Optional @@ -7,7 +8,7 @@ import torch from hivemind.client.averaging import DecentralizedAverager -from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background +from hivemind.utils import nested_flatten, nested_pack, get_logger logger = get_logger(__name__) @@ -39,6 +40,7 @@ def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, aver self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0 self.opt_statistics = tuple(average_opt_statistics) self.average_parameters, self.average_gradients = average_parameters, average_gradients + self.step_executor = ThreadPoolExecutor(max_workers=1) self.lock_averager_step = Lock() if initialize_optimizer: initialize_optimizer_state(opt) # note: this will run one optimizer step! @@ -47,15 +49,15 @@ def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, aver averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()] super().__init__(averaged_tensors=averaged_tensors, **kwargs) - @torch.no_grad() def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs): - """ Average optimizer weights and gradients with peers. + """ + Average optimizer weights and gradients with peers. + :param data_lock: averager locks it when model parameters are modified. Otherwise it's assumed that no model modifications occur during averaging step - :param wait: if True waits, otherwise returns Future """ if not wait: - return run_in_background(self.step, data_lock, wait=True, **kwargs) + return self.step_executor.submit(self.step, data_lock, wait=True, **kwargs) # if data_lock is supplied, tensors might change during averaging, so we need to copy them use_old_local_tensors = data_lock is not None @@ -63,7 +65,7 @@ def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs): data_lock = nullcontext() local_tensors = list(self.local_tensors()) - with self.lock_averager_step: + with self.lock_averager_step, torch.no_grad(): # fill averager's tensors with current local tensors with data_lock, self.get_tensors() as averaged_tensors: if use_old_local_tensors: @@ -73,7 +75,7 @@ def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs): for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors): averaged_tensor[...] 
= local_tensor.cpu().float() - # find a group and hopefully average tensors with peers, scaled by peer's weight + # find a group and hopefully average tensors with peers, use batch sizes as weights gathered = super().step(**kwargs) if gathered is not None: # load averaged tensors back into model diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index e921218d7..ad29697d6 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -127,8 +127,8 @@ def get(self, key: DHTKey, latest: bool = False, return_future: bool = False, ** :param kwargs: parameters forwarded to DHTNode.get_many_by_id :returns: (value, expiration time); if value was not found, returns None """ - future, _future = MPFuture.make_pair() - self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs))) + future = MPFuture() + self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=future, **kwargs))) return future if return_future else future.result() async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs): @@ -153,9 +153,9 @@ def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background. :returns: True if store succeeds, False if it fails (due to no response or newer value) """ - future, _future = MPFuture.make_pair() + future = MPFuture() self._outer_pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, - future=_future, **kwargs))) + future=future, **kwargs))) return future if return_future else future.result() async def _store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, @@ -184,8 +184,8 @@ def run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task. 
""" - future, _future = MPFuture.make_pair() - self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future))) + future = MPFuture() + self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=future))) return future if return_future else future.result() async def _run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], @@ -226,8 +226,8 @@ def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[E """ assert num_peers is None or peers == (), "please specify either a num_peers or the list of peers, not both" assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints" - future, _future = MPFuture.make_pair() - self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future))) + future = MPFuture() + self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=future))) return future.result() async def _get_visible_address(self, num_peers: Optional[int], peers: Sequence[Endpoint], diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py index 7a3c6ba6b..a9a952af1 100644 --- a/hivemind/hivemind_cli/run_server.py +++ b/hivemind/hivemind_cli/run_server.py @@ -6,7 +6,7 @@ from hivemind.proto.runtime_pb2 import CompressionType from hivemind.server import Server -from hivemind.utils.threading import increase_file_limit +from hivemind.utils.limits import increase_file_limit from hivemind.utils.logging import get_logger from hivemind.server.layers import schedule_name_to_scheduler diff --git a/hivemind/server/task_pool.py b/hivemind/server/task_pool.py index bb57ddd1d..ad4a819ab 100644 --- a/hivemind/server/task_pool.py +++ b/hivemind/server/task_pool.py @@ -14,7 +14,8 @@ import torch -from hivemind.utils import MPFuture, get_logger, FutureStateError +from hivemind.utils import get_logger +from hivemind.utils.mpfuture import MPFuture, InvalidStateError logger = get_logger(__name__) Task = namedtuple("Task", ("future", "args")) @@ -89,15 +90,14 @@ def __init__(self, process_func: callable, max_batch_size: int, name: str, min_b def submit_task(self, *args: torch.Tensor) -> Future: """ Add task to this pool's queue, return Future for its output """ - future1, future2 = MPFuture.make_pair() - task = Task(future1, args) + task = Task(MPFuture(), args) if self.get_task_size(task) > self.max_batch_size: exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed") - future2.set_exception(exc) + task.future.set_exception(exc) else: self.tasks.put(task) self.undispatched_task_timestamps.put(time.time()) - return future2 + return task.future def iterate_minibatches(self, *args, **kwargs): """ Form minibatches by grouping one or more tasks together up to self.max_batch_size """ @@ -127,7 +127,7 @@ def iterate_minibatches(self, *args, **kwargs): if task.future.set_running_or_notify_cancel(): batch.append(task) total_size += task_size - except FutureStateError as e: + except InvalidStateError as e: logger.debug(f"Failed to add task to batch: {task.future} raised {e}") def run(self, *args, **kwargs): @@ -196,7 +196,7 @@ def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]): for task, task_outputs in zip(batch_tasks, outputs_per_task): try: task.future.set_result(tuple(task_outputs)) - except FutureStateError as e: + except InvalidStateError as e: logger.debug(f"Failed to send task result due to an exception: {e}") @property diff --git 
a/hivemind/utils/__init__.py b/hivemind/utils/__init__.py index 3287ca2e2..3d2d69868 100644 --- a/hivemind/utils/__init__.py +++ b/hivemind/utils/__init__.py @@ -1,11 +1,11 @@ from hivemind.utils.asyncio import * from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor from hivemind.utils.grpc import * +from hivemind.utils.limits import increase_file_limit from hivemind.utils.logging import get_logger from hivemind.utils.mpfuture import * from hivemind.utils.nested import * from hivemind.utils.networking import * from hivemind.utils.serializer import * from hivemind.utils.tensor_descr import * -from hivemind.utils.threading import * from hivemind.utils.timed_storage import * diff --git a/hivemind/utils/compression.py b/hivemind/utils/compression.py index a80dedba9..22583649f 100644 --- a/hivemind/utils/compression.py +++ b/hivemind/utils/compression.py @@ -1,3 +1,5 @@ +import os +from concurrent.futures import ThreadPoolExecutor from typing import Tuple, Sequence, Optional import numpy as np @@ -6,7 +8,7 @@ from hivemind.proto import runtime_pb2 from hivemind.proto.runtime_pb2 import CompressionType -from hivemind.utils.threading import run_in_background + FP32_EPS = 1e-06 NUM_BYTES_FLOAT32 = 4 @@ -17,6 +19,8 @@ FP16_MAX = 65_504 UINT8_RANGE = 256 +COMPRESSION_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTILE_COMPRESSION_THREADS", 128))) + warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning) @@ -48,8 +52,7 @@ def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size jobs = [] for i in range(num_chunks): chunk = slice(chunk_size * i, chunk_size * (i + 1)) - jobs.append(run_in_background( - np.quantile, array[chunk], quantiles, out=partition_quantiles[i])) + jobs.append(COMPRESSION_EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i])) for job in jobs: job.result() diff --git a/hivemind/utils/threading.py b/hivemind/utils/limits.py similarity index 58% rename from hivemind/utils/threading.py rename to hivemind/utils/limits.py index 584e44ffe..0521671bf 100644 --- a/hivemind/utils/threading.py +++ b/hivemind/utils/limits.py @@ -1,21 +1,7 @@ -import os -from concurrent.futures import Future, ThreadPoolExecutor - from hivemind.utils.logging import get_logger logger = get_logger(__name__) -EXECUTOR_PID, GLOBAL_EXECUTOR = None, None - - -def run_in_background(func: callable, *args, **kwargs) -> Future: - """ run func(*args, **kwargs) in background and return Future for its outputs """ - global EXECUTOR_PID, GLOBAL_EXECUTOR - if os.getpid() != EXECUTOR_PID: - GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("HIVEMIND_THREADS", 128))) - EXECUTOR_PID = os.getpid() - return GLOBAL_EXECUTOR.submit(func, *args, **kwargs) - def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15): """ Increase the maximum number of open files. On Linux, this allows spawning more processes/threads. 
""" diff --git a/hivemind/utils/mpfuture.py b/hivemind/utils/mpfuture.py index cf47c399c..628434759 100644 --- a/hivemind/utils/mpfuture.py +++ b/hivemind/utils/mpfuture.py @@ -2,171 +2,262 @@ import asyncio import concurrent.futures._base as base +from contextlib import nullcontext import multiprocessing as mp import multiprocessing.connection -import time -from functools import lru_cache -from typing import Optional, Tuple, Generic, TypeVar +import os +import threading +import uuid +from enum import Enum, auto +from typing import Generic, TypeVar, Dict, Optional, Any, Callable -from hivemind.utils.threading import run_in_background +import torch # used for py3.7-compatible shared memory +from hivemind.utils.logging import get_logger + + +logger = get_logger(__name__) + +# flavour types ResultType = TypeVar('ResultType') +PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection +ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED +TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED} +try: + from concurrent.futures import InvalidStateError +except ImportError: + # Python 3.7 doesn't raise concurrent.futures.InvalidStateError for repeating set_result/set_exception calls and + # doesn't even define this error. In this module, we simulate the Python 3.8+ behavior, + # defining and raising this error if necessary. + class InvalidStateError(Exception): + """Raised when attempting to change state of a future in a terminal state (e.g. finished)""" -class FutureStateError(RuntimeError): - """Raised when attempting to change state of a future in a terminal state (e.g. finished)""" - pass + +class UpdateType(Enum): + RESULT = auto() + EXCEPTION = auto() + CANCEL = auto() class MPFuture(base.Future, Generic[ResultType]): - """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """ + """ + A version of concurrent.futures.Future / asyncio.Future that can be fulfilled from a separate process. + Any process can access future status and set the result / exception and check for state. + However, only the original process (i.e. the process that created the future) can await the result or exception. + + :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe; + If set to False, writing to this future ignores global lock, slightly improving performance, but making user + responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin. + :param loop: if specified, overrides default asyncio event loop for the purpose of awaiting MPFuture + + :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications. + More specifically, there are two known limitations: + - MPFuture works between processes created through inheritance (e.g. fork), *not* for independent processes + - MPFuture is deterministic if only one process can call set_result/set_exception/set_running_or_notify_cancel + and only the origin process can call result/exception/cancel. 
+ """ + _initialization_lock = mp.Lock() # global lock that prevents simultaneous initialization of two processes + _update_lock = mp.Lock() # global lock that prevents simultaneous writing to the same pipe + _global_sender_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process + _pipe_waiter_thread: Optional[threading.Thread] = None # process-specific thread that receives results/exceptions + _active_futures: Optional[Dict[UID, MPFuture]] = None # pending or running futures originated from current process + _active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively - TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED} + def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None): + self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int + self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_() + self._state_cache: Dict[State, State] = {} # mapping from global to cached local future used that makes updates immediately + # available on setter side; dictionary-based cache works because future can visit any state at most once - def __init__(self, connection: mp.connection.Connection): - """ manually create MPFuture. Please use MPFuture.make_pair instead """ + base.Future.__init__(self) # parent init is deferred because it uses self._shared_state_code self._state, self._result, self._exception = base.PENDING, None, None - self.connection = connection + self._use_lock = use_lock - @classmethod - def make_pair(cls) -> Tuple[MPFuture, MPFuture]: - """ Create a pair of linked futures to be used in two processes """ - connection1, connection2 = mp.Pipe() - return cls(connection1), cls(connection2) + if self._origin_pid != MPFuture._active_pid: + with MPFuture._initialization_lock: + if self._origin_pid != MPFuture._active_pid: + # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking + self._initialize_mpfuture_backend() + assert self._uid not in MPFuture._active_futures + MPFuture._active_futures[self._uid] = self + self._sender_pipe = MPFuture._global_sender_pipe - def _send_updates(self): - """ Send updates to a paired MPFuture """ try: - self.connection.send((self._state, self._result, self._exception)) - if self._state in self.TERMINAL_STATES: - self._shutdown_trigger.set_result(True) - self.connection.close() - return True - except BrokenPipeError: - return False + self._loop = loop or asyncio.get_event_loop() + self._aio_event = asyncio.Event() + except RuntimeError: + self._loop, self._aio_event = None, None - def _recv_updates(self, timeout: Optional[float]): - """ Await updates from a paired MPFuture """ - try: - future = base.wait([run_in_background(self.connection.poll, timeout), self._shutdown_trigger], - return_when=base.FIRST_COMPLETED)[0].pop() - if future is self._shutdown_trigger: - raise BrokenPipeError() - if not future.result(): - raise TimeoutError() - self._state, result, exception = self.connection.recv() - self._result = result if result is not None else self._result - self._exception = exception if exception is not None else self._exception - if self._state in self.TERMINAL_STATES: - self.connection.close() - except TimeoutError as e: - raise e - except (BrokenPipeError, OSError, EOFError) as e: - if self._state in (base.PENDING, base.RUNNING): - self._state, self._exception = base.FINISHED, e - - def _await_terminal_state(self, timeout: Optional[float]): - """ Await updates 
until future is either finished, cancelled or got an exception """ - time_left = float('inf') if timeout is None else timeout - time_before = time.monotonic() - while self._state not in self.TERMINAL_STATES and time_left > 0: - self._recv_updates(time_left if timeout else None) - time_spent = time.monotonic() - time_before - time_left, time_before = time_left - time_spent, time_before + time_spent - - def _sync_updates(self): - """ Apply queued updates from a paired MPFuture without waiting for new ones """ + @property + def _state(self) -> State: + shared_state = ALL_STATES[self._shared_state_code.item()] + return self._state_cache.get(shared_state, shared_state) + + @_state.setter + def _state(self, new_state: State): + self._shared_state_code[...] = ALL_STATES.index(new_state) + if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set(): + self._set_event_threadsafe() + + def _set_event_threadsafe(self): try: - self._recv_updates(timeout=0) - except TimeoutError: - pass + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + async def _event_setter(): + self._aio_event.set() + + if loop == self.get_loop(): + asyncio.create_task(_event_setter()) + else: + asyncio.run_coroutine_threadsafe(_event_setter(), self._loop) + + @classmethod + def _initialize_mpfuture_backend(cls): + pid = os.getpid() + logger.debug(f"Initializing MPFuture backend for pid {pid}") + assert pid != cls._active_pid, "already initialized" + + receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False) + cls._active_pid, cls._active_futures = pid, {} + cls._pipe_waiter_thread = threading.Thread(target=cls._process_updates_in_background, args=[receiver_pipe], + name=f'{__name__}.BACKEND', daemon=True) + cls._pipe_waiter_thread.start() + + @classmethod + def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection): + pid = os.getpid() + while True: + try: + uid, update_type, payload = receiver_pipe.recv() + if uid not in cls._active_futures: + logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed") + elif update_type == UpdateType.RESULT: + cls._active_futures.pop(uid).set_result(payload) + elif update_type == UpdateType.EXCEPTION: + cls._active_futures.pop(uid).set_exception(payload) + elif update_type == UpdateType.CANCEL: + cls._active_futures.pop(uid).cancel() + else: + raise RuntimeError(f"Received unexpected update type {update_type}") + except (BrokenPipeError, EOFError): + logger.debug(f"Update pipe was shut down unexpectedly (pid={pid})") + except Exception as e: + logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})") + + def _send_update(self, update_type: UpdateType, payload: Any = None): + """ This method sends result, exception or cancel to the MPFuture origin.
""" + with MPFuture._update_lock if self._use_lock else nullcontext(): + self._sender_pipe.send((self._uid, update_type, payload)) def set_result(self, result: ResultType): - self._sync_updates() - if self._state in self.TERMINAL_STATES: - raise FutureStateError(f"Can't set_result to a future that is {self._state} ({self})") - self._state, self._result = base.FINISHED, result - return self._send_updates() - - def set_exception(self, exception: BaseException): - self._sync_updates() - if self._state in self.TERMINAL_STATES: - raise FutureStateError(f"Can't set_exception to a future that is {self._state} ({self})") - self._state, self._exception = base.FINISHED, exception - self._send_updates() + if os.getpid() == self._origin_pid: + super().set_result(result) + MPFuture._active_futures.pop(self._uid, None) + elif self._state in TERMINAL_STATES: + raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})") + else: + self._state_cache[self._state], self._result = base.FINISHED, result + self._send_update(UpdateType.RESULT, result) + + def set_exception(self, exception: Optional[BaseException]): + if os.getpid() == self._origin_pid: + super().set_exception(exception) + MPFuture._active_futures.pop(self._uid, None) + elif self._state in TERMINAL_STATES: + raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})") + else: + self._state_cache[self._state], self._exception = base.FINISHED, exception + self._send_update(UpdateType.EXCEPTION, exception) + + def cancel(self) -> bool: + if os.getpid() == self._origin_pid: + MPFuture._active_futures.pop(self._uid, None) + return super().cancel() + elif self._state in [base.RUNNING, base.FINISHED]: + return False + else: + self._state_cache[self._state] = base.CANCELLED + self._send_update(UpdateType.CANCEL) + return True def set_running_or_notify_cancel(self): - self._sync_updates() if self._state == base.PENDING: self._state = base.RUNNING - return self._send_updates() + return True elif self._state == base.CANCELLED: return False else: - raise FutureStateError(f"Can't set_running_or_notify_cancel to a future that is in {self._state} ({self})") - - def cancel(self): - self._sync_updates() - if self._state in self.TERMINAL_STATES: - return False - self._state, self._exception = base.CANCELLED, base.CancelledError() - return self._send_updates() + raise InvalidStateError(f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})") def result(self, timeout: Optional[float] = None) -> ResultType: - self._await_terminal_state(timeout) - if self._exception is not None: + if self._state not in TERMINAL_STATES: + if os.getpid() != self._origin_pid: + raise RuntimeError("Only the process that created MPFuture can await result") + return super().result(timeout) + elif self._state == base.CANCELLED: + raise base.CancelledError() + elif self._exception: raise self._exception - return self._result + else: + return self._result - def exception(self, timeout=None) -> BaseException: - self._await_terminal_state(timeout) - if self._state == base.CANCELLED: + def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]: + if self._state not in TERMINAL_STATES: + if os.getpid() != self._origin_pid: + raise RuntimeError("Only the process that created MPFuture can await exception") + return super().exception(timeout) + elif self._state == base.CANCELLED: raise base.CancelledError() return self._exception def done(self) -> bool: - self._sync_updates() - return 
self._state in self.TERMINAL_STATES + return self._state in TERMINAL_STATES def running(self): - self._sync_updates() return self._state == base.RUNNING def cancelled(self): - self._sync_updates() return self._state == base.CANCELLED - def add_done_callback(self, callback): - raise NotImplementedError(f"MPFuture doesn't support callbacks.") - - def remove_done_callback(self, callback): - raise NotImplementedError(f"MPFuture doesn't support callbacks.") + def add_done_callback(self, callback: Callable[[MPFuture], None]): + if os.getpid() != self._origin_pid: + raise RuntimeError("Only the process that created MPFuture can set callbacks") + return super().add_done_callback(callback) - def get_loop(self): - raise NotImplementedError(f"MPFuture doesn't support get_loop") - - @property - @lru_cache() - def _shutdown_trigger(self): - return base.Future() - - def __repr__(self): - self._sync_updates() - if self._state == base.FINISHED: - if self._exception: - return "<MPFuture at {:#x}, raised {}>".format(id(self), type(self._exception)) - else: - return "<MPFuture at {:#x}, returned {}>".format(id(self), type(self._result)) - else: - return "<MPFuture at {:#x}, state: {}>".format(id(self), self._state) + def get_loop(self) -> Optional[asyncio.BaseEventLoop]: + return self._loop def __await__(self): - yield from asyncio.get_running_loop().run_in_executor(None, self._await_terminal_state, None).__await__() - if self._exception: - raise self._exception - return self._result + if not self._aio_event: + raise RuntimeError("Can't await: MPFuture was created with no event loop") + yield from self._aio_event.wait().__await__() + try: + return super().result(timeout=0) + except base.CancelledError: + raise asyncio.CancelledError() def __del__(self): - self._shutdown_trigger.set_result(True) - if hasattr(self, 'connection'): - self.connection.close() + if getattr(self, '_origin_pid', None) == os.getpid(): + MPFuture._active_futures.pop(self._uid, None) + if getattr(self, '_aio_event', None): + self._aio_event.set() + + def __getstate__(self): + return dict(_sender_pipe=self._sender_pipe, _shared_state_code=self._shared_state_code, + _origin_pid=self._origin_pid, _uid=self._uid, _use_lock=self._use_lock, + _result=self._result, _exception=self._exception) + + def __setstate__(self, state): + self._sender_pipe = state['_sender_pipe'] + self._shared_state_code = state['_shared_state_code'] + self._origin_pid, self._uid = state['_origin_pid'], state['_uid'] + self._result, self._exception = state['_result'], state['_exception'] + self._use_lock = state['_use_lock'] + + self._waiters, self._done_callbacks = [], [] + self._condition = threading.Condition() + self._aio_event, self._loop = None, None + self._state_cache = {} diff --git a/tests/test_averaging.py b/tests/test_averaging.py index 80ffeb409..e6a3a58c7 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -423,3 +423,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16): assert torch.allclose(x2.grad, grad_avg) assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg) assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg) + + averager1.shutdown() + averager2.shutdown() + dht.shutdown() diff --git a/tests/test_util_modules.py b/tests/test_util_modules.py index 5d1f6059d..97d2e36ab 100644 --- a/tests/test_util_modules.py +++ b/tests/test_util_modules.py @@ -1,129 +1,310 @@ import asyncio -from concurrent.futures import CancelledError +import concurrent.futures +import multiprocessing as mp +import random +import time -import numpy as np import pytest import torch +import numpy as np +import hivemind
from hivemind.proto.dht_pb2_grpc import DHTStub from hivemind.proto.runtime_pb2 import CompressionType from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub -import hivemind from hivemind.utils import MSGPackSerializer from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip -from hivemind.utils.mpfuture import FutureStateError +from hivemind.utils.mpfuture import InvalidStateError +@pytest.mark.forked def test_mpfuture_result(): - f1, f2 = hivemind.MPFuture.make_pair() - f1.set_result(321) - assert f2.result() == 321 - assert f1.result() == 321 + future = hivemind.MPFuture() - for future in [f1, f2]: - with pytest.raises(FutureStateError): - future.set_result(123) - with pytest.raises(FutureStateError): - future.set_exception(ValueError()) - assert future.cancel() is False - assert future.done() and not future.running() and not future.cancelled() + def _proc(future): + with pytest.raises(RuntimeError): + future.result() # only creator process can await result + + future.set_result(321) + + p = mp.Process(target=_proc, args=(future,)) + p.start() + p.join() - f1, f2 = hivemind.MPFuture.make_pair() - with pytest.raises(TimeoutError): - f1.result(timeout=1e-3) + assert future.result() == 321 + assert future.exception() is None + assert future.cancel() is False + assert future.done() and not future.running() and not future.cancelled() - f2.set_result(['abacaba', 123]) - assert f1.result() == ['abacaba', 123] + future = hivemind.MPFuture() + with pytest.raises(concurrent.futures.TimeoutError): + future.result(timeout=1e-3) + future.set_result(['abacaba', 123]) + assert future.result() == ['abacaba', 123] + +@pytest.mark.forked def test_mpfuture_exception(): - f1, f2 = hivemind.MPFuture.make_pair() - with pytest.raises(TimeoutError): - f1.exception(timeout=1e-3) + future = hivemind.MPFuture() + with pytest.raises(concurrent.futures.TimeoutError): + future.exception(timeout=1e-3) - f2.set_exception(NotImplementedError()) + def _proc(future): + future.set_exception(NotImplementedError()) - for future in [f1, f2]: - assert isinstance(future.exception(), NotImplementedError) - with pytest.raises(NotImplementedError): - future.result() - assert future.cancel() is False - assert future.done() and not future.running() and not future.cancelled() + p = mp.Process(target=_proc, args=(future,)) + p.start() + p.join() + + assert isinstance(future.exception(), NotImplementedError) + with pytest.raises(NotImplementedError): + future.result() + assert future.cancel() is False + assert future.done() and not future.running() and not future.cancelled() +@pytest.mark.forked def test_mpfuture_cancel(): - f1, f2 = hivemind.MPFuture.make_pair() - assert not f2.cancelled() - f1.cancel() - for future in [f1, f2]: - with pytest.raises(CancelledError): + future = hivemind.MPFuture() + assert not future.cancelled() + future.cancel() + evt = mp.Event() + + def _proc(): + with pytest.raises(concurrent.futures.CancelledError): future.result() - with pytest.raises(CancelledError): + with pytest.raises(concurrent.futures.CancelledError): future.exception() - with pytest.raises(FutureStateError): + with pytest.raises(InvalidStateError): future.set_result(123) - with pytest.raises(FutureStateError): + with pytest.raises(InvalidStateError): future.set_exception(NotImplementedError()) assert future.cancelled() and future.done() and not future.running() + evt.set() + p = mp.Process(target=_proc) + 
p.start() + p.join() + assert evt.is_set() + +@pytest.mark.forked def test_mpfuture_status(): - f1, f2 = hivemind.MPFuture.make_pair() - assert f1.set_running_or_notify_cancel() is True - for future in [f1, f2]: - assert future.running() and not future.done() and not future.cancelled() - with pytest.raises(RuntimeError): - future.set_running_or_notify_cancel() - f2.cancel() - for future in [f1, f2]: + evt = mp.Event() + future = hivemind.MPFuture() + + def _proc1(future): + assert future.set_running_or_notify_cancel() is True + evt.set() + + p = mp.Process(target=_proc1, args=(future,)) + p.start() + p.join() + assert evt.is_set() + evt.clear() + + assert future.running() and not future.done() and not future.cancelled() + with pytest.raises(InvalidStateError): + future.set_running_or_notify_cancel() + + future = hivemind.MPFuture() + assert future.cancel() + + def _proc2(future): assert not future.running() and future.done() and future.cancelled() assert future.set_running_or_notify_cancel() is False + evt.set() - f1, f2 = hivemind.MPFuture.make_pair() - f1.cancel() - for future in [f1, f2]: - assert future.set_running_or_notify_cancel() is False + p = mp.Process(target=_proc2, args=(future,)) + p.start() + p.join() + evt.set() + + future2 = hivemind.MPFuture() + future2.cancel() + assert future2.set_running_or_notify_cancel() is False @pytest.mark.asyncio async def test_await_mpfuture(): - # await result - f1, f2 = hivemind.MPFuture.make_pair() + # await result from the same process, but a different coroutine + f1, f2 = hivemind.MPFuture(), hivemind.MPFuture() - async def wait_and_assign(): + async def wait_and_assign_async(): assert f2.set_running_or_notify_cancel() is True await asyncio.sleep(0.1) - f2.set_result((123, 'ololo')) + f1.set_result((123, 'ololo')) + f2.set_result((456, 'pyshpysh')) + + asyncio.create_task(wait_and_assign_async()) - asyncio.create_task(wait_and_assign()) - for future in [f1, f2]: - res = await future - assert res == (123, 'ololo') + assert (await asyncio.gather(f1, f2)) == [(123, 'ololo'), (456, 'pyshpysh')] + + # await result from separate processes + f1, f2 = hivemind.MPFuture(), hivemind.MPFuture() + + def wait_and_assign(future, value): + time.sleep(0.1 * random.random()) + future.set_result(value) + + p1 = mp.Process(target=wait_and_assign, args=(f1, 'abc')) + p2 = mp.Process(target=wait_and_assign, args=(f2, 'def')) + for p in p1, p2: + p.start() + + assert (await asyncio.gather(f1, f2)) == ['abc', 'def'] + for p in p1, p2: + p.join() # await cancel - f1, f2 = hivemind.MPFuture.make_pair() + f1, f2 = hivemind.MPFuture(), hivemind.MPFuture() - async def wait_and_cancel(): - await asyncio.sleep(0.1) + def wait_and_cancel(): + time.sleep(0.01) + f2.set_result(123456) + time.sleep(0.1) f1.cancel() - asyncio.create_task(wait_and_cancel()) - for future in [f1, f2]: - with pytest.raises(CancelledError): - await future + p = mp.Process(target=wait_and_cancel) + p.start() + + with pytest.raises(asyncio.CancelledError): + # note: it is intended that MPFuture raises Cancel + await asyncio.gather(f1, f2) + + p.join() # await exception - f1, f2 = hivemind.MPFuture.make_pair() + f1, f2 = hivemind.MPFuture(), hivemind.MPFuture() - async def wait_and_raise(): - await asyncio.sleep(0.1) - f1.set_exception(SystemError()) + def wait_and_raise(): + time.sleep(0.01) + f2.set_result(123456) + time.sleep(0.1) + f1.set_exception(ValueError('we messed up')) + + p = mp.Process(target=wait_and_raise) + p.start() + + with pytest.raises(ValueError): + # note: it is intended that 
MPFuture raises Cancel + await asyncio.gather(f1, f2) + + p.join() + + +@pytest.mark.forked +def test_mpfuture_bidirectional(): + evt = mp.Event() + future_from_main = hivemind.MPFuture() + + def _future_creator(): + future_from_fork = hivemind.MPFuture() + future_from_main.set_result(('abc', future_from_fork)) + + if future_from_fork.result() == ['we', 'need', 'to', 'go', 'deeper']: + evt.set() + + p = mp.Process(target=_future_creator) + p.start() + + out = future_from_main.result() + assert isinstance(out[1], hivemind.MPFuture) + out[1].set_result(['we', 'need', 'to', 'go', 'deeper']) + + p.join() + assert evt.is_set() + + +@pytest.mark.forked +def test_mpfuture_done_callback(): + receiver, sender = mp.Pipe(duplex=False) + events = [mp.Event() for _ in range(5)] + + def _future_creator(): + future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture() + + def _check_result_and_set(future): + assert future.done() + assert future.result() == 123 + events[0].set() + + future1.add_done_callback(_check_result_and_set) + future1.add_done_callback(lambda future: events[1].set()) + future2.add_done_callback(lambda future: events[2].set()) + future3.add_done_callback(lambda future: events[3].set()) + + sender.send((future1, future2)) + future2.cancel() # trigger future2 callback from the same process + + events[0].wait() + future1.add_done_callback(lambda future: events[4].set()) # schedule callback after future1 is already finished + + p = mp.Process(target=_future_creator) + p.start() + + future1, future2 = receiver.recv() + future1.set_result(123) + + with pytest.raises(RuntimeError): + future1.add_done_callback(lambda future: (1, 2, 3)) + + p.join() + events[0].wait(1) + events[1].wait(1) + assert future1.done() and not future1.cancelled() + assert future2.done() and future2.cancelled() + assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set() + assert not events[3].is_set() + + +@pytest.mark.forked +def test_many_futures(): + evt = mp.Event() + receiver, sender = mp.Pipe() + main_futures = [hivemind.MPFuture() for _ in range(1000)] + assert len(hivemind.MPFuture._active_futures) == 1000 + + def _run_peer(): + fork_futures = [hivemind.MPFuture() for _ in range(500)] + assert len(hivemind.MPFuture._active_futures) == 500 + + for i, future in enumerate(random.sample(main_futures, 300)): + if random.random() < 0.5: + future.set_result(i) + else: + future.set_exception(ValueError(f"{i}")) + + sender.send(fork_futures[:-100]) + for future in fork_futures[-100:]: + future.cancel() + + evt.wait() + + assert len(hivemind.MPFuture._active_futures) == 200 + for future in fork_futures: + future.cancel() + assert len(hivemind.MPFuture._active_futures) == 0 + + p = mp.Process(target=_run_peer) + p.start() + + some_fork_futures = receiver.recv() + assert len(hivemind.MPFuture._active_futures) == 700 + + for future in some_fork_futures: + future.set_running_or_notify_cancel() + for future in random.sample(some_fork_futures, 200): + future.set_result(321) - asyncio.create_task(wait_and_raise()) - for future in [f1, f2]: - with pytest.raises(SystemError): - await future + time.sleep(0.5) + evt.set() + for future in main_futures: + future.cancel() + assert len(hivemind.MPFuture._active_futures) == 0 + p.join() def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008): @@ -139,7 +320,7 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008): error = deserialize_torch_tensor(serialize_torch_tensor(X, 
CompressionType.UNIFORM_8BIT)) - X assert error.square().mean() < beta - zeros = torch.zeros(5,5) + zeros = torch.zeros(5, 5) for compression_type in CompressionType.values(): assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
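Note: the rewritten MPFuture is no longer created in linked pairs. A single future is created in its origin process, can be pickled and handed to any forked process, and any of those processes may fulfill it, while only the origin process may await or consume the result. A minimal usage sketch (mirroring `test_mpfuture_result` above; it assumes a fork-based start method, since the docstring states MPFuture only works between processes created through inheritance):

```python
import multiprocessing as mp
import hivemind

def fulfill(future: hivemind.MPFuture):
    # any forked process may set the result; the update is piped back to the origin process
    future.set_result(321)

if __name__ == "__main__":
    future = hivemind.MPFuture()                   # created (and later awaited) in the origin process
    worker = mp.Process(target=fulfill, args=(future,))
    worker.start()
    worker.join()
    assert future.result() == 321                  # only the origin process may call result()/exception()
```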
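The same single-future object is what replaces `MPFuture.make_pair()` throughout `DHT`, `DecentralizedAverager` and `TaskPool`: the caller creates one `MPFuture`, sends it through the existing command pipe, and the background process fulfills it. A condensed, self-contained sketch of that request/response pattern under the same fork assumption (the doubling handler and the `None` shutdown signal are stand-ins for the real handlers such as `DHT._get`, not the library's API):

```python
import multiprocessing as mp
import multiprocessing.connection

import hivemind

def backend(commands: mp.connection.Connection):
    # background process: receive (payload, future) requests and fulfill them
    while True:
        request = commands.recv()
        if request is None:
            break
        payload, future = request
        future.set_result(payload * 2)     # fulfilling the future notifies the origin process

if __name__ == "__main__":
    caller_end, backend_end = mp.Pipe()
    proc = mp.Process(target=backend, args=(backend_end,), daemon=True)
    proc.start()

    future = hivemind.MPFuture()           # one picklable future instead of a make_pair() tuple
    caller_end.send((21, future))          # the future itself travels through the pipe
    assert future.result() == 42           # block for the result, or hand back the future if wait=False

    caller_end.send(None)                  # hypothetical shutdown signal for this sketch
    proc.join()
```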