Add gradient clipping support to ExpertBackend (#214)
mryab authored Apr 7, 2021
1 parent ca3aadb commit 6128cbb
Showing 5 changed files with 41 additions and 20 deletions.
7 changes: 5 additions & 2 deletions hivemind/hivemind_cli/run_server.py
@@ -40,8 +40,9 @@ def main():
parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',
help='LR scheduler type to use')
parser.add_argument('--num-warmup-steps', type=int, required=False, help='the number of warmup steps for LR schedule')
parser.add_argument('--num-training-steps', type=int, required=False, help='the total number of steps for LR schedule')
parser.add_argument('--num_warmup_steps', type=int, required=False, help='The number of warmup steps for LR schedule')
parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')

parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
@@ -53,6 +54,8 @@ def main():
parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
parser.add_argument('--checkpoint_dir', type=Path, required=False, help='Directory to store expert checkpoints')
parser.add_argument('--stats_report_interval', type=int, required=False,
help='Interval between two reports of batch processing performance statistics')

# fmt:on
args = vars(parser.parse_args())
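All of the new flags are optional: argparse leaves them as None when they are not passed, and the server only enables the LR scheduler, gradient clipping and statistics reporting when the corresponding value is set. A minimal standalone sketch of that behaviour (the flag values below are illustrative, not taken from this diff):

import argparse

# Reproduce the new optional flags: each one stays None when omitted, and the
# server code further down only activates the feature if the value is not None.
parser = argparse.ArgumentParser()
parser.add_argument('--num_warmup_steps', type=int, required=False)
parser.add_argument('--num_total_steps', type=int, required=False)
parser.add_argument('--clip_grad_norm', type=float, required=False)
parser.add_argument('--stats_report_interval', type=int, required=False)

print(vars(parser.parse_args([])))
# all four values are None: no LR schedule, no clipping, no stats reporting

print(vars(parser.parse_args(['--clip_grad_norm', '1.0', '--stats_report_interval', '30'])))
# clip_grad_norm=1.0 and stats_report_interval=30, the rest stays None
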
16 changes: 10 additions & 6 deletions hivemind/server/__init__.py
@@ -70,9 +70,10 @@ def __init__(
@classmethod
def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
num_warmup_steps=None, num_training_steps=None, num_handlers=None, max_batch_size=4096, device=None,
no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, *, start: bool,
**kwargs) -> Server:
"""
Instantiate a server with several identical experts. See argparse comments below for details
:param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -89,7 +90,8 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
:param optim_cls: uses this optimizer to train all experts
:param scheduler: if not `none`, the name of the expert LR scheduler
:param num_warmup_steps: the number of warmup steps for LR schedule
:param num_training_steps: the total number of steps for LR schedule
:param num_total_steps: the total number of steps for LR schedule
:param clip_grad_norm: maximum gradient norm used for clipping
:param no_dht: if specified, the server will not be attached to a dht
:param initial_peers: a list of peers that will introduce this node to the dht,\
@@ -105,6 +107,7 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
for each BatchTensorProto in ExpertBackend for the respective experts.
:param start: if True, starts server right away and returns when server is ready for requests
:param stats_report_interval: interval between two reports of batch processing performance statistics
"""
if len(kwargs) != 0:
logger.info("Ignored kwargs:", kwargs)
@@ -165,14 +168,15 @@ def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str
optimizer=optim_cls(expert.parameters()),
scheduler=scheduler,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_total_steps=num_total_steps,
clip_grad_norm=clip_grad_norm,
max_batch_size=max_batch_size)

if checkpoint_dir is not None:
load_experts(experts, checkpoint_dir)

return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
checkpoint_dir=checkpoint_dir, start=start)
checkpoint_dir=checkpoint_dir, stats_report_interval=stats_report_interval, start=start)

def run(self):
"""
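For reference, a hedged sketch of how Server.create might be invoked with the renamed and newly added keyword arguments; the import path, scheduler name and numeric values here are assumptions rather than something this commit shows:

from hivemind.server import Server  # assumed import path

# Hypothetical invocation: num_total_steps replaces the old num_training_steps,
# while clip_grad_norm and stats_report_interval are the new optional arguments.
server = Server.create(
    listen_on='0.0.0.0:*',
    num_experts=2,
    expert_cls='ffn',
    hidden_dim=1024,
    scheduler='linear',          # assumed to be a key of schedule_name_to_scheduler
    num_warmup_steps=100,
    num_total_steps=10000,
    clip_grad_norm=1.0,          # enables gradient clipping in every ExpertBackend
    stats_report_interval=30,    # enables periodic throughput reports in Runtime
    start=True,
)
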
13 changes: 9 additions & 4 deletions hivemind/server/expert_backend.py
@@ -35,7 +35,8 @@ class ExpertBackend:
:param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
:param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
:param num_warmup_steps: the number of warmup steps for LR schedule
:param num_training_steps: the total number of steps for LR schedule
:param num_total_steps: the total number of steps for LR schedule
:param clip_grad_norm: maximum gradient norm used for clipping
:param kwargs: extra parameters to be forwarded into TaskPool.__init__
"""

@@ -44,16 +45,17 @@ def __init__(self, name: str, expert: nn.Module, optimizer: torch.optim.Optimize
args_schema: Tuple[BatchTensorDescriptor, ...] = None,
kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
num_warmup_steps: int = None, num_training_steps: int = None,
num_warmup_steps: int = None, num_total_steps: int = None, clip_grad_norm: float = None,
**kwargs):
super().__init__()
self.expert, self.optimizer, self.name = expert, optimizer, name

if scheduler is None:
self.scheduler = None
else:
assert optimizer is not None and num_warmup_steps is not None and num_training_steps is not None
self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_training_steps)
assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None
self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps)
self.clip_grad_norm = clip_grad_norm

self.args_schema = args_schema = tuple(args_schema or ())
self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
@@ -147,6 +149,9 @@ def apply_gradients(self, batch_size) -> None:
"""
Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
"""
if self.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)

self.optimizer.step()
self.optimizer.zero_grad()

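The core of the change is the clipping step inserted just before the optimizer update in apply_gradients. A self-contained PyTorch sketch of the same sequence on a toy model (the model, data and threshold are placeholders):

import torch

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters())
clip_grad_norm = 1.0  # corresponds to ExpertBackend's new clip_grad_norm argument

# the backward pass produces gradients, as ExpertBackend.backward would
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()

# same guard as in apply_gradients: rescale gradients so that their global L2 norm
# does not exceed clip_grad_norm, then apply the optimizer step
if clip_grad_norm is not None:
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

optimizer.step()
optimizer.zero_grad()

clip_grad_norm_ rescales the gradients in place whenever their combined L2 norm exceeds the threshold, which is why it has to run after the backward pass and before the optimizer step.
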
23 changes: 16 additions & 7 deletions hivemind/server/runtime.py
@@ -7,7 +7,7 @@
from selectors import DefaultSelector, EVENT_READ
from statistics import mean
from time import time
from typing import Dict, NamedTuple
from typing import Dict, NamedTuple, Optional

import torch
from prefetch_generator import BackgroundGenerator
@@ -43,15 +43,17 @@ class Runtime(threading.Thread):
"""

def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
device: torch.device = None, stats_report_interval=30):
device: torch.device = None, stats_report_interval: Optional[int] = None):
super().__init__()
self.expert_backends = expert_backends
self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
self.ready = mp.Event() # event is set iff server is currently running and ready to accept batches

self.stats_reporter = StatsReporter(stats_report_interval)
self.stats_report_interval = stats_report_interval
if self.stats_report_interval is not None:
self.stats_reporter = StatsReporter(self.stats_report_interval)

def run(self):
for pool in self.pools:
@@ -64,8 +66,10 @@ def run(self):
with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
try:
self.ready.set()
self.stats_reporter.start()
if self.stats_report_interval is not None:
self.stats_reporter.start()
logger.info("Started")

for pool, batch_index, batch in BackgroundGenerator(
self.iterate_minibatches_from_pools(), self.prefetch_batches):
logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
@@ -76,13 +80,18 @@ def run(self):

batch_size = outputs[0].size(0)
logger.debug(f"Pool {pool.uid}: batch {batch_index} processed, size {batch_size}")
self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)

if self.stats_report_interval is not None:
self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)

output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
finally:
logger.info("Shutting down")
self.stats_reporter.stop.set()
self.stats_reporter.join()

if self.stats_report_interval is not None:
self.stats_reporter.stop.set()
self.stats_reporter.join()

self.shutdown()

SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
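Every Runtime change follows the same pattern: the reporter thread is only constructed, started, fed and joined when stats_report_interval is not None. A stripped-down sketch of that pattern, using a hypothetical stand-in for StatsReporter:

import threading
from typing import Optional


class DummyReporter(threading.Thread):
    """Hypothetical stand-in for StatsReporter: wakes up every `interval` seconds."""

    def __init__(self, interval: float):
        super().__init__()
        self.interval, self.stop = interval, threading.Event()

    def run(self):
        while not self.stop.wait(self.interval):
            print("reporting batch processing statistics")


stats_report_interval: Optional[float] = None  # reporting is disabled unless set
reporter = DummyReporter(stats_report_interval) if stats_report_interval is not None else None

if reporter is not None:
    reporter.start()

# ... the main batch-processing loop runs here; any per-batch reporting is
# likewise guarded by `reporter is not None` ...

if reporter is not None:
    reporter.stop.set()
    reporter.join()
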
2 changes: 1 addition & 1 deletion tests/test_expert_backend.py
@@ -25,7 +25,7 @@ def example_experts():
expert_backend = ExpertBackend(name=EXPERT_NAME, expert=expert, optimizer=opt,
scheduler=get_linear_schedule_with_warmup,
num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
args_schema=args_schema, outputs_schema=BatchTensorDescriptor(1), max_batch_size=1,
)
experts = {EXPERT_NAME: expert_backend}
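Beyond the test above, a hedged example of constructing an ExpertBackend with clipping enabled; the import paths, expert module and schema values are assumptions for illustration only:

import torch
from hivemind.server.expert_backend import ExpertBackend  # assumed import path
from hivemind.utils import BatchTensorDescriptor          # assumed import path

expert = torch.nn.Linear(1, 1)
backend = ExpertBackend(
    name='expert.0',
    expert=expert,
    optimizer=torch.optim.SGD(expert.parameters(), lr=0.01),
    scheduler=None,          # no LR schedule, so the step counts can stay None
    clip_grad_norm=1.0,      # gradients are clipped before every optimizer step
    args_schema=(BatchTensorDescriptor(1),),
    outputs_schema=BatchTensorDescriptor(1),
    max_batch_size=16,
)
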
