Commit: Fix docs and formatting
mryab committed Jul 3, 2021
1 parent 7152350 commit be17da5
Showing 6 changed files with 23 additions and 28 deletions.
12 changes: 6 additions & 6 deletions benchmarks/benchmark_throughput.py
@@ -115,18 +115,18 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
total_examples = batch_size * num_clients * num_batches_per_client

logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
logger.info(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, max_batch_size={max_batch_size},"
f" expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}")
logger.info(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}")
logger.info(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
f"batch_size={batch_size}, backprop={backprop}")
f"batch_size={batch_size}, backprop={backprop}")

logger.info("Results: ")
logger.info(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
logger.info(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
if benchmarking_failed.is_set():
logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
2 changes: 1 addition & 1 deletion docs/user/quickstart.md
@@ -24,7 +24,7 @@ You can also install it in the editable mode with `pip install -e .`.

## Host a server

- `hivemind.moe.server` hosts one or several experts (PyTorch modules) for remote access. These experts are responsible for
+ `hivemind.moe.Server` hosts one or several experts (PyTorch modules) for remote access. These experts are responsible for
most of the model parameters and computation. The server can be started using either Python or
[a shell script](https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_server.py). We'll use the shell
for now. To host a server with default experts, run this in your shell:
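The quickstart's actual launch command is not shown in this diff. As rough orientation only, here is a minimal sketch of what hosting a server through the Python API might look like; the `Server.create` keyword arguments below are assumptions for illustration, not taken from this commit or the quickstart.

```python
# Minimal sketch only: hosting a hivemind.moe.Server from Python instead of the shell.
# The keyword arguments below are assumed names/values for illustration; they are not
# the documented quickstart invocation (which is not shown in this diff).
import hivemind

server = hivemind.moe.Server.create(
    num_experts=2,           # number of expert modules to host (assumed)
    expert_cls="ffn",        # expert architecture to instantiate (assumed)
    hidden_dim=512,          # hidden dimension of each expert (assumed)
    listen_on="0.0.0.0:*",   # interface and port to accept requests on (assumed)
    start=True,              # launch the server in the background immediately
)
try:
    server.join()            # block until the server is stopped
except KeyboardInterrupt:
    server.shutdown()        # release experts and network resources on Ctrl+C
```
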
2 changes: 1 addition & 1 deletion hivemind/averaging/matchmaking.py
@@ -391,7 +391,7 @@ async def pop_next_leader(self) -> Endpoint:
self.update_triggered.set()

if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (
- self.declared_expiration_time, self.endpoint):
+ self.declared_expiration_time, self.endpoint):
await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
return_when=asyncio.FIRST_COMPLETED)
self.declared_expiration.clear()
7 changes: 3 additions & 4 deletions hivemind/dht/node.py
@@ -140,8 +140,7 @@ async def create(

self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
parallel_rpc, cache_size, listen, listen_on, endpoint,
- record_validator,
- **kwargs)
+ record_validator, **kwargs)
self.port = self.protocol.port

if initial_peers:
@@ -362,8 +361,8 @@ async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set
try:
await asyncio.gather(store_task, *(evt.wait() for evt in store_finished_events.values()))
assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
- return {(key, subkey) if subkey is not None else key: status or False for (key, subkey), status in
- store_ok.items()}
+ return {(key, subkey) if subkey is not None else key: status or False
+ for (key, subkey), status in store_ok.items()}
except asyncio.CancelledError as e:
store_task.cancel()
raise e
6 changes: 3 additions & 3 deletions hivemind/dht/storage.py
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import Optional, Union

from hivemind.dht.routing import DHTID, BinaryDHTValue, Subkey
from hivemind.utils.serializer import MSGPackSerializer
- from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage,DHTExpiration
+ from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage, DHTExpiration


@MSGPackSerializer.ext_serializable(0x50)
@@ -32,6 +33,7 @@ def unpackb(cls, raw: bytes) -> DictionaryDHTValue:

class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTValue]]):
""" A dictionary-like storage that can store binary values and/or nested dictionaries until expiration """

def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration,
subkey: Optional[Subkey] = None) -> bool:
"""
@@ -63,5 +65,3 @@ def store_subkey(self, key: DHTID, subkey: Subkey, value: BinaryDHTValue, expira
return previous_value.store(subkey, value, expiration_time)
else:
return False
-
-
22 changes: 9 additions & 13 deletions tests/test_averaging.py
@@ -47,7 +47,7 @@ def _test_allreduce_once(n_clients, n_aux):

n_peers = 4
modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (
- n_peers - n_clients - n_aux)
+ n_peers - n_clients - n_aux)
random.shuffle(modes)

tensors1 = [torch.randn(123), torch.zeros(3)]
@@ -110,12 +110,10 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
tensors2 = [torch.rand(123), torch.ones(3)]
tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
- averagers = [hivemind.averaging.DecentralizedAverager(tensors, dht=dht, target_group_size=4,
- averaging_expiration=15,
- prefix='mygroup', listen=listen,
- listen_on='127.0.0.1:*',
- start=True)
- for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
+ averagers = [
+ hivemind.averaging.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
+ prefix='mygroup', listen=listen, listen_on='127.0.0.1:*', start=True)
+ for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
reference = [(tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2]
+ tensors4[i] * weights[3]) / sum(weights) for i in range(len(tensors1))]
@@ -150,8 +148,8 @@ def test_allreduce_compression():
for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
averager1 = hivemind.averaging.DecentralizedAverager([x.clone() for x in tensors1], dht=dht,
compression_type=compression_type_pair,
- listen=False,
- target_group_size=2, prefix='mygroup', start=True)
+ listen=False, target_group_size=2, prefix='mygroup',
+ start=True)
averager2 = hivemind.averaging.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
compression_type=compression_type_pair,
target_group_size=2, prefix='mygroup', start=True)
@@ -224,10 +222,8 @@ def test_allreduce_grid():
def test_allgather():
dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
averagers = [hivemind.averaging.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4,
- averaging_expiration=15,
- prefix='mygroup', initial_group_bits='000',
- listen_on='127.0.0.1:*',
- start=True)
+ averaging_expiration=15, prefix='mygroup',
+ initial_group_bits='000', listen_on='127.0.0.1:*', start=True)
for _ in range(8)]

futures = []
