Fix device in Switch-MoE, overhaul Server architecture #256

Merged 9 commits on May 6, 2021
11 changes: 7 additions & 4 deletions hivemind/client/switch_moe.py
@@ -156,8 +156,11 @@ def compute_expert_scores(
batch_size = len(batch_experts)
max_num_experts = max(expert_counts)
total_num_experts = sum(expert_counts)
expert_index_in_batch = torch.arange(total_num_experts, device=grid_probs[0].device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=grid_probs[0].device), dim=-1)[:-1]

device = grid_probs[0].device

expert_index_in_batch = torch.arange(total_num_experts, device=device)
expert_strides = torch.cumsum(torch.as_tensor([0] + expert_counts, device=device), dim=-1)[:-1]
flat_batch_indices = (expert_index_in_batch >= expert_strides[:, None]).to(torch.int32).sum(0) - 1
flat_local_indices = expert_index_in_batch - expert_strides[flat_batch_indices]
flat_experts = [expert for row in batch_experts for expert in row]
@@ -169,10 +172,10 @@ def compute_expert_scores(
grid_indices[i] = torch.as_tensor(expert_indices, dtype=grid_indices.dtype)

scores_per_dim = [
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0)
dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0, device=device)
for dim_scores, dim_indices in zip(grid_probs, grid_indices.T)]
flat_scores = torch.prod(torch.stack(scores_per_dim, dim=0), dim=0)

scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=grid_probs[0].device)
scores = torch.full((batch_size, max_num_experts), fill_value=-float('inf'), device=device)
scores[flat_batch_indices, flat_local_indices] = flat_scores # backprop-able w.r.t. flat_scores
return scores
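
The hoisted `device` variable keeps every tensor created in `compute_expert_scores` on the same device as the gating probabilities, so the empty fallback `torch.zeros(0)` can no longer land on the CPU while the scores live on a GPU. A minimal sketch of the pattern (toy function and shapes, not the hivemind code):

```python
import torch

def toy_scores(grid_probs):
    device = grid_probs[0].device                 # resolve the device once
    index = torch.arange(grid_probs[0].shape[0], device=device)
    fallback = torch.zeros(0, device=device)      # same device as the gating probabilities
    return index, fallback

probs = [torch.rand(4, 8)]                        # pass .to("cuda") tensors to see the effect
index, fallback = toy_scores(probs)
assert index.device == fallback.device == probs[0].device
```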
39 changes: 20 additions & 19 deletions hivemind/dht/__init__.py
@@ -69,25 +69,26 @@ def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[En
def run(self) -> None:
""" Serve DHT forever. This function will not return until DHT node is shut down """
loop = switch_to_uvloop()
pipe_awaiter = ThreadPoolExecutor(max_workers=1)

async def _run():
node = await DHTNode.create(
initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
num_workers=self.max_workers or 1, record_validator=self._record_validator,
**self.kwargs)
if node.port is not None:
self._port.value = node.port
self.ready.set()

while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(node, *args, **kwargs))

try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")
with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
async def _run():
node = await DHTNode.create(
initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
num_workers=self.max_workers or 1, record_validator=self._record_validator,
**self.kwargs)
if node.port is not None:
self._port.value = node.port
self.ready.set()

while True:
method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
asyncio.create_task(getattr(self, method)(node, *args, **kwargs))

coro = _run()
try:
loop.run_until_complete(coro)
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")

def run_in_background(self, await_ready=True, timeout=None):
"""
@@ -96,7 +97,7 @@ def run_in_background(self, await_ready=True, timeout=None):
"""
self.start()
if await_ready and not self.ready.wait(timeout=timeout):
raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")

def shutdown(self) -> None:
""" Shut down a running dht process """
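
The rewritten `DHT.run` keeps the same logic but creates the pipe-reading executor inside a `with` block, so its worker thread is released even when the event loop exits through `KeyboardInterrupt`. A stand-alone sketch of that pattern (hypothetical worker, not the hivemind API):

```python
import asyncio
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
from multiprocessing.connection import Connection

def run_worker(pipe: Connection) -> None:
    loop = asyncio.new_event_loop()
    with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:   # joined on exit, even on errors
        async def serve() -> None:
            while True:
                # blocking Pipe.recv() runs in the executor so it never stalls the event loop
                method, args = await loop.run_in_executor(pipe_awaiter, pipe.recv)
                print(f"dispatching {method}{args}")          # real code schedules an asyncio task

        try:
            loop.run_until_complete(serve())
        except KeyboardInterrupt:
            pass                                              # graceful shutdown path

if __name__ == "__main__":
    parent, child = mp.Pipe()
    worker = mp.Process(target=run_worker, args=(child,), daemon=True)
    worker.start()
    parent.send(("get", (42,)))
    worker.join(timeout=1)
    worker.terminate()
```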
4 changes: 3 additions & 1 deletion hivemind/hivemind_cli/run_server.py
@@ -32,7 +32,9 @@ def main():

parser.add_argument('--num_handlers', type=int, default=None, required=False,
help='server will use this many processes to handle incoming requests')
parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all expert operations')
parser.add_argument('--max_batch_size', type=int, default=16384,
help='The total number of examples in the same batch will not exceed this value')
parser.add_argument('--device', type=str, default=None, required=False,
help='all experts will use this device in torch notation; default: cuda if available else cpu')
52 changes: 30 additions & 22 deletions hivemind/server/__init__.py
@@ -65,16 +65,20 @@ def __init__(
self.checkpoint_saver = None
self.runtime = Runtime(self.experts, **kwargs)

if self.dht and self.experts:
self.dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht, endpoint=self.listen_on,
update_period=self.update_period)

if start:
self.run_in_background(await_ready=True)

@classmethod
def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None,
*, start: bool, **kwargs) -> Server:
num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1,
max_batch_size=4096, device=None, no_dht=False, initial_peers=(), dht_port=None,
checkpoint_dir: Optional[Path] = None, compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None, custom_module_path=None, *, start: bool) -> Server:
"""
Instantiate a server with several identical experts. See argparse comments below for details
:param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,6 +89,7 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
:param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
:param hidden_dim: main dimension for expert_cls
:param num_handlers: server will use this many parallel processes to handle incoming requests
:param min_batch_size: total num examples in the same batch will be no less than this value
:param max_batch_size: total num examples in the same batch will not exceed this value
:param device: all experts will use this device in torch notation; default: cuda if available else cpu

@@ -112,9 +117,6 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
"""
if custom_module_path is not None:
add_custom_models_from_file(custom_module_path)

if len(kwargs) != 0:
logger.info("Ignored kwargs:", kwargs)
assert expert_cls in name_to_block

if no_dht:
@@ -172,6 +174,7 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
num_warmup_steps=num_warmup_steps,
num_total_steps=num_total_steps,
clip_grad_norm=clip_grad_norm,
min_batch_size=min_batch_size,
max_batch_size=max_batch_size)

if checkpoint_dir is not None:
@@ -196,9 +199,7 @@ def run(self):
self.dht.run_in_background(await_ready=True)

if self.experts:
dht_handler_thread = DHTHandlerThread(
experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
dht_handler_thread.start()
self.dht_handler_thread.start()
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()

@@ -207,16 +208,10 @@ def run(self):
process.start()
process.ready.wait()

self.runtime.run()

for process in self.conn_handlers:
process.join()
if self.dht and self.experts:
dht_handler_thread.stop.set()
dht_handler_thread.join()
if self.checkpoint_saver is not None:
self.checkpoint_saver.stop.set()
self.checkpoint_saver.join()
try:
self.runtime.run()
finally:
self.shutdown()

def run_in_background(self, await_ready=True, timeout=None):
"""
@@ -242,19 +237,32 @@ def ready(self) -> mp.synchronize.Event:

def shutdown(self):
"""
Gracefully terminate a hivemind server, process-safe.
Gracefully terminate the server, process-safe.
Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
"""
self.ready.clear()

for process in self.conn_handlers:
process.terminate()
process.join()
logger.debug("Connection handlers terminated")

if self.dht and self.experts:
self.dht_handler_thread.stop.set()
self.dht_handler_thread.join()

if self.checkpoint_saver is not None:
self.checkpoint_saver.stop.set()
self.checkpoint_saver.join()

if self.dht is not None:
self.dht.shutdown()
self.dht.join()

self.runtime.shutdown()
logger.debug(f"Shutting down runtime")
self.runtime.stop.set()
logger.info("Server shut down successfully")


@contextmanager
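
A usage sketch consistent with the revised `Server.create` signature above, assuming `Server` runs as a background thread (as `run_in_background` suggests); the expert pattern, hidden size and expert count are illustrative values, and the `try/finally` shows one typical way to serve until interrupted:

```python
from hivemind.server import Server

server = Server.create(
    listen_on='0.0.0.0:*',
    num_experts=2,
    expert_pattern='expert.[0:256]',   # illustrative grid pattern
    expert_cls='ffn',
    hidden_dim=512,
    min_batch_size=1,                  # new in this PR, forwarded to each ExpertBackend
    max_batch_size=4096,
    start=True,                        # keyword-only: launch the background thread immediately
)
try:
    server.join()                      # block until the server thread stops
except KeyboardInterrupt:
    pass
finally:
    server.shutdown()                  # stops conn handlers, dht_handler_thread, checkpoint_saver, DHT, runtime
```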
5 changes: 4 additions & 1 deletion hivemind/server/connection_handler.py
@@ -52,7 +52,10 @@ async def _run():
await server.wait_for_termination()
logger.debug(f"ConnectionHandler terminated: (pid={os.getpid()})")

loop.run_until_complete(_run())
try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug('Caught KeyboardInterrupt, shutting down')

async def info(self, request: runtime_pb2.ExpertUID, context: grpc.ServicerContext):
return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
8 changes: 4 additions & 4 deletions hivemind/server/expert_backend.py
@@ -74,12 +74,13 @@ def __init__(self, name: str, expert: nn.Module, optimizer: torch.optim.Optimize

self.backward_schema = (self.forward_schema, self.outputs_schema) # inputs to backward
self.grad_inputs_schema = self.forward_schema # outputs from backward
self.forward_pool = TaskPool(self.forward, uid=f'{self.name}_forward', **kwargs)
self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
self.forward_pool = TaskPool(self.forward, name=f'{self.name}_forward', **kwargs)
self.backward_pool = TaskPool(self.backward, name=f'{self.name}_backward', **kwargs)

self.update_count = 0
self.examples_processed = 0

@torch.no_grad()
def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""
Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
@@ -99,8 +100,7 @@ def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
if args[0].shape[0] == 0:
raise RuntimeError("Batch should contain more than 0 samples")

with torch.no_grad():
outputs = self.expert(*args, **kwargs)
outputs = self.expert(*args, **kwargs)

# Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
return tuple(nested_flatten(outputs))
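
The `@torch.no_grad()` decorator replaces the inner `with torch.no_grad():` block and disables autograd for the whole forward pass. A minimal sketch of that equivalence:

```python
import torch

@torch.no_grad()                         # same effect as wrapping the body in `with torch.no_grad():`
def forward_no_grad(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    return module(x)                     # runs without building an autograd graph

layer = torch.nn.Linear(4, 4)
out = forward_no_grad(layer, torch.randn(2, 4, requires_grad=True))
assert not out.requires_grad
```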
4 changes: 2 additions & 2 deletions hivemind/server/expert_uid.py
@@ -62,8 +62,8 @@ def _generate_uid():
uid.append(str(random.randint(slice_start, slice_end - 1)))
else:
raise ValueError("Block must be either fixed or a range [from:to]")
except KeyboardInterrupt as e:
raise e
except KeyboardInterrupt:
raise
except Exception as e:
raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
return UID_DELIMITER.join(uid)
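
The bare `raise` is the conventional way to propagate `KeyboardInterrupt` untouched: it re-raises the active exception with its original traceback, whereas `raise e` re-raises the same object but also records the re-raise line. A small sketch of the idiom (hypothetical helper, not the hivemind code):

```python
def parse_block(block: str) -> int:
    try:
        return int(block)
    except KeyboardInterrupt:
        raise                            # propagate Ctrl+C without touching the traceback
    except Exception as e:
        raise ValueError(f"invalid block {block!r}: {e}")
```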
37 changes: 17 additions & 20 deletions hivemind/server/runtime.py
@@ -48,8 +48,8 @@ def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=6
self.expert_backends = expert_backends
self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
self.ready = mp.Event() # event is set iff server is currently running and ready to accept batches
self.stop = threading.Event()

self.stats_report_interval = stats_report_interval
if self.stats_report_interval is not None:
@@ -72,62 +72,59 @@ def run(self):

for pool, batch_index, batch in BackgroundGenerator(
self.iterate_minibatches_from_pools(), self.prefetch_batches):
logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

start = time()
outputs = pool.process_func(*batch)
batch_processing_time = time() - start

batch_size = outputs[0].size(0)
logger.debug(f"Pool {pool.uid}: batch {batch_index} processed, size {batch_size}")
logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

if self.stats_report_interval is not None:
self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)
self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
finally:
logger.info("Shutting down")

if self.stats_report_interval is not None:
self.stats_reporter.stop.set()
self.stats_reporter.join()

self.shutdown()

SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

def shutdown(self):
""" Gracefully terminate a running runtime. """
self.ready.clear()
self.shutdown_send.send(self.SHUTDOWN_TRIGGER) # trigger background thread to shutdown
logger.info("Shutting down")

if self.stats_report_interval is not None:
self.stats_reporter.stop.set()
self.stats_reporter.join()

self.stop.set() # trigger background thread to shutdown

logger.debug("Terminating pools")
for pool in self.pools:
if pool.is_alive():
pool.terminate()
pool.join()
logger.debug("Pools terminated")

def iterate_minibatches_from_pools(self, timeout=None):
"""
Chooses pool according to priority, then copies exposed batch and frees the buffer
"""
with DefaultSelector() as selector:
selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
for pool in self.pools:
selector.register(pool.batch_receiver, EVENT_READ, pool)

while True:
while not self.stop.is_set():
# wait until at least one batch_receiver becomes available
logger.debug("Waiting for inputs from task pools")
ready_fds = selector.select()
ready_objects = {key.data for (key, events) in ready_fds}
if self.SHUTDOWN_TRIGGER in ready_objects:
break # someone asked us to shutdown, break from the loop

logger.debug("Choosing the pool with highest priority")
pool = max(ready_objects, key=lambda pool: pool.priority)

logger.debug(f"Loading batch from {pool.uid}")
logger.debug(f"Loading batch from {pool.name}")
batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
logger.debug(f"Loaded batch from {pool.uid}")
logger.debug(f"Loaded batch from {pool.name}")
yield pool, batch_index, batch_tensors


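
The runtime now signals shutdown through a `threading.Event` checked by the minibatch loop rather than a sentinel written to a dedicated pipe; `shutdown()` simply sets the event before terminating the pools. A toy sketch of that pattern (the producer thread and pipe are hypothetical, not hivemind code):

```python
import threading
import time
from multiprocessing import Pipe
from selectors import DefaultSelector, EVENT_READ

stop = threading.Event()
receiver, sender = Pipe(duplex=False)

def producer() -> None:
    for i in range(3):
        sender.send(i)
        time.sleep(0.1)
    stop.set()                                   # ask the consumer loop to finish...
    sender.send(None)                            # ...and wake a blocked select() one last time

threading.Thread(target=producer, daemon=True).start()

with DefaultSelector() as selector:
    selector.register(receiver, EVENT_READ)
    while not stop.is_set():                     # same loop condition as the new Runtime code
        for key, _ in selector.select():
            item = key.fileobj.recv()
            if item is not None:
                print("processing batch", item)
```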