From 7d4e74c7454ef7d4ae2d13b73a87e913c3e70ef2 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Wed, 10 Mar 2021 14:19:07 +0000
Subject: [PATCH] [bug] All_gather support tensor on cpu (#6416)

* add test

* update changelog

* update

* rename function
---
 CHANGELOG.md                              | 3 +++
 pytorch_lightning/utilities/apply_func.py | 6 ++++++
 tests/utilities/test_all_gather_grad.py   | 4 ++++
 3 files changed, 13 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f73292f79342a..899e79ffae28f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))
 
 
+- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
+
+
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 9fd42008b9d8d..e100a803bcd00 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -164,6 +164,12 @@ def batch_to(data):
 def convert_to_tensors(data, device: torch.device = None):
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
+
     for src_dtype, conversion_func in CONVERSION_DTYPES:
         data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))
+
+    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device):
+        return t.to(device).contiguous()
+
+    data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_make_contiguous, device=device))
     return data
diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index ae3addb4d6a66..259f9f4c09871 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -58,6 +58,8 @@ def training_epoch_end(self, outputs) -> None:
             self.training_epoch_end_called = True
             losses = torch.stack([x["loss"] for x in outputs])
             gathered_loss = self.all_gather({
+                "losses_tensor_int": torch.rand(2, 2).int().t(),
+                "losses_tensor_float": torch.rand(2, 2).t(),
                 "losses_np_ndarray": np.array([1, 2, 3]),
                 "losses_bool": [True, False],
                 "losses_float": [0., 1., 2.],
@@ -65,6 +67,8 @@ def training_epoch_end(self, outputs) -> None:
                 "losses": losses,
                 "losses_list": [losses, losses]
             })
+            assert gathered_loss["losses_tensor_int"][0].dtype == torch.int32
+            assert gathered_loss["losses_tensor_float"][0].dtype == torch.float
             assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
             # torch.bool can't be all_gathered
             assert gathered_loss["losses_bool"][0].dtype == torch.uint8
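
Reviewer note (not part of the patch): torch.distributed.all_gather expects
contiguous tensors on the gathering device, which is why the fix routes every
tensor through `.to(device).contiguous()`. Below is a minimal standalone
sketch of that behavior, runnable without a process group; the name
`normalize_for_gather` is an illustrative stand-in for the patched
`_move_to_device_and_make_contiguous` helper, not Lightning API:

    import torch

    def normalize_for_gather(t: torch.Tensor, device: torch.device) -> torch.Tensor:
        # Keep the dtype, move to the gathering device, and materialize a
        # contiguous copy, mirroring the helper added in apply_func.py.
        return t.to(device).contiguous()

    t = torch.rand(2, 2).int().t()  # transposed view: int32 and non-contiguous
    assert not t.is_contiguous()

    out = normalize_for_gather(t, torch.device("cpu"))
    assert out.is_contiguous()
    assert out.dtype == torch.int32  # dtype is preserved, matching the test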
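And a quick sanity check on the dtype assertions in the updated test, again
only an illustration of PyTorch/numpy defaults rather than anything the patch
adds (numpy's default integer dtype is platform-dependent; it is int64 on
Linux, where the CI assertion holds):

    import numpy as np
    import torch

    assert torch.from_numpy(np.array([1, 2, 3])).dtype == torch.int64  # numpy default int
    assert torch.rand(2, 2).int().dtype == torch.int32  # tensors keep their dtype
    assert torch.rand(2, 2).dtype == torch.float        # default float32
    # torch.bool can't be all_gathered, so bools are converted to uint8:
    assert torch.tensor(True, dtype=torch.uint8).dtype == torch.uint8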