From ef2185483f5b537722fb3fa1d441a80cf48e77f7 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Wed, 5 May 2021 18:40:35 +0300 Subject: [PATCH 1/9] Set correct device for scores --- hivemind/client/switch_moe.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/hivemind/client/switch_moe.py b/hivemind/client/switch_moe.py index 40dd0d7f1..84b9e3013 100644 --- a/hivemind/client/switch_moe.py +++ b/hivemind/client/switch_moe.py @@ -156,8 +156,11 @@ def compute_expert_scores( batch_size = len(batch_experts) max_num_experts = max(expert_counts) total_num_experts = sum(expert_counts) - expert_index_in_batch = torch.arange(total_num_experts, device=grid_probs[0].device) - expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_probs[0].device), dim=-1)[:-1] + + device = grid_probs[0].device + + expert_index_in_batch = torch.arange(total_num_experts, device=device) + expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1] flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1 flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices] flat_experts = [expert for row in batch_experts for expert in row] @@ -169,10 +172,10 @@ def compute_expert_scores( grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype) scores_per_dim = [ - dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0) + dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device) for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)] flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0) - scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_probs[0].device) + scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device) scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores return scores From f79f181e872a16174559c76b5fd93b74dee5ca5e Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 7 May 2021 01:29:18 +0300 Subject: [PATCH 2/9] Put pipe_awaiter in a context manager --- hivemind/dht/__init__.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index 7a6f6ec30..d4eb38284 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -69,25 +69,26 @@ def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[En def run(self) -> None: """ Serve DHT forever. 
This function will not return until DHT node is shut down """ loop = switch_to_uvloop() - pipe_awaiter = ThreadPoolExecutor(max_workers=1) - async def _run(): - node = await DHTNode.create( - initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc, - num_workers=self.max_workers or 1, record_validator=self._record_validator, - **self.kwargs) - if node.port is not None: - self._port.value = node.port - self.ready.set() - - while True: - method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv) - asyncio.create_task(getattr(self, method)(node, *args, **kwargs)) - - try: - loop.run_until_complete(_run()) - except KeyboardInterrupt: - logger.debug("Caught KeyboardInterrupt, shutting down") + with ThreadPoolExecutor(max_workers=1) as pipe_awaiter: + async def _run(): + node = await DHTNode.create( + initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc, + num_workers=self.max_workers or 1, record_validator=self._record_validator, + **self.kwargs) + if node.port is not None: + self._port.value = node.port + self.ready.set() + + while True: + method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv) + asyncio.create_task(getattr(self, method)(node, *args, **kwargs)) + + coro = _run() + try: + loop.run_until_complete(coro) + except KeyboardInterrupt: + logger.debug("Caught KeyboardInterrupt, shutting down") def run_in_background(self, await_ready=True, timeout=None): """ @@ -96,7 +97,7 @@ def run_in_background(self, await_ready=True, timeout=None): """ self.start() if await_ready and not self.ready.wait(timeout=timeout): - raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds") + raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds") def shutdown(self) -> None: """ Shut down a running dht process """ From 357f9b0efa5da46a6e7853ff287ab87bce1b11a1 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 7 May 2021 01:29:58 +0300 Subject: [PATCH 3/9] Pass min_batch_size to ExpertBackend in Server.create --- hivemind/hivemind_cli/run_server.py | 4 +++- hivemind/server/__init__.py | 13 ++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py index fe29eb4fc..7a3c6ba6b 100644 --- a/hivemind/hivemind_cli/run_server.py +++ b/hivemind/hivemind_cli/run_server.py @@ -32,7 +32,9 @@ def main(): parser.add_argument('--num_handlers', type=int, default=None, required=False, help='server will use this many processes to handle incoming requests') - parser.add_argument('--max_batch_size', type=int, default=16384, required=False, + parser.add_argument('--min_batch_size', type=int, default=1, + help='Minimum required batch size for all expert operations') + parser.add_argument('--max_batch_size', type=int, default=16384, help='The total number of examples in the same batch will not exceed this value') parser.add_argument('--device', type=str, default=None, required=False, help='all experts will use this device in torch notation; default: cuda if available else cpu') diff --git a/hivemind/server/__init__.py b/hivemind/server/__init__.py index 4aa7e9404..db101c549 100644 --- a/hivemind/server/__init__.py +++ b/hivemind/server/__init__.py @@ -71,10 +71,10 @@ def __init__( @classmethod def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None, expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: 
str = 'none', - num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096, - device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None, - compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None, - *, start: bool, **kwargs) -> Server: + num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1, + max_batch_size=4096, device=None, no_dht=False, initial_peers=(), dht_port=None, + checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE, + stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server: """ Instantiate a server with several identical experts. See argparse comments below for details :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80" @@ -85,6 +85,7 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'; :param hidden_dim: main dimension for expert_cls :param num_handlers: server will use this many parallel processes to handle incoming requests + :param min_batch_size: total num examples in the same batch will be greater than this value :param max_batch_size: total num examples in the same batch will not exceed this value :param device: all experts will use this device in torch notation; default: cuda if available else cpu @@ -112,9 +113,6 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str """ if custom_module_path is not None: add_custom_models_from_file(custom_module_path) - - if len(kwargs) != 0: - logger.info("Ignored kwargs:", kwargs) assert expert_cls in name_to_block if no_dht: @@ -172,6 +170,7 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str num_warmup_steps=num_warmup_steps, num_total_steps=num_total_steps, clip_grad_norm=clip_grad_norm, + min_batch_size=min_batch_size, max_batch_size=max_batch_size) if checkpoint_dir is not None: From 02a4737b9276ba68faa3ea1777e0d6bc8222fc19 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 7 May 2021 01:30:52 +0300 Subject: [PATCH 4/9] Remove unneeded variable for exception in generate_uids_from_pattern --- hivemind/server/expert_uid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hivemind/server/expert_uid.py b/hivemind/server/expert_uid.py index b8b1ec03a..4053c1c9e 100644 --- a/hivemind/server/expert_uid.py +++ b/hivemind/server/expert_uid.py @@ -62,8 +62,8 @@ def _generate_uid(): uid.append(str(random.randint(slice_start, slice_end - 1))) else: raise ValueError("Block must be either fixed or a range [from:to]") - except KeyboardInterrupt as e: - raise e + except KeyboardInterrupt: + raise except Exception as e: raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}") return UID_DELIMITER.join(uid) From e25348c610adeb5bac8bbb3b9414cec2c48d023e Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 7 May 2021 01:32:56 +0300 Subject: [PATCH 5/9] Overhaul server architecture --- hivemind/server/__init__.py | 39 +++++--- hivemind/server/connection_handler.py | 5 +- hivemind/server/expert_backend.py | 8 +- hivemind/server/runtime.py | 35 ++++--- hivemind/server/task_pool.py | 129 ++++++++++++-------------- 5 files changed, 109 insertions(+), 107 deletions(-) diff --git a/hivemind/server/__init__.py 
b/hivemind/server/__init__.py index db101c549..b34e6d4fc 100644 --- a/hivemind/server/__init__.py +++ b/hivemind/server/__init__.py @@ -65,6 +65,10 @@ def __init__( self.checkpoint_saver = None self.runtime = Runtime(self.experts, **kwargs) + if self.dht and self.experts: + self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on, + update_period=self.update_period) + if start: self.run_in_background(await_ready=True) @@ -195,9 +199,7 @@ def run(self): self.dht.run_in_background(await_ready=True) if self.experts: - dht_handler_thread = DHTHandlerThread( - experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period) - dht_handler_thread.start() + self.dht_handler_thread.start() if self.checkpoint_saver is not None: self.checkpoint_saver.start() @@ -206,16 +208,10 @@ def run(self): process.start() process.ready.wait() - self.runtime.run() - - for process in self.conn_handlers: - process.join() - if self.dht and self.experts: - dht_handler_thread.stop.set() - dht_handler_thread.join() - if self.checkpoint_saver is not None: - self.checkpoint_saver.stop.set() - self.checkpoint_saver.join() + try: + self.runtime.run() + finally: + self.shutdown() def run_in_background(self, await_ready=True, timeout=None): """ @@ -241,19 +237,32 @@ def ready(self) -> mp.synchronize.Event: def shutdown(self): """ - Gracefully terminate a hivemind server, process-safe. + Gracefully terminate the server, process-safe. Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes. If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL). """ self.ready.clear() + for process in self.conn_handlers: process.terminate() + process.join() + logger.debug("Connection handlers terminated") + + if self.dht and self.experts: + self.dht_handler_thread.stop.set() + self.dht_handler_thread.join() + + if self.checkpoint_saver is not None: + self.checkpoint_saver.stop.set() + self.checkpoint_saver.join() if self.dht is not None: self.dht.shutdown() self.dht.join() - self.runtime.shutdown() + logger.debug(f"Shutting down runtime") + self.runtime.stop.set() + logger.info("Server shutdown succesfully") @contextmanager diff --git a/hivemind/server/connection_handler.py b/hivemind/server/connection_handler.py index f19042cdc..7f5307af2 100644 --- a/hivemind/server/connection_handler.py +++ b/hivemind/server/connection_handler.py @@ -52,7 +52,10 @@ async def _run(): await server.wait_for_termination() logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})") - loop.run_until_complete(_run()) + try: + loop.run_until_complete(_run()) + except KeyboardInterrupt: + logger.debug('Caught KeyboardInterrupt, shutting down') async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext): return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info())) diff --git a/hivemind/server/expert_backend.py b/hivemind/server/expert_backend.py index cb0006a5a..e6d867c42 100644 --- a/hivemind/server/expert_backend.py +++ b/hivemind/server/expert_backend.py @@ -74,12 +74,13 @@ def __init__(self, name: str, expert: nn.Module, optimizer: torch.optim.Optimize self.backward_schema = (self.forward_schema, self.outputs_schema) # inputs to backward self.grad_inputs_schema = self.forward_schema # outputs from backward - self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs) - self.backward_pool = TaskPool(self.backward, 
uid=f'{self.name}_backward', **kwargs) + self.forward_pool = TaskPool(self.forward, name=f'{self.name}_forward', **kwargs) + self.backward_pool = TaskPool(self.backward, name=f'{self.name}_backward', **kwargs) self.update_count = 0 self.examples_processed = 0 + @torch.no_grad() def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually; @@ -99,8 +100,7 @@ def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: if args[0].shape[0] == 0: raise RuntimeError("Batch should contain more than 0 samples") - with torch.no_grad(): - outputs = self.expert(*args, **kwargs) + outputs = self.expert(*args, **kwargs) # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side return tuple(nested_flatten(outputs)) diff --git a/hivemind/server/runtime.py b/hivemind/server/runtime.py index f8df3c088..3fa99f999 100644 --- a/hivemind/server/runtime.py +++ b/hivemind/server/runtime.py @@ -48,8 +48,8 @@ def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=6 self.expert_backends = expert_backends self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values()))) self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads - self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False) self.ready = mp.Event() # event is set iff server is currently running and ready to accept batches + self.stop = threading.Event() self.stats_report_interval = stats_report_interval if self.stats_report_interval is not None: @@ -72,49 +72,48 @@ def run(self): for pool, batch_index, batch in BackgroundGenerator( self.iterate_minibatches_from_pools(), self.prefetch_batches): - logger.debug(f"Processing batch {batch_index} from pool {pool.uid}") + logger.debug(f"Processing batch {batch_index} from pool {pool.name}") start = time() outputs = pool.process_func(*batch) batch_processing_time = time() - start batch_size = outputs[0].size(0) - logger.debug(f"Pool {pool.uid}: batch {batch_index} processed, size {batch_size}") + logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}") if self.stats_report_interval is not None: - self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time) + self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time) output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs]) finally: - logger.info("Shutting down") - - if self.stats_report_interval is not None: - self.stats_reporter.stop.set() - self.stats_reporter.join() - self.shutdown() - SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED" - def shutdown(self): """ Gracefully terminate a running runtime. 
""" - self.ready.clear() - self.shutdown_send.send(self.SHUTDOWN_TRIGGER) # trigger background thread to shutdown + logger.info("Shutting down") + + if self.stats_report_interval is not None: + self.stats_reporter.stop.set() + self.stats_reporter.join() + + self.stop.set() # trigger background thread to shutdown + + logger.debug("Terminating pools") for pool in self.pools: if pool.is_alive(): pool.terminate() pool.join() + logger.debug("Pools terminated") def iterate_minibatches_from_pools(self, timeout=None): """ Chooses pool according to priority, then copies exposed batch and frees the buffer """ with DefaultSelector() as selector: - selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER) for pool in self.pools: selector.register(pool.batch_receiver, EVENT_READ, pool) - while True: + while not self.stop.is_set(): # wait until at least one batch_receiver becomes available logger.debug("Waiting for inputs from task pools") ready_fds = selector.select() @@ -125,9 +124,9 @@ def iterate_minibatches_from_pools(self, timeout=None): logger.debug("Choosing the pool with highest priority") pool = max(ready_objects, key=lambda pool: pool.priority) - logger.debug(f"Loading batch from {pool.uid}") + logger.debug(f"Loading batch from {pool.name}") batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device) - logger.debug(f"Loaded batch from {pool.uid}") + logger.debug(f"Loaded batch from {pool.name}") yield pool, batch_index, batch_tensors diff --git a/hivemind/server/task_pool.py b/hivemind/server/task_pool.py index 0dac5a47a..8f5a2aec9 100644 --- a/hivemind/server/task_pool.py +++ b/hivemind/server/task_pool.py @@ -6,7 +6,6 @@ import os import threading import time -import uuid from abc import ABCMeta, abstractmethod from collections import namedtuple from concurrent.futures import Future @@ -24,8 +23,8 @@ class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta): """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """ - def __init__(self, process_func: callable, daemon=True): - super().__init__(daemon=daemon) + def __init__(self, process_func: callable, **kwargs): + super().__init__(**kwargs) self.process_func = process_func self._priority = mp.Value(ctypes.c_double, 1.0) # higher priority = the more urgent to process this pool @@ -63,6 +62,7 @@ class TaskPool(TaskPoolBase): :param process_func: function to be applied to every formed batch; called by Runtime Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs) + :param name: pool name :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more :param timeout: wait for a subsequent task for at most this many seconds :param pool_size: store at most this many unprocessed tasks in a queue @@ -71,11 +71,10 @@ class TaskPool(TaskPoolBase): :param start: if True, start automatically at the end of __init__ """ - def __init__(self, process_func: callable, max_batch_size: int, min_batch_size=1, - timeout=None, pool_size=None, prefetch_batches=1, uid=None, daemon=True, start=False): - super().__init__(process_func, daemon=daemon) + def __init__(self, process_func: callable, max_batch_size: int, name: str, min_batch_size=1, + timeout=None, pool_size=None, prefetch_batches=1, daemon=True, start=False): + super().__init__(process_func, daemon=daemon, name=name) self.min_batch_size, 
self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout - self.uid = uid or uuid.uuid4() self.prefetch_batches = prefetch_batches # interaction with ConnectionHandlers @@ -112,7 +111,7 @@ def iterate_minibatches(self, *args, **kwargs): batch = [] total_size = 0 try: - logger.debug(f"{self.uid} getting next task") + logger.debug(f"{self.name} getting next task") task = self.tasks.get(timeout=self.timeout) except Empty: logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet") @@ -134,80 +133,72 @@ def iterate_minibatches(self, *args, **kwargs): def run(self, *args, **kwargs): torch.set_num_threads(1) - logger.info(f'{self.uid} starting, pid={os.getpid()}') + logger.info(f'{self.name} starting, pid={os.getpid()}') pending_batches = {} # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime + output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches], - name=f'{self.uid}_output') + name=f'{self.name}_output') + try: output_thread.start() self._pool_input_loop(pending_batches, *args, **kwargs) - except BaseException as e: - # terminate output loop - self.outputs_sender.send(e) + except KeyboardInterrupt: + logger.debug('Caught KeyboardInterrupt, shutting down') + finally: output_thread.join() - raise e def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs): """ Infinite loop: aggregate tasks into batches and send them to runtime """ - try: - prev_num_tasks = 0 # number of tasks currently in shared buffer - batch_index = max(pending_batches.keys(), default=0) - batch_iterator = self.iterate_minibatches(*args, **kwargs) - - while True: - # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task - # assumes that tasks are processed in the same order as they are created - for skip_i in range(prev_num_tasks): - finished_task_timestamp = self.undispatched_task_timestamps.get() # earlier timestamp = higher priority - if skip_i == prev_num_tasks - 1: - self.priority = finished_task_timestamp - - logger.debug(f"{self.uid} getting next batch") - batch_tasks = next(batch_iterator) - # save batch futures, _output_loop will deliver on them later - pending_batches[batch_index] = batch_tasks - - logger.debug(f"{self.uid}, batch {batch_index}: aggregating inputs") - # find or create shared arrays for current batch size - batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in - range(len(batch_tasks[0].args))] - batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs] - - logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime") - self.batch_sender.send((batch_index, batch_inputs)) - logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime") - prev_num_tasks = len(batch_tasks) - batch_index += 1 - except KeyboardInterrupt: - logger.debug('Caught KeyboardInterrupt, shutting down') + + prev_num_tasks = 0 # number of tasks currently in shared buffer + batch_index = max(pending_batches.keys(), default=0) + batch_iterator = self.iterate_minibatches(*args, **kwargs) + + while True: + # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task + # assumes that tasks are processed in the same order as they are created + for skip_i in range(prev_num_tasks): + finished_task_timestamp = self.undispatched_task_timestamps.get() # earlier timestamp = higher priority + if skip_i == prev_num_tasks - 1: + self.priority = finished_task_timestamp + + 
logger.debug(f"{self.name} getting next batch") + batch_tasks = next(batch_iterator) + # save batch futures, _output_loop will deliver on them later + pending_batches[batch_index] = batch_tasks + + logger.debug(f"{self.name}, batch {batch_index}: aggregating inputs") + # find or create shared arrays for current batch size + batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in + range(len(batch_tasks[0].args))] + batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs] + + logger.debug(f"{self.name}, batch {batch_index}: sending to runtime") + self.batch_sender.send((batch_index, batch_inputs)) + logger.debug(f"{self.name}, batch {batch_index}: sent to runtime") + prev_num_tasks = len(batch_tasks) + batch_index += 1 def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]): """ Infinite loop: receive results from runtime and dispatch them to task Futures """ - try: - while True: - logger.debug(f"{self.uid} waiting for results from runtime") - payload = self.outputs_receiver.recv() - if isinstance(payload, BaseException): - raise payload - else: - batch_index, batch_outputs = payload - logger.debug(f"{self.uid}, batch {batch_index}: got results") - - # split batch into partitions for individual tasks - batch_tasks = pending_batches.pop(batch_index) - task_sizes = [self.get_task_size(task) for task in batch_tasks] - outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs)) - logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers") - - # dispatch results to futures - for task, task_outputs in zip(batch_tasks, outputs_per_task): - try: - task.future.set_result(tuple(task_outputs)) - except FutureStateError as e: - logger.debug(f"Failed to send task result due to an exception: {e}") - except KeyboardInterrupt: - logger.debug(f"Caught KeyboardInterrupt, shutting down") + while True: + logger.debug(f"{self.name} waiting for results from runtime") + batch_index, batch_outputs = self.outputs_receiver.recv() + logger.debug(f"{self.name}, batch {batch_index}: got results") + + # split batch into partitions for individual tasks + batch_tasks = pending_batches.pop(batch_index) + task_sizes = [self.get_task_size(task) for task in batch_tasks] + outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs)) + logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers") + + # dispatch results to futures + for task, task_outputs in zip(batch_tasks, outputs_per_task): + try: + task.future.set_result(tuple(task_outputs)) + except FutureStateError as e: + logger.debug(f"Failed to send task result due to an exception: {e}") @property def empty(self): From 6fb6252d0bf1d7e0aa9f5e81d249195ef58cbf77 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 7 May 2021 01:33:05 +0300 Subject: [PATCH 6/9] Overhaul server architecture --- hivemind/server/runtime.py | 2 -- hivemind/server/task_pool.py | 1 - 2 files changed, 3 deletions(-) diff --git a/hivemind/server/runtime.py b/hivemind/server/runtime.py index 3fa99f999..95ae8c99d 100644 --- a/hivemind/server/runtime.py +++ b/hivemind/server/runtime.py @@ -118,8 +118,6 @@ def iterate_minibatches_from_pools(self, timeout=None): logger.debug("Waiting for inputs from task pools") ready_fds = selector.select() ready_objects = {key.data for (key, events) in ready_fds} - if self.SHUTDOWN_TRIGGER in ready_objects: - break # someone asked us to shutdown, break from the loop 
logger.debug("Choosing the pool with highest priority") pool = max(ready_objects, key=lambda pool: pool.priority) diff --git a/hivemind/server/task_pool.py b/hivemind/server/task_pool.py index 8f5a2aec9..858b632ad 100644 --- a/hivemind/server/task_pool.py +++ b/hivemind/server/task_pool.py @@ -67,7 +67,6 @@ class TaskPool(TaskPoolBase): :param timeout: wait for a subsequent task for at most this many seconds :param pool_size: store at most this many unprocessed tasks in a queue :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime - :param uid: pool identifier used for shared array allocation :param start: if True, start automatically at the end of __init__ """ From e39c44e8222d30b55ac799abf269a9241ddb37a7 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 7 May 2021 01:50:10 +0300 Subject: [PATCH 7/9] Address review fixes --- hivemind/client/averaging/__init__.py | 54 +++++++++++++-------------- hivemind/client/moe.py | 11 ++++-- hivemind/dht/__init__.py | 5 +-- 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/hivemind/client/averaging/__init__.py b/hivemind/client/averaging/__init__.py index 0181eac2e..c009cb41d 100644 --- a/hivemind/client/averaging/__init__.py +++ b/hivemind/client/averaging/__init__.py @@ -171,35 +171,34 @@ def _run_internal(self): """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """ loop = switch_to_uvloop() # initialize asyncio synchronization primitives in this event loop - pipe_awaiter = ThreadPoolExecutor(max_workers=1) - - async def _run(): - grpc.aio.init_grpc_aio() - - if self.listen: - server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS) - averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server) - found_port = server.add_insecure_port(self.listen_on) - assert found_port != 0, f"Failed to listen to {self.listen_on}" - self._port.value = found_port - await server.start() - else: - logger.info(f"The averager running in an experimental client mode, please report any bugs.") + with ThreadPoolExecutor(max_workers=1) as pipe_awaiter: + async def _run(): + grpc.aio.init_grpc_aio() + + if self.listen: + server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS) + averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server) + found_port = server.add_insecure_port(self.listen_on) + assert found_port != 0, f"Failed to listen to {self.listen_on}" + self._port.value = found_port + await server.start() + else: + logger.info(f"The averager running in an experimental client mode, please report any bugs.") - self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, - client_mode=not self.listen) - if self.listen: - asyncio.create_task(self._declare_for_download_periodically()) + self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, + client_mode=not self.listen) + if self.listen: + asyncio.create_task(self._declare_for_download_periodically()) - self._pending_group_assembled = asyncio.Event() - self._pending_group_assembled.set() - self.ready.set() + self._pending_group_assembled = asyncio.Event() + self._pending_group_assembled.set() + self.ready.set() - while True: - method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv) - asyncio.create_task(getattr(self, method)(*args, **kwargs)) + while True: + method, args, kwargs = await loop.run_in_executor(pipe_awaiter, 
self._pipe.recv) + asyncio.create_task(getattr(self, method)(*args, **kwargs)) - loop.run_until_complete(_run()) + loop.run_until_complete(_run()) def run_in_background(self, await_ready=True, timeout=None): """ @@ -255,7 +254,8 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float, try: self._pending_group_assembled.clear() data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary]) - group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather) + group_info = await self._matchmaking.look_for_group(timeout=timeout, + data_for_gather=data_for_gather) if group_info is None: raise AllreduceException("Averaging step failed: could not find a group.") group_id = group_info.group_id @@ -294,7 +294,7 @@ async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: i """ Use a group description found by Matchmaking to form AllreduceRunner """ try: weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered)) - user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered))) + user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered))) # compute optimal part sizes from peer throughputs incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)] diff --git a/hivemind/client/moe.py b/hivemind/client/moe.py index 1ff015d80..9cadf81b3 100644 --- a/hivemind/client/moe.py +++ b/hivemind/client/moe.py @@ -120,8 +120,11 @@ def compute_expert_scores( batch_size = len(batch_experts) max_num_experts = max(expert_counts) total_num_experts = sum(expert_counts) - expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device) - expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1] + + device = grid_scores[0].device + + expert_index_in_batch = torch.arange(total_num_experts, device=device) + expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1] flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1 flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices] flat_experts = [expert for row in batch_experts for expert in row] @@ -133,11 +136,11 @@ def compute_expert_scores( grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype) scores_per_dim = [ - dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0) + dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device) for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)] flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0) - scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device) + scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device) scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. 
flat_scores return scores diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index d4eb38284..fc037ebdf 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -85,10 +85,7 @@ async def _run(): asyncio.create_task(getattr(self, method)(node, *args, **kwargs)) coro = _run() - try: - loop.run_until_complete(coro) - except KeyboardInterrupt: - logger.debug("Caught KeyboardInterrupt, shutting down") + loop.run_until_complete(coro) def run_in_background(self, await_ready=True, timeout=None): """ From df87874cf5abb8348f6b88aef8e23b13a14336b5 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 7 May 2021 01:52:02 +0300 Subject: [PATCH 8/9] daemon=True by default --- hivemind/server/task_pool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hivemind/server/task_pool.py b/hivemind/server/task_pool.py index 858b632ad..b0d03540d 100644 --- a/hivemind/server/task_pool.py +++ b/hivemind/server/task_pool.py @@ -23,8 +23,8 @@ class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta): """ A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime """ - def __init__(self, process_func: callable, **kwargs): - super().__init__(**kwargs) + def __init__(self, process_func: callable, daemon=True, **kwargs): + super().__init__(daemon=daemon, **kwargs) self.process_func = process_func self._priority = mp.Value(ctypes.c_double, 1.0) # higher priority = the more urgent to process this pool From af245738d569106c4b3b649136804866cb25a545 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 7 May 2021 02:01:27 +0300 Subject: [PATCH 9/9] Return torch.no_grad inside forward --- hivemind/server/expert_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hivemind/server/expert_backend.py b/hivemind/server/expert_backend.py index e6d867c42..a67dba652 100644 --- a/hivemind/server/expert_backend.py +++ b/hivemind/server/expert_backend.py @@ -80,7 +80,6 @@ def __init__(self, name: str, expert: nn.Module, optimizer: torch.optim.Optimize self.update_count = 0 self.examples_processed = 0 - @torch.no_grad() def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually; @@ -100,7 +99,8 @@ def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: if args[0].shape[0] == 0: raise RuntimeError("Batch should contain more than 0 samples") - outputs = self.expert(*args, **kwargs) + with torch.no_grad(): + outputs = self.expert(*args, **kwargs) # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side return tuple(nested_flatten(outputs))
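
Patches 2 and 7 apply the same refactoring: the blocking pipe reader is moved inside a ThreadPoolExecutor used as a context manager, so its worker thread is joined whenever the event loop exits, cleanly or not. The sketch below illustrates that pattern in isolation; it uses plain asyncio and multiprocessing with made-up names (serve, "ping"), not hivemind's actual DHT or DecentralizedAverager classes.

    import asyncio
    import multiprocessing as mp
    import time
    from concurrent.futures import ThreadPoolExecutor

    def serve(pipe) -> None:
        """Background process: read (method, args) pairs from a pipe and dispatch them."""
        loop = asyncio.new_event_loop()

        # the executor lives in a with-block, so it is shut down even if _run() raises
        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
            async def _run():
                while True:
                    # blocking recv() runs in the executor thread; the event loop stays responsive
                    method, args = await loop.run_in_executor(pipe_awaiter, pipe.recv)
                    print(f"dispatching {method}{args}")

            try:
                loop.run_until_complete(_run())
            except KeyboardInterrupt:
                pass  # graceful shutdown; the with-block still joins the executor

    if __name__ == '__main__':
        parent_conn, child_conn = mp.Pipe()
        mp.Process(target=serve, args=(child_conn,), daemon=True).start()
        parent_conn.send(("ping", (1, 2)))
        time.sleep(0.5)  # give the daemon a moment to dispatch before exiting

Running recv() through run_in_executor is what lets the loop keep scheduling the coroutines created for each incoming call, and the context manager removes the need for explicit executor cleanup in every error path.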