Fix device in Switch-MoE, overhaul Server architecture (#256)
* Set correct device for scores

* Put pipe_awaiter in a context manager

* Pass min_batch_size to ExpertBackend in Server.create

* Remove unneeded variable for exception in generate_uids_from_pattern

* Overhaul server architecture
mryab authored May 6, 2021
1 parent 94b9db0 commit 2328ba9
Showing 11 changed files with 174 additions and 170 deletions.
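
A minimal, self-contained sketch of the "pipe_awaiter in a context manager" change (second bullet above): the ThreadPoolExecutor that awaits pipe messages is now opened in a with block, so its worker thread is cleaned up when the event loop exits. All names below (serve_forever, _handle, the three-message demo) are illustrative, not the actual hivemind code.

import asyncio
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
from multiprocessing.connection import Connection


async def _handle(message: str) -> None:
    # Stand-in for a real handler such as averager.step or dht.get.
    print(f"handled: {message}")


def serve_forever(pipe: Connection) -> None:
    """Background-process loop in the spirit of DecentralizedAverager._run_internal (illustrative)."""
    loop = asyncio.new_event_loop()

    # The executor now lives in a `with` block, so its worker thread is joined
    # automatically when _run() returns or raises -- the point of this change.
    with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:

        async def _run() -> None:
            tasks = []
            for _ in range(3):  # the real loop runs forever; bounded so the sketch terminates
                # pipe.recv() blocks, so it is delegated to the executor thread
                # instead of blocking the event loop.
                message = await loop.run_in_executor(pipe_awaiter, pipe.recv)
                tasks.append(asyncio.create_task(_handle(message)))
            await asyncio.gather(*tasks)

        loop.run_until_complete(_run())


if __name__ == "__main__":
    parent_end, child_end = mp.Pipe()
    worker = mp.Process(target=serve_forever, args=(child_end,))
    worker.start()
    for msg in ("a", "b", "c"):
        parent_end.send(msg)
    worker.join()
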
54 changes: 27 additions & 27 deletions hivemind/client/averaging/__init__.py
@@ -171,35 +171,34 @@ def _run_internal(self):
""" Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
loop = switch_to_uvloop()
# initialize asyncio synchronization primitives in this event loop
pipe_awaiter = ThreadPoolExecutor(max_workers=1)

async def _run():
grpc.aio.init_grpc_aio()

if self.listen:
server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
found_port = server.add_insecure_port(self.listen_on)
assert found_port != 0, f"Failed to listen to {self.listen_on}"
self._port.value = found_port
await server.start()
else:
logger.info(f"The averager running in an experimental client mode, please report any bugs.")
with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
async def _run():
grpc.aio.init_grpc_aio()

if self.listen:
server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
found_port = server.add_insecure_port(self.listen_on)
assert found_port != 0, f"Failed to listen to {self.listen_on}"
self._port.value = found_port
await server.start()
else:
logger.info(f"The averager running in an experimental client mode, please report any bugs.")

self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
client_mode=not self.listen)
if self.listen:
asyncio.create_task(self._declare_for_download_periodically())
self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
client_mode=not self.listen)
if self.listen:
asyncio.create_task(self._declare_for_download_periodically())

self._pending_group_assembled = asyncio.Event()
self._pending_group_assembled.set()
self.ready.set()
self._pending_group_assembled = asyncio.Event()
self._pending_group_assembled.set()
self.ready.set()

while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(*args, **kwargs))
while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(*args, **kwargs))

loop.run_until_complete(_run())
loop.run_until_complete(_run())

def run_in_background(self, await_ready=True, timeout=None):
"""
@@ -255,7 +254,8 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
try:
self._pending_group_assembled.clear()
data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
group_info = await self._matchmaking.look_for_group(timeout=timeout,
data_for_gather=data_for_gather)
if group_info is None:
raise AllreduceException("Averaging step failed: could not find a group.")
group_id = group_info.group_id
@@ -294,7 +294,7 @@ async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: i
""" Use a group description found by Matchmaking to form AllreduceRunner """
try:
weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))

# compute optimal part sizes from peer throughputs
incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
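
In the hunks above, _step packs a per-peer payload of [weight, throughput, listen, gather_binary] and _make_allreduce_runner transposes all peers' payloads with zip(*map(loads, ...)). A generic round-trip sketch of that idiom, using pickle as a stand-in for hivemind's serializer (values are made up):

import pickle

# One serialized payload per peer: [weight, throughput, listen_mode, user_data]
payloads = [pickle.dumps([1.0, 125.0, True, b"peer0-data"]),
            pickle.dumps([2.0, 0.0, False, b"peer1-data"])]

# Mirrors _make_allreduce_runner: deserialize each payload, then transpose
# the per-peer lists into per-field tuples.
weights, throughputs, modes, user_gathered = zip(*map(pickle.loads, payloads))
print(weights)      # (1.0, 2.0)
print(throughputs)  # (125.0, 0.0)
print(modes)        # (True, False)
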
11 changes: 7 additions & 4 deletions hivemind/client/moe.py
@@ -120,8 +120,11 @@ def compute_expert_scores(
batch_size = len(batch_experts)
max_num_experts = max(expert_counts)
total_num_experts = sum(expert_counts)
expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]

device = grid_scores[0].device

expert_index_in_batch = torch.arange(total_num_experts, device=device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
flat_experts = [expert for row in batch_experts for expert in row]
Expand All @@ -133,11 +136,11 @@ def compute_expert_scores(
grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

scores_per_dim = [
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores
return scores

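
The moe.py and switch_moe.py hunks both hoist grid_scores[0].device into a local device and pass it to every new tensor, so nothing silently lands on the CPU when the gating scores live on a GPU. A minimal sketch of the fixed pattern (shapes and values are made up):

import torch

# Illustrative stand-ins for the gating outputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
grid_scores = [torch.randn(4, 3, device=device), torch.randn(4, 5, device=device)]

# The fix: read the device once and thread it through every tensor constructor.
device = grid_scores[0].device
expert_index_in_batch = torch.arange(6, device=device)
empty_fallback = torch.zeros(0, device=device)              # was torch.zeros(0) -> default (CPU) device
scores = torch.full((4, 3), -float("inf"), device=device)   # was device=grid_scores[0].device repeated inline
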
11 changes: 7 additions & 4 deletions hivemind/client/switch_moe.py
@@ -156,8 +156,11 @@ def compute_expert_scores(
batch_size = len(batch_experts)
max_num_experts = max(expert_counts)
total_num_experts = sum(expert_counts)
expert_index_in_batch = torch.arange(total_num_experts, device=grid_probs[0].device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_probs[0].device), dim=-1)[:-1]

device = grid_probs[0].device

expert_index_in_batch = torch.arange(total_num_experts, device=device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
flat_experts = [expert for row in batch_experts for expert in row]
Expand All @@ -169,10 +172,10 @@ def compute_expert_scores(
grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

scores_per_dim = [
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)]
flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)

scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_probs[0].device)
scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores
return scores
32 changes: 15 additions & 17 deletions hivemind/dht/__init__.py
@@ -69,25 +69,23 @@ def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[En
def run(self) -> None:
""" Serve DHT forever. This function will not return until DHT node is shut down """
loop = switch_to_uvloop()
pipe_awaiter = ThreadPoolExecutor(max_workers=1)

async def _run():
node = await DHTNode.create(
initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
num_workers=self.max_workers or 1, record_validator=self._record_validator,
**self.kwargs)
if node.port is not None:
self._port.value = node.port
self.ready.set()
with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
async def _run():
node = await DHTNode.create(
initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
num_workers=self.max_workers or 1, record_validator=self._record_validator,
**self.kwargs)
if node.port is not None:
self._port.value = node.port
self.ready.set()

while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(node, *args, **kwargs))

try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")
coro = _run()
loop.run_until_complete(coro)

def run_in_background(self, await_ready=True, timeout=None):
"""
Expand All @@ -96,7 +94,7 @@ def run_in_background(self, await_ready=True, timeout=None):
"""
self.start()
if await_ready and not self.ready.wait(timeout=timeout):
raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")

def shutdown(self) -> None:
""" Shut down a running dht process """
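
The DHT hunks mirror the averager change (executor in a context manager) and sharpen the TimeoutError message. A hedged usage sketch of the lifecycle those methods implement; the constructor arguments and the dht.port attribute are inferred from this diff and hivemind's public API of the time, so treat exact signatures as assumptions:

import hivemind

# Start the DHT in a background process and wait until it reports ready.
dht = hivemind.DHT(listen_on="0.0.0.0:*", initial_peers=[], start=False)
dht.run_in_background(await_ready=True, timeout=30)  # raises TimeoutError if .ready is not set in time

try:
    print(f"DHT is listening on port {dht.port}")
finally:
    dht.shutdown()  # terminates the background process
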
4 changes: 3 additions & 1 deletion hivemind/hivemind_cli/run_server.py
@@ -32,7 +32,9 @@ def main():

parser.add_argument('--num_handlers', type=int, default=None, required=False,
help='server will use this many processes to handle incoming requests')
parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all expert operations')
parser.add_argument('--max_batch_size', type=int, default=16384,
help='The total number of examples in the same batch will not exceed this value')
parser.add_argument('--device', type=str, default=None, required=False,
help='all experts will use this device in torch notation; default: cuda if available else cpu')
52 changes: 30 additions & 22 deletions hivemind/server/__init__.py
@@ -65,16 +65,20 @@ def __init__(
self.checkpoint_saver = None
self.runtime = Runtime(self.experts, **kwargs)

if self.dht and self.experts:
self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
update_period=self.update_period)

if start:
self.run_in_background(await_ready=True)

@classmethod
def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None,
*, start: bool, **kwargs) -> Server:
num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1,
max_batch_size=4096, device=None, no_dht=False, initial_peers=(), dht_port=None,
checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server:
"""
Instantiate a server with several identical experts. See argparse comments below for details
:param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
Expand All @@ -85,6 +89,7 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
:param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
:param hidden_dim: main dimension for expert_cls
:param num_handlers: server will use this many parallel processes to handle incoming requests
:param min_batch_size: total num examples in the same batch will be greater than this value
:param max_batch_size: total num examples in the same batch will not exceed this value
:param device: all experts will use this device in torch notation; default: cuda if available else cpu
@@ -112,9 +117,6 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
"""
if custom_module_path is not None:
add_custom_models_from_file(custom_module_path)

if len(kwargs) != 0:
logger.info("Ignored kwargs:", kwargs)
assert expert_cls in name_to_block

if no_dht:
@@ -172,6 +174,7 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
num_warmup_steps=num_warmup_steps,
num_total_steps=num_total_steps,
clip_grad_norm=clip_grad_norm,
min_batch_size=min_batch_size,
max_batch_size=max_batch_size)

if checkpoint_dir is not None:
Expand All @@ -196,9 +199,7 @@ def run(self):
self.dht.run_in_background(await_ready=True)

if self.experts:
dht_handler_thread = DHTHandlerThread(
experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
dht_handler_thread.start()
self.dht_handler_thread.start()
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()

Expand All @@ -207,16 +208,10 @@ def run(self):
process.start()
process.ready.wait()

self.runtime.run()

for process in self.conn_handlers:
process.join()
if self.dht and self.experts:
dht_handler_thread.stop.set()
dht_handler_thread.join()
if self.checkpoint_saver is not None:
self.checkpoint_saver.stop.set()
self.checkpoint_saver.join()
try:
self.runtime.run()
finally:
self.shutdown()

def run_in_background(self, await_ready=True, timeout=None):
"""
Expand All @@ -242,19 +237,32 @@ def ready(self) -> mp.synchronize.Event:

def shutdown(self):
"""
Gracefully terminate a hivemind server, process-safe.
Gracefully terminate the server, process-safe.
Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
"""
self.ready.clear()

for process in self.conn_handlers:
process.terminate()
process.join()
logger.debug("Connection handlers terminated")

if self.dht and self.experts:
self.dht_handler_thread.stop.set()
self.dht_handler_thread.join()

if self.checkpoint_saver is not None:
self.checkpoint_saver.stop.set()
self.checkpoint_saver.join()

if self.dht is not None:
self.dht.shutdown()
self.dht.join()

self.runtime.shutdown()
logger.debug(f"Shutting down runtime")
self.runtime.stop.set()
logger.info("Server shutdown succesfully")


@contextmanager
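
A hedged usage sketch of the refactored Server lifecycle: Server.create now takes min_batch_size explicitly and no longer accepts stray kwargs, and run() delegates all cleanup to shutdown() via try/finally. All argument values below are placeholders (the expert_pattern string in particular is an assumption):

from hivemind.server import Server

# Spin up a server with a few feed-forward experts; start=True launches the
# background process and blocks until it is ready.
server = Server.create(
    listen_on="0.0.0.0:*",
    num_experts=2,
    expert_pattern="expert.[0:256]",
    expert_cls="ffn",
    hidden_dim=1024,
    min_batch_size=1,       # newly forwarded to every ExpertBackend
    max_batch_size=4096,
    start=True,
)

try:
    pass  # serve requests here; training clients would connect to server.listen_on
finally:
    server.shutdown()  # now also joins the DHT handler thread, checkpoint saver, and runtime
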
5 changes: 4 additions & 1 deletion hivemind/server/connection_handler.py
@@ -52,7 +52,10 @@ async def _run():
await server.wait_for_termination()
logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")

loop.run_until_complete(_run())
try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug('Caught KeyboardInterrupt, shutting down')

async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
4 changes: 2 additions & 2 deletions hivemind/server/expert_backend.py
@@ -74,8 +74,8 @@ def __init__(self, name: str, expert: nn.Module, optimizer: torch.optim.Optimize

self.backward_schema = (self.forward_schema, self.outputs_schema) # inputs to backward
self.grad_inputs_schema = self.forward_schema # outputs from backward
self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
self.forward_pool = TaskPool(self.forward, name=f'{self.name}_forward', **kwargs)
self.backward_pool = TaskPool(self.backward, name=f'{self.name}_backward', **kwargs)

self.update_count = 0
self.examples_processed = 0
4 changes: 2 additions & 2 deletions hivemind/server/expert_uid.py
@@ -62,8 +62,8 @@ def _generate_uid():
uid.append(str(random.randint(slice_start, slice_end - 1)))
else:
raise ValueError("Block must be either fixed or a range [from:to]")
except KeyboardInterrupt as e:
raise e
except KeyboardInterrupt:
raise
except Exception as e:
raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
return UID_DELIMITER.join(uid)
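
The expert_uid.py change swaps `except KeyboardInterrupt as e: raise e` for a bare `raise`, which re-raises the active exception without rebinding it and keeps the original traceback intact. A tiny generic illustration of the idiom (not hivemind-specific):

def parse_block(block: str) -> int:
    try:
        return int(block)
    except KeyboardInterrupt:
        raise  # propagate Ctrl+C untouched, preserving the original traceback
    except Exception as e:
        # everything else is wrapped in a descriptive error, as in _generate_uid
        raise ValueError(f"invalid block {block!r}: {e}")
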