From e5a3e46613052e39ffd9de21883050ab4adc583a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 12 Jun 2022 23:05:24 +0300 Subject: [PATCH 01/33] Increase default update_period to 30s, set default expiration to 2 * update_period --- hivemind/moe/server/dht_handler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hivemind/moe/server/dht_handler.py b/hivemind/moe/server/dht_handler.py index 4d1c10515..0e913fc91 100644 --- a/hivemind/moe/server/dht_handler.py +++ b/hivemind/moe/server/dht_handler.py @@ -20,17 +20,18 @@ class DHTHandlerThread(threading.Thread): - def __init__(self, experts, dht: DHT, update_period: int = 5, **kwargs): + def __init__(self, experts, dht: DHT, update_period: int = 30, expiration: Optional[int] = None, **kwargs): super().__init__(**kwargs) self.experts = experts self.dht = dht self.update_period = update_period + self.expiration = expiration if expiration is not None else 2 * update_period self.stop = threading.Event() def run(self) -> None: declare_experts(self.dht, self.experts.keys()) while not self.stop.wait(self.update_period): - declare_experts(self.dht, self.experts.keys()) + declare_experts(self.dht, self.experts.keys(), expiration=self.expiration) def declare_experts( From 637fb01acb359f9568e1c5675f9370a2b16528cf Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 12 Jun 2022 23:11:44 +0300 Subject: [PATCH 02/33] black-isort --- hivemind/moe/server/dht_handler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hivemind/moe/server/dht_handler.py b/hivemind/moe/server/dht_handler.py index 0e913fc91..401a55859 100644 --- a/hivemind/moe/server/dht_handler.py +++ b/hivemind/moe/server/dht_handler.py @@ -16,16 +16,18 @@ split_uid, ) from hivemind.p2p import PeerID -from hivemind.utils import MPFuture, get_dht_time +from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, MPFuture, get_dht_time class DHTHandlerThread(threading.Thread): def __init__(self, experts, dht: DHT, 
update_period: int = 30, expiration: Optional[int] = None, **kwargs): super().__init__(**kwargs) + if expiration is None: + expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) self.experts = experts self.dht = dht self.update_period = update_period - self.expiration = expiration if expiration is not None else 2 * update_period + self.expiration = expiration self.stop = threading.Event() def run(self) -> None: From 372d9159daf83cdf030010463ef308a8049708fb Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 13 Jun 2022 00:07:32 +0300 Subject: [PATCH 03/33] review --- hivemind/moe/server/dht_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/moe/server/dht_handler.py b/hivemind/moe/server/dht_handler.py index 401a55859..f9cfffe3c 100644 --- a/hivemind/moe/server/dht_handler.py +++ b/hivemind/moe/server/dht_handler.py @@ -37,7 +37,7 @@ def run(self) -> None: def declare_experts( - dht: DHT, uids: Sequence[ExpertUID], expiration: DHTExpiration = 300, wait: bool = True + dht: DHT, uids: Sequence[ExpertUID], expiration: DHTExpiration, wait: bool = True ) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]: """ Make experts visible to all DHT peers; update timestamps if declared previously. 
From 98d495202d5611ad1f3efef449a82819dc720fdd Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 13 Jun 2022 00:21:45 +0300 Subject: [PATCH 04/33] review --- hivemind/moe/server/dht_handler.py | 15 ++++++++------- tests/test_dht_experts.py | 23 ++++++++++++----------- tests/test_moe.py | 3 ++- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/hivemind/moe/server/dht_handler.py b/hivemind/moe/server/dht_handler.py index f9cfffe3c..79594d8a0 100644 --- a/hivemind/moe/server/dht_handler.py +++ b/hivemind/moe/server/dht_handler.py @@ -31,20 +31,20 @@ def __init__(self, experts, dht: DHT, update_period: int = 30, expiration: Optio self.stop = threading.Event() def run(self) -> None: - declare_experts(self.dht, self.experts.keys()) + declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration) while not self.stop.wait(self.update_period): - declare_experts(self.dht, self.experts.keys(), expiration=self.expiration) + declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration) def declare_experts( - dht: DHT, uids: Sequence[ExpertUID], expiration: DHTExpiration, wait: bool = True + dht: DHT, uids: Sequence[ExpertUID], expiration_time: DHTExpiration, wait: bool = True ) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]: """ Make experts visible to all DHT peers; update timestamps if declared previously. :param uids: a list of expert ids to update :param wait: if True, awaits for declaration to finish, otherwise runs in background - :param expiration: experts will be visible for this many seconds + :param expiration_time: experts will be visible until this DHT timestamp (e.g. get_dht_time() + seconds) :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected) """ assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
@@ -52,14 +52,15 @@ def declare_experts( uids = list(uids) for uid in uids: assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}" - return dht.run_coroutine(partial(_declare_experts, uids=uids, expiration=expiration), return_future=not wait) + return dht.run_coroutine( + partial(_declare_experts, uids=uids, expiration_time=expiration_time), return_future=not wait + ) async def _declare_experts( - dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration: DHTExpiration + dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: DHTExpiration ) -> Dict[ExpertUID, bool]: num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) - expiration_time = get_dht_time() + expiration data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {} peer_id_base58 = dht.peer_id.to_base58() diff --git a/tests/test_dht_experts.py b/tests/test_dht_experts.py index 6961df723..9789c0e54 100644 --- a/tests/test_dht_experts.py +++ b/tests/test_dht_experts.py @@ -6,10 +6,10 @@ import pytest import hivemind -from hivemind.dht import DHTNode +from hivemind import get_dht_time from hivemind.moe.client.beam_search import MoEBeamSearcher from hivemind.moe.expert_uid import ExpertInfo, is_valid_prefix, is_valid_uid, split_uid -from hivemind.moe.server import declare_experts, get_experts +from hivemind.moe.server.dht_handler import declare_experts, get_experts @pytest.mark.forked @@ -24,14 +24,14 @@ def test_store_get_experts(n_peers=10): expert_uids = [f"my_expert.{i}" for i in range(50)] batch_size = 10 for batch_start in range(0, len(expert_uids), batch_size): - declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size]) + declare_experts(first_peer, expert_uids[batch_start : batch_start + batch_size], get_dht_time() + 30) found = get_experts(other_peer, random.sample(expert_uids, 5) + ["foo", "bar"]) assert all(res is not None for res in found[:-2]), "Could not find 
some existing experts" assert all(res is None for res in found[-2:]), "Found non-existing experts" other_expert = "my_other_expert.1337" - declare_experts(other_peer, [other_expert]) + declare_experts(other_peer, [other_expert], get_dht_time() + 30) first_notfound, first_found = get_experts(first_peer, ["foobar", other_expert]) assert isinstance(first_found, hivemind.RemoteExpert) assert first_found.peer_id == other_peer.peer_id @@ -43,7 +43,7 @@ def test_store_get_experts(n_peers=10): time.sleep(1.0) remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()]) remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()]) - assert all(declare_experts(remaining_peer1, ["new_expert.1"])) + assert all(declare_experts(remaining_peer1, ["new_expert.1"], expiration_time=hivemind.get_dht_time() + 30)) assert get_experts(remaining_peer2, ["new_expert.1"])[0].peer_id == remaining_peer1.peer_id @@ -63,6 +63,7 @@ def test_beam_search( declare_experts( dht, real_experts[batch_start : batch_start + batch_size], + expiration_time=get_dht_time() + 30 ) neighbors = sum( @@ -90,14 +91,14 @@ def test_dht_single_node(): node = hivemind.DHT(start=True) beam_search = MoEBeamSearcher(node, "expert.", grid_size=(10,)) - assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"]).values()) - assert len(declare_experts(node, ["ffn.1", "ffn.2"])) == 4 - assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"])) == 7 + assert all(declare_experts(node, ["expert.1", "expert.2", "expert.3"], get_dht_time() + 30).values()) + assert len(declare_experts(node, ["ffn.1", "ffn.2"], get_dht_time() + 30)) == 4 + assert len(declare_experts(node, ["e.1.2.3", "e.1.2.5", "e.2.0"], get_dht_time() + 30)) == 7 for expert in get_experts(node, ["expert.3", "expert.2"]): assert expert.peer_id == node.peer_id - assert all(declare_experts(node, ["expert.5", "expert.2"]).values()) + assert all(declare_experts(node, ["expert.5", "expert.2"], get_dht_time() + 
30).values()) found_experts = beam_search.find_best_experts([(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)], beam_size=2) assert len(found_experts) == 2 and [expert.uid for expert in found_experts] == ["expert.5", "expert.3"] @@ -196,7 +197,7 @@ async def test_negative_caching(n_peers=10): peers += [hivemind.DHT(initial_peers=initial_peers, start=True, **dht_kwargs) for _ in range(n_peers - 1)] writer_peer = random.choice(peers) - assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"]).values()) + assert all(declare_experts(writer_peer, ["ffn.1.2.3", "ffn.3.4.5"], get_dht_time() + 30).values()) neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(peers, min(3, len(peers)))], []) neg_caching_peer = hivemind.DHT(initial_peers=neighbors, start=True, **dht_kwargs) @@ -204,7 +205,7 @@ async def test_negative_caching(n_peers=10): # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.* assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2 - node = await DHTNode.create(initial_peers=neighbors) + node = await declare_expertsDHTNode.create(initial_peers=neighbors) fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10))) for i in range(6): assert fetched[i] is not None, f"node should have cached ffn.{i}." 
diff --git a/tests/test_moe.py b/tests/test_moe.py index 274bb975a..373c05f51 100644 --- a/tests/test_moe.py +++ b/tests/test_moe.py @@ -2,6 +2,7 @@ import pytest import torch +import hivemind from hivemind.dht import DHT from hivemind.moe.client.expert import RemoteExpert, create_remote_experts from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany @@ -163,7 +164,7 @@ def test_remote_module_call(hidden_dim=16): def test_beam_search_correctness(): all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)] dht = DHT(start=True) - assert all(declare_experts(dht, all_expert_uids)) + assert all(declare_experts(dht, all_expert_uids, expiration_time=hivemind.get_dht_time() + 30)) dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.") From b43c243bbce4b33903d5b20e5e7a872f14c618cb Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 13 Jun 2022 00:22:24 +0300 Subject: [PATCH 05/33] typo --- tests/test_dht_experts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_dht_experts.py b/tests/test_dht_experts.py index 9789c0e54..06d00ddf8 100644 --- a/tests/test_dht_experts.py +++ b/tests/test_dht_experts.py @@ -8,6 +8,7 @@ import hivemind from hivemind import get_dht_time from hivemind.moe.client.beam_search import MoEBeamSearcher +from hivemind.dht.node import DHTNode from hivemind.moe.expert_uid import ExpertInfo, is_valid_prefix, is_valid_uid, split_uid from hivemind.moe.server.dht_handler import declare_experts, get_experts @@ -205,7 +206,7 @@ async def test_negative_caching(n_peers=10): # get prefixes by the peer with negative caching. 
Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.* assert len(beam_search.get_initial_beam(scores=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], beam_size=3)) == 2 - node = await declare_expertsDHTNode.create(initial_peers=neighbors) + node = await DHTNode.create(initial_peers=neighbors) fetched = await asyncio.gather(*(node.get(f"ffn.{i}.") for i in range(10))) for i in range(6): assert fetched[i] is not None, f"node should have cached ffn.{i}." From 108a24ba9cb29f1eff6919713fefcf1b72ed31df Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 13 Jun 2022 00:26:20 +0300 Subject: [PATCH 06/33] add expiration param --- hivemind/moe/server/checkpoints.py | 2 +- hivemind/moe/server/dht_handler.py | 2 +- hivemind/moe/server/server.py | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/hivemind/moe/server/checkpoints.py b/hivemind/moe/server/checkpoints.py index d013aa4f4..23a4a4a2e 100644 --- a/hivemind/moe/server/checkpoints.py +++ b/hivemind/moe/server/checkpoints.py @@ -34,7 +34,7 @@ def copy_tree(src: str, dst: str): class CheckpointSaver(threading.Thread): - def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int): + def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: float): super().__init__() assert is_directory(checkpoint_dir) self.expert_backends = expert_backends diff --git a/hivemind/moe/server/dht_handler.py b/hivemind/moe/server/dht_handler.py index 79594d8a0..ea1a4f658 100644 --- a/hivemind/moe/server/dht_handler.py +++ b/hivemind/moe/server/dht_handler.py @@ -20,7 +20,7 @@ class DHTHandlerThread(threading.Thread): - def __init__(self, experts, dht: DHT, update_period: int = 30, expiration: Optional[int] = None, **kwargs): + def __init__(self, experts, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs): super().__init__(**kwargs) if expiration is None: expiration = max(2 * update_period, 
MAX_DHT_TIME_DISCREPANCY_SECONDS) diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index c2d01cf9b..2fd7ca591 100644 --- a/hivemind/moe/server/server.py +++ b/hivemind/moe/server/server.py @@ -46,6 +46,7 @@ class Server(threading.Thread): if too small for normal functioning, we recommend 4 handlers per expert backend. :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT; if dht is None, this parameter is ignored. + :param expiration: when server declares its experts to the DHT, these entries will expire after this many seconds :param start: if True, the server will immediately start as a background thread and returns control after server is ready (see .ready below) """ @@ -55,7 +56,8 @@ def __init__( dht: DHT, expert_backends: Dict[str, ExpertBackend], num_connection_handlers: int = 1, - update_period: int = 30, + update_period: float = 30, + expiration: Optional[float] = None, start=False, checkpoint_dir=None, **kwargs, @@ -75,6 +77,7 @@ def __init__( experts=self.experts, dht=self.dht, update_period=self.update_period, + expiration=expiration, daemon=True, ) From a5aa6f9e5b7f28470000c63005956fdb1e220da0 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 13 Jun 2022 00:26:53 +0300 Subject: [PATCH 07/33] black-isort --- tests/test_dht_experts.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_dht_experts.py b/tests/test_dht_experts.py index 06d00ddf8..aa7eccc36 100644 --- a/tests/test_dht_experts.py +++ b/tests/test_dht_experts.py @@ -7,8 +7,8 @@ import hivemind from hivemind import get_dht_time -from hivemind.moe.client.beam_search import MoEBeamSearcher from hivemind.dht.node import DHTNode +from hivemind.moe.client.beam_search import MoEBeamSearcher from hivemind.moe.expert_uid import ExpertInfo, is_valid_prefix, is_valid_uid, split_uid from hivemind.moe.server.dht_handler import declare_experts, get_experts @@ -61,11 +61,7 @@ def test_beam_search( ) 
for batch_start in range(0, len(real_experts), batch_size): dht = random.choice(dht_instances) - declare_experts( - dht, - real_experts[batch_start : batch_start + batch_size], - expiration_time=get_dht_time() + 30 - ) + declare_experts(dht, real_experts[batch_start : batch_start + batch_size], expiration_time=get_dht_time() + 30) neighbors = sum( [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], [] From e94091114fb4369fc850e025be6518d4c2af0032 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 13 Jun 2022 00:35:12 +0300 Subject: [PATCH 08/33] more requests --- hivemind/hivemind_cli/run_server.py | 4 ++++ hivemind/moe/server/server.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py index 535aba55c..7c0d70a28 100644 --- a/hivemind/hivemind_cli/run_server.py +++ b/hivemind/hivemind_cli/run_server.py @@ -50,6 +50,10 @@ def main(): help='LR scheduler type to use') parser.add_argument('--num_warmup_steps', type=int, required=False, help='The number of warmup steps for LR schedule') + parser.add_argument('--update_period', type=float, required=False, default=30, + help='Server will report experts to DHT once in this many seconds') + parser.add_argument('--expiration', type=float, required=False, default=None, + help='DHT entries will expire after this many seconds') parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule') parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping') diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index 2fd7ca591..f7b691561 100644 --- a/hivemind/moe/server/server.py +++ b/hivemind/moe/server/server.py @@ -106,6 +106,8 @@ def create( compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None, + update_period: float = 30, + 
expiration: Optional[float] = None, *, start: bool, **kwargs, @@ -216,6 +218,8 @@ def create( device=device, checkpoint_dir=checkpoint_dir, stats_report_interval=stats_report_interval, + update_period=update_period, + expiration=expiration, start=start, ) From 2356bc10562bddc085d0b662ed3d8b68e284b698 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 13 Jun 2022 00:36:25 +0300 Subject: [PATCH 09/33] review --- tests/test_moe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_moe.py b/tests/test_moe.py index 373c05f51..402306c62 100644 --- a/tests/test_moe.py +++ b/tests/test_moe.py @@ -2,7 +2,6 @@ import pytest import torch -import hivemind from hivemind.dht import DHT from hivemind.moe.client.expert import RemoteExpert, create_remote_experts from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany @@ -11,7 +10,7 @@ from hivemind.moe.server import ExpertBackend, Server, background_server, declare_experts from hivemind.moe.server.layers import name_to_block from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError -from hivemind.utils.tensor_descr import BatchTensorDescriptor +from hivemind.utils import BatchTensorDescriptor, get_dht_time @pytest.mark.forked @@ -164,7 +163,7 @@ def test_remote_module_call(hidden_dim=16): def test_beam_search_correctness(): all_expert_uids = [f"ffn.{5 + i}.{10 + j}.{15 + k}" for i in range(10) for j in range(10) for k in range(10)] dht = DHT(start=True) - assert all(declare_experts(dht, all_expert_uids, expiration_time=hivemind.get_dht_time() + 30)) + assert all(declare_experts(dht, all_expert_uids, expiration_time=get_dht_time() + 30)) dmoe = RemoteMixtureOfExperts(in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix="ffn.") From 771ebc16b81d828c8086435b52ec454a890e7fe5 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 13 Jun 2022 00:37:29 +0300 Subject: [PATCH 10/33] Update tests/test_dht_experts.py --- tests/test_dht_experts.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dht_experts.py b/tests/test_dht_experts.py index aa7eccc36..0061aaf3e 100644 --- a/tests/test_dht_experts.py +++ b/tests/test_dht_experts.py @@ -44,7 +44,7 @@ def test_store_get_experts(n_peers=10): time.sleep(1.0) remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()]) remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()]) - assert all(declare_experts(remaining_peer1, ["new_expert.1"], expiration_time=hivemind.get_dht_time() + 30)) + assert all(declare_experts(remaining_peer1, ["new_expert.1"], expiration_time=get_dht_time() + 30)) assert get_experts(remaining_peer2, ["new_expert.1"])[0].peer_id == remaining_peer1.peer_id From 8fb098654bb723bb09983fe4bf53447a305d5b22 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 13 Jun 2022 02:23:47 +0300 Subject: [PATCH 11/33] py39 --- tests/test_dht_experts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dht_experts.py b/tests/test_dht_experts.py index aa7eccc36..b6c77b325 100644 --- a/tests/test_dht_experts.py +++ b/tests/test_dht_experts.py @@ -61,7 +61,7 @@ def test_beam_search( ) for batch_start in range(0, len(real_experts), batch_size): dht = random.choice(dht_instances) - declare_experts(dht, real_experts[batch_start : batch_start + batch_size], expiration_time=get_dht_time() + 30) + declare_experts(dht, real_experts[batch_start : batch_start + batch_size], get_dht_time() + 30) neighbors = sum( [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], [] From ec22edab6ae644711075eed1d0d2d10efed5c016 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 07:21:18 +0300 Subject: [PATCH 12/33] - rename num_total_steps -> num_training_steps (to match the source) - extract batching and clipping from ExpertBackend, reassign this role to optimizer/scheduler - rename full_state -> state_dict, rationale: there is no "non-full"
state in this context - rename ExpertBackend.expert -> ExpertBackend.module to avoid confusion --- benchmarks/benchmark_throughput.py | 2 +- hivemind/hivemind_cli/run_server.py | 2 +- hivemind/moe/server/checkpoints.py | 4 +- hivemind/moe/server/expert_backend.py | 108 ++++++++------------------ hivemind/moe/server/layers/optim.py | 63 +++++++++++++++ hivemind/moe/server/runtime.py | 2 +- hivemind/moe/server/server.py | 27 ++++--- tests/test_expert_backend.py | 14 ++-- tests/test_moe.py | 4 +- 9 files changed, 128 insertions(+), 98 deletions(-) create mode 100644 hivemind/moe/server/layers/optim.py diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 6e1389a4b..e62695c8b 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -123,7 +123,7 @@ def benchmark_throughput( expert = torch.jit.script(name_to_block[expert_cls](hid_dim)) experts[f"expert.{i}"] = ExpertBackend( name=f"expert.{i}", - expert=expert, + module=expert, optimizer=torch.optim.Adam(expert.parameters()), args_schema=(BatchTensorDescriptor(hid_dim),), outputs_schema=BatchTensorDescriptor(hid_dim), diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py index 7c0d70a28..e087067a6 100644 --- a/hivemind/hivemind_cli/run_server.py +++ b/hivemind/hivemind_cli/run_server.py @@ -54,7 +54,7 @@ def main(): help='Server will report experts to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, help='DHT entries will expire after this many seconds') - parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule') + parser.add_argument('--num_training_steps', type=int, required=False, help='The total number of steps for LR schedule') parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping') parser.add_argument('--initial_peers', type=str, 
nargs='*', required=False, default=[], diff --git a/hivemind/moe/server/checkpoints.py b/hivemind/moe/server/checkpoints.py index 23a4a4a2e..cd67b5d6f 100644 --- a/hivemind/moe/server/checkpoints.py +++ b/hivemind/moe/server/checkpoints.py @@ -59,7 +59,7 @@ def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path): expert_dir = Path(tmpdirname) / expert_name expert_dir.mkdir() checkpoint_name = expert_dir / f"checkpoint_{timestamp}.pt" - torch.save(expert_backend.get_full_state(), checkpoint_name) + torch.save(expert_backend.state_dict(), checkpoint_name) os.symlink(checkpoint_name, expert_dir / "checkpoint_last.pt") copy_tree(tmpdirname, str(checkpoint_dir)) @@ -70,6 +70,6 @@ def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path): checkpoints_folder = checkpoint_dir / expert_name latest_checkpoint = checkpoints_folder / "checkpoint_last.pt" if latest_checkpoint.exists(): - expert.load_full_state(torch.load(latest_checkpoint)) + expert.load_state_dict(torch.load(latest_checkpoint)) else: logger.warning(f"Failed to load checkpoint for expert {expert_name}") diff --git a/hivemind/moe/server/expert_backend.py b/hivemind/moe/server/expert_backend.py index b35238158..c3edcde3b 100644 --- a/hivemind/moe/server/expert_backend.py +++ b/hivemind/moe/server/expert_backend.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, Dict, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Sequence, Tuple, Union, Optional import torch from torch import nn from hivemind.moe.server.task_pool import TaskPool +from hivemind.optim.state_averager import LRSchedulerBase from hivemind.utils.logging import get_logger from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor @@ -20,7 +21,7 @@ class ExpertBackend: - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched. 
- get_info - return expert metadata. Not batched. - :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations: + :param module: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations: - Experts must always receive the same set of args and kwargs and produce output tensors of same type - All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size @@ -34,41 +35,28 @@ class ExpertBackend: :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto - :param num_warmup_steps: the number of warmup steps for LR schedule - :param num_total_steps: the total number of steps for LR schedule - :param clip_grad_norm: maximum gradient norm used for clipping :param kwargs: extra parameters to be forwarded into TaskPool.__init__ """ def __init__( self, name: str, - expert: nn.Module, - optimizer: torch.optim.Optimizer, + module: nn.Module, *, - scheduler: Callable = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerBase] = None, args_schema: Tuple[BatchTensorDescriptor, ...] 
= None, kwargs_schema: Dict[str, BatchTensorDescriptor] = None, outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None, - num_warmup_steps: int = None, - num_total_steps: int = None, - clip_grad_norm: float = None, **kwargs, ): super().__init__() - self.expert, self.optimizer, self.name = expert, optimizer, name - - if scheduler is None: - self.scheduler = None - else: - assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None - self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps) - self.clip_grad_norm = clip_grad_norm + self.name, self.module, self.optimizer, self.scheduler = name, module, optimizer, scheduler self.args_schema = args_schema = tuple(args_schema or ()) self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {}) assert args_schema or kwargs_schema, ( - "expert must receive at least one positional or keyword input." + f"{self.__class__.__name__} must receive at least one positional or keyword input." " Did you forget to provide args_schema/kwargs_schema?" ) @@ -76,7 +64,7 @@ def __init__( # run expert once to get outputs schema dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema) dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()} - dummy_outputs = self.expert(*dummy_args, **dummy_kwargs) + dummy_outputs = self.module(*dummy_args, **dummy_kwargs) outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs) self.forward_schema = (self.args_schema, self.kwargs_schema) # inputs for forward @@ -87,9 +75,6 @@ def __init__( self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs) self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs) - self.update_count = 0 - self.examples_processed = 0 - def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Apply forward pass to an aggregated batch of requests. 
Used by Runtime, do not call this manually; @@ -97,12 +82,7 @@ def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: Subclassing: This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``; - It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``; - - .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice. - .. For now, either register all buffers as outputs or avoid stateful experts - """ args, kwargs = nested_pack(inputs, structure=self.forward_schema) @@ -110,7 +90,7 @@ def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: raise RuntimeError("Batch should contain more than 0 samples") with torch.no_grad(): - outputs = self.expert(*args, **kwargs) + outputs = self.module(*args, **kwargs) # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side return tuple(nested_flatten(outputs)) @@ -128,8 +108,6 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: Runtime doesn't guarantee that backward will be performed in the same order and for the same data as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward. - .. 
todo correct state handling (see forward) - Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train """ (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema) @@ -148,7 +126,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: batch_size = args[0].size(0) - outputs = self.expert(*args, **kwargs) + outputs = self.module(*args, **kwargs) assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure" outputs_flat = tuple(nested_flatten(outputs)) @@ -163,65 +141,47 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: torch.autograd.backward( outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False ) - self.apply_gradients(batch_size) + self.on_backward(batch_size) return tuple( x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs)) ) - def apply_gradients(self, batch_size) -> None: + def on_backward(self, batch_size: int) -> None: """ Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients. 
""" - if self.clip_grad_norm is not None: - torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm) - - self.optimizer.step() - self.optimizer.zero_grad() + if self.optimizer is not None: + self.optimizer.step() + self.optimizer.zero_grad() if self.scheduler is not None: self.scheduler.step() - self.update_count += 1 - self.examples_processed += batch_size - - def get_stats(self) -> Dict: - """ - Return current expert training statistics (number of updates, number of processed examples after - last optimizer step) - """ - return {"updates": self.update_count, "examples_processed": self.examples_processed} - - def get_full_state(self) -> Dict: + def state_dict(self) -> Dict: """ Return the current state of the expert (including batch processing statistics) """ - full_state = { - "stats": self.get_stats(), - "model": self.expert.state_dict(), - "optimizer": self.optimizer.state_dict(), - "scheduler": {} if self.scheduler is None else self.scheduler.state_dict(), - } + full_state = dict(module=self.module.state_dict()) + if self.optimizer is not None: + full_state["optimizer"] = self.optimizer.state_dict() + if self.scheduler is not None: + full_state["scheduler"] = self.scheduler.state_dict() return full_state - def load_full_state(self, state_dict: Dict): - if "stats" in state_dict: - self.update_count = state_dict["stats"]["updates"] - self.examples_processed = state_dict["stats"]["examples_processed"] - else: - logger.warning(f"Batch processing stats missing for expert {self.name}") + def load_state_dict(self, state_dict: Dict): + self.module.load_state_dict(state_dict["module"]) + if self.optimizer is not None: + if "optimizer" in state_dict: + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + logger.warning(f"Optimizer state missing for {self.name}") - self.expert.load_state_dict(state_dict["model"]) - - if "optimizer" in state_dict: - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - 
logger.warning(f"Optimizer state missing for expert {self.name}") - - if self.scheduler is not None and "scheduler" in state_dict: - self.scheduler.load_state_dict(state_dict["scheduler"]) - else: - logger.warning(f"Learning rate scheduler state missing for expert {self.name}") + if self.scheduler is not None: + if "scheduler" in state_dict: + self.scheduler.load_state_dict(state_dict["scheduler"]) + else: + logger.warning(f"Learning rate scheduler state missing for {self.name}") def get_info(self) -> Dict[str, Any]: """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration.""" diff --git a/hivemind/moe/server/layers/optim.py b/hivemind/moe/server/layers/optim.py new file mode 100644 index 000000000..f94f911b1 --- /dev/null +++ b/hivemind/moe/server/layers/optim.py @@ -0,0 +1,63 @@ +import torch + + +class OptimizerWrapper(torch.optim.Optimizer): + """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer""" + + def __init__(self, optim: torch.optim.Optimizer): + object.__init__(self) + self.optim = optim + + @property + def defaults(self): + return self.optim.defaults + + @property + def state(self): + return self.optim.state + + def __getstate__(self): + return self.optim.__getstate__() + + def __setstate__(self, state): + self.optim.__setstate__(state) + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.optim)})" + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + return self.optim.load_state_dict(state_dict) + + def step(self, *args, **kwargs): + return self.optim.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs): + return self.optim.zero_grad(*args, **kwargs) + + @property + def param_groups(self): + return self.optim.param_groups + + def add_param_group(self, param_group: dict) -> None: + return self.optim.add_param_group(param_group) + + +class ClippingWrapper(OptimizerWrapper): + """A 
wrapper to pytorch.optimizer that clips gradients by global norm before each step""" + + def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float): + super().__init__(optim) + self.clip_grad_norm = clip_grad_norm + + def step(self, *args, **kwargs): + parameters = tuple(param for group in self.param_groups for param in group["params"]) + torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm) + return super().step(*args, **kwargs) + + @classmethod + def create(cls, optim_cls: type, *args, clip_grad_norm: float, **kwargs): + """Create a wrapped optimizer and wrap it with clipping""" + return cls(optim=optim_cls(*args, **kwargs), clip_grad_norm=clip_grad_norm) diff --git a/hivemind/moe/server/runtime.py b/hivemind/moe/server/runtime.py index b79410fc5..255d5ff69 100644 --- a/hivemind/moe/server/runtime.py +++ b/hivemind/moe/server/runtime.py @@ -70,7 +70,7 @@ def run(self): pool.start() if self.device is not None: for expert_backend in self.expert_backends.values(): - expert_backend.expert.to(self.device) + expert_backend.module.to(self.device) with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool: try: diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index f7b691561..d8557caa7 100644 --- a/hivemind/moe/server/server.py +++ b/hivemind/moe/server/server.py @@ -22,6 +22,7 @@ name_to_input, schedule_name_to_scheduler, ) +from hivemind.moe.server.layers.optim import ClippingWrapper from hivemind.moe.server.runtime import Runtime from hivemind.p2p import PeerInfo from hivemind.proto.runtime_pb2 import CompressionType @@ -95,7 +96,7 @@ def create( optim_cls=torch.optim.Adam, scheduler: str = "none", num_warmup_steps=None, - num_total_steps=None, + num_training_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1, @@ -129,7 +130,7 @@ def create( :param optim_cls: uses this optimizer to train all experts :param scheduler: if not `none`, the name of the expert LR scheduler :param num_warmup_steps: 
the number of warmup steps for LR schedule - :param num_total_steps: the total number of steps for LR schedule + :param num_training_steps: the total number of steps for LR schedule :param clip_grad_norm: maximum gradient norm used for clipping :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT) @@ -180,7 +181,6 @@ def create( num_experts = len(expert_uids) num_handlers = num_handlers if num_handlers is not None else num_experts * 8 - optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0) device = device or ("cuda" if torch.cuda.is_available() else "cpu") sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, hidden_dim) @@ -189,21 +189,26 @@ def create( else: args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),) - scheduler = schedule_name_to_scheduler[scheduler] + scheduler_cls = schedule_name_to_scheduler[scheduler] + if scheduler_cls is not None: + scheduler_cls = partial( + scheduler_cls, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps + ) # initialize experts experts = {} for expert_uid in expert_uids: expert = name_to_block[expert_cls](hidden_dim) + optimizer = optim_cls(expert.parameters()) if optim_cls is not None else None + scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None + if clip_grad_norm is not None: + scheduler = ClippingWrapper(scheduler, clip_grad_norm) experts[expert_uid] = ExpertBackend( name=expert_uid, - expert=expert, + module=expert, args_schema=args_schema, - optimizer=optim_cls(expert.parameters()), + optimizer=optimizer, scheduler=scheduler, - num_warmup_steps=num_warmup_steps, - num_total_steps=num_total_steps, - clip_grad_norm=clip_grad_norm, min_batch_size=min_batch_size, max_batch_size=max_batch_size, ) @@ -230,8 +235,8 @@ def run(self): """ logger.info(f"Server started with {len(self.experts)} experts:") for expert_name, backend in self.experts.items(): - num_parameters = 
sum(p.numel() for p in backend.expert.parameters() if p.requires_grad) - logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters") + num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad) + logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters") if not self.dht.is_alive(): self.dht.run_in_background(await_ready=True) diff --git a/tests/test_expert_backend.py b/tests/test_expert_backend.py index 752cc96b1..90200057c 100644 --- a/tests/test_expert_backend.py +++ b/tests/test_expert_backend.py @@ -24,11 +24,13 @@ def example_experts(): args_schema = (BatchTensorDescriptor(1),) expert_backend = ExpertBackend( name=EXPERT_NAME, - expert=expert, + module=expert, optimizer=opt, - scheduler=get_linear_schedule_with_warmup, - num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE, - num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE, + scheduler=get_linear_schedule_with_warmup( + opt, + num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE, + num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE, + ), args_schema=args_schema, outputs_schema=BatchTensorDescriptor(1), max_batch_size=1, @@ -39,7 +41,7 @@ def example_experts(): @pytest.mark.forked def test_save_load_checkpoints(example_experts): - expert = example_experts[EXPERT_NAME].expert + expert = example_experts[EXPERT_NAME].module with TemporaryDirectory() as tmpdir: tmp_path = Path(tmpdir) @@ -79,7 +81,7 @@ def test_restore_update_count(example_experts): expert_backend.backward(batch, loss_grad) load_experts(example_experts, tmp_path) - assert expert_backend.update_count == BACKWARD_PASSES_BEFORE_SAVE + assert expert_backend.scheduler._step_count == BACKWARD_PASSES_BEFORE_SAVE + 1 @pytest.mark.forked diff --git a/tests/test_moe.py b/tests/test_moe.py index 402306c62..b2bceecdb 100644 --- a/tests/test_moe.py +++ b/tests/test_moe.py @@ -259,14 +259,14 @@ def 
test_client_anomaly_detection(): expert = name_to_block["ffn"](HID_DIM) experts[f"expert.{i}"] = ExpertBackend( name=f"expert.{i}", - expert=expert, + module=expert, optimizer=torch.optim.Adam(expert.parameters()), args_schema=(BatchTensorDescriptor(HID_DIM),), outputs_schema=BatchTensorDescriptor(HID_DIM), max_batch_size=16, ) - experts["expert.3"].expert.ffn.weight.data[0, 0] = float("nan") + experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan") dht = DHT(start=True) server = Server(dht, experts, num_connection_handlers=1) From fa2da451bf08d4705b5acfa7aea8f8e171348c99 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 07:25:34 +0300 Subject: [PATCH 13/33] rename --- hivemind/moe/server/expert_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/moe/server/expert_backend.py b/hivemind/moe/server/expert_backend.py index c3edcde3b..7b3859b5c 100644 --- a/hivemind/moe/server/expert_backend.py +++ b/hivemind/moe/server/expert_backend.py @@ -56,7 +56,7 @@ def __init__( self.args_schema = args_schema = tuple(args_schema or ()) self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {}) assert args_schema or kwargs_schema, ( - f"{self.__class__.__name__} must receive at least one positional or keyword input." + f"Module must take at least one positional or keyword input." " Did you forget to provide args_schema/kwargs_schema?" 
) From 6c49fe96b22cd49813237334430136dbbd1e87f9 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 07:33:40 +0300 Subject: [PATCH 14/33] black-isort --- hivemind/moe/server/expert_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/moe/server/expert_backend.py b/hivemind/moe/server/expert_backend.py index 7b3859b5c..78d041082 100644 --- a/hivemind/moe/server/expert_backend.py +++ b/hivemind/moe/server/expert_backend.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Sequence, Tuple, Union, Optional +from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch from torch import nn From 2c77de0e6be3ceb3f77347ec9ee26ba4115ab60b Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 08:18:22 +0300 Subject: [PATCH 15/33] rename --- docs/modules/server.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/modules/server.rst b/docs/modules/server.rst index 4e8c61456..65cf230de 100644 --- a/docs/modules/server.rst +++ b/docs/modules/server.rst @@ -27,7 +27,7 @@ The hivemind.moe.server module is organized as follows: .. _ExpertBackend: .. autoclass:: ExpertBackend - :members: forward, backward, apply_gradients, get_info, get_pools + :members: forward, backward, on_backward, get_info, get_pools :member-order: bysource .. 
currentmodule:: hivemind.moe.server.runtime From b1873e1ba063d3918807cabdcf0e70d5504acf72 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 16:42:57 +0300 Subject: [PATCH 16/33] un-hardcode experts from private interface on server side --- benchmarks/benchmark_throughput.py | 2 +- hivemind/moe/server/runtime.py | 16 ++++++++-------- hivemind/moe/server/server.py | 28 ++++++++++++++-------------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e62695c8b..58678c65f 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -133,7 +133,7 @@ def benchmark_throughput( server = Server( dht=server_dht, - expert_backends=experts, + backends=experts, num_connection_handlers=num_handlers, device=device, ) diff --git a/hivemind/moe/server/runtime.py b/hivemind/moe/server/runtime.py index 255d5ff69..6c40cddb4 100644 --- a/hivemind/moe/server/runtime.py +++ b/hivemind/moe/server/runtime.py @@ -20,7 +20,7 @@ class Runtime(threading.Thread): """ - A group of processes that processes incoming requests for multiple experts on a shared device. + A group of processes that processes incoming requests for multiple module backends on a shared device. Runtime is usually created and managed by Server, humans need not apply. For debugging, you can start runtime manually with .start() or .run() @@ -29,11 +29,11 @@ class Runtime(threading.Thread): >>> runtime = Runtime(expert_backends) >>> runtime.start() # start runtime in background thread. 
To start in current thread, use runtime.run() >>> runtime.ready.wait() # await for runtime to load all experts on device and create request pools - >>> future = runtime.expert_backends['expert_name'].forward_pool.submit_task(*expert_inputs) + >>> future = runtime.backends['expert_name'].forward_pool.submit_task(*expert_inputs) >>> print("Returned:", future.result()) >>> runtime.shutdown() - :param expert_backends: a dict [expert uid -> ExpertBackend] + :param backends: a dict [expert uid -> ExpertBackend] :param prefetch_batches: form up to this many batches in advance :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads :param device: if specified, moves all experts and data to this device via .to(device=device). @@ -46,15 +46,15 @@ class Runtime(threading.Thread): def __init__( self, - expert_backends: Dict[str, ExpertBackend], + backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1, device: torch.device = None, stats_report_interval: Optional[int] = None, ): super().__init__() - self.expert_backends = expert_backends - self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values()))) + self.backends = backends + self.pools = tuple(chain(*(expert.get_pools() for expert in backends.values()))) self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False) self.shutdown_trigger = mp.Event() @@ -69,8 +69,8 @@ def run(self): if not pool.is_alive(): pool.start() if self.device is not None: - for expert_backend in self.expert_backends.values(): - expert_backend.module.to(self.device) + for backend in self.backends.values(): + backend.module.to(self.device) with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool: try: diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index d8557caa7..8a55b1277 100644 --- a/hivemind/moe/server/server.py +++ 
b/hivemind/moe/server/server.py @@ -34,7 +34,7 @@ class Server(threading.Thread): """ - Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts. + Server allows you to host "experts" - pytorch subnetworks that can be accessed by your peers in the swarm. After creation, a server should be started: see Server.run or Server.run_in_background. A working server does two things: @@ -42,7 +42,7 @@ class Server(threading.Thread): - publishes updates to expert status every :update_period: seconds :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions. - :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server. + :param backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server. :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1 if too small for normal functioning, we recommend 4 handlers per expert backend. :param update_period: how often will server attempt to publish its state (i.e. 
experts) to the DHT; @@ -55,7 +55,7 @@ class Server(threading.Thread): def __init__( self, dht: DHT, - expert_backends: Dict[str, ExpertBackend], + backends: Dict[str, ExpertBackend], num_connection_handlers: int = 1, update_period: float = 30, expiration: Optional[float] = None, @@ -64,18 +64,18 @@ def __init__( **kwargs, ): super().__init__() - self.dht, self.experts, self.update_period = dht, expert_backends, update_period + self.dht, self.backends, self.update_period = dht, backends, update_period - self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)] + self.conn_handlers = [ConnectionHandler(dht, self.backends) for _ in range(num_connection_handlers)] if checkpoint_dir is not None: - self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period) + self.checkpoint_saver = CheckpointSaver(backends, checkpoint_dir, update_period) else: self.checkpoint_saver = None - self.runtime = Runtime(self.experts, **kwargs) + self.runtime = Runtime(self.backends, **kwargs) - if self.experts: + if self.backends: self.dht_handler_thread = DHTHandlerThread( - experts=self.experts, + backends=self.backends, dht=self.dht, update_period=self.update_period, expiration=expiration, @@ -114,7 +114,7 @@ def create( **kwargs, ) -> Server: """ - Instantiate a server with several identical experts. See argparse comments below for details + Instantiate a server with several identical modules. See argparse comments below for details :param num_experts: run this many identical experts :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\ @@ -233,15 +233,15 @@ def run(self): Starts Server in the current thread. Initializes dht if necessary, starts connection handlers, runs Runtime (self.runtime) to process incoming requests. 
""" - logger.info(f"Server started with {len(self.experts)} experts:") - for expert_name, backend in self.experts.items(): + logger.info(f"Server started with {len(self.backends)} modules:") + for expert_name, backend in self.backends.items(): num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad) logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters") if not self.dht.is_alive(): self.dht.run_in_background(await_ready=True) - if self.experts: + if self.backends: self.dht_handler_thread.start() if self.checkpoint_saver is not None: @@ -292,7 +292,7 @@ def shutdown(self): process.join() logger.debug("Connection handlers terminated") - if self.experts: + if self.backends: self.dht_handler_thread.stop.set() self.dht_handler_thread.join() From 9f3187f5ac084604d20b12e761abe1632d5e1eee Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 16:44:03 +0300 Subject: [PATCH 17/33] un-hardcode experts from private interface on server side --- hivemind/moe/server/dht_handler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hivemind/moe/server/dht_handler.py b/hivemind/moe/server/dht_handler.py index ea1a4f658..9c8469ead 100644 --- a/hivemind/moe/server/dht_handler.py +++ b/hivemind/moe/server/dht_handler.py @@ -20,20 +20,20 @@ class DHTHandlerThread(threading.Thread): - def __init__(self, experts, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs): + def __init__(self, backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs): super().__init__(**kwargs) if expiration is None: expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) - self.experts = experts + self.backends = backends self.dht = dht self.update_period = update_period self.expiration = expiration self.stop = threading.Event() def run(self) -> None: - declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() 
+ self.expiration) + declare_experts(self.dht, self.backends.keys(), expiration_time=get_dht_time() + self.expiration) while not self.stop.wait(self.update_period): - declare_experts(self.dht, self.experts.keys(), expiration_time=get_dht_time() + self.expiration) + declare_experts(self.dht, self.backends.keys(), expiration_time=get_dht_time() + self.expiration) def declare_experts( From d87a7b1e3efb433c5f0e6840c11f6d9d1e3dc07a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 16:46:08 +0300 Subject: [PATCH 18/33] un-hardcode experts from private interface on server side --- hivemind/moe/server/checkpoints.py | 8 ++++---- hivemind/moe/server/runtime.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/hivemind/moe/server/checkpoints.py b/hivemind/moe/server/checkpoints.py index cd67b5d6f..a9d0d32be 100644 --- a/hivemind/moe/server/checkpoints.py +++ b/hivemind/moe/server/checkpoints.py @@ -34,20 +34,20 @@ def copy_tree(src: str, dst: str): class CheckpointSaver(threading.Thread): - def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: float): + def __init__(self, backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: float): super().__init__() assert is_directory(checkpoint_dir) - self.expert_backends = expert_backends + self.backends = backends self.update_period = update_period self.checkpoint_dir = checkpoint_dir self.stop = threading.Event() # create expert directories to ensure that the directory is writable and checkpoints can be loaded - store_experts(self.expert_backends, self.checkpoint_dir) + store_experts(self.backends, self.checkpoint_dir) def run(self) -> None: while not self.stop.wait(self.update_period): - store_experts(self.expert_backends, self.checkpoint_dir) + store_experts(self.backends, self.checkpoint_dir) def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path): diff --git a/hivemind/moe/server/runtime.py 
b/hivemind/moe/server/runtime.py index 6c40cddb4..45fdda618 100644 --- a/hivemind/moe/server/runtime.py +++ b/hivemind/moe/server/runtime.py @@ -54,7 +54,7 @@ def __init__( ): super().__init__() self.backends = backends - self.pools = tuple(chain(*(expert.get_pools() for expert in backends.values()))) + self.pools = tuple(chain(*(backend.get_pools() for backend in backends.values()))) self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False) self.shutdown_trigger = mp.Event() From 65d622b746765f78402306420c41bc9f3cf6c526 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 14 Jun 2022 16:46:37 +0300 Subject: [PATCH 19/33] wrap optimizer, not scheduler --- hivemind/moe/server/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index 8a55b1277..9b60ab1d8 100644 --- a/hivemind/moe/server/server.py +++ b/hivemind/moe/server/server.py @@ -202,7 +202,7 @@ def create( optimizer = optim_cls(expert.parameters()) if optim_cls is not None else None scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None if clip_grad_norm is not None: - scheduler = ClippingWrapper(scheduler, clip_grad_norm) + optimizer = ClippingWrapper(optimizer, clip_grad_norm) experts[expert_uid] = ExpertBackend( name=expert_uid, module=expert, From a00fb9e2a97dedf486a08ef03b764b5d847d270e Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 17:57:01 +0300 Subject: [PATCH 20/33] ModuleBackend --- benchmarks/benchmark_throughput.py | 4 ++-- docs/modules/server.rst | 8 ++++---- hivemind/__init__.py | 2 +- hivemind/moe/__init__.py | 2 +- hivemind/moe/server/__init__.py | 2 +- hivemind/moe/server/checkpoints.py | 8 ++++---- hivemind/moe/server/connection_handler.py | 6 +++--- hivemind/moe/server/layers/dropout.py | 2 +- .../{expert_backend.py => module_backend.py} | 14 +++++++------- 
hivemind/moe/server/runtime.py | 8 ++++---- hivemind/moe/server/server.py | 10 +++++----- tests/test_connection_handler.py | 10 +++++----- tests/test_expert_backend.py | 4 ++-- tests/test_moe.py | 4 ++-- 14 files changed, 42 insertions(+), 42 deletions(-) rename hivemind/moe/server/{expert_backend.py => module_backend.py} (96%) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 58678c65f..358083a9e 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -10,7 +10,7 @@ from hivemind.moe.client.expert import RemoteExpert from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.moe.expert_uid import ExpertInfo -from hivemind.moe.server import ExpertBackend, Server +from hivemind.moe.server import ModuleBackend, Server from hivemind.moe.server.layers import name_to_block from hivemind.p2p import P2P from hivemind.utils.limits import increase_file_limit @@ -121,7 +121,7 @@ def benchmark_throughput( experts = {} for i in range(num_experts): expert = torch.jit.script(name_to_block[expert_cls](hid_dim)) - experts[f"expert.{i}"] = ExpertBackend( + experts[f"expert.{i}"] = ModuleBackend( name=f"expert.{i}", module=expert, optimizer=torch.optim.Adam(expert.parameters()), diff --git a/docs/modules/server.rst b/docs/modules/server.rst index 65cf230de..a958ec057 100644 --- a/docs/modules/server.rst +++ b/docs/modules/server.rst @@ -9,9 +9,9 @@ or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the mo The hivemind.moe.server module is organized as follows: - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute. -- ExpertBackend_ is a wrapper for `torch.nn.Module `_ \ +- ModuleBackend_ is a wrapper for `torch.nn.Module `_ \ that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests. 
-- Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert. +- Runtime_ balances the device (GPU) usage between several ModuleBackend_ instances that each service one expert. - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \ and offers those batches to Runtime_ for processing. @@ -25,8 +25,8 @@ The hivemind.moe.server module is organized as follows: :members: :member-order: bysource -.. _ExpertBackend: -.. autoclass:: ExpertBackend +.. _ModuleBackend: +.. autoclass:: ModuleBackend :members: forward, backward, on_backward, get_info, get_pools :member-order: bysource diff --git a/hivemind/__init__.py b/hivemind/__init__.py index 32443f7f7..f74a640a7 100644 --- a/hivemind/__init__.py +++ b/hivemind/__init__.py @@ -2,7 +2,7 @@ from hivemind.compression import * from hivemind.dht import DHT from hivemind.moe import ( - ExpertBackend, + ModuleBackend, RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts, diff --git a/hivemind/moe/__init__.py b/hivemind/moe/__init__.py index 00905507d..1436ab35d 100644 --- a/hivemind/moe/__init__.py +++ b/hivemind/moe/__init__.py @@ -1,6 +1,6 @@ from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts from hivemind.moe.server import ( - ExpertBackend, + ModuleBackend, Server, background_server, declare_experts, diff --git a/hivemind/moe/server/__init__.py b/hivemind/moe/server/__init__.py index 1ac24db2d..7d5504608 100644 --- a/hivemind/moe/server/__init__.py +++ b/hivemind/moe/server/__init__.py @@ -1,4 +1,4 @@ from hivemind.moe.server.dht_handler import declare_experts, get_experts -from hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.layers import register_expert_class from hivemind.moe.server.server import Server, background_server diff --git 
a/hivemind/moe/server/checkpoints.py b/hivemind/moe/server/checkpoints.py index a9d0d32be..5cb7153b8 100644 --- a/hivemind/moe/server/checkpoints.py +++ b/hivemind/moe/server/checkpoints.py @@ -8,7 +8,7 @@ import torch -from hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils.logging import get_logger logger = get_logger(__name__) @@ -34,7 +34,7 @@ def copy_tree(src: str, dst: str): class CheckpointSaver(threading.Thread): - def __init__(self, backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: float): + def __init__(self, backends: Dict[str, ModuleBackend], checkpoint_dir: Path, update_period: float): super().__init__() assert is_directory(checkpoint_dir) self.backends = backends @@ -50,7 +50,7 @@ def run(self) -> None: store_experts(self.backends, self.checkpoint_dir) -def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path): +def store_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path): logger.debug(f"Storing experts at {checkpoint_dir.absolute()}") assert is_directory(checkpoint_dir) timestamp = datetime.now().isoformat(sep="_") @@ -64,7 +64,7 @@ def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path): copy_tree(tmpdirname, str(checkpoint_dir)) -def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path): +def load_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path): assert is_directory(checkpoint_dir) for expert_name, expert in experts.items(): checkpoints_folder = checkpoint_dir / expert_name diff --git a/hivemind/moe/server/connection_handler.py b/hivemind/moe/server/connection_handler.py index ff610a7bc..435a0b7d8 100644 --- a/hivemind/moe/server/connection_handler.py +++ b/hivemind/moe/server/connection_handler.py @@ -6,7 +6,7 @@ from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor from hivemind.dht import DHT -from 
hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.task_pool import TaskPool from hivemind.p2p import P2PContext, ServicerBase from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE, P2P @@ -25,10 +25,10 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase): :note: ConnectionHandler is designed so as to allow using multiple handler processes for the same port :param dht: a running hivemind.dht.DHT, used to let other peers connect to this one - :param experts: a dict [UID -> ExpertBackend] with all active experts + :param experts: a dict [UID -> ModuleBackend] with all active experts """ - def __init__(self, dht: DHT, experts: Dict[str, ExpertBackend]): + def __init__(self, dht: DHT, experts: Dict[str, ModuleBackend]): super().__init__() self.dht, self.experts = dht, experts self._p2p: Optional[P2P] = None diff --git a/hivemind/moe/server/layers/dropout.py b/hivemind/moe/server/layers/dropout.py index 8efad903e..526787c7c 100644 --- a/hivemind/moe/server/layers/dropout.py +++ b/hivemind/moe/server/layers/dropout.py @@ -19,7 +19,7 @@ def backward(ctx, grad_output): class DeterministicDropout(nn.Module): """ Custom dropout layer which accepts dropout mask as an input (drop_prob is only used for scaling input activations). 
- Can be used with RemoteExpert/ExpertBackend to ensure that dropout mask is the same at forward and backward steps + Can be used with RemoteExpert/ModuleBackend to ensure that dropout mask is the same at forward and backward steps """ def __init__(self, drop_prob): diff --git a/hivemind/moe/server/expert_backend.py b/hivemind/moe/server/module_backend.py similarity index 96% rename from hivemind/moe/server/expert_backend.py rename to hivemind/moe/server/module_backend.py index 78d041082..901741bc0 100644 --- a/hivemind/moe/server/expert_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -12,10 +12,10 @@ logger = get_logger(__name__) -class ExpertBackend: +class ModuleBackend: """ - ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime - By default, ExpertBackend handles three types of requests: + ModuleBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime + By default, ModuleBackend handles three types of requests: - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization. - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched. @@ -78,7 +78,7 @@ def __init__( def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually; - To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``. + To submit a request for asynchronous processing, please use ``ModuleBackend.forward_pool.submit_task``. 
Subclassing: This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``; @@ -98,7 +98,7 @@ def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually - To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``. + To submit a request for asynchronous processing, please use ``ModuleBackend.backward_pool.submit_task``. Subclassing: This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``; @@ -108,7 +108,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: Runtime doesn't guarantee that backward will be performed in the same order and for the same data as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward. - Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train + Please make sure to call ``ModuleBackend.apply_gradients`` here, otherwise the expert will not train """ (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema) @@ -149,7 +149,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: def on_backward(self, batch_size: int) -> None: """ - Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients. + Train the expert for one step. This method is called by ``ModuleBackend.backward`` after computing gradients. 
""" if self.optimizer is not None: self.optimizer.step() diff --git a/hivemind/moe/server/runtime.py b/hivemind/moe/server/runtime.py index 45fdda618..0327a92dc 100644 --- a/hivemind/moe/server/runtime.py +++ b/hivemind/moe/server/runtime.py @@ -12,7 +12,7 @@ import torch from prefetch_generator import BackgroundGenerator -from hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger logger = get_logger(__name__) @@ -25,7 +25,7 @@ class Runtime(threading.Thread): For debugging, you can start runtime manually with .start() or .run() - >>> expert_backends = {'expert_name': ExpertBackend(**kwargs)} + >>> expert_backends = {'expert_name': ModuleBackend(**kwargs)} >>> runtime = Runtime(expert_backends) >>> runtime.start() # start runtime in background thread. To start in current thread, use runtime.run() >>> runtime.ready.wait() # await for runtime to load all experts on device and create request pools @@ -33,7 +33,7 @@ class Runtime(threading.Thread): >>> print("Returned:", future.result()) >>> runtime.shutdown() - :param backends: a dict [expert uid -> ExpertBackend] + :param backends: a dict [expert uid -> ModuleBackend] :param prefetch_batches: form up to this many batches in advance :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads :param device: if specified, moves all experts and data to this device via .to(device=device). 
@@ -46,7 +46,7 @@ class Runtime(threading.Thread): def __init__( self, - backends: Dict[str, ExpertBackend], + backends: Dict[str, ModuleBackend], prefetch_batches=64, sender_threads: int = 1, device: torch.device = None, diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index 9b60ab1d8..0873701b4 100644 --- a/hivemind/moe/server/server.py +++ b/hivemind/moe/server/server.py @@ -15,7 +15,7 @@ from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts from hivemind.moe.server.connection_handler import ConnectionHandler from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts -from hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.layers import ( add_custom_models_from_file, name_to_block, @@ -42,7 +42,7 @@ class Server(threading.Thread): - publishes updates to expert status every :update_period: seconds :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions. - :param backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server. + :param backends: dict{expert uid (str) : ModuleBackend} for all expert hosted by this server. :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1 if too small for normal functioning, we recommend 4 handlers per expert backend. :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT; @@ -55,7 +55,7 @@ class Server(threading.Thread): def __init__( self, dht: DHT, - backends: Dict[str, ExpertBackend], + backends: Dict[str, ModuleBackend], num_connection_handlers: int = 1, update_period: float = 30, expiration: Optional[float] = None, @@ -139,7 +139,7 @@ def create( :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts hosted on this server. 
For a more fine-grained compression, start server in python and specify compression - for each BatchTensorProto in ExpertBackend for the respective experts. + for each BatchTensorProto in ModuleBackend for the respective experts. :param start: if True, starts server right away and returns when server is ready for requests :param stats_report_interval: interval between two reports of batch processing performance statistics @@ -203,7 +203,7 @@ def create( scheduler = scheduler_cls(optimizer) if scheduler_cls is not None else None if clip_grad_norm is not None: optimizer = ClippingWrapper(optimizer, clip_grad_norm) - experts[expert_uid] = ExpertBackend( + experts[expert_uid] = ModuleBackend( name=expert_uid, module=expert, args_schema=args_schema, diff --git a/tests/test_connection_handler.py b/tests/test_connection_handler.py index 3b4ac9ab9..afc6179f0 100644 --- a/tests/test_connection_handler.py +++ b/tests/test_connection_handler.py @@ -10,7 +10,7 @@ from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor from hivemind.dht import DHT from hivemind.moe.server.connection_handler import ConnectionHandler -from hivemind.moe.server.expert_backend import ExpertBackend +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.task_pool import TaskPool from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PHandlerError from hivemind.proto import runtime_pb2 @@ -25,7 +25,7 @@ async def test_connection_handler_info(): handler = ConnectionHandler( DHT(start=True), - dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)), + dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)), ) handler.start() @@ -48,7 +48,7 @@ async def test_connection_handler_info(): async def test_connection_handler_forward(): handler = ConnectionHandler( DHT(start=True), - dict(expert1=DummyExpertBackend("expert1", k=1), 
expert2=DummyExpertBackend("expert2", k=2)), + dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)), ) handler.start() @@ -109,7 +109,7 @@ async def test_connection_handler_forward(): async def test_connection_handler_backward(): handler = ConnectionHandler( DHT(start=True), - dict(expert1=DummyExpertBackend("expert1", k=1), expert2=DummyExpertBackend("expert2", k=2)), + dict(expert1=DummyModuleBackend("expert1", k=1), expert2=DummyModuleBackend("expert2", k=2)), ) handler.start() @@ -179,7 +179,7 @@ async def submit_task(self, *inputs: torch.Tensor): return [inputs[0] * self.k] -class DummyExpertBackend(ExpertBackend): +class DummyModuleBackend(ModuleBackend): def __init__(self, name: str, k: float): self.name = name self.outputs_schema = [BatchTensorDescriptor.from_tensor(torch.randn(1, 2))] diff --git a/tests/test_expert_backend.py b/tests/test_expert_backend.py index 90200057c..e9d83231b 100644 --- a/tests/test_expert_backend.py +++ b/tests/test_expert_backend.py @@ -5,7 +5,7 @@ import torch from torch.nn import Linear -from hivemind import BatchTensorDescriptor, ExpertBackend +from hivemind import BatchTensorDescriptor, ModuleBackend from hivemind.moe.server.checkpoints import load_experts, store_experts from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup @@ -22,7 +22,7 @@ def example_experts(): opt = torch.optim.SGD(expert.parameters(), PEAK_LR) args_schema = (BatchTensorDescriptor(1),) - expert_backend = ExpertBackend( + expert_backend = ModuleBackend( name=EXPERT_NAME, module=expert, optimizer=opt, diff --git a/tests/test_moe.py b/tests/test_moe.py index b2bceecdb..46e9279a8 100644 --- a/tests/test_moe.py +++ b/tests/test_moe.py @@ -7,7 +7,7 @@ from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts from hivemind.moe.expert_uid import ExpertInfo -from hivemind.moe.server import 
ExpertBackend, Server, background_server, declare_experts +from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts from hivemind.moe.server.layers import name_to_block from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError from hivemind.utils import BatchTensorDescriptor, get_dht_time @@ -257,7 +257,7 @@ def test_client_anomaly_detection(): experts = {} for i in range(4): expert = name_to_block["ffn"](HID_DIM) - experts[f"expert.{i}"] = ExpertBackend( + experts[f"expert.{i}"] = ModuleBackend( name=f"expert.{i}", module=expert, optimizer=torch.optim.Adam(expert.parameters()), From 5569c4202eb22fb0ee2e2bea9dab0daa4f31d109 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 17:59:53 +0300 Subject: [PATCH 21/33] ModuleBackend --- benchmarks/benchmark_throughput.py | 6 +++--- hivemind/moe/server/checkpoints.py | 8 ++++---- hivemind/moe/server/dht_handler.py | 8 ++++---- hivemind/moe/server/runtime.py | 16 ++++++++-------- hivemind/moe/server/server.py | 24 ++++++++++++------------ 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 358083a9e..c9af194c4 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -118,10 +118,10 @@ def benchmark_throughput( timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter() device = device or ("cuda" if torch.cuda.is_available() else "cpu") - experts = {} + module_backends = {} for i in range(num_experts): expert = torch.jit.script(name_to_block[expert_cls](hid_dim)) - experts[f"expert.{i}"] = ModuleBackend( + module_backends[f"expert.{i}"] = ModuleBackend( name=f"expert.{i}", module=expert, optimizer=torch.optim.Adam(expert.parameters()), @@ -133,7 +133,7 @@ def benchmark_throughput( server = Server( dht=server_dht, - backends=experts, + module_backends=module_backends, num_connection_handlers=num_handlers, 
device=device, ) diff --git a/hivemind/moe/server/checkpoints.py b/hivemind/moe/server/checkpoints.py index 5cb7153b8..6003a1c39 100644 --- a/hivemind/moe/server/checkpoints.py +++ b/hivemind/moe/server/checkpoints.py @@ -34,20 +34,20 @@ def copy_tree(src: str, dst: str): class CheckpointSaver(threading.Thread): - def __init__(self, backends: Dict[str, ModuleBackend], checkpoint_dir: Path, update_period: float): + def __init__(self, module_backends: Dict[str, ModuleBackend], checkpoint_dir: Path, update_period: float): super().__init__() assert is_directory(checkpoint_dir) - self.backends = backends + self.module_backends = module_backends self.update_period = update_period self.checkpoint_dir = checkpoint_dir self.stop = threading.Event() # create expert directories to ensure that the directory is writable and checkpoints can be loaded - store_experts(self.backends, self.checkpoint_dir) + store_experts(self.module_backends, self.checkpoint_dir) def run(self) -> None: while not self.stop.wait(self.update_period): - store_experts(self.backends, self.checkpoint_dir) + store_experts(self.module_backends, self.checkpoint_dir) def store_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path): diff --git a/hivemind/moe/server/dht_handler.py b/hivemind/moe/server/dht_handler.py index 9c8469ead..eb9c64389 100644 --- a/hivemind/moe/server/dht_handler.py +++ b/hivemind/moe/server/dht_handler.py @@ -20,20 +20,20 @@ class DHTHandlerThread(threading.Thread): - def __init__(self, backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs): + def __init__(self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs): super().__init__(**kwargs) if expiration is None: expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) - self.backends = backends + self.module_backends = module_backends self.dht = dht self.update_period = update_period self.expiration = expiration self.stop = 
threading.Event() def run(self) -> None: - declare_experts(self.dht, self.backends.keys(), expiration_time=get_dht_time() + self.expiration) + declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration) while not self.stop.wait(self.update_period): - declare_experts(self.dht, self.backends.keys(), expiration_time=get_dht_time() + self.expiration) + declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration) def declare_experts( diff --git a/hivemind/moe/server/runtime.py b/hivemind/moe/server/runtime.py index 0327a92dc..1e750812f 100644 --- a/hivemind/moe/server/runtime.py +++ b/hivemind/moe/server/runtime.py @@ -25,15 +25,15 @@ class Runtime(threading.Thread): For debugging, you can start runtime manually with .start() or .run() - >>> expert_backends = {'expert_name': ModuleBackend(**kwargs)} - >>> runtime = Runtime(expert_backends) + >>> module_backends = {'expert_name': ModuleBackend(**kwargs)} + >>> runtime = Runtime(module_backends) >>> runtime.start() # start runtime in background thread. To start in current thread, use runtime.run() >>> runtime.ready.wait() # await for runtime to load all experts on device and create request pools - >>> future = runtime.backends['expert_name'].forward_pool.submit_task(*expert_inputs) + >>> future = runtime.module_backends['expert_name'].forward_pool.submit_task(*module_inputs) >>> print("Returned:", future.result()) >>> runtime.shutdown() - :param backends: a dict [expert uid -> ModuleBackend] + :param module_backends: a dict [expert uid -> ModuleBackend] :param prefetch_batches: form up to this many batches in advance :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads :param device: if specified, moves all experts and data to this device via .to(device=device). 
@@ -46,15 +46,15 @@ class Runtime(threading.Thread): def __init__( self, - backends: Dict[str, ModuleBackend], + module_backends: Dict[str, ModuleBackend], prefetch_batches=64, sender_threads: int = 1, device: torch.device = None, stats_report_interval: Optional[int] = None, ): super().__init__() - self.backends = backends - self.pools = tuple(chain(*(backend.get_pools() for backend in backends.values()))) + self.module_backends = module_backends + self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values()))) self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False) self.shutdown_trigger = mp.Event() @@ -69,7 +69,7 @@ def run(self): if not pool.is_alive(): pool.start() if self.device is not None: - for backend in self.backends.values(): + for backend in self.module_backends.values(): backend.module.to(self.device) with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool: diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index 0873701b4..82de71ee2 100644 --- a/hivemind/moe/server/server.py +++ b/hivemind/moe/server/server.py @@ -42,7 +42,7 @@ class Server(threading.Thread): - publishes updates to expert status every :update_period: seconds :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions. - :param backends: dict{expert uid (str) : ModuleBackend} for all expert hosted by this server. + :param module_backends: dict{expert uid (str) : ModuleBackend} for all expert hosted by this server. :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1 if too small for normal functioning, we recommend 4 handlers per expert backend. :param update_period: how often will server attempt to publish its state (i.e. 
experts) to the DHT; @@ -55,7 +55,7 @@ class Server(threading.Thread): def __init__( self, dht: DHT, - backends: Dict[str, ModuleBackend], + module_backends: Dict[str, ModuleBackend], num_connection_handlers: int = 1, update_period: float = 30, expiration: Optional[float] = None, @@ -64,18 +64,18 @@ def __init__( **kwargs, ): super().__init__() - self.dht, self.backends, self.update_period = dht, backends, update_period + self.dht, self.module_backends, self.update_period = dht, module_backends, update_period - self.conn_handlers = [ConnectionHandler(dht, self.backends) for _ in range(num_connection_handlers)] + self.conn_handlers = [ConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)] if checkpoint_dir is not None: - self.checkpoint_saver = CheckpointSaver(backends, checkpoint_dir, update_period) + self.checkpoint_saver = CheckpointSaver(module_backends, checkpoint_dir, update_period) else: self.checkpoint_saver = None - self.runtime = Runtime(self.backends, **kwargs) + self.runtime = Runtime(self.module_backends, **kwargs) - if self.backends: + if self.module_backends: self.dht_handler_thread = DHTHandlerThread( - backends=self.backends, + module_backends=self.module_backends, dht=self.dht, update_period=self.update_period, expiration=expiration, @@ -233,15 +233,15 @@ def run(self): Starts Server in the current thread. Initializes dht if necessary, starts connection handlers, runs Runtime (self.runtime) to process incoming requests. 
""" - logger.info(f"Server started with {len(self.backends)} modules:") - for expert_name, backend in self.backends.items(): + logger.info(f"Server started with {len(self.module_backends)} modules:") + for expert_name, backend in self.module_backends.items(): num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad) logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters") if not self.dht.is_alive(): self.dht.run_in_background(await_ready=True) - if self.backends: + if self.module_backends: self.dht_handler_thread.start() if self.checkpoint_saver is not None: @@ -292,7 +292,7 @@ def shutdown(self): process.join() logger.debug("Connection handlers terminated") - if self.backends: + if self.module_backends: self.dht_handler_thread.stop.set() self.dht_handler_thread.join() From add83b5416d7b7d40f20749db0dd5e31bed8347f Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 19:51:36 +0300 Subject: [PATCH 22/33] Update hivemind/moe/server/layers/optim.py Co-authored-by: Max Ryabinin --- hivemind/moe/server/layers/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/moe/server/layers/optim.py b/hivemind/moe/server/layers/optim.py index f94f911b1..f6b62362a 100644 --- a/hivemind/moe/server/layers/optim.py +++ b/hivemind/moe/server/layers/optim.py @@ -46,7 +46,7 @@ def add_param_group(self, param_group: dict) -> None: class ClippingWrapper(OptimizerWrapper): - """A wrapper to pytorch.optimizer that clips gradients by global norm before each step""" + """A wrapper of torch.Optimizer that clips gradients by global norm before each step""" def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float): super().__init__(optim) From 9664d05780dcc411cababd96bdc6285b976c9bbd Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 14 Jun 2022 19:52:33 +0300 Subject: [PATCH 23/33] fix import --- hivemind/moe/server/module_backend.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py index 901741bc0..c431b5ebd 100644 --- a/hivemind/moe/server/module_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -4,11 +4,11 @@ from torch import nn from hivemind.moe.server.task_pool import TaskPool -from hivemind.optim.state_averager import LRSchedulerBase from hivemind.utils.logging import get_logger from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor +LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None) logger = get_logger(__name__) From 7aed0a82a87a510a291a056fa8217037cd9cd0ee Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 14 Jun 2022 19:56:37 +0300 Subject: [PATCH 24/33] review --- hivemind/moe/server/module_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py index c431b5ebd..a890a3951 100644 --- a/hivemind/moe/server/module_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -59,6 +59,7 @@ def __init__( f"Module must take at least one positional or keyword input." " Did you forget to provide args_schema/kwargs_schema?" 
) + assert optimizer is not None or scheduler is None, "scheduler should only be used if optimizer is not None" if outputs_schema is None: # run expert once to get outputs schema From fa48f2c9aba6693e8ffde1993deda68593c3516b Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 19:57:44 +0300 Subject: [PATCH 25/33] review --- hivemind/moe/server/module_backend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py index a890a3951..38ad2a1d3 100644 --- a/hivemind/moe/server/module_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -160,9 +160,7 @@ def on_backward(self, batch_size: int) -> None: self.scheduler.step() def state_dict(self) -> Dict: - """ - Return the current state of the expert (including batch processing statistics) - """ + """ Return the current state of the module, optimizer, and scheduler """ full_state = dict(module=self.module.state_dict()) if self.optimizer is not None: full_state["optimizer"] = self.optimizer.state_dict() From 5601a95d36e1093bd82f899c1bccf5a3c7245307 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 14 Jun 2022 20:02:22 +0300 Subject: [PATCH 26/33] review --- hivemind/moe/server/layers/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/moe/server/layers/optim.py b/hivemind/moe/server/layers/optim.py index f6b62362a..988928bbf 100644 --- a/hivemind/moe/server/layers/optim.py +++ b/hivemind/moe/server/layers/optim.py @@ -4,8 +4,8 @@ class OptimizerWrapper(torch.optim.Optimizer): """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer""" + # noinspection PyMissingConstructor def __init__(self, optim: torch.optim.Optimizer): - object.__init__(self) self.optim = optim @property From decf1b4c11ec6fda88035d0966a28c6e125d9108 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 20:03:37 +0300 Subject: [PATCH 27/33] Update 
hivemind/moe/server/server.py Co-authored-by: Max Ryabinin --- hivemind/moe/server/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index 82de71ee2..0e2af4dfc 100644 --- a/hivemind/moe/server/server.py +++ b/hivemind/moe/server/server.py @@ -34,7 +34,7 @@ class Server(threading.Thread): """ - Server allows you to host "experts" - pytorch subnetworks that can be accessed by your peers in the swarm. + Server allows you to host "experts" - pytorch subnetworks that can be accessed remotely by peers. After creation, a server should be started: see Server.run or Server.run_in_background. A working server does two things: From 409e035553760eeb6b4fcaf85cc52c412e6efc66 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 20:06:01 +0300 Subject: [PATCH 28/33] review --- hivemind/moe/server/layers/optim.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hivemind/moe/server/layers/optim.py b/hivemind/moe/server/layers/optim.py index 988928bbf..d0248c78c 100644 --- a/hivemind/moe/server/layers/optim.py +++ b/hivemind/moe/server/layers/optim.py @@ -6,6 +6,7 @@ class OptimizerWrapper(torch.optim.Optimizer): # noinspection PyMissingConstructor def __init__(self, optim: torch.optim.Optimizer): + super().__init__(optim.param_groups, optim.defaults) self.optim = optim @property From dd6fc946520bd4d02602cdf7d91c542348e2e762 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 14 Jun 2022 20:12:16 +0300 Subject: [PATCH 29/33] black-isort --- hivemind/moe/server/__init__.py | 2 +- hivemind/moe/server/dht_handler.py | 4 +++- hivemind/moe/server/module_backend.py | 2 +- hivemind/moe/server/server.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/hivemind/moe/server/__init__.py b/hivemind/moe/server/__init__.py index 7d5504608..b370ffbff 100644 --- a/hivemind/moe/server/__init__.py +++ b/hivemind/moe/server/__init__.py @@ -1,4 +1,4 @@ from 
hivemind.moe.server.dht_handler import declare_experts, get_experts -from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.layers import register_expert_class +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.server import Server, background_server diff --git a/hivemind/moe/server/dht_handler.py b/hivemind/moe/server/dht_handler.py index eb9c64389..e5cbb1935 100644 --- a/hivemind/moe/server/dht_handler.py +++ b/hivemind/moe/server/dht_handler.py @@ -20,7 +20,9 @@ class DHTHandlerThread(threading.Thread): - def __init__(self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs): + def __init__( + self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs + ): super().__init__(**kwargs) if expiration is None: expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py index 38ad2a1d3..8f6e1ba63 100644 --- a/hivemind/moe/server/module_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -160,7 +160,7 @@ def on_backward(self, batch_size: int) -> None: self.scheduler.step() def state_dict(self) -> Dict: - """ Return the current state of the module, optimizer, and scheduler """ + """Return the current state of the module, optimizer, and scheduler""" full_state = dict(module=self.module.state_dict()) if self.optimizer is not None: full_state["optimizer"] = self.optimizer.state_dict() diff --git a/hivemind/moe/server/server.py b/hivemind/moe/server/server.py index 0e2af4dfc..f4d7d7a77 100644 --- a/hivemind/moe/server/server.py +++ b/hivemind/moe/server/server.py @@ -15,7 +15,6 @@ from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts from hivemind.moe.server.connection_handler import ConnectionHandler from hivemind.moe.server.dht_handler import DHTHandlerThread, get_experts -from 
hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.layers import ( add_custom_models_from_file, name_to_block, @@ -23,6 +22,7 @@ schedule_name_to_scheduler, ) from hivemind.moe.server.layers.optim import ClippingWrapper +from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.runtime import Runtime from hivemind.p2p import PeerInfo from hivemind.proto.runtime_pb2 import CompressionType From a9b7643272fee8ab01c3e8f237932033c64624cb Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Wed, 15 Jun 2022 12:26:06 +0300 Subject: [PATCH 30/33] review --- hivemind/moe/server/module_backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py index 8f6e1ba63..d6fe4d722 100644 --- a/hivemind/moe/server/module_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -81,6 +81,9 @@ def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually; To submit a request for asynchronous processing, please use ``ModuleBackend.forward_pool.submit_task``. + .. warning: if the underlying module performs non-gradient updates (e.g. batchnorm), it will be updated twice: + once during forward pass, and again during backward. This behavior is similar to gradient checkpointing. + Subclassing: This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``; It should return gradients w.r.t. 
inputs that follow ``nested_flatten(self.outputs_schema)``; From efeb31b89b97800e8a749009f2989ebe1efa6303 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Wed, 15 Jun 2022 12:58:35 +0300 Subject: [PATCH 31/33] review --- hivemind/moe/server/layers/optim.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hivemind/moe/server/layers/optim.py b/hivemind/moe/server/layers/optim.py index d0248c78c..c4f1b8d6b 100644 --- a/hivemind/moe/server/layers/optim.py +++ b/hivemind/moe/server/layers/optim.py @@ -4,7 +4,6 @@ class OptimizerWrapper(torch.optim.Optimizer): """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer""" - # noinspection PyMissingConstructor def __init__(self, optim: torch.optim.Optimizer): super().__init__(optim.param_groups, optim.defaults) self.optim = optim From 04af589b98214f8a45a83fb8a3bc0272952b74c3 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 15 Jun 2022 13:00:21 +0300 Subject: [PATCH 32/33] review --- hivemind/moe/server/layers/optim.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/hivemind/moe/server/layers/optim.py b/hivemind/moe/server/layers/optim.py index c4f1b8d6b..f280ba427 100644 --- a/hivemind/moe/server/layers/optim.py +++ b/hivemind/moe/server/layers/optim.py @@ -56,8 +56,3 @@ def step(self, *args, **kwargs): parameters = tuple(param for group in self.param_groups for param in group["params"]) torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm) return super().step(*args, **kwargs) - - @classmethod - def create(cls, optim_cls: type, *args, clip_grad_norm: float, **kwargs): - """Create a wrapped optimizer and wrap it with clipping""" - return cls(optim=optim_cls(*args, **kwargs), clip_grad_norm=clip_grad_norm) From e05d3dcc30be96d01692cd93f9ffbe86a01732bc Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 15 Jun 2022 13:02:41 +0300 Subject: [PATCH 33/33] review --- hivemind/moe/server/module_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py index d6fe4d722..f6260371a 100644 --- a/hivemind/moe/server/module_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -112,7 +112,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: Runtime doesn't guarantee that backward will be performed in the same order and for the same data as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward. - Please make sure to call ``ModuleBackend.apply_gradients`` here, otherwise the expert will not train + Please make sure to call ``ModuleBackend.on_backward`` after each call to backward """ (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)