From 00b76938043c6fc782e5d1806f8f58a61cfbed95 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Tue, 16 Feb 2021 16:14:55 +0100
Subject: [PATCH 1/4] reduction docs

---
 .../plugins/training_type/ddp.py              | 20 ++++++++++++----
 .../plugins/training_type/ddp2.py             | 24 ++++++++++++++-----
 .../plugins/training_type/ddp_spawn.py        | 20 ++++++++++++----
 pytorch_lightning/plugins/training_type/dp.py | 23 +++++++++++++-----
 .../plugins/training_type/horovod.py          | 16 +++++++++++--
 .../plugins/training_type/single_device.py    | 16 +++++++++++--
 6 files changed, 95 insertions(+), 24 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 52a24655f0846..eb16ac29f60ec 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -278,10 +278,22 @@ def model_to_device(self):
             torch.cuda.set_device(self.root_device)
         self.model.to(self.root_device)
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
-        if isinstance(output, torch.Tensor):
-            output = sync_ddp_if_available(output, group, reduce_op)
-        return output
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to 'sum'.
+                Can also be the string 'avg' or 'mean' to calculate the mean during reduction.
+
+        Return:
+            the reduced value, or the unmodified input if it was not a tensor
+        """
+        if isinstance(tensor, torch.Tensor):
+            tensor = sync_ddp_if_available(tensor, group, reduce_op)
+        return tensor
 
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py
index a7c8477a40c2d..a94bb5459bb1e 100644
--- a/pytorch_lightning/plugins/training_type/ddp2.py
+++ b/pytorch_lightning/plugins/training_type/ddp2.py
@@ -25,14 +25,26 @@ def setup(self, model):
         self.task_idx = self.cluster_environment.local_rank()
         # the difference to DDP is that we don't call children processes here
 
-    def reduce(self, output, *args, **kwargs):
-        if isinstance(output, Result):
-            output.dp_reduce()
+    def reduce(self, tensor, *args, **kwargs):
+        """
+        Reduces a tensor from all processes to one aggregated tensor.
+        In DDP2, the reduction is only across the local devices within the node.
 
-        elif isinstance(output, torch.Tensor):
-            output = output.mean()
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored for DDP2
+            **kwargs: ignored for DDP2
 
-        return output
+        Return:
+            the reduced value, or the unmodified input if it was not a tensor
+        """
+        if isinstance(tensor, Result):
+            tensor.dp_reduce()
+
+        elif isinstance(tensor, torch.Tensor):
+            tensor = tensor.mean()
+
+        return tensor
 
     @property
     def root_device(self):
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 6b6d85ee0d29f..9678aed261f36 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -257,10 +257,22 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti
         if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
             prepare_for_backward(self.model, closure_loss)
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
-        if isinstance(output, torch.Tensor):
-            output = sync_ddp_if_available(output, group, reduce_op)
-        return output
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to 'sum'.
+                Can also be the string 'avg' or 'mean' to calculate the mean during reduction.
+
+        Return:
+            the reduced value, or the unmodified input if it was not a tensor
+        """
+        if isinstance(tensor, torch.Tensor):
+            tensor = sync_ddp_if_available(tensor, group, reduce_op)
+        return tensor
 
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index d1a3e26e22693..77f8bf2252d43 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -31,14 +31,25 @@ def setup(self, model):
         model.to(self.root_device)
         self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
 
-    def reduce(self, output, *args, **kwargs):
-        if isinstance(output, Result):
-            output.dp_reduce()
+    def reduce(self, tensor, *args, **kwargs):
+        """
+        Reduces a tensor from all parallel devices to one aggregated tensor.
 
-        elif isinstance(output, torch.Tensor):
-            output = output.mean()
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored for DP
+            **kwargs: ignored for DP
 
-        return output
+        Return:
+            the reduced value, or the unmodified input if it was not a tensor
+        """
+        if isinstance(tensor, Result):
+            tensor.dp_reduce()
+
+        elif isinstance(tensor, torch.Tensor):
+            tensor = tensor.mean()
+
+        return tensor
 
     @property
     def root_device(self):
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 2393c040bcc8f..0c17a6075c97d 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -124,7 +124,19 @@ def model_to_device(self):
             torch.cuda.set_device(self.root_device)
         self.model.to(self.root_device)
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to 'sum'.
+                Can also be the string 'avg' or 'mean' to calculate the mean during reduction.
+
+        Return:
+            the reduced value, or the unmodified input if it was not a tensor
+        """
         if group is not None:
             raise ValueError(
                 "Horovod does not support allreduce using a subcommunicator at this time. "
@@ -140,7 +152,7 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
 
         # sync all processes before reduction
         hvd.join()
-        return hvd.allreduce(output, op=reduce_op)
+        return hvd.allreduce(tensor, op=reduce_op)
 
     def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
         if group is not None:
diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py
index 0eda31833d6fa..b4e6fc1a0bb19 100644
--- a/pytorch_lightning/plugins/training_type/single_device.py
+++ b/pytorch_lightning/plugins/training_type/single_device.py
@@ -20,8 +20,20 @@ def on_tpu(self) -> bool:
     def on_gpu(self) -> bool:
         return self.device.type == "cuda" and torch.cuda.is_available()
 
-    def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
-        return output
+    def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+        As this plugin only operates with a single device, the reduction is simply the identity.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored
+            **kwargs: ignored
+
+        Return:
+            the unmodified input, as reduction is not needed for single-process operation
+        """
+        return tensor
 
     @property
     def root_device(self) -> torch.device:
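
Below is a minimal standalone sketch (not part of the patch) of the contract these docstrings describe: tensors are aggregated, anything else is passed through unchanged. The helper names are illustrative only and do not exist in Lightning.

    import torch

    def reduce_like_dp(value):
        # DP/DDP2 take the mean over the tensor gathered from the local devices;
        # non-tensor inputs are returned unchanged.
        if isinstance(value, torch.Tensor):
            return value.mean()
        return value

    def reduce_like_single_device(value):
        # The single-device plugin's reduce is the identity: nothing to aggregate.
        return value

    per_device_loss = torch.tensor([0.2, 0.4])        # e.g. one loss per GPU replica
    print(reduce_like_dp(per_device_loss))            # tensor(0.3000)
    print(reduce_like_single_device("not a tensor"))  # returned unchanged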

From 985aa923cc171958af4d394c57c83fd00bfaf1c8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Tue, 16 Feb 2021 16:24:17 +0100
Subject: [PATCH 2/4] docs for abstract base method

---
 .../plugins/training_type/training_type_plugin.py     | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index db0e390c4b03e..a5cee68658f3b 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -59,8 +59,15 @@ def is_global_zero(self) -> bool:
         """Whether the current process is the rank zero process not only on the local node, but for all nodes."""
 
     @abstractmethod
-    def reduce(self, output: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
-        """Reduces the given output (e.g. across GPUs/Processes)"""
+    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
+        """
+        Reduces the given tensor (e.g. across GPUs/processes).
+
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: plugin-specific positional arguments
+            **kwargs: plugin-specific keyword arguments
+        """
 
     @abstractmethod
     def barrier(self, name: Optional[str] = None) -> None:
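
For context, a hedged sketch of what overriding this abstract method looks like; the base class below is a stand-in for illustration only, not the real TrainingTypePlugin.

    from abc import ABC, abstractmethod
    from typing import Any, Union

    import torch

    class _PluginBase(ABC):
        # Stand-in mirroring the abstract interface in the hunk above.
        @abstractmethod
        def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
            """Reduces the given tensor (e.g. across GPUs/processes)."""

    class _MeanReducePlugin(_PluginBase):
        # A concrete override: mean-reduce tensors, leave everything else untouched.
        def reduce(self, tensor, *args, **kwargs):
            if isinstance(tensor, torch.Tensor):
                return tensor.mean()
            return tensor

    print(_MeanReducePlugin().reduce(torch.tensor([1.0, 3.0])))  # tensor(2.)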

From d841d7791dc08720d0fc111d639ce71de837e4b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Wed, 17 Feb 2021 17:08:09 +0100
Subject: [PATCH 3/4] make mean the default

---
 pytorch_lightning/plugins/training_type/ddp.py       |  8 ++++----
 pytorch_lightning/plugins/training_type/ddp_spawn.py |  8 ++++----
 pytorch_lightning/plugins/training_type/horovod.py   | 12 ++++++------
 3 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index eb16ac29f60ec..19951b84a9661 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -278,21 +278,21 @@ def model_to_device(self):
             torch.cuda.set_device(self.root_device)
         self.model.to(self.root_device)
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
         """
         Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
             tensor: the tensor to sync and reduce
             group: the process group to gather results from. Defaults to all processes (world)
-            reduce_op: the reduction operation. Defaults to 'sum'.
-                Can also be the string 'avg' or 'mean' to calculate the mean during reduction.
+            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
+                Can also be the string 'sum' to calculate the sum during reduction.
 
         Return:
             the reduced value, or the unmodified input if it was not a tensor
         """
         if isinstance(tensor, torch.Tensor):
-            tensor = sync_ddp_if_available(tensor, group, reduce_op)
+            tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
         return tensor
 
     def training_step(self, *args, **kwargs):
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 9678aed261f36..a394ba7ca3fef 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -257,21 +257,21 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti
         if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
             prepare_for_backward(self.model, closure_loss)
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
         """
         Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
             tensor: the tensor to sync and reduce
             group: the process group to gather results from. Defaults to all processes (world)
-            reduce_op: the reduction operation. Defaults to 'sum'.
-                Can also be the string 'avg' or 'mean' to calculate the mean during reduction.
+            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
+                Can also be the string 'sum' to calculate the sum during reduction.
 
         Return:
             the reduced value, or the unmodified input if it was not a tensor
         """
         if isinstance(tensor, torch.Tensor):
-            tensor = sync_ddp_if_available(tensor, group, reduce_op)
+            tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
         return tensor
 
     def training_step(self, *args, **kwargs):
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 0c17a6075c97d..ae6cea8aa4e8f 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -124,15 +124,15 @@ def model_to_device(self):
             torch.cuda.set_device(self.root_device)
         self.model.to(self.root_device)
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
         """
         Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
             tensor: the tensor to sync and reduce
             group: the process group to gather results from. Defaults to all processes (world)
-            reduce_op: the reduction operation. Defaults to 'sum'.
-                Can also be the string 'avg' or 'mean' to calculate the mean during reduction.
+            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
+                Can also be the string 'sum' to calculate the sum during reduction.
 
         Return:
             the reduced value, or the unmodified input if it was not a tensor
@@ -143,10 +143,10 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[
                 "Unset `group`."
             )
 
-        if reduce_op is None or reduce_op == "sum":
-            reduce_op = hvd.Sum
-        elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
+        if reduce_op in (None, "avg", "mean"):
             reduce_op = hvd.Average
+        elif reduce_op == "sum":
+            reduce_op = hvd.Sum
         else:
             raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")
 
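A small standalone sketch (not part of the patch) of the reduce_op resolution after this change, with torch.mean and torch.sum standing in for hvd.Average and hvd.Sum:

    import torch

    def resolve_reduce_op(reduce_op=None):
        # Mirrors the new branch order: None, 'avg' and 'mean' map to averaging;
        # 'sum' must now be requested explicitly.
        if reduce_op in (None, "avg", "mean"):
            return torch.mean
        if reduce_op == "sum":
            return torch.sum
        raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

    per_process = torch.tensor([1.0, 3.0])        # pretend: one value per process
    print(resolve_reduce_op()(per_process))       # tensor(2.) -- mean by default
    print(resolve_reduce_op("sum")(per_process))  # tensor(4.)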

From 1feda4fbe226c48ce2140699661bdba439858364 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 01:39:36 +0100
Subject: [PATCH 4/4] add preliminary chlog

---
 CHANGELOG.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 43c0a4947e58c..b0f5d4b7cdd7f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+
+ 
+- Made the `Plugin.reduce` method more consistent across all plugins to reflect a mean reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
+
+
 ## [1.2] - YYYY-MM-DD
 
 ### Added