diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 6e1389a4b..c9af194c4 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -10,7 +10,7 @@ from hivemind.moe.client.expert import RemoteExpert from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.moe.expert_uid import ExpertInfo -from hivemind.moe.server import ExpertBackend, Server +from hivemind.moe.server import ModuleBackend, Server from hivemind.moe.server.layers import name_to_block from hivemind.p2p import P2P from hivemind.utils.limits import increase_file_limit @@ -118,12 +118,12 @@ def benchmark_throughput( timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter() device = device or ("cuda" if torch.cuda.is_available() else "cpu") - experts = {} + module_backends = {} for i in range(num_experts): expert = torch.jit.script(name_to_block[expert_cls](hid_dim)) - experts[f"expert.{i}"] = ExpertBackend( + module_backends[f"expert.{i}"] = ModuleBackend( name=f"expert.{i}", - expert=expert, + module=expert, optimizer=torch.optim.Adam(expert.parameters()), args_schema=(BatchTensorDescriptor(hid_dim),), outputs_schema=BatchTensorDescriptor(hid_dim), @@ -133,7 +133,7 @@ def benchmark_throughput( server = Server( dht=server_dht, - expert_backends=experts, + module_backends=module_backends, num_connection_handlers=num_handlers, device=device, ) diff --git a/docs/modules/server.rst b/docs/modules/server.rst index 4e8c61456..a958ec057 100644 --- a/docs/modules/server.rst +++ b/docs/modules/server.rst @@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo The hivemind.moe.server module is organized as follows: - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute. -- ExpertBackend_ is a wrapper for `torch.nn.Module `_ \ +- ModuleBackend_ is a wrapper for `torch.nn.Module `_ \ that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests. -- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert. +- Runtime_ balances the device (GPU) usage between several ModuleBackend_ instances that each service one expert. - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \ and offers those batches to Runtime_ for processing. @@ -25,9 +25,9 @@ The hivemind.moe.server module is organized as follows: :members: :member-order: bysource -.. _ExpertBackend: -.. autoclass:: ExpertBackend - :members: forward, backward, apply_gradients, get_info, get_pools +.. _ModuleBackend: +.. autoclass:: ModuleBackend + :members: forward, backward, on_backward, get_info, get_pools :member-order: bysource .. 
currentmodule:: hivemind.moe.server.runtime diff --git a/hivemind/__init__.py b/hivemind/__init__.py index 32443f7f7..f74a640a7 100644 --- a/hivemind/__init__.py +++ b/hivemind/__init__.py @@ -2,7 +2,7 @@ from hivemind.compression import * from hivemind.dht import DHT from hivemind.moe import ( - ExpertBackend, + ModuleBackend, RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts, diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py index 0418ac8df..078702c8e 100644 --- a/hivemind/hivemind_cli/run_server.py +++ b/hivemind/hivemind_cli/run_server.py @@ -54,7 +54,8 @@ def main(): help='Server will report experts to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, help='DHT entries will expire after this many seconds') - parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule') + parser.add_argument('--num_training_steps', type=int, required=False, help='The total number of steps for LR schedule') + parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping') parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[], diff --git a/hivemind/moe/__init__.py b/hivemind/moe/__init__.py index 00905507d..1436ab35d 100644 --- a/hivemind/moe/__init__.py +++ b/hivemind/moe/__init__.py @@ -1,6 +1,6 @@ from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts from hivemind.moe.server import ( - ExpertBackend, + ModuleBackend, Server, background_server, declare_experts, diff --git a/hivemind/moe/server/__init__.py b/hivemind/moe/server/__init__.py index 1ac24db2d..b370ffbff 100644 --- a/hivemind/moe/server/__init__.py +++ b/hivemind/moe/server/__init__.py @@ -1,4 +1,4 @@ from hivemind.moe.server.dht_handler import declare_experts, get_experts -from hivemind.moe.server.expert_backend import ExpertBackend from hivemind.moe.server.layers import register_expert_class +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.server import Server, background_server diff --git a/hivemind/moe/server/checkpoints.py b/hivemind/moe/server/checkpoints.py index 23a4a4a2e..6003a1c39 100644 --- a/hivemind/moe/server/checkpoints.py +++ b/hivemind/moe/server/checkpoints.py @@ -8,7 +8,7 @@ import torch -from hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils.logging import get_logger logger = get_logger(__name__) @@ -34,23 +34,23 @@ def copy_tree(src: str, dst: str): class CheckpointSaver(threading.Thread): - def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: float): + def __init__(self, module_backends: Dict[str, ModuleBackend], checkpoint_dir: Path, update_period: float): super().__init__() assert is_directory(checkpoint_dir) - self.expert_backends = expert_backends + self.module_backends = module_backends self.update_period = update_period self.checkpoint_dir = checkpoint_dir self.stop = threading.Event() # create expert directories to ensure that the directory is writable and checkpoints can be loaded - store_experts(self.expert_backends, self.checkpoint_dir) + store_experts(self.module_backends, self.checkpoint_dir) def run(self) -> None: while not self.stop.wait(self.update_period): - store_experts(self.expert_backends, self.checkpoint_dir) + 
store_experts(self.module_backends, self.checkpoint_dir) -def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path): +def store_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path): logger.debug(f"Storing experts at {checkpoint_dir.absolute()}") assert is_directory(checkpoint_dir) timestamp = datetime.now().isoformat(sep="_") @@ -59,17 +59,17 @@ def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path): expert_dir = Path(tmpdirname) / expert_name expert_dir.mkdir() checkpoint_name = expert_dir / f"checkpoint_{timestamp}.pt" - torch.save(expert_backend.get_full_state(), checkpoint_name) + torch.save(expert_backend.state_dict(), checkpoint_name) os.symlink(checkpoint_name, expert_dir / "checkpoint_last.pt") copy_tree(tmpdirname, str(checkpoint_dir)) -def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path): +def load_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path): assert is_directory(checkpoint_dir) for expert_name, expert in experts.items(): checkpoints_folder = checkpoint_dir / expert_name latest_checkpoint = checkpoints_folder / "checkpoint_last.pt" if latest_checkpoint.exists(): - expert.load_full_state(torch.load(latest_checkpoint)) + expert.load_state_dict(torch.load(latest_checkpoint)) else: logger.warning(f"Failed to load checkpoint for expert {expert_name}") diff --git a/hivemind/moe/server/connection_handler.py b/hivemind/moe/server/connection_handler.py index ff610a7bc..435a0b7d8 100644 --- a/hivemind/moe/server/connection_handler.py +++ b/hivemind/moe/server/connection_handler.py @@ -6,7 +6,7 @@ from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor from hivemind.dht import DHT -from hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.task_pool import TaskPool from hivemind.p2p import P2PContext, ServicerBase from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE, P2P @@ -25,10 +25,10 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase): :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one - :param experts: a dict [UID -> ExpertBackend] with all active experts + :param experts: a dict [UID -> ModuleBackend] with all active experts """ - def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]): + def __init__(self, dht: DHT, experts: Dict[str, ModuleBackend]): super().__init__() self.dht, self.experts = dht, experts self._p2p: Optional[P2P] = None diff --git a/hivemind/moe/server/dht_handler.py b/hivemind/moe/server/dht_handler.py index ea1a4f658..e5cbb1935 100644 --- a/hivemind/moe/server/dht_handler.py +++ b/hivemind/moe/server/dht_handler.py @@ -20,20 +20,22 @@ class DHTHandlerThread(threading.Thread): - def __init__(self, experts, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs): + def __init__( + self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs + ): super().__init__(**kwargs) if expiration is None: expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) - self.experts = experts + self.module_backends = module_backends self.dht = dht self.update_period = update_period self.expiration = expiration self.stop = threading.Event() def run(self) -> None: - declare_experts(self.dht, self.experts.keys(), 
expiration_time=get_dht_time() + self.expiration) + declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration) while not self.stop.wait(self.update_period): - declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration) + declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration) def declare_experts( diff --git a/hivemind/moe/server/layers/dropout.py b/hivemind/moe/server/layers/dropout.py index 8efad903e..526787c7c 100644 --- a/hivemind/moe/server/layers/dropout.py +++ b/hivemind/moe/server/layers/dropout.py @@ -19,7 +19,7 @@ def backward(ctx, grad_output): class DeterministicDropout(nn.Module): """ Custom dropout layer which accepts dropout mask as an input (drop_prob is only used for scaling input activations). - Can be used with RemoteExpert/ExpertBackend to ensure that dropout mask is the same at forward and backward steps + Can be used with RemoteExpert/ModuleBackend to ensure that dropout mask is the same at forward and backward steps """ def __init__(self, drop_prob): diff --git a/hivemind/moe/server/layers/optim.py b/hivemind/moe/server/layers/optim.py new file mode 100644 index 000000000..f280ba427 --- /dev/null +++ b/hivemind/moe/server/layers/optim.py @@ -0,0 +1,58 @@ +import torch + + +class OptimizerWrapper(torch.optim.Optimizer): + """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer""" + + def __init__(self, optim: torch.optim.Optimizer): + object.__init__(self)  # do not call Optimizer.__init__: defaults/state/param_groups below are read-only forwarding properties + self.optim = optim + + @property + def defaults(self): + return self.optim.defaults + + @property + def state(self): + return self.optim.state + + def __getstate__(self): + return self.optim.__getstate__() + + def __setstate__(self, state): + self.optim.__setstate__(state) + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.optim)})" + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + return self.optim.load_state_dict(state_dict) + + def step(self, *args, **kwargs): + return self.optim.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs): + return self.optim.zero_grad(*args, **kwargs) + + @property + def param_groups(self): + return self.optim.param_groups + + def add_param_group(self, param_group: dict) -> None: + return self.optim.add_param_group(param_group) + + +class ClippingWrapper(OptimizerWrapper): + """A wrapper of torch.Optimizer that clips gradients by global norm before each step""" + + def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float): + super().__init__(optim) + self.clip_grad_norm = clip_grad_norm + + def step(self, *args, **kwargs): + parameters = tuple(param for group in self.param_groups for param in group["params"]) + torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm) + return super().step(*args, **kwargs) diff --git a/hivemind/moe/server/expert_backend.py b/hivemind/moe/server/module_backend.py similarity index 65% rename from hivemind/moe/server/expert_backend.py rename to hivemind/moe/server/module_backend.py index b35238158..f6260371a 100644 --- a/hivemind/moe/server/expert_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch from torch import nn @@ -8,19 +8,20 @@ from hivemind.utils.nested import nested_compare,
nested_flatten, nested_map, nested_pack from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor +LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None) logger = get_logger(__name__) -class ExpertBackend: +class ModuleBackend: """ - ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime - By default, ExpertBackend handles three types of requests: + ModuleBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime + By default, ModuleBackend handles three types of requests: - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization. - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched. - get_info - return expert metadata. Not batched. - :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations: + :param module: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations: - Experts must always receive the same set of args and kwargs and produce output tensors of same type - All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size @@ -34,49 +35,37 @@ class ExpertBackend: :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto - :param num_warmup_steps: the number of warmup steps for LR schedule - :param num_total_steps: the total number of steps for LR schedule - :param clip_grad_norm: maximum gradient norm used for clipping :param kwargs: extra parameters to be forwarded into TaskPool.__init__ """ def __init__( self, name: str, - expert: nn.Module, - optimizer: torch.optim.Optimizer, + module: nn.Module, *, - scheduler: Callable = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerBase] = None, args_schema: Tuple[BatchTensorDescriptor, ...] = None, kwargs_schema: Dict[str, BatchTensorDescriptor] = None, outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None, - num_warmup_steps: int = None, - num_total_steps: int = None, - clip_grad_norm: float = None, **kwargs, ): super().__init__() - self.expert, self.optimizer, self.name = expert, optimizer, name - - if scheduler is None: - self.scheduler = None - else: - assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None - self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps) - self.clip_grad_norm = clip_grad_norm + self.name, self.module, self.optimizer, self.scheduler = name, module, optimizer, scheduler self.args_schema = args_schema = tuple(args_schema or ()) self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {}) assert args_schema or kwargs_schema, ( - "expert must receive at least one positional or keyword input." + f"Module must take at least one positional or keyword input." " Did you forget to provide args_schema/kwargs_schema?" 
) + assert optimizer is not None or scheduler is None, "scheduler should only be used if optimizer is not None" if outputs_schema is None: # run expert once to get outputs schema dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema) dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()} - dummy_outputs = self.expert(*dummy_args, **dummy_kwargs) + dummy_outputs = self.module(*dummy_args, **dummy_kwargs) outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs) self.forward_schema = (self.args_schema, self.kwargs_schema) # inputs for forward @@ -87,22 +76,17 @@ def __init__( self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs) self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs) - self.update_count = 0 - self.examples_processed = 0 - def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually; - To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``. + To submit a request for asynchronous processing, please use ``ModuleBackend.forward_pool.submit_task``. + + .. warning: if the underlying module performs non-gradient updates (e.g. batchnorm), it will be updated twice: + once during forward pass, and again during backward. This behavior is similar to gradient checkpointing. Subclassing: This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``; - It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``; - - .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice. - .. For now, either register all buffers as outputs or avoid stateful experts - """ args, kwargs = nested_pack(inputs, structure=self.forward_schema) @@ -110,7 +94,7 @@ def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: raise RuntimeError("Batch should contain more than 0 samples") with torch.no_grad(): - outputs = self.expert(*args, **kwargs) + outputs = self.module(*args, **kwargs) # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side return tuple(nested_flatten(outputs)) @@ -118,7 +102,7 @@ def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually - To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``. + To submit a request for asynchronous processing, please use ``ModuleBackend.backward_pool.submit_task``. Subclassing: This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``; @@ -128,9 +112,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: Runtime doesn't guarantee that backward will be performed in the same order and for the same data as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward. - .. 
todo correct state handling (see forward) - - Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train + Please make sure to call ``ModuleBackend.on_backward`` after each call to backward """ (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema) @@ -148,7 +130,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: batch_size = args[0].size(0) - outputs = self.expert(*args, **kwargs) + outputs = self.module(*args, **kwargs) assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure" outputs_flat = tuple(nested_flatten(outputs)) @@ -163,65 +145,45 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: torch.autograd.backward( outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False ) - self.apply_gradients(batch_size) + self.on_backward(batch_size) return tuple( x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs)) ) - def apply_gradients(self, batch_size) -> None: + def on_backward(self, batch_size: int) -> None: """ - Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients. + Train the expert for one step. This method is called by ``ModuleBackend.backward`` after computing gradients. """ - if self.clip_grad_norm is not None: - torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm) - - self.optimizer.step() - self.optimizer.zero_grad() + if self.optimizer is not None: + self.optimizer.step() + self.optimizer.zero_grad() if self.scheduler is not None: self.scheduler.step() - self.update_count += 1 - self.examples_processed += batch_size - - def get_stats(self) -> Dict: - """ - Return current expert training statistics (number of updates, number of processed examples after - last optimizer step) - """ - return {"updates": self.update_count, "examples_processed": self.examples_processed} - - def get_full_state(self) -> Dict: - """ - Return the current state of the expert (including batch processing statistics) - """ - full_state = { - "stats": self.get_stats(), - "model": self.expert.state_dict(), - "optimizer": self.optimizer.state_dict(), - "scheduler": {} if self.scheduler is None else self.scheduler.state_dict(), - } + def state_dict(self) -> Dict: + """Return the current state of the module, optimizer, and scheduler""" + full_state = dict(module=self.module.state_dict()) + if self.optimizer is not None: + full_state["optimizer"] = self.optimizer.state_dict() + if self.scheduler is not None: + full_state["scheduler"] = self.scheduler.state_dict() return full_state - def load_full_state(self, state_dict: Dict): - if "stats" in state_dict: - self.update_count = state_dict["stats"]["updates"] - self.examples_processed = state_dict["stats"]["examples_processed"] - else: - logger.warning(f"Batch processing stats missing for expert {self.name}") - - self.expert.load_state_dict(state_dict["model"]) + def load_state_dict(self, state_dict: Dict): + self.module.load_state_dict(state_dict["module"]) + if self.optimizer is not None: + if "optimizer" in state_dict: + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + logger.warning(f"Optimizer state missing for {self.name}") - if "optimizer" in state_dict: - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - logger.warning(f"Optimizer state missing for expert {self.name}") - - if self.scheduler is not None and 
"scheduler" in state_dict: - self.scheduler.load_state_dict(state_dict["scheduler"]) - else: - logger.warning(f"Learning rate scheduler state missing for expert {self.name}") + if self.scheduler is not None: + if "scheduler" in state_dict: + self.scheduler.load_state_dict(state_dict["scheduler"]) + else: + logger.warning(f"Learning rate scheduler state missing for {self.name}") def get_info(self) -> Dict[str, Any]: """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration.""" diff --git a/hivemind/moe/server/runtime.py b/hivemind/moe/server/runtime.py index b79410fc5..1e750812f 100644 --- a/hivemind/moe/server/runtime.py +++ b/hivemind/moe/server/runtime.py @@ -12,7 +12,7 @@ import torch from prefetch_generator import BackgroundGenerator -from hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger logger = get_logger(__name__) @@ -20,20 +20,20 @@ class Runtime(threading.Thread): """ - A group of processes that processes incoming requests for multiple experts on a shared device. + A group of processes that processes incoming requests for multiple module backends on a shared device. Runtime is usually created and managed by Server, humans need not apply. For debugging, you can start runtime manually with .start() or .run() - >>> expert_backends = {'expert_name': ExpertBackend(**kwargs)} - >>> runtime = Runtime(expert_backends) + >>> module_backends = {'expert_name': ModuleBackend(**kwargs)} + >>> runtime = Runtime(module_backends) >>> runtime.start() # start runtime in background thread. To start in current thread, use runtime.run() >>> runtime.ready.wait() # await for runtime to load all experts on device and create request pools - >>> future = runtime.expert_backends['expert_name'].forward_pool.submit_task(*expert_inputs) + >>> future = runtime.module_backends['expert_name'].forward_pool.submit_task(*module_inputs) >>> print("Returned:", future.result()) >>> runtime.shutdown() - :param expert_backends: a dict [expert uid -> ExpertBackend] + :param module_backends: a dict [expert uid -> ModuleBackend] :param prefetch_batches: form up to this many batches in advance :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads :param device: if specified, moves all experts and data to this device via .to(device=device). 
@@ -46,15 +46,15 @@ class Runtime(threading.Thread): def __init__( self, - expert_backends: Dict[str, ExpertBackend], + module_backends: Dict[str, ModuleBackend], prefetch_batches=64, sender_threads: int = 1, device: torch.device = None, stats_report_interval: Optional[int] = None, ): super().__init__() - self.expert_backends = expert_backends - self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values()))) + self.module_backends = module_backends + self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values()))) self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False) self.shutdown_trigger = mp.Event() @@ -69,8 +69,8 @@ def run(self): if not pool.is_alive(): pool.start() if self.device is not None: - for expert_backend in self.expert_backends.values(): - expert_backend.expert.to(self.device) + for backend in self.module_backends.values(): + backend.module.to(self.device) with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool: try: diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index f7b691561..f4d7d7a77 100644 --- a/hivemind/moe/server/server.py +++ b/hivemind/moe/server/server.py @@ -15,13 +15,14 @@ from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts from hivemind.moe.server.connection_handler import ConnectionHandler from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts -from hivemind.moe.server.expert_backend import ExpertBackend from hivemind.moe.server.layers import ( add_custom_models_from_file, name_to_block, name_to_input, schedule_name_to_scheduler, ) +from hivemind.moe.server.layers.optim import ClippingWrapper +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.runtime import Runtime from hivemind.p2p import PeerInfo from hivemind.proto.runtime_pb2 import CompressionType @@ -33,7 +34,7 @@ class Server(threading.Thread): """ - Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts. + Server allows you to host "experts" - pytorch subnetworks that can be accessed remotely by peers. After creation, a server should be started: see Server.run or Server.run_in_background. A working server does two things: @@ -41,7 +42,7 @@ class Server(threading.Thread): - publishes updates to expert status every :update_period: seconds :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions. - :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server. + :param module_backends: dict{expert uid (str) : ModuleBackend} for all expert hosted by this server. :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1 if too small for normal functioning, we recommend 4 handlers per expert backend. :param update_period: how often will server attempt to publish its state (i.e. 
experts) to the DHT; @@ -54,7 +55,7 @@ class Server(threading.Thread): def __init__( self, dht: DHT, - expert_backends: Dict[str, ExpertBackend], + module_backends: Dict[str, ModuleBackend], num_connection_handlers: int = 1, update_period: float = 30, expiration: Optional[float] = None, @@ -63,18 +64,18 @@ def __init__( **kwargs, ): super().__init__() - self.dht, self.experts, self.update_period = dht, expert_backends, update_period + self.dht, self.module_backends, self.update_period = dht, module_backends, update_period - self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)] + self.conn_handlers = [ConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)] if checkpoint_dir is not None: - self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period) + self.checkpoint_saver = CheckpointSaver(module_backends, checkpoint_dir, update_period) else: self.checkpoint_saver = None - self.runtime = Runtime(self.experts, **kwargs) + self.runtime = Runtime(self.module_backends, **kwargs) - if self.experts: + if self.module_backends: self.dht_handler_thread = DHTHandlerThread( - experts=self.experts, + module_backends=self.module_backends, dht=self.dht, update_period=self.update_period, expiration=expiration, @@ -95,7 +96,7 @@ def create( optim_cls=torch.optim.Adam, scheduler: str = "none", num_warmup_steps=None, - num_total_steps=None, + num_training_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1, @@ -113,7 +114,7 @@ def create( **kwargs, ) -> Server: """ - Instantiate a server with several identical experts. See argparse comments below for details + Instantiate a server with several identical modules. See argparse comments below for details :param num_experts: run this many identical experts :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\ @@ -129,7 +130,7 @@ def create( :param optim_cls: uses this optimizer to train all experts :param scheduler: if not `none`, the name of the expert LR scheduler :param num_warmup_steps: the number of warmup steps for LR schedule - :param num_total_steps: the total number of steps for LR schedule + :param num_training_steps: the total number of steps for LR schedule :param clip_grad_norm: maximum gradient norm used for clipping :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT) @@ -138,7 +139,7 @@ def create( :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts hosted on this server. For a more fine-grained compression, start server in python and specify compression - for each BatchTensorProto in ExpertBackend for the respective experts. + for each BatchTensorProto in ModuleBackend for the respective experts. 
:param start: if True, starts server right away and returns when server is ready for requests :param stats_report_interval: interval between two reports of batch processing performance statistics @@ -180,7 +181,6 @@ def create( num_experts = len(expert_uids) num_handlers = num_handlers if num_handlers is not None else num_experts * 8 - optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0) device = device or ("cuda" if torch.cuda.is_available() else "cpu") sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim) @@ -189,21 +189,26 @@ def create( else: args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),) - scheduler = schedule_name_to_scheduler[scheduler] + scheduler_cls = schedule_name_to_scheduler[scheduler] + if scheduler_cls is not None: + scheduler_cls = partial( + scheduler_cls, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps + ) # initialize experts experts = {} for expert_uid in expert_uids: expert = name_to_block[expert_cls](hidden_dim) - experts[expert_uid] = ExpertBackend( + optimizer = optim_cls(expert.parameters()) if optim_cls is not None else None + scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None + if clip_grad_norm is not None: + optimizer = ClippingWrapper(optimizer, clip_grad_norm) + experts[expert_uid] = ModuleBackend( name=expert_uid, - expert=expert, + module=expert, args_schema=args_schema, - optimizer=optim_cls(expert.parameters()), + optimizer=optimizer, scheduler=scheduler, - num_warmup_steps=num_warmup_steps, - num_total_steps=num_total_steps, - clip_grad_norm=clip_grad_norm, min_batch_size=min_batch_size, max_batch_size=max_batch_size, ) @@ -228,15 +233,15 @@ def run(self): Starts Server in the current thread. Initializes dht if necessary, starts connection handlers, runs Runtime (self.runtime) to process incoming requests. 
""" - logger.info(f"Server started with {len(self.experts)} experts:") - for expert_name, backend in self.experts.items(): - num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad) - logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters") + logger.info(f"Server started with {len(self.module_backends)} modules:") + for expert_name, backend in self.module_backends.items(): + num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad) + logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters") if not self.dht.is_alive(): self.dht.run_in_background(await_ready=True) - if self.experts: + if self.module_backends: self.dht_handler_thread.start() if self.checkpoint_saver is not None: @@ -287,7 +292,7 @@ def shutdown(self): process.join() logger.debug("Connection handlers terminated") - if self.experts: + if self.module_backends: self.dht_handler_thread.stop.set() self.dht_handler_thread.join() diff --git a/tests/test_connection_handler.py b/tests/test_connection_handler.py index 3b4ac9ab9..afc6179f0 100644 --- a/tests/test_connection_handler.py +++ b/tests/test_connection_handler.py @@ -10,7 +10,7 @@ from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor from hivemind.dht import DHT from hivemind.moe.server.connection_handler import ConnectionHandler -from hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.task_pool import TaskPool from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PHandlerError from hivemind.proto import runtime_pb2 @@ -25,7 +25,7 @@ async def test_connection_handler_info(): handler = ConnectionHandler( DHT(start=True), - dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)), + dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)), ) handler.start() @@ -48,7 +48,7 @@ async def test_connection_handler_info(): async def test_connection_handler_forward(): handler = ConnectionHandler( DHT(start=True), - dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)), + dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)), ) handler.start() @@ -109,7 +109,7 @@ async def test_connection_handler_forward(): async def test_connection_handler_backward(): handler = ConnectionHandler( DHT(start=True), - dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)), + dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)), ) handler.start() @@ -179,7 +179,7 @@ async def submit_task(self, *inputs: torch.Tensor): return [inputs[0] * self.k] -class DummyExpertBackend(ExpertBackend): +class DummyModuleBackend(ModuleBackend): def __init__(self, name: str, k: float): self.name = name self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))] diff --git a/tests/test_expert_backend.py b/tests/test_expert_backend.py index 752cc96b1..e9d83231b 100644 --- a/tests/test_expert_backend.py +++ b/tests/test_expert_backend.py @@ -5,7 +5,7 @@ import torch from torch.nn import Linear -from hivemind import BatchTensorDescriptor, ExpertBackend +from hivemind import BatchTensorDescriptor, ModuleBackend from hivemind.moe.server.checkpoints import load_experts, store_experts from 
hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup @@ -22,13 +22,15 @@ def example_experts(): opt = torch.optim.SGD(expert.parameters(), PEAK_LR) args_schema = (BatchTensorDescriptor(1),) - expert_backend = ExpertBackend( + expert_backend = ModuleBackend( name=EXPERT_NAME, - expert=expert, + module=expert, optimizer=opt, - scheduler=get_linear_schedule_with_warmup, - num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE, - num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE, + scheduler=get_linear_schedule_with_warmup( + opt, + num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE, + num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE, + ), args_schema=args_schema, outputs_schema=BatchTensorDescriptor(1), max_batch_size=1, @@ -39,7 +41,7 @@ def example_experts(): @pytest.mark.forked def test_save_load_checkpoints(example_experts): - expert = example_experts[EXPERT_NAME].expert + expert = example_experts[EXPERT_NAME].module with TemporaryDirectory() as tmpdir: tmp_path = Path(tmpdir) @@ -79,7 +81,7 @@ def test_restore_update_count(example_experts): expert_backend.backward(batch, loss_grad) load_experts(example_experts, tmp_path) - assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE + assert expert_backend.scheduler._step_count == BACKWARD_PASSES_BEFORE_SAVE + 1 @pytest.mark.forked diff --git a/tests/test_moe.py b/tests/test_moe.py index 402306c62..46e9279a8 100644 --- a/tests/test_moe.py +++ b/tests/test_moe.py @@ -7,7 +7,7 @@ from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts from hivemind.moe.expert_uid import ExpertInfo -from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts +from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts from hivemind.moe.server.layers import name_to_block from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError from hivemind.utils import BatchTensorDescriptor, get_dht_time @@ -257,16 +257,16 @@ def test_client_anomaly_detection(): experts = {} for i in range(4): expert = name_to_block["ffn"](HID_DIM) - experts[f"expert.{i}"] = ExpertBackend( + experts[f"expert.{i}"] = ModuleBackend( name=f"expert.{i}", - expert=expert, + module=expert, optimizer=torch.optim.Adam(expert.parameters()), args_schema=(BatchTensorDescriptor(HID_DIM),), outputs_schema=BatchTensorDescriptor(HID_DIM), max_batch_size=16, ) - experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan") + experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan") dht = DHT(start=True) server = Server(dht, experts, num_connection_handlers=1)
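As a closing illustration of the new checkpoint layout, here is a short sketch (not part of the patch) that follows the same pattern as tests/test_expert_backend.py above: ModuleBackend.state_dict() now contains "module"/"optimizer"/"scheduler" entries instead of the old "model"/"stats" bundle, and store_experts/load_experts round-trip it directly. The Linear module, learning rate, and temporary directory here are arbitrary.

import torch
from pathlib import Path
from tempfile import TemporaryDirectory

from hivemind import BatchTensorDescriptor, ModuleBackend
from hivemind.moe.server.checkpoints import load_experts, store_experts

module = torch.nn.Linear(1, 1)
backend = ModuleBackend(
    name="expert.0",
    module=module,
    optimizer=torch.optim.SGD(module.parameters(), lr=1e-3),
    args_schema=(BatchTensorDescriptor(1),),
    outputs_schema=BatchTensorDescriptor(1),
    max_batch_size=1,
)

state = backend.state_dict()
assert "module" in state and "optimizer" in state  # no "scheduler" key, since no scheduler was given

with TemporaryDirectory() as tmpdir:
    store_experts({backend.name: backend}, Path(tmpdir))  # saves backend.state_dict() as checkpoint_last.pt per expert
    load_experts({backend.name: backend}, Path(tmpdir))   # restores via backend.load_state_dict(...)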