Skip to content

Commit

Permalink
Optimize and harden additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mryab committed Jul 14, 2024
1 parent 814b73d commit 0abe59c
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 29 deletions.
2 changes: 1 addition & 1 deletion tests/test_cli_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def test_dht_connection_successful():
dht_refresh_period = 1
dht_refresh_period = 3

cloned_env = os.environ.copy()
# overriding the loglevel to prevent debug print statements
Expand Down
11 changes: 6 additions & 5 deletions tests/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,9 @@ def test_client_anomaly_detection():
server.shutdown()


def _measure_coro_running_time(n_coros, elapsed_fut, counter):
def _measure_coro_running_time(n_coros, elapsed_fut, counter, coroutine_time):
async def coro():
await asyncio.sleep(0.1)
await asyncio.sleep(coroutine_time)
counter.value += 1

try:
Expand All @@ -337,20 +337,21 @@ async def coro():


@pytest.mark.forked
def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10, coroutine_time=0.1):
    """Check that the remote-expert worker executes coroutines concurrently rather than sequentially.

    Spawns a mix of threads and processes (alternating, to exercise both backends), each of which
    runs ``n_coros`` coroutines sleeping for ``coroutine_time`` and reports its elapsed wall time
    through an MPFuture. If the coroutines ran concurrently, each worker's elapsed time stays well
    below ``n_coros * coroutine_time``.

    :param n_processes: number of workers to spawn (threads and processes alternate)
    :param n_coros: number of coroutines each worker schedules
    :param coroutine_time: sleep duration of each coroutine, seconds
    """
    processes = []
    counter = mp.Value(ctypes.c_int64)
    for i in range(n_processes):
        elapsed_fut = MPFuture()
        factory = threading.Thread if i % 2 == 0 else mp.Process  # Test both threads and processes

        proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter, coroutine_time))
        proc.start()
        processes.append((proc, elapsed_fut))

    for proc, elapsed_fut in processes:
        # Ensure that the coroutines were run concurrently, not sequentially:
        # a sequential run would take at least n_coros * coroutine_time.
        expected_time = coroutine_time * 3  # from non-blocking calls + blocking call + some overhead
        assert elapsed_fut.result() < expected_time
        proc.join()

    assert counter.value == n_processes * n_coros  # Ensure all coroutines have finished
49 changes: 30 additions & 19 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import multiprocessing as mp
import time
from functools import partial
from typing import List

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from multiaddr import Multiaddr

import hivemind
from hivemind.averaging.control import AveragingStage
Expand Down Expand Up @@ -227,8 +229,10 @@ def test_progress_tracker():
finished_evt = mp.Event()
emas = mp.Array(ctypes.c_double, 5)

def run_worker(index: int, batch_size: int, period: float, **kwargs):
dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
root_maddrs = dht_root.get_visible_maddrs()

def run_worker(index: int, batch_size: int, step_time: float, initial_peers: List[Multiaddr]):
dht = hivemind.DHT(initial_peers=initial_peers, start=True)
tracker = ProgressTracker(
dht,
prefix,
Expand All @@ -238,18 +242,17 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
default_refresh_period=0.2,
max_refresh_period=0.5,
private_key=RSAPrivateKey(),
**kwargs,
)
with tracker.pause_updates():
barrier.wait()
if index == 4:
delayed_start_evt.wait()

barrier.wait()
if index == 4:
delayed_start_evt.wait()

local_epoch = 2 if index == 4 else 0
samples_accumulated = 0
local_epoch = 2 if index == 4 else 0
samples_accumulated = 0

while True:
time.sleep(period)
time.sleep(step_time)
if finished_evt.is_set():
break

Expand All @@ -270,10 +273,10 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
dht.shutdown()

workers = [
mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, step_time=0.6, initial_peers=root_maddrs)),
mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, step_time=0.5, initial_peers=root_maddrs)),
mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, step_time=0.2, initial_peers=root_maddrs)),
mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, step_time=0.2, initial_peers=root_maddrs)),
]
for worker in workers:
worker.start()
Expand Down Expand Up @@ -336,7 +339,7 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
(False, True, True, True, True),
(False, True, True, False, True),
(True, False, False, False, False),
(True, True, False, False, False,),
(True, True, False, False, False),
],
# fmt: on
)
Expand All @@ -359,6 +362,8 @@ def test_optimizer(
def _test_optimizer(
num_peers: int = 1,
num_clients: int = 0,
default_batch_size: int = 4,
default_batch_time: int = 0.1,
target_batch_size: int = 32,
total_epochs: int = 3,
use_local_updates: bool = False,
Expand Down Expand Up @@ -422,20 +427,21 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):

prev_time = time.perf_counter()

time.sleep(1.0)
optimizer.shutdown()
return optimizer

peers = []

for index in range(num_peers):
peer_batch_size = default_batch_size + index
peer_batch_time = default_batch_time + 0.01 * index
peers.append(
mp.Process(
target=run_trainer,
name=f"trainer-{index}",
kwargs=dict(
batch_size=4 + index,
batch_time=0.3 + 0.2 * index,
batch_size=peer_batch_size,
batch_time=peer_batch_time,
client_mode=(index >= num_peers - num_clients),
),
)
Expand All @@ -451,7 +457,12 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
expected_samples_accumulated = target_batch_size * total_epochs
assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
expected_performance = default_batch_size / default_batch_time
assert (
expected_performance * 0.8
<= optimizer.tracker.performance_ema.samples_per_second
<= expected_performance * 1.2
)

assert not optimizer.state_averager.is_alive()
assert not optimizer.tracker.is_alive()
Expand Down
19 changes: 15 additions & 4 deletions tests/test_util_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Event

import numpy as np
import pytest
Expand Down Expand Up @@ -266,9 +267,10 @@ def _check_result_and_set(future):
with pytest.raises(RuntimeError):
future1.add_done_callback(lambda future: (1, 2, 3))

events[0].wait()
assert future1.done() and not future1.cancelled()
assert future2.done() and future2.cancelled()
for i in 0, 1, 4:
for i in 1, 4:
events[i].wait(1)
assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
assert not events[3].is_set()
Expand Down Expand Up @@ -557,16 +559,25 @@ def test_performance_ema_threadsafe(
bias_power: float = 0.7,
tolerance: float = 0.05,
):
def run_task(ema):
task_size = random.randint(1, 4)
def run_task(ema, start_event, task_size):
start_event.wait()
with ema.update_threadsafe(task_size):
time.sleep(task_size * interval * (0.9 + 0.2 * random.random()))
return task_size

with ThreadPoolExecutor(max_workers) as pool:
ema = PerformanceEMA(alpha=alpha)
start_event = Event()
start_time = time.perf_counter()
futures = [pool.submit(run_task, ema) for _ in range(num_updates)]

futures = []
for _ in range(num_updates):
task_size = random.randint(1, 4)
future = pool.submit(run_task, ema, start_event, task_size)
futures.append(future)

ema.reset_timer()
start_event.set()
total_size = sum(future.result() for future in as_completed(futures))
end_time = time.perf_counter()
target = total_size / (end_time - start_time)
Expand Down

0 comments on commit 0abe59c

Please sign in to comment.