Simplify ExpertBackend interface (#483)
- extract gradient clipping from ExpertBackend: this behavior can be achieved with a user-defined Optimizer
- remove stats from ExpertBackend: this behavior can be achieved with a user-defined Scheduler
- rename full_state -> state_dict, rationale: there is no "non-full" state in this context
- rename ExpertBackend.expert -> ExpertBackend.module to avoid confusion (a short usage sketch of the renamed interface follows below)
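
For orientation, a minimal sketch of the renamed interface as it stands after this commit, assembled from the benchmark and checkpoint changes below. The hidden size, expert name, and plain Linear module are illustrative, and the BatchTensorDescriptor import path is an assumption:

import torch

from hivemind.moe.server import ModuleBackend  # formerly ExpertBackend
from hivemind.utils import BatchTensorDescriptor  # assumed import path

hid_dim = 512  # illustrative size
expert = torch.nn.Linear(hid_dim, hid_dim)  # illustrative module

backend = ModuleBackend(
    name="expert.0",
    module=expert,  # formerly passed as `expert=...`
    optimizer=torch.optim.Adam(expert.parameters()),
    args_schema=(BatchTensorDescriptor(hid_dim),),
    outputs_schema=BatchTensorDescriptor(hid_dim),
)

snapshot = backend.state_dict()  # formerly get_full_state()
backend.load_state_dict(snapshot)  # formerly load_full_state()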

Co-authored-by: Max Ryabinin <[email protected]>
justheuristic and mryab authored Jun 15, 2022
1 parent 6c56a87 commit 5ea21a7
Showing 17 changed files with 201 additions and 171 deletions.
10 changes: 5 additions & 5 deletions benchmarks/benchmark_throughput.py
@@ -10,7 +10,7 @@
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.moe.server import ExpertBackend, Server
+from hivemind.moe.server import ModuleBackend, Server
 from hivemind.moe.server.layers import name_to_block
 from hivemind.p2p import P2P
 from hivemind.utils.limits import increase_file_limit
@@ -118,12 +118,12 @@ def benchmark_throughput(
     timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()

     device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-    experts = {}
+    module_backends = {}
     for i in range(num_experts):
         expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
-        experts[f"expert.{i}"] = ExpertBackend(
+        module_backends[f"expert.{i}"] = ModuleBackend(
             name=f"expert.{i}",
-            expert=expert,
+            module=expert,
             optimizer=torch.optim.Adam(expert.parameters()),
             args_schema=(BatchTensorDescriptor(hid_dim),),
             outputs_schema=BatchTensorDescriptor(hid_dim),
@@ -133,7 +133,7 @@

     server = Server(
         dht=server_dht,
-        expert_backends=experts,
+        module_backends=module_backends,
         num_connection_handlers=num_handlers,
         device=device,
     )
10 changes: 5 additions & 5 deletions docs/modules/server.rst
@@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo
 The hivemind.moe.server module is organized as follows:

 - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute.
-- ExpertBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
+- ModuleBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
   that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests.
-- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
+- Runtime_ balances the device (GPU) usage between several ModuleBackend_ instances that each service one expert.
 - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \
   and offers those batches to Runtime_ for processing.

@@ -25,9 +25,9 @@ The hivemind.moe.server module is organized as follows:
    :members:
    :member-order: bysource

-.. _ExpertBackend:
-.. autoclass:: ExpertBackend
-   :members: forward, backward, apply_gradients, get_info, get_pools
+.. _ModuleBackend:
+.. autoclass:: ModuleBackend
+   :members: forward, backward, on_backward, get_info, get_pools
    :member-order: bysource

 .. currentmodule:: hivemind.moe.server.runtime
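To place the renamed class in context, a minimal sketch of wiring a ModuleBackend into a Server, based on the benchmark changes in this commit. DHT bootstrapping is elided, `backend` refers to the sketch under the commit message, and Server's threaded start() is assumed from the library:

from hivemind.dht import DHT
from hivemind.moe.server import Server

dht = DHT(start=True)  # standalone DHT node; real deployments would pass initial_peers
server = Server(
    dht=dht,
    module_backends={"expert.0": backend},  # renamed from `expert_backends=`
    num_connection_handlers=1,
)
server.start()  # assumed: Server runs as a thread and serves forward/backward requests for its backends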
2 changes: 1 addition & 1 deletion hivemind/__init__.py
@@ -2,7 +2,7 @@
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
-    ExpertBackend,
+    ModuleBackend,
     RemoteExpert,
     RemoteMixtureOfExperts,
     RemoteSwitchMixtureOfExperts,
3 changes: 2 additions & 1 deletion hivemind/hivemind_cli/run_server.py
@@ -54,7 +54,8 @@ def main():
                         help='Server will report experts to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
-    parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
+    parser.add_argument('--num_training_steps', type=int, required=False, help='The total number of steps for LR schedule')
+
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')

     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
2 changes: 1 addition & 1 deletion hivemind/moe/__init__.py
@@ -1,6 +1,6 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import (
-    ExpertBackend,
+    ModuleBackend,
     Server,
     background_server,
     declare_experts,
2 changes: 1 addition & 1 deletion hivemind/moe/server/__init__.py
@@ -1,4 +1,4 @@
 from hivemind.moe.server.dht_handler import declare_experts, get_experts
-from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.layers import register_expert_class
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.server import Server, background_server
18 changes: 9 additions & 9 deletions hivemind/moe/server/checkpoints.py
@@ -8,7 +8,7 @@

 import torch

-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils.logging import get_logger

 logger = get_logger(__name__)
@@ -34,23 +34,23 @@ def copy_tree(src: str, dst: str):


 class CheckpointSaver(threading.Thread):
-    def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: float):
+    def __init__(self, module_backends: Dict[str, ModuleBackend], checkpoint_dir: Path, update_period: float):
         super().__init__()
         assert is_directory(checkpoint_dir)
-        self.expert_backends = expert_backends
+        self.module_backends = module_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
         self.stop = threading.Event()

         # create expert directories to ensure that the directory is writable and checkpoints can be loaded
-        store_experts(self.expert_backends, self.checkpoint_dir)
+        store_experts(self.module_backends, self.checkpoint_dir)

     def run(self) -> None:
         while not self.stop.wait(self.update_period):
-            store_experts(self.expert_backends, self.checkpoint_dir)
+            store_experts(self.module_backends, self.checkpoint_dir)


-def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+def store_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path):
     logger.debug(f"Storing experts at {checkpoint_dir.absolute()}")
     assert is_directory(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep="_")
@@ -59,17 +59,17 @@ def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
             checkpoint_name = expert_dir / f"checkpoint_{timestamp}.pt"
-            torch.save(expert_backend.get_full_state(), checkpoint_name)
+            torch.save(expert_backend.state_dict(), checkpoint_name)
             os.symlink(checkpoint_name, expert_dir / "checkpoint_last.pt")
         copy_tree(tmpdirname, str(checkpoint_dir))


-def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+def load_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path):
     assert is_directory(checkpoint_dir)
     for expert_name, expert in experts.items():
         checkpoints_folder = checkpoint_dir / expert_name
         latest_checkpoint = checkpoints_folder / "checkpoint_last.pt"
         if latest_checkpoint.exists():
-            expert.load_full_state(torch.load(latest_checkpoint))
+            expert.load_state_dict(torch.load(latest_checkpoint))
         else:
             logger.warning(f"Failed to load checkpoint for expert {expert_name}")
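
A short sketch of the renamed checkpoint round-trip; the directory and update period are illustrative, and `backend` is the ModuleBackend from the sketch under the commit message:

from pathlib import Path

from hivemind.moe.server.checkpoints import CheckpointSaver, load_experts

backends = {"expert.0": backend}
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)  # CheckpointSaver asserts that this directory exists
saver = CheckpointSaver(backends, checkpoint_dir, update_period=300.0)
saver.start()  # periodically stores each backend.state_dict() snapshot

# after a restart, restore the latest snapshots via the renamed load_state_dict
load_experts(backends, checkpoint_dir)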
6 changes: 3 additions & 3 deletions hivemind/moe/server/connection_handler.py
@@ -6,7 +6,7 @@

 from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
-from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE, P2P
@@ -25,10 +25,10 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port
     :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one
-    :param experts: a dict [UID -> ExpertBackend] with all active experts
+    :param experts: a dict [UID -> ModuleBackend] with all active experts
     """

-    def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]):
+    def __init__(self, dht: DHT, experts: Dict[str, ModuleBackend]):
         super().__init__()
         self.dht, self.experts = dht, experts
         self._p2p: Optional[P2P] = None
10 changes: 6 additions & 4 deletions hivemind/moe/server/dht_handler.py
@@ -20,20 +20,22 @@


 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs):
+    def __init__(
+        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
+    ):
         super().__init__(**kwargs)
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
-        self.experts = experts
+        self.module_backends = module_backends
         self.dht = dht
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()

     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration)
+        declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration)
+            declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration)


 def declare_experts(
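A sketch of the renamed thread in use, with `dht` and `backends` as in the earlier sketches and an illustrative update period:

handler = DHTHandlerThread(module_backends=backends, dht=dht, update_period=30, daemon=True)
handler.start()  # re-declares every expert UID in the DHT each update_period seconds
handler.stop.set()  # signal the loop to exit; join() afterwards to wait for it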
2 changes: 1 addition & 1 deletion hivemind/moe/server/layers/dropout.py
@@ -19,7 +19,7 @@ def backward(ctx, grad_output):
 class DeterministicDropout(nn.Module):
     """
     Custom dropout layer which accepts dropout mask as an input (drop_prob is only used for scaling input activations).
-    Can be used with RemoteExpert/ExpertBackend to ensure that dropout mask is the same at forward and backward steps
+    Can be used with RemoteExpert/ModuleBackend to ensure that dropout mask is the same at forward and backward steps
     """

     def __init__(self, drop_prob):
58 changes: 58 additions & 0 deletions hivemind/moe/server/layers/optim.py
@@ -0,0 +1,58 @@
+import torch
+
+
+class OptimizerWrapper(torch.optim.Optimizer):
+    """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""
+
+    def __init__(self, optim: torch.optim.Optimizer):
+        object.__init__(self)  # bypass Optimizer.__init__: it assigns defaults/state/param_groups, which are read-only properties here
+        self.optim = optim
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self.optim)})"
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        return self.optim.load_state_dict(state_dict)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        return self.optim.zero_grad(*args, **kwargs)
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        return self.optim.add_param_group(param_group)
+
+
+class ClippingWrapper(OptimizerWrapper):
+    """A wrapper of torch.Optimizer that clips gradients by global norm before each step"""
+
+    def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
+        super().__init__(optim)
+        self.clip_grad_norm = clip_grad_norm
+
+    def step(self, *args, **kwargs):
+        parameters = tuple(param for group in self.param_groups for param in group["params"])
+        torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
+        return super().step(*args, **kwargs)
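
A minimal sketch of recreating the removed built-in clipping with this wrapper; the model and threshold are illustrative:

import torch

from hivemind.moe.server.layers.optim import ClippingWrapper

model = torch.nn.Linear(64, 64)  # illustrative module
optimizer = ClippingWrapper(torch.optim.Adam(model.parameters()), clip_grad_norm=1.0)

loss = model(torch.randn(8, 64)).square().mean()
loss.backward()
optimizer.step()  # clips the global gradient norm to 1.0, then delegates to Adam
optimizer.zero_grad()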
(6 more changed files not shown)
