diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ad54381a082b..b29321cbd4b0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011) + + + ## [1.2.0] - 2021-02-18 ### Added diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 6a4e948e899cf..80161d6e59b6b 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -278,10 +278,22 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): - if isinstance(output, torch.Tensor): - output = sync_ddp_if_available(output, group, reduce_op) - return output + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, torch.Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) + return tensor def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index a7c8477a40c2d..a94bb5459bb1e 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -25,14 +25,26 @@ def setup(self, model): self.task_idx = self.cluster_environment.local_rank() # the difference to DDP is that we don't call children processes here - def reduce(self, output, *args, **kwargs): - if isinstance(output, Result): - output.dp_reduce() + def reduce(self, tensor, *args, **kwargs): + """ + Reduces a tensor from all processes to one aggregated tensor. + In DDP2, the reduction here is only across local devices within the node. - elif isinstance(output, torch.Tensor): - output = output.mean() + Args: + tensor: the tensor to sync and reduce + *args: ignored for DDP2 + **kwargs: ignored for DDP2 - return output + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, Result): + tensor.dp_reduce() + + elif isinstance(tensor, torch.Tensor): + tensor = tensor.mean() + + return tensor @property def root_device(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 66d3cb7bf4619..ca25a6d8bc382 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -256,10 +256,22 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: prepare_for_backward(self.model, closure_loss) - def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): - if isinstance(output, torch.Tensor): - output = sync_ddp_if_available(output, group, reduce_op) - return output + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, torch.Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) + return tensor def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index e1002faf8a3b4..c2b16303e5d4e 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -31,14 +31,25 @@ def setup(self, model): model.to(self.root_device) self._model = DataParallel(LightningParallelModule(model), self.parallel_devices) - def reduce(self, output, *args, **kwargs): - if isinstance(output, Result): - output.dp_reduce() + def reduce(self, tensor, *args, **kwargs): + """ + Reduces a tensor from all parallel processes to one aggregated tensor. - elif isinstance(output, torch.Tensor): - output = output.mean() + Args: + tensor: the tensor to sync and reduce + *args: ignored for DP + **kwargs: ignored for DP - return output + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, Result): + tensor.dp_reduce() + + elif isinstance(tensor, torch.Tensor): + tensor = tensor.mean() + + return tensor @property def root_device(self): diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 351d945675a0c..e940cb1d7229b 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -127,23 +127,35 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ if group is not None: raise ValueError( "Horovod does not support allreduce using a subcommunicator at this time. " "Unset `group`." ) - if reduce_op is None or reduce_op == "sum": - reduce_op = hvd.Sum - elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): + if reduce_op in (None, "avg", "mean"): reduce_op = hvd.Average + elif reduce_op == "sum": + reduce_op = hvd.Sum else: raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") # sync all processes before reduction hvd.join() - return hvd.allreduce(output, op=reduce_op) + return hvd.allreduce(tensor, op=reduce_op) def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): if group is not None: diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 4b1d24301b8a0..5bf0597ed7f18 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -19,8 +19,20 @@ def on_tpu(self) -> bool: def on_gpu(self) -> bool: return self.device.type == "cuda" and torch.cuda.is_available() - def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: - return output + def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + As this plugin only operates with a single device, the reduction is simply the identity. + + Args: + tensor: the tensor to sync and reduce + *args: ignored + **kwargs: ignored + + Return: + the unmodified input as reduction is not needed for single process operation + """ + return tensor @property def root_device(self) -> torch.device: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index d7c3b4d4d77e1..b60b63df23e48 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -55,8 +55,15 @@ def is_global_zero(self) -> bool: """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" @abstractmethod - def reduce(self, output: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: - """Reduces the given output (e.g. across GPUs/Processes)""" + def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: + """ + Reduces the given tensor (e.g. across GPUs/processes). + + Args: + tensor: the tensor to sync and reduce + *args: plugin-specific positional arguments + **kwargs: plugin-specific keyword arguments + """ @abstractmethod def barrier(self, name: Optional[str] = None) -> None: