diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index 20891b5f30..3b9e6bfd27 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -32,7 +32,12 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
         best_function = max if args.maximize_best_checkpoint_metric else min
         save_checkpoint.best = best_function(val_loss, prev_best)
 
-    if args.no_save or not trainer.is_data_parallel_master:
+    if args.no_save:
+        return
+
+    trainer.consolidate_optimizer()
+
+    if not trainer.is_data_parallel_master:
         return
 
     def is_better(a, b):
diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py
index 273aa5e8f6..dff140d580 100644
--- a/fairseq/optim/__init__.py
+++ b/fairseq/optim/__init__.py
@@ -10,12 +10,14 @@
 from fairseq.optim.fairseq_optimizer import FairseqOptimizer
 from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
 from fairseq.optim.bmuf import FairseqBMUF  # noqa
+from fairseq.optim.shard import shard_
 
 
 __all__ = [
     'FairseqOptimizer',
     'FP16Optimizer',
     'MemoryEfficientFP16Optimizer',
+    'shard_',
 ]
 
 
diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py
index b1b9c76edb..e00a04dd1b 100644
--- a/fairseq/optim/fairseq_optimizer.py
+++ b/fairseq/optim/fairseq_optimizer.py
@@ -28,6 +28,15 @@ def optimizer(self):
             raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
         return self._optimizer
 
+    @optimizer.setter
+    def optimizer(self, optimizer):
+        """Reset optimizer instance."""
+        if not hasattr(self, '_optimizer'):
+            raise NotImplementedError
+        if not isinstance(self._optimizer, torch.optim.Optimizer):
+            raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
+        self._optimizer = optimizer
+
     @property
     def optimizer_config(self):
         """
diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py
index 960c3a67eb..777d43a713 100644
--- a/fairseq/optim/fp16_optimizer.py
+++ b/fairseq/optim/fp16_optimizer.py
@@ -232,6 +232,10 @@ def build_optimizer(cls, args, params):
     def optimizer(self):
         return self.fp32_optimizer.optimizer
 
+    @optimizer.setter
+    def optimizer(self, optimizer):
+        self.fp32_optimizer.optimizer = optimizer
+
     @property
     def optimizer_config(self):
         return self.fp32_optimizer.optimizer_config
@@ -279,19 +283,20 @@ def load_state_dict(self, state_dict, optimizer_overrides=None):
         # params are FP16 while the optimizer state is FP32 and we don't want
         # to cast. A workaround is to manually copy back the original state
         # after the optimizer has been loaded.
-        groups = self.optimizer.param_groups
-        saved_groups = state_dict['param_groups']
-        id_map = {
-            old_id: p
-            for old_id, p in zip(
-                chain(*(g['params'] for g in saved_groups)),
-                chain(*(g['params'] for g in groups))
-            )
-        }
-        for k, v in state_dict['state'].items():
-            if k in id_map:
-                param = id_map[k]
-                self.optimizer.state[param] = v
+        if not getattr(self.optimizer, 'disable_mem_eff_fp16_loading_hack', False):
+            groups = self.optimizer.param_groups
+            saved_groups = state_dict['param_groups']
+            id_map = {
+                old_id: p
+                for old_id, p in zip(
+                    chain(*(g['params'] for g in saved_groups)),
+                    chain(*(g['params'] for g in groups))
+                )
+            }
+            for k, v in state_dict['state'].items():
+                if k in id_map:
+                    param = id_map[k]
+                    self.optimizer.state[param] = v
 
     def backward(self, loss):
         """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
@@ -412,6 +417,10 @@ def build_optimizer(cls, args, params):
     def optimizer(self):
         return self.wrapped_optimizer.optimizer
 
+    @optimizer.setter
+    def optimizer(self, optimizer):
+        self.wrapped_optimizer.optimizer = optimizer
+
     @property
     def optimizer_config(self):
         return self.wrapped_optimizer.optimizer_config
diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py
new file mode 100644
index 0000000000..4f35dbda47
--- /dev/null
+++ b/fairseq/optim/shard.py
@@ -0,0 +1,33 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+try:
+    from fairscale.optim import OSS
+    _has_fairscale = True
+except ImportError:
+    _has_fairscale = False
+
+
+def shard_(args, optimizer):
+    if not _has_fairscale:
+        raise ImportError(
+            '\n\nPlease install the fairscale package:'
+            '\n\n pip install fairscale'
+        )
+
+    class FairseqOSS(OSS):
+        @property
+        def disable_mem_eff_fp16_loading_hack(self):
+            return True
+
+        def __getattr__(self, name):
+            if name.startswith("supports") and hasattr(self.optim, name):
+                return getattr(self.optim, name)
+            raise AttributeError("'FairseqOSS' object has no attribute {0!r}".format(name))
+
+    torch_optimizer = optimizer.optimizer
+    optim_cls = type(torch_optimizer)
+    optimizer.optimizer = FairseqOSS(torch_optimizer.param_groups, optim_cls, **optimizer.optimizer_config)
diff --git a/fairseq/options.py b/fairseq/options.py
index 74c1499e06..01150bda4a 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -448,6 +448,10 @@ def add_distributed_training_args(parser, default_world_size=None):
                        help='number of GPUs in each node. An allreduce operation across GPUs in '
                             'a node is very fast. Hence, we do allreduce across GPUs in a node, '
                             'and gossip across different nodes')
+    # Add argument for ZeRO sharding of OptimizerState(os), gradients(g) and parameters(p)
+    group.add_argument('--zero-sharding', default='none', type=str,
+                       choices=['none', 'os'],
+                       help='ZeRO sharding')
     # fmt: on
     return group
 
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 6cd73a631a..5022ceea2d 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -214,11 +214,28 @@ def _build_optimizer(self):
         if self.args.use_bmuf:
             self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
 
+        if self.args.zero_sharding == 'os':
+            if (self.args.fp16
+                    and not self.args.memory_efficient_fp16
+                    and not self.args.memory_efficient_bf16
+                    ) and not self.args.fp16_no_flatten_grads:
+                raise ValueError(
+                    "ZeRO is incompatible with fp16 and flattened grads. "
+                    "Please use --fp16-no-flatten-grads"
+                )
+            else:
+                optim.shard_(self.args, self._optimizer)
+
         # We should initialize the learning rate scheduler immediately after
         # building the optimizer, so that the initial learning rate is set.
         self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
         self._lr_scheduler.step_update(0)
 
+    def consolidate_optimizer(self):
+        """For OSS, we need to consolidate the state dict."""
+        if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
+            self.optimizer.optimizer.consolidate_state_dict()
+
     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
         if self.is_data_parallel_master:  # only save one checkpoint
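For context: the new `--zero-sharding os` path swaps the inner torch optimizer for fairscale's `OSS`, so each rank keeps only its shard of the optimizer state, and `consolidate_optimizer()` gathers the shards back onto rank 0 before checkpointing. The snippet below is a minimal, illustrative sketch of that flow outside fairseq and is not part of the patch; the single-process gloo group, toy `Linear` model, and Adam settings are placeholders, and the exact `OSS` constructor arguments may differ across fairscale versions.

# Illustrative sketch only -- mirrors what shard_() and trainer.consolidate_optimizer() do.
# Assumes fairscale is installed; the single-process gloo group is only here so the
# snippet runs standalone (in real training, fairseq initializes distributed itself).
import torch
import torch.distributed as dist
from fairscale.optim import OSS

if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

model = torch.nn.Linear(16, 16)

# Roughly what shard_() does: wrap the param groups in fairscale's OSS, passing the
# original optimizer class and its config as the defaults for the per-shard optimizer.
optimizer = OSS(model.parameters(), optim=torch.optim.Adam, lr=1e-3)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()  # each rank updates only its own shard of the optimizer state

# Roughly what trainer.consolidate_optimizer() does before save_checkpoint():
# gather the sharded state onto rank 0 so state_dict() returns the full state there.
optimizer.consolidate_state_dict()
if dist.get_rank() == 0:
    checkpoint = {"optimizer": optimizer.state_dict()}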