Fix device in Switch-MoE, overhaul Server architecture #256

Merged: 9 commits, May 6, 2021
54 changes: 27 additions & 27 deletions hivemind/client/averaging/__init__.py
@@ -171,35 +171,34 @@ def _run_internal(self):
""" Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
loop = switch_to_uvloop()
# initialize asyncio synchronization primitives in this event loop
pipe_awaiter = ThreadPoolExecutor(max_workers=1)

async def _run():
grpc.aio.init_grpc_aio()

if self.listen:
server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
found_port = server.add_insecure_port(self.listen_on)
assert found_port != 0, f"Failed to listen to {self.listen_on}"
self._port.value = found_port
await server.start()
else:
logger.info(f"The averager running in an experimental client mode, please report any bugs.")
with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
async def _run():
grpc.aio.init_grpc_aio()

if self.listen:
server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
found_port = server.add_insecure_port(self.listen_on)
assert found_port != 0, f"Failed to listen to {self.listen_on}"
self._port.value = found_port
await server.start()
else:
logger.info(f"The averager running in an experimental client mode, please report any bugs.")

self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
client_mode=not self.listen)
if self.listen:
asyncio.create_task(self._declare_for_download_periodically())
self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
client_mode=not self.listen)
if self.listen:
asyncio.create_task(self._declare_for_download_periodically())

self._pending_group_assembled = asyncio.Event()
self._pending_group_assembled.set()
self.ready.set()
self._pending_group_assembled = asyncio.Event()
self._pending_group_assembled.set()
self.ready.set()

while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(*args, **kwargs))
while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(*args, **kwargs))

loop.run_until_complete(_run())
loop.run_until_complete(_run())

def run_in_background(self, await_ready=True, timeout=None):
"""
@@ -255,7 +254,8 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
try:
self._pending_group_assembled.clear()
data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
group_info = await self._matchmaking.look_for_group(timeout=timeout,
data_for_gather=data_for_gather)
if group_info is None:
raise AllreduceException("Averaging step failed: could not find a group.")
group_id = group_info.group_id
@@ -294,7 +294,7 @@ async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: i
""" Use a group description found by Matchmaking to form AllreduceRunner """
try:
weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))

# compute optimal part sizes from peer throughputs
incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
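A note on the hunk above: the pipe-reading `ThreadPoolExecutor` in `DecentralizedAverager._run_internal` (and likewise in `DHT.run` further below) now lives inside a `with` block, so its worker thread is shut down together with the event loop instead of lingering for the process lifetime. A minimal sketch of the pattern; `handle_command` and the pipe wiring are illustrative stand-ins, not the hivemind API:

```python
import asyncio
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor


async def handle_command(method, *args, **kwargs):
    print(f"dispatching {method} with {args} {kwargs}")  # stand-in for getattr(self, method)(...)


def _run_internal(pipe: mp.connection.Connection):
    """Illustrative process body: dispatch (method, args, kwargs) commands received over a Pipe."""
    loop = asyncio.new_event_loop()

    # Scoping the executor with `with` ties its lifetime to this function:
    # the worker thread is joined once the loop finishes.
    with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
        async def _run():
            while True:  # runs until the process is terminated, as in the original
                # Pipe.recv() blocks, so it is offloaded to the single-worker executor
                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, pipe.recv)
                asyncio.create_task(handle_command(method, *args, **kwargs))

        loop.run_until_complete(_run())
```

The dispatch loop itself behaves as before; only the executor's cleanup is now tied to the surrounding scope.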
11 changes: 7 additions & 4 deletions hivemind/client/moe.py
@@ -120,8 +120,11 @@ def compute_expert_scores(
batch_size = len(batch_experts)
max_num_experts = max(expert_counts)
total_num_experts = sum(expert_counts)
expert_index_in_batch = torch.arange(total_num_experts, device=grid_scores[0].device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_scores[0].device), dim=-1)[:-1]

device = grid_scores[0].device

expert_index_in_batch = torch.arange(total_num_experts, device=device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
flat_experts = [expert for row in batch_experts for expert in row]
@@ -133,11 +136,11 @@ def compute_expert_scores(
grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

scores_per_dim = [
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)

scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_scores[0].device)
scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores
return scores

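The hunk above (and the identical change in `hivemind/client/switch_moe.py` below) hoists `device = grid_scores[0].device` and passes it to every tensor constructor, so the empty fallback `torch.zeros(0)` no longer defaults to CPU while the rest of the computation sits on a GPU. A toy, standalone version of the fixed logic; `combine_dim_scores` is a made-up name, not a hivemind function:

```python
import torch


def combine_dim_scores(grid_scores, flat_batch_indices, grid_indices):
    """Toy version of the fixed score gathering: every tensor, including the
    empty fallback, is created on the same device as the incoming scores."""
    device = grid_scores[0].device  # hoisted once instead of repeating grid_scores[0].device

    scores_per_dim = [
        dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices)
        else torch.zeros(0, device=device)  # pre-fix: torch.zeros(0), which lands on CPU
        for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)]
    return torch.sum(torch.stack(scores_per_dim, dim=0), dim=0)
```

On a CUDA setup, the old CPU placeholder could trigger a device-mismatch error once it met CUDA tensors downstream; `switch_moe.py` gets the same treatment but combines dimensions with `torch.prod` instead of `torch.sum`.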
11 changes: 7 additions & 4 deletions hivemind/client/switch_moe.py
@@ -156,8 +156,11 @@ def compute_expert_scores(
batch_size = len(batch_experts)
max_num_experts = max(expert_counts)
total_num_experts = sum(expert_counts)
expert_index_in_batch = torch.arange(total_num_experts, device=grid_probs[0].device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_probs[0].device), dim=-1)[:-1]

device = grid_probs[0].device

expert_index_in_batch = torch.arange(total_num_experts, device=device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
flat_experts = [expert for row in batch_experts for expert in row]
@@ -169,10 +172,10 @@ def compute_expert_scores(
grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

scores_per_dim = [
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)]
flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)

scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_probs[0].device)
scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores
return scores
32 changes: 15 additions & 17 deletions hivemind/dht/__init__.py
@@ -69,25 +69,23 @@ def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[En
def run(self) -> None:
""" Serve DHT forever. This function will not return until DHT node is shut down """
loop = switch_to_uvloop()
pipe_awaiter = ThreadPoolExecutor(max_workers=1)

async def _run():
node = await DHTNode.create(
initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
num_workers=self.max_workers or 1, record_validator=self._record_validator,
**self.kwargs)
if node.port is not None:
self._port.value = node.port
self.ready.set()
with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
async def _run():
node = await DHTNode.create(
initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
num_workers=self.max_workers or 1, record_validator=self._record_validator,
**self.kwargs)
if node.port is not None:
self._port.value = node.port
self.ready.set()

while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(node, *args, **kwargs))

try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")
coro = _run()
loop.run_until_complete(coro)

def run_in_background(self, await_ready=True, timeout=None):
"""
@@ -96,7 +94,7 @@ def run_in_background(self, await_ready=True, timeout=None):
"""
self.start()
if await_ready and not self.ready.wait(timeout=timeout):
raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")

def shutdown(self) -> None:
""" Shut down a running dht process """
4 changes: 3 additions & 1 deletion hivemind/hivemind_cli/run_server.py
@@ -32,7 +32,9 @@ def main():

parser.add_argument('--num_handlers', type=int, default=None, required=False,
help='server will use this many processes to handle incoming requests')
parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all expert operations')
parser.add_argument('--max_batch_size', type=int, default=16384,
help='The total number of examples in the same batch will not exceed this value')
parser.add_argument('--device', type=str, default=None, required=False,
help='all experts will use this device in torch notation; default: cuda if available else cpu')
52 changes: 30 additions & 22 deletions hivemind/server/__init__.py
@@ -65,16 +65,20 @@ def __init__(
self.checkpoint_saver = None
self.runtime = Runtime(self.experts, **kwargs)

if self.dht and self.experts:
self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
update_period=self.update_period)

if start:
self.run_in_background(await_ready=True)

@classmethod
def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None,
*, start: bool, **kwargs) -> Server:
num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1,
max_batch_size=4096, device=None, no_dht=False, initial_peers=(), dht_port=None,
checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server:
"""
Instantiate a server with several identical experts. See argparse comments below for details
:param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,6 +89,7 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
:param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
:param hidden_dim: main dimension for expert_cls
:param num_handlers: server will use this many parallel processes to handle incoming requests
:param min_batch_size: total num examples in the same batch will be greater than this value
:param max_batch_size: total num examples in the same batch will not exceed this value
:param device: all experts will use this device in torch notation; default: cuda if available else cpu
@@ -112,9 +117,6 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
"""
if custom_module_path is not None:
add_custom_models_from_file(custom_module_path)

if len(kwargs) != 0:
logger.info("Ignored kwargs:", kwargs)
assert expert_cls in name_to_block

if no_dht:
@@ -172,6 +174,7 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
num_warmup_steps=num_warmup_steps,
num_total_steps=num_total_steps,
clip_grad_norm=clip_grad_norm,
min_batch_size=min_batch_size,
max_batch_size=max_batch_size)

if checkpoint_dir is not None:
@@ -196,9 +199,7 @@ def run(self):
self.dht.run_in_background(await_ready=True)

if self.experts:
dht_handler_thread = DHTHandlerThread(
experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
dht_handler_thread.start()
self.dht_handler_thread.start()
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()

@@ -207,16 +208,10 @@ def run(self):
process.start()
process.ready.wait()

self.runtime.run()

for process in self.conn_handlers:
process.join()
if self.dht and self.experts:
dht_handler_thread.stop.set()
dht_handler_thread.join()
if self.checkpoint_saver is not None:
self.checkpoint_saver.stop.set()
self.checkpoint_saver.join()
try:
self.runtime.run()
finally:
self.shutdown()

def run_in_background(self, await_ready=True, timeout=None):
"""
@@ -242,19 +237,32 @@ def ready(self) -> mp.synchronize.Event:

def shutdown(self):
"""
Gracefully terminate a hivemind server, process-safe.
Gracefully terminate the server, process-safe.
Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
"""
self.ready.clear()

for process in self.conn_handlers:
process.terminate()
process.join()
logger.debug("Connection handlers terminated")

if self.dht and self.experts:
self.dht_handler_thread.stop.set()
self.dht_handler_thread.join()

if self.checkpoint_saver is not None:
self.checkpoint_saver.stop.set()
self.checkpoint_saver.join()

if self.dht is not None:
self.dht.shutdown()
self.dht.join()

self.runtime.shutdown()
logger.debug(f"Shutting down runtime")
self.runtime.stop.set()
logger.info("Server shutdown succesfully")


@contextmanager
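The Server hunks above concentrate teardown in `shutdown()` and have `run()` invoke it from a `try/finally`, so connection handlers, the DHT handler thread, the checkpoint saver, the DHT process and the runtime are released even when the runtime raises. A compact sketch of that structure with a placeholder component instead of the real hivemind classes:

```python
import threading


class ToyServer:
    """Sketch of the run/shutdown split: run() only starts components and defers
    all cleanup to shutdown(), which executes even if the main loop raises."""

    def __init__(self):
        self.stop = threading.Event()
        self.worker = threading.Thread(target=self.stop.wait)  # stand-in for handlers/threads

    def run(self):
        self.worker.start()
        try:
            self.main_loop()      # stand-in for self.runtime.run()
        finally:
            self.shutdown()       # guaranteed cleanup, mirroring the PR's try/finally

    def main_loop(self):
        raise KeyboardInterrupt   # simulate an interrupted runtime

    def shutdown(self):
        # cleanup lives in one place and runs in a fixed order: signal, then join
        self.stop.set()
        self.worker.join()


if __name__ == "__main__":
    try:
        ToyServer().run()
    except KeyboardInterrupt:
        print("shut down cleanly")
```

Keeping `dht_handler_thread` as an attribute created in `__init__` is what lets `shutdown()` reach it; previously it was a local variable of `run()` and could only be joined there.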
5 changes: 4 additions & 1 deletion hivemind/server/connection_handler.py
@@ -52,7 +52,10 @@ async def _run():
await server.wait_for_termination()
logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")

loop.run_until_complete(_run())
try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug('Caught KeyboardInterrupt, shutting down')

async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
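The ConnectionHandler hunk above wraps `loop.run_until_complete(_run())` in `try/except KeyboardInterrupt`, so a Ctrl+C delivered to the whole process group no longer prints a traceback from every handler subprocess. A minimal sketch; the waiting coroutine stands in for `await server.wait_for_termination()`:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


def handler_process_body():
    """Illustrative process target: run an asyncio loop until interrupted."""
    loop = asyncio.new_event_loop()

    async def _run():
        await asyncio.Event().wait()  # stand-in for `await server.wait_for_termination()`

    try:
        loop.run_until_complete(_run())
    except KeyboardInterrupt:
        # SIGINT reaches every process in the group; exit quietly with a debug
        # message instead of dumping one traceback per connection handler
        logger.debug('Caught KeyboardInterrupt, shutting down')
```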
4 changes: 2 additions & 2 deletions hivemind/server/expert_backend.py
@@ -74,8 +74,8 @@ def __init__(self, name: str, expert: nn.Module, optimizer: torch.optim.Optimize

self.backward_schema = (self.forward_schema, self.outputs_schema) # inputs to backward
self.grad_inputs_schema = self.forward_schema # outputs from backward
self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
self.forward_pool = TaskPool(self.forward, name=f'{self.name}_forward', **kwargs)
self.backward_pool = TaskPool(self.backward, name=f'{self.name}_backward', **kwargs)

self.update_count = 0
self.examples_processed = 0
4 changes: 2 additions & 2 deletions hivemind/server/expert_uid.py
@@ -62,8 +62,8 @@ def _generate_uid():
uid.append(str(random.randint(slice_start, slice_end - 1)))
else:
raise ValueError("Block must be either fixed or a range [from:to]")
except KeyboardInterrupt as e:
raise e
except KeyboardInterrupt:
raise
except Exception as e:
raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
return UID_DELIMITER.join(uid)
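The `expert_uid.py` tweak above swaps `except KeyboardInterrupt as e: raise e` for a bare `raise`. Both variants let Ctrl+C bypass the broad `except Exception` that follows, but the bare form is the idiomatic way to re-raise the active exception unchanged. A small illustration with hypothetical names:

```python
def parse_block(block: str) -> int:
    """Hypothetical helper: parse one block of an expert UID pattern."""
    try:
        return int(block)
    except KeyboardInterrupt:
        raise  # bare raise: let Ctrl+C propagate as-is rather than rewrapping it below
    except Exception as e:
        raise ValueError(f"invalid block {block!r}") from e
```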