Skip to content

Commit

Permalink
Implement a CLI for hivemind.DHT (#465)
Browse files Browse the repository at this point in the history
* Implement a CLI for hivemind.DHT

* Fix log message in README

* Update examples/albert/README.md

* Add a basic test for hivemind-dht

* Move log_visible_maddrs to hivemind.utils.networking

Co-authored-by: Michael Diskin <[email protected]>
Co-authored-by: Alexander Borzunov <[email protected]>
  • Loading branch information
3 people authored Jun 5, 2022
1 parent 724cdfe commit c49802a
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 33 deletions.
4 changes: 2 additions & 2 deletions examples/albert/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ Run the first DHT peer to welcome trainers and record training statistics (e.g.,

```
$ ./run_training_monitor.py --wandb_project Demo-run
Oct 14 16:26:36.083 [INFO] Running a DHT peer. To connect other peers to this one over the Internet,
use --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
Oct 14 16:26:36.083 [INFO] Running a DHT instance. To connect other peers to this one, use
--initial_peers /ip4/1.2.3.4/tcp/1337/p2p/XXXX /ip4/1.2.3.4/udp/31337/quic/p2p/XXXX
Oct 14 16:26:36.083 [INFO] Full list of visible multiaddresses: ...
wandb: Currently logged in as: XXX (use `wandb login --relogin` to force relogin)
wandb: Tracking run with wandb version 0.10.32
Expand Down
3 changes: 2 additions & 1 deletion examples/albert/run_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs

import utils
from arguments import (
Expand Down Expand Up @@ -227,7 +228,7 @@ def main():
announce_maddrs=collaboration_args.announce_maddrs,
identity_path=collaboration_args.identity_path,
)
utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)

total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
if torch.cuda.device_count() != 0:
Expand Down
3 changes: 2 additions & 1 deletion examples/albert/run_training_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import hivemind
from hivemind.optim.state_averager import TrainingStateAverager
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs

import utils
from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments
Expand Down Expand Up @@ -168,7 +169,7 @@ def upload_checkpoint(self, current_loss):
announce_maddrs=monitor_args.announce_maddrs,
identity_path=monitor_args.identity_path,
)
utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)

if monitor_args.wandb_project is not None:
wandb.init(project=monitor_args.wandb_project)
Expand Down
24 changes: 1 addition & 23 deletions examples/albert/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from typing import Dict, List, Tuple

from multiaddr import Multiaddr
from pydantic import BaseModel, StrictFloat, confloat, conint

from hivemind import choose_ip_address
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.validation import RecordValidatorBase
from hivemind.utils.logging import TextStyle, get_logger
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)

Expand All @@ -28,23 +26,3 @@ def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
signature_validator = RSASignatureValidator()
validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
return validators, signature_validator.local_public_key


def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
    """
    Log the ``--initial_peers`` value that other peers can use to connect to this peer.

    :param visible_maddrs: multiaddrs under which this peer is visible
    :param only_p2p: if True, advertise only the unique ``/p2p/<peer id>`` parts of the multiaddrs;
      otherwise advertise full multiaddrs, restricted to the IP address picked by ``choose_ip_address``
      whenever at least one ip4/ip6 multiaddr is available
    """
    if only_p2p:
        unique_addrs = {addr["p2p"] for addr in visible_maddrs}
        initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
    else:
        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
        if available_ips:
            preferred_ip = choose_ip_address(available_ips)
            # Match whole "/"-separated components instead of substrings: otherwise a preferred IP
            # such as 10.0.0.1 would also select multiaddrs for 10.0.0.10 or 110.0.0.1
            selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr).split("/")]
        else:
            selected_maddrs = visible_maddrs
        initial_peers_str = " ".join(str(addr) for addr in selected_maddrs)

    logger.info(
        f"Running a DHT peer. To connect other peers to this one over the Internet, use "
        f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
    )
    logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")
2 changes: 1 addition & 1 deletion hivemind/dht/dht.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def run_coroutine(
DHT fields made by this coroutine will not be accessible from the host process.
:note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
:note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
:note: when run_coroutine is called with return_future=True, the returned MPFuture can be cancelled to interrupt the task.
"""
future = MPFuture()
self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
Expand Down
4 changes: 2 additions & 2 deletions hivemind/dht/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def create(
:param cache_locally: if True, caches all values (stored or found) in a node-local cache
:param cache_on_store: if True, update cache entries for a key after storing a new item for that key
:param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
:param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
:param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
if they are accessed this many seconds before expiration time.
Expand Down Expand Up @@ -341,7 +341,7 @@ async def store(
) -> bool:
"""
Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
:note: store is a simplified interface to store_many, all kwargs are be forwarded there
:note: store is a simplified interface to store_many, all kwargs are forwarded there
:returns: True if store succeeds, False if it fails (due to no response or newer value)
"""
store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)
Expand Down
76 changes: 76 additions & 0 deletions hivemind/hivemind_cli/run_dht.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import time
from argparse import ArgumentParser

from hivemind.dht import DHT, DHTNode
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs

# Configure hivemind's log handling for this CLI entry point
# NOTE(review): "in_root_logger" presumably attaches the handler to the root logger — confirm
# against hivemind.utils.logging.use_hivemind_log_handler
use_hivemind_log_handler("in_root_logger")
# Module-level logger used by report_status() below
logger = get_logger(__name__)


async def report_status(dht: DHT, node: DHTNode):
    """
    Log a snapshot of the local DHT node: routing table size and number of locally stored keys.

    This coroutine is passed to ``dht.run_coroutine`` by ``main``.

    :param dht: the DHT instance; not used by this coroutine
    :param node: the DHT node whose ``protocol.routing_table`` and ``protocol.storage`` are reported
    """
    protocol = node.protocol

    routing_table = protocol.routing_table
    # The message counts this node as well ("including this one"), hence the + 1
    known_nodes = len(routing_table.uid_to_peer_id) + 1
    logger.info(f"{known_nodes} DHT nodes (including this one) are in the local routing table ")
    logger.debug(f"Routing table contents: {routing_table}")

    storage = protocol.storage
    logger.info(f"Local storage contains {len(storage)} keys")
    logger.debug(f"Local storage contents: {storage}")


def main():
    """
    Run a standalone DHT peer from the command line and periodically log its status.

    Parses the command-line options, starts a ``DHT`` configured with them, logs the multiaddrs
    that other peers can use to connect, and then runs ``report_status`` every ``--refresh_period``
    seconds until interrupted. The DHT is shut down explicitly on the way out, including on Ctrl+C.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--initial_peers",
        nargs="*",
        help="Multiaddrs of the peers that will welcome you into the existing DHT. "
        "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
    )
    parser.add_argument(
        "--host_maddrs",
        nargs="*",
        default=["/ip4/0.0.0.0/tcp/0"],
        help="Multiaddrs to listen for external connections from other DHT instances. "
        "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0",
    )
    parser.add_argument(
        "--announce_maddrs",
        nargs="*",
        help="Visible multiaddrs the host announces for external connections from other DHT instances",
    )
    parser.add_argument(
        "--use_ipfs",
        action="store_true",
        help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
        "part of the multiaddrs for the initial_peers "
        "(no need to specify a particular IPv4/IPv6 host and port)",
    )
    parser.add_argument(
        "--identity_path",
        help="Path to a private key file. If defined, makes the peer ID deterministic. "
        "If the file does not exist, writes a new private key to this file.",
    )
    parser.add_argument(
        "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
    )

    args = parser.parse_args()

    dht = DHT(
        start=True,
        initial_peers=args.initial_peers,
        host_maddrs=args.host_maddrs,
        announce_maddrs=args.announce_maddrs,
        use_ipfs=args.use_ipfs,
        identity_path=args.identity_path,
    )

    try:
        log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)

        # Report status forever; the only way out of this loop is an exception such as Ctrl+C
        while True:
            dht.run_coroutine(report_status, return_future=False)
            time.sleep(args.refresh_period)
    except KeyboardInterrupt:
        logger.info("Caught KeyboardInterrupt, shutting down")
    finally:
        # Stop the DHT explicitly instead of relying on interpreter teardown
        dht.shutdown()


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion hivemind/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.mpfuture import *
from hivemind.utils.nested import *
from hivemind.utils.networking import *
from hivemind.utils.networking import get_free_port, log_visible_maddrs
from hivemind.utils.performance_ema import PerformanceEMA
from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
Expand Down
26 changes: 25 additions & 1 deletion hivemind/utils/networking.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import socket
from contextlib import closing
from ipaddress import ip_address
from typing import Sequence
from typing import List, Sequence

from multiaddr import Multiaddr

from hivemind.utils.logging import TextStyle, get_logger

# IPv4 loopback address
LOCALHOST = "127.0.0.1"

# Module-level logger; used by log_visible_maddrs()
logger = get_logger(__name__)


def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
"""
Expand Down Expand Up @@ -52,3 +56,23 @@ def choose_ip_address(
return value_for_protocol

raise ValueError(f"No IP address found among given multiaddrs: {maddrs}")


def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
    """
    Log the ``--initial_peers`` value that other peers can use to connect to this DHT instance.

    :param visible_maddrs: multiaddrs under which this peer is visible
    :param only_p2p: if True, advertise only the unique ``/p2p/<peer id>`` parts of the multiaddrs;
      otherwise advertise full multiaddrs, restricted to the IP address picked by ``choose_ip_address``
      whenever at least one ip4/ip6 multiaddr is available
    """
    if only_p2p:
        unique_addrs = {addr["p2p"] for addr in visible_maddrs}
        initial_peers = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
    else:
        available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
        if available_ips:
            preferred_ip = choose_ip_address(available_ips)
            # Match whole "/"-separated components instead of substrings: otherwise a preferred IP
            # such as 10.0.0.1 would also select multiaddrs for 10.0.0.10 or 110.0.0.1
            selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr).split("/")]
        else:
            selected_maddrs = visible_maddrs
        initial_peers = " ".join(str(addr) for addr in selected_maddrs)

    logger.info(
        f"Running a DHT instance. To connect other peers to this one, use "
        f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers}{TextStyle.RESET}"
    )
    logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def run(self):
],
entry_points={
"console_scripts": [
"hivemind-dht = hivemind.hivemind_cli.run_dht:main",
"hivemind-server = hivemind.hivemind_cli.run_server:main",
]
},
Expand Down
63 changes: 63 additions & 0 deletions tests/test_cli_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import re
from subprocess import PIPE, Popen
from time import sleep

# Matches the startup log line of hivemind-dht; group 1 captures the rest of the line after "use "
# (i.e. the "--initial_peers ..." arguments). The dot after "instance" is escaped so that it only
# matches the literal sentence.
DHT_START_PATTERN = re.compile(r"Running a DHT instance\. To connect other peers to this one, use (.+)$")


def test_dht_connection_successful():
    """
    Start two ``hivemind-dht`` processes and check that they find each other.

    The first process prints the ``--initial_peers`` arguments needed to reach it; the second one is
    started with those arguments. Both are then expected to report two nodes in their routing tables.
    Both processes are terminated even if an assertion fails, so a failing run leaves nothing behind.
    """
    dht_refresh_period = 1
    two_nodes_msg = "2 DHT nodes (including this one) are in the local routing table"
    empty_storage_msg = "Local storage contains 0 keys"

    started_procs = []
    try:
        dht_proc = Popen(
            ["hivemind-dht", "--host_maddrs", "/ip4/127.0.0.1/tcp/0", "--refresh_period", str(dht_refresh_period)],
            stderr=PIPE,
            text=True,
            encoding="utf-8",
        )
        started_procs.append(dht_proc)

        first_line = dht_proc.stderr.readline()
        second_line = dht_proc.stderr.readline()
        dht_pattern_match = DHT_START_PATTERN.search(first_line)
        assert dht_pattern_match is not None, first_line
        assert "Full list of visible multiaddresses:" in second_line, second_line

        # The captured group is the ready-to-use argument list: "--initial_peers <maddr> ..."
        initial_peers = dht_pattern_match.group(1).split(" ")

        dht_client_proc = Popen(
            ["hivemind-dht", *initial_peers, "--host_maddrs", "/ip4/127.0.0.1/tcp/0"],
            stderr=PIPE,
            text=True,
            encoding="utf-8",
        )
        started_procs.append(dht_client_proc)

        # skip first two lines with connectivity info
        for _ in range(2):
            dht_client_proc.stderr.readline()
        first_report_msg = dht_client_proc.stderr.readline()

        assert two_nodes_msg in first_report_msg, first_report_msg

        # ensure we get the output of dht_proc after the start of dht_client_proc
        sleep(dht_refresh_period)

        # expect that one of the next logging outputs from the first peer shows a new connection
        connection_reported = False
        for _ in range(5):
            first_report_msg = dht_proc.stderr.readline()
            second_report_msg = dht_proc.stderr.readline()

            if two_nodes_msg in first_report_msg and empty_storage_msg in second_report_msg:
                connection_reported = True
                break

        assert connection_reported, (first_report_msg, second_report_msg)
    finally:
        # Signal every started process first, then reap them and release their pipes
        for proc in started_procs:
            proc.terminate()
        for proc in started_procs:
            proc.wait()
            proc.stderr.close()
2 changes: 1 addition & 1 deletion tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import random
from itertools import chain, zip_longest

from hivemind import LOCALHOST
from hivemind.dht.routing import DHTID, RoutingTable
from hivemind.utils.networking import LOCALHOST


def test_ids_basic():
Expand Down

0 comments on commit c49802a

Please sign in to comment.